diff --git a/stubgen/config.py b/stubgen/config.py new file mode 100644 index 0000000..02618d1 --- /dev/null +++ b/stubgen/config.py @@ -0,0 +1,149 @@ +"""Configuration loader for stubgen.""" + +import tomllib +from pathlib import Path +from typing import Any, Dict, List, Optional + +import sys + + +class ConfigError(Exception): + """Configuration error.""" + pass + + +class ConfigLoader: + """Loads configuration from pyproject.toml and stubgen.toml.""" + + def __init__(self): + self.config: Dict[str, Any] = {} + self.config_path: Optional[Path] = None + + def load(self, path: Optional[Path] = None) -> Dict[str, Any]: + """Load configuration from file.""" + if path is None: + path = self._find_config() + + if path is None or not path.exists(): + self.config = self._get_default_config() + return self.config + + self.config_path = path + + try: + with open(path, 'rb') as f: + raw_config = tomllib.load(f) + except tomllib.TOMLDecodeError as e: + raise ConfigError(f"Failed to parse config file: {e}") + + self.config = self._parse_config(raw_config) + + return self.config + + def _find_config(self) -> Optional[Path]: + """Search for config file in current directory and parents.""" + current = Path.cwd() + + search_paths = [ + current / "pyproject.toml", + current / "stubgen.toml", + current / ".stubgen.toml", + ] + + for config_path in search_paths: + if config_path.exists(): + return config_path + + return None + + def _parse_config(self, raw: Dict[str, Any]) -> Dict[str, Any]: + """Parse raw TOML config into structured format.""" + config = self._get_default_config() + + if "tool" in raw and "stubgen" in raw["tool"]: + stubgen_config = raw["tool"]["stubgen"] + config = self._merge_config(config, stubgen_config) + + if "stubgen" in raw: + config = self._merge_config(config, raw["stubgen"]) + + return config + + def _merge_config(self, base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: + """Merge configuration dictionaries.""" + result = base.copy() + + for key, value in override.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = self._merge_config(result[key], value) + else: + result[key] = value + + return result + + def _get_default_config(self) -> Dict[str, Any]: + """Return default configuration.""" + return { + "exclude_patterns": [ + "tests/*", + "*/__pycache__/*", + "*/.venv/*", + "*/venv/*", + "*/node_modules/*", + "*.egg-info/*", + ], + "infer_depth": 3, + "strict_mode": False, + "interactive": False, + "output_dir": None, + "recursive": True, + "verbose": False, + "dry_run": False, + } + + def get(self, key: str, default: Any = None) -> Any: + """Get a configuration value.""" + return self.config.get(key, default) + + def get_list(self, key: str) -> List[str]: + """Get a configuration value as a list.""" + value = self.config.get(key, []) + if isinstance(value, str): + return [value] + return value + + def get_bool(self, key: str) -> bool: + """Get a configuration value as a boolean.""" + value = self.config.get(key, False) + return bool(value) + + def get_int(self, key: str, default: int = 0) -> int: + """Get a configuration value as an integer.""" + value = self.config.get(key, default) + try: + return int(value) + except (TypeError, ValueError): + return default + + +def load_config(config_path: Optional[Path] = None) -> Dict[str, Any]: + """Load configuration from file.""" + loader = ConfigLoader() + return loader.load(config_path) + + +def should_exclude(path: Path, exclude_patterns: List[str]) -> bool: + """Check if a path should be excluded based on patterns.""" + import fnmatch + + path_str = str(path) + + for pattern in exclude_patterns: + if fnmatch.fnmatch(path_str, pattern): + return True + if fnmatch.fnmatch(path.name, pattern): + return True + if path.parent.name in pattern: + return True + + return False