"""Base tool infrastructure for MCP Server CLI.

Provides the result model returned by tools, the abstract base class that
every tool derives from, and a simple name-keyed registry of tool instances.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from mcp_server_cli.models import ToolParameter, ToolSchema


class ToolResult(BaseModel):
    """Outcome of a single tool invocation.

    Attributes:
        success: Whether the invocation succeeded; defaults to ``True``.
        output: Text produced by the tool (required).
        error: Optional error description; defaults to ``None``.
    """

    success: bool = True
    output: str
    error: Optional[str] = None


class ToolBase(ABC):
    """Common interface shared by every MCP tool.

    Concrete tools implement ``_create_input_schema`` to describe the
    parameters they accept and the coroutine ``execute`` to do the work.
    """

    def __init__(
        self,
        name: str,
        description: str,
        annotations: Optional[Dict[str, Any]] = None,
    ):
        """Store the tool's identity and metadata.

        Args:
            name: Identifier of the tool; ``ToolRegistry`` keys tools by it.
            description: Human-readable description of the tool.
            annotations: Optional extra metadata, stored as given.
        """
        self.name = name
        self.description = description
        self.annotations = annotations
        # Populated lazily the first time ``input_schema`` is read.
        self._input_schema: Optional[ToolSchema] = None

    @property
    def input_schema(self) -> ToolSchema:
        """Return the tool's input schema, building and caching it on first use."""
        schema = self._input_schema
        if schema is None:
            schema = self._create_input_schema()
            self._input_schema = schema
        return schema

    @abstractmethod
    def _create_input_schema(self) -> ToolSchema:
        """Build the schema describing the parameters this tool accepts.

        The returned value is cached by ``input_schema``.

        Returns:
            ToolSchema for the tool's parameters.
        """

    def get_parameters(self) -> Dict[str, ToolParameter]:
        """Return the tool's parameter definitions.

        Returns:
            The ``properties`` mapping of ``input_schema`` (parameter name to
            ToolParameter), returned as-is rather than copied.
        """
        return self.input_schema.properties

    def validate_arguments(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Check ``arguments`` against the input schema and filter them.

        Every name in the schema's ``required`` list must be present. Each
        supplied argument that the schema declares is checked against that
        parameter's ``enum`` (when one is set). Arguments the schema does not
        declare are silently dropped.

        Args:
            arguments: Raw input arguments keyed by parameter name.

        Returns:
            A new dict holding only the declared parameters that were
            supplied, in the schema's ``properties`` order.

        Raises:
            ValueError: If a required parameter is missing (the first one in
                ``required`` order is reported), or if a value is not one of
                its parameter's allowed ``enum`` values.
        """
        required = self.input_schema.required
        properties = self.input_schema.properties

        absent = [key for key in required if key not in arguments]
        if absent:
            raise ValueError(f"Missing required parameter: {absent[0]}")

        accepted: Dict[str, Any] = {}
        for key, spec in properties.items():
            if key not in arguments:
                continue
            candidate = arguments[key]
            # An empty/unset enum means "any value is allowed".
            if spec.enum and candidate not in spec.enum:
                raise ValueError(f"Invalid value for {key}: {candidate}")
            accepted[key] = candidate
        return accepted

    @abstractmethod
    async def execute(self, arguments: Dict[str, Any]) -> ToolResult:
        """Run the tool.

        Args:
            arguments: Tool arguments keyed by parameter name.

        Returns:
            ToolResult describing the execution outcome.
        """


class ToolRegistry:
    """In-memory collection of tools keyed by their ``name``."""

    def __init__(self):
        """Create an empty registry."""
        self._tools: Dict[str, ToolBase] = {}

    def register(self, tool: ToolBase):
        """Add a tool under ``tool.name``.

        An existing tool registered under the same name is replaced.

        Args:
            tool: Tool instance to store.
        """
        self._tools[tool.name] = tool

    def get(self, name: str) -> Optional[ToolBase]:
        """Look up a tool by name.

        Args:
            name: Name the tool was registered under.

        Returns:
            The tool, or ``None`` when no tool has that name.
        """
        return self._tools.get(name)

    def unregister(self, name: str) -> bool:
        """Remove the tool registered under ``name``.

        Args:
            name: Name of the tool to remove.

        Returns:
            ``True`` if a tool was removed, ``False`` if none was registered.
        """
        try:
            del self._tools[name]
        except KeyError:
            return False
        return True

    def list(self) -> List[ToolBase]:
        """Return a new list of all registered tools, in registration order."""
        # Unpacking avoids calling ``list(...)`` inside a method that is itself
        # named ``list``, which reads confusingly even though it would work.
        return [*self._tools.values()]

    def list_names(self) -> List[str]:
        """Return a new list of all registered tool names, in registration order."""
        return [*self._tools]

    def clear(self):
        """Remove every registered tool."""
        self._tools.clear()