"""Configuration loader for stubgen.""" import tomllib from pathlib import Path from typing import Any, Dict, List, Optional import sys class ConfigError(Exception): """Configuration error.""" pass class ConfigLoader: """Loads configuration from pyproject.toml and stubgen.toml.""" def __init__(self): self.config: Dict[str, Any] = {} self.config_path: Optional[Path] = None def load(self, path: Optional[Path] = None) -> Dict[str, Any]: """Load configuration from file.""" if path is None: path = self._find_config() if path is None or not path.exists(): self.config = self._get_default_config() return self.config self.config_path = path try: with open(path, 'rb') as f: raw_config = tomllib.load(f) except tomllib.TOMLDecodeError as e: raise ConfigError(f"Failed to parse config file: {e}") self.config = self._parse_config(raw_config) return self.config def _find_config(self) -> Optional[Path]: """Search for config file in current directory and parents.""" current = Path.cwd() search_paths = [ current / "pyproject.toml", current / "stubgen.toml", current / ".stubgen.toml", ] for config_path in search_paths: if config_path.exists(): return config_path return None def _parse_config(self, raw: Dict[str, Any]) -> Dict[str, Any]: """Parse raw TOML config into structured format.""" config = self._get_default_config() if "tool" in raw and "stubgen" in raw["tool"]: stubgen_config = raw["tool"]["stubgen"] config = self._merge_config(config, stubgen_config) if "stubgen" in raw: config = self._merge_config(config, raw["stubgen"]) return config def _merge_config(self, base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: """Merge configuration dictionaries.""" result = base.copy() for key, value in override.items(): if key in result and isinstance(result[key], dict) and isinstance(value, dict): result[key] = self._merge_config(result[key], value) else: result[key] = value return result def _get_default_config(self) -> Dict[str, Any]: """Return default configuration.""" return { "exclude_patterns": [ "tests/*", "*/__pycache__/*", "*/.venv/*", "*/venv/*", "*/node_modules/*", "*.egg-info/*", ], "infer_depth": 3, "strict_mode": False, "interactive": False, "output_dir": None, "recursive": True, "verbose": False, "dry_run": False, } def get(self, key: str, default: Any = None) -> Any: """Get a configuration value.""" return self.config.get(key, default) def get_list(self, key: str) -> List[str]: """Get a configuration value as a list.""" value = self.config.get(key, []) if isinstance(value, str): return [value] return value def get_bool(self, key: str) -> bool: """Get a configuration value as a boolean.""" value = self.config.get(key, False) return bool(value) def get_int(self, key: str, default: int = 0) -> int: """Get a configuration value as an integer.""" value = self.config.get(key, default) try: return int(value) except (TypeError, ValueError): return default def load_config(config_path: Optional[Path] = None) -> Dict[str, Any]: """Load configuration from file.""" loader = ConfigLoader() return loader.load(config_path) def should_exclude(path: Path, exclude_patterns: List[str]) -> bool: """Check if a path should be excluded based on patterns.""" import fnmatch path_str = str(path) for pattern in exclude_patterns: if fnmatch.fnmatch(path_str, pattern): return True if fnmatch.fnmatch(path.name, pattern): return True if path.parent.name in pattern: return True return False