diff --git a/depnav/src/depnav/config.py b/depnav/src/depnav/config.py new file mode 100644 index 0000000..b14c887 --- /dev/null +++ b/depnav/src/depnav/config.py @@ -0,0 +1,162 @@ +"""Configuration management for depnav.""" + +import os +from pathlib import Path +from typing import Any, Optional, Tuple, Union + +import yaml + + +class Config: + """Configuration manager for depnav.""" + + DEFAULT_CONFIG_NAMES = [".depnav.yaml", "depnav.yaml", "pyproject.toml"] + DEFAULT_DEPTH = 3 + DEFAULT_MAX_NODES = 100 + + def __init__(self, config_path: Optional[Path] = None): + self._config: dict[str, Any] = {} + self._config_path = config_path + + def load(self) -> None: + """Load configuration from file and environment.""" + self._config = {} + + if self._config_path and self._config_path.exists(): + self._load_from_file(self._config_path) + else: + self._find_and_load_config() + + self._apply_env_overrides() + + def _find_and_load_config(self) -> None: + """Find and load configuration from standard locations.""" + for config_name in self.DEFAULT_CONFIG_NAMES: + config_path = Path.cwd() / config_name + if config_path.exists(): + self._load_from_file(config_path) + break + + def _load_from_file(self, path: Path) -> None: + """Load configuration from a YAML or TOML file.""" + try: + content = path.read_text() + ext = path.suffix.lower() + + if ext == ".toml": + try: + import tomli + data = tomli.loads(content) or {} + except ImportError: + try: + import tomllib + data = tomllib.loads(content) or {} + except ImportError: + data = {} + + if path.name == "pyproject.toml": + data = data.get("tool", {}).get("depnav", {}) + else: + data = yaml.safe_load(content) or {} + + self._config.update(data) + except (yaml.YAMLError, OSError): + pass + + def _apply_env_overrides(self) -> None: + """Apply environment variable overrides.""" + env_map: dict[str, Tuple[str, Optional[int]]] = { + "DEPNAV_CONFIG": ("config_file", None), + "DEPNAV_THEME": ("theme", None), + "DEPNAV_PAGER": ("pager", None), + "DEPNAV_DEPTH": ("depth", self.DEFAULT_DEPTH), + "DEPNAV_MAX_NODES": ("max_nodes", self.DEFAULT_MAX_NODES), + } + + for env_var, (key, default) in env_map.items(): + env_value = os.environ.get(env_var) + if env_value is not None: + value: Union[int, str] = env_value + if default is not None and isinstance(default, int): + value = int(env_value) + self._config[key] = value + + def get(self, key: str, default: Any = None) -> Any: + """Get a configuration value.""" + return self._config.get(key, default) + + def get_theme(self) -> dict[str, Any]: + """Get the current theme configuration.""" + themes = self._config.get("themes", {}) + theme_name = self._config.get("theme", "default") + return themes.get(theme_name, themes.get("default", self._default_theme())) + + def _default_theme(self) -> dict[str, Any]: + """Return the default theme configuration.""" + return { + "node_style": "cyan", + "edge_style": "dim", + "highlight_style": "yellow", + "cycle_style": "red", + } + + def get_exclude_patterns(self) -> list[str]: + """Get patterns for excluding files/directories.""" + return self._config.get("exclude", ["__pycache__", "node_modules", ".git"]) + + def get_include_extensions(self) -> list[str]: + """Get file extensions to include.""" + return self._config.get( + "extensions", [".py", ".js", ".jsx", ".ts", ".tsx", ".go"] + ) + + def get_output_format(self) -> str: + """Get the default output format.""" + return self._config.get("output", "ascii") + + def get_depth(self) -> int: + """Get the default traversal depth.""" + return int(self._config.get("depth", self.DEFAULT_DEPTH)) + + def get_max_nodes(self) -> int: + """Get the maximum number of nodes to display.""" + return int(self._config.get("max_nodes", self.DEFAULT_MAX_NODES)) + + def save(self, path: Path) -> None: + """Save current configuration to a file.""" + with open(path, "w") as f: + yaml.dump(self._config, f, default_flow_style=False) + + def set(self, key: str, value: Any) -> None: + """Set a configuration value.""" + self._config[key] = value + + def merge(self, other: dict[str, Any]) -> None: + """Merge another configuration dictionary.""" + self._config.update(other) + + def __repr__(self) -> str: + return f"Config({self._config})" + + +def load_config(config_path: Optional[Path] = None) -> Config: + """Load and return a configuration object.""" + config = Config(config_path) + config.load() + return config + + +def get_config_value( + key: str, default: Any = None, config: Optional[Config] = None +) -> Any: + """Get a configuration value from a config or environment.""" + if config is not None: + return config.get(key, default) + + env_key = f"DEPNAV_{key.upper()}" + env_value = os.environ.get(env_key) + if env_value is not None: + return env_value + + cfg = load_config() + return cfg.get(key, default)