diff --git a/src/mcp_server_cli/config.py b/src/mcp_server_cli/config.py new file mode 100644 index 0000000..9edc38e --- /dev/null +++ b/src/mcp_server_cli/config.py @@ -0,0 +1,256 @@ +"""Configuration management for MCP Server CLI.""" + +import os +from pathlib import Path +from typing import Dict, Optional, Any +import yaml +from pydantic import ValidationError + +from mcp_server_cli.models import ( + AppConfig, + ServerConfig, + LocalLLMConfig, + SecurityConfig, + ToolConfig, +) + + +class ConfigManager: + """Manages application configuration with file and environment support.""" + + DEFAULT_CONFIG_FILENAME = "config.yaml" + ENV_VAR_PREFIX = "MCP" + + def __init__(self, config_path: Optional[Path] = None): + """Initialize the configuration manager. + + Args: + config_path: Optional path to configuration file. + """ + self.config_path = config_path + self._config: Optional[AppConfig] = None + + @classmethod + def get_env_var_name(cls, key: str) -> str: + """Convert a config key to an environment variable name. + + Args: + key: Configuration key (e.g., 'server.port') + + Returns: + Environment variable name (e.g., 'MCP_SERVER_PORT') + """ + return f"{cls.ENV_VAR_PREFIX}_{key.upper().replace('.', '_')}" + + def get_from_env(self, key: str, default: Any = None) -> Any: + """Get a configuration value from environment variables. + + Args: + key: Configuration key (e.g., 'server.port') + default: Default value if not found + + Returns: + The environment variable value or default + """ + env_key = self.get_env_var_name(key) + return os.environ.get(env_key, default) + + def load(self, path: Optional[Path] = None) -> AppConfig: + """Load configuration from file and environment. + + Args: + path: Optional path to configuration file. + + Returns: + Loaded and validated AppConfig object. + """ + config_path = path or self.config_path + + if config_path and config_path.exists(): + with open(config_path, "r") as f: + config_data = yaml.safe_load(f) or {} + else: + config_data = {} + + config = self._merge_with_defaults(config_data) + config = self._apply_env_overrides(config) + + try: + self._config = AppConfig(**config) + except ValidationError as e: + raise ValueError(f"Configuration validation error: {e}") + + return self._config + + def _merge_with_defaults(self, config_data: Dict[str, Any]) -> Dict[str, Any]: + """Merge configuration data with default values. + + Args: + config_data: Configuration dictionary. + + Returns: + Merged configuration dictionary. + """ + defaults = { + "server": { + "host": "127.0.0.1", + "port": 3000, + "log_level": "INFO", + }, + "llm": { + "enabled": False, + "base_url": "http://localhost:11434", + "model": "llama2", + "temperature": 0.7, + "max_tokens": 2048, + "timeout": 60, + }, + "security": { + "allowed_commands": ["ls", "cat", "echo", "pwd", "git"], + "blocked_paths": ["/etc", "/root"], + "max_shell_timeout": 30, + "require_confirmation": False, + }, + "tools": [], + } + + if "server" not in config_data: + config_data["server"] = {} + config_data["server"] = {**defaults["server"], **config_data["server"]} + + if "llm" not in config_data: + config_data["llm"] = {} + config_data["llm"] = {**defaults["llm"], **config_data["llm"]} + + if "security" not in config_data: + config_data["security"] = {} + config_data["security"] = {**defaults["security"], **config_data["security"]} + + if "tools" not in config_data: + config_data["tools"] = defaults["tools"] + + return config_data + + def _apply_env_overrides(self, config: Dict[str, Any]) -> Dict[str, Any]: + """Apply environment variable overrides to configuration. + + Args: + config: Configuration dictionary. + + Returns: + Configuration with environment overrides applied. + """ + env_mappings = { + "MCP_PORT": ("server", "port", int), + "MCP_HOST": ("server", "host", str), + "MCP_CONFIG_PATH": ("_config_path", None, str), + "MCP_LOG_LEVEL": ("server", "log_level", str), + "MCP_LLM_URL": ("llm", "base_url", str), + "MCP_LLM_MODEL": ("llm", "model", str), + "MCP_LLM_ENABLED": ("llm", "enabled", lambda x: x.lower() == "true"), + } + + for env_var, mapping in env_mappings.items(): + value = os.environ.get(env_var) + if value is not None: + if mapping[1] is None: + config[mapping[0]] = mapping[2](value) + else: + section, key, converter = mapping + if section not in config: + config[section] = {} + config[section][key] = converter(value) + + return config + + def save(self, config: AppConfig, path: Optional[Path] = None) -> Path: + """Save configuration to a YAML file. + + Args: + config: Configuration to save. + path: Optional path to save to. + + Returns: + Path to the saved configuration file. + """ + save_path = path or self.config_path or Path(self.DEFAULT_CONFIG_FILENAME) + + config_dict = { + "server": config.server.model_dump(), + "llm": config.llm.model_dump(), + "security": config.security.model_dump(), + "tools": [tc.model_dump() for tc in config.tools], + } + + with open(save_path, "w") as f: + yaml.dump(config_dict, f, default_flow_style=False, indent=2) + + return save_path + + def get_config(self) -> Optional[AppConfig]: + """Get the loaded configuration. + + Returns: + The loaded AppConfig or None if not loaded. + """ + return self._config + + @staticmethod + def generate_default_config() -> AppConfig: + """Generate a default configuration. + + Returns: + AppConfig with default values. + """ + return AppConfig() + + +def load_config_from_path(config_path: str) -> AppConfig: + """Load configuration from a specific path. + + Args: + config_path: Path to configuration file. + + Returns: + Loaded AppConfig. + + Raises: + FileNotFoundError: If config file doesn't exist. + ValidationError: If configuration is invalid. + """ + path = Path(config_path) + if not path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + manager = ConfigManager(path) + return manager.load() + + +def create_config_template() -> Dict[str, Any]: + """Create a configuration template. + + Returns: + Dictionary with configuration template. + """ + return { + "server": { + "host": "127.0.0.1", + "port": 3000, + "log_level": "INFO", + }, + "llm": { + "enabled": False, + "base_url": "http://localhost:11434", + "model": "llama2", + "temperature": 0.7, + "max_tokens": 2048, + "timeout": 60, + }, + "security": { + "allowed_commands": ["ls", "cat", "echo", "pwd", "git", "grep", "find"], + "blocked_paths": ["/etc", "/root", "/home/*/.ssh"], + "max_shell_timeout": 30, + "require_confirmation": False, + }, + "tools": [], + }