162 lines
4.0 KiB
Python
162 lines
4.0 KiB
Python
"""Base tool infrastructure for MCP Server CLI."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from mcp_server_cli.models import ToolParameter, ToolSchema
|
|
|
|
|
|
class ToolResult(BaseModel):
    """Result of tool execution."""

    # NOTE(review): pydantic may surface the class docstring as the model's
    # JSON-schema description, so the docstring text is intentionally kept as-is.

    # Whether execution succeeded. Defaults to True, so a failure must be
    # reported by setting success=False explicitly.
    success: bool = True
    # Output text produced by the tool. It has no default, so every result
    # (including a failed one) must supply it.
    output: str
    # Optional error message; None unless provided.
    error: Optional[str] = None
|
|
|
|
|
|
class ToolBase(ABC):
    """Abstract base class for all MCP tools.

    Subclasses provide their input schema via ``_create_input_schema`` and
    their behavior via the async ``execute`` method. The schema is built
    lazily on first access of ``input_schema`` and cached on the instance.
    """

    def __init__(
        self,
        name: str,
        description: str,
        annotations: Optional[Dict[str, Any]] = None,
    ):
        """Initialize a tool.

        Args:
            name: Tool name.
            description: Tool description.
            annotations: Optional tool annotations. The dict is stored as
                given (not copied).
        """
        self.name = name
        self.description = description
        self.annotations = annotations
        # Populated lazily by the ``input_schema`` property.
        self._input_schema: Optional[ToolSchema] = None

    @property
    def input_schema(self) -> ToolSchema:
        """Get the tool's input schema, building and caching it on first use."""
        if self._input_schema is None:
            self._input_schema = self._create_input_schema()
        return self._input_schema

    @abstractmethod
    def _create_input_schema(self) -> ToolSchema:
        """Create the tool's input schema.

        Called at most once per instance while it returns a non-None value;
        the result is cached by ``input_schema``.

        Returns:
            ToolSchema defining expected parameters.
        """
        pass

    def get_parameters(self) -> Dict[str, ToolParameter]:
        """Get the tool's parameters.

        Returns:
            Dictionary of parameter name to ToolParameter. This is the
            schema's own ``properties`` object, not a copy.
        """
        return self.input_schema.properties

    def validate_arguments(
        self, arguments: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Validate and sanitize tool arguments.

        Every name in the schema's ``required`` list must be present, and a
        value for a parameter with a non-empty ``enum`` must be one of the
        listed values. Keys not declared in the schema's ``properties`` are
        dropped from the result.

        Args:
            arguments: Input arguments to validate. ``None`` is treated as an
                empty mapping, since a caller may omit arguments entirely.

        Returns:
            A new dict holding only the declared parameters that were supplied.

        Raises:
            ValueError: If a required parameter is missing, or a value is not
                among its parameter's ``enum`` values.
        """
        # Without this, None would fail with an opaque TypeError on the
        # membership tests below instead of the documented ValueError.
        if arguments is None:
            arguments = {}

        schema = self.input_schema

        for param_name in schema.required:
            if param_name not in arguments:
                raise ValueError(f"Missing required parameter: {param_name}")

        # NOTE(review): a name listed in ``required`` but absent from
        # ``properties`` passes the check above yet is dropped below.
        validated = {}
        for name, param in schema.properties.items():
            if name not in arguments:
                continue
            value = arguments[name]
            # A falsy ``enum`` (e.g. None or empty) skips the membership check.
            if param.enum and value not in param.enum:
                raise ValueError(f"Invalid value for {name}: {value}")
            validated[name] = value

        return validated

    @abstractmethod
    async def execute(self, arguments: Dict[str, Any]) -> ToolResult:
        """Execute the tool with given arguments.

        Args:
            arguments: Tool arguments.

        Returns:
            ToolResult with execution output.
        """
        pass
|
|
|
|
|
|
class ToolRegistry:
    """In-memory collection of tools keyed by each tool's ``name``."""

    def __init__(self):
        """Start with no tools registered."""
        self._tools: Dict[str, ToolBase] = dict()

    def register(self, tool: ToolBase):
        """Store ``tool`` under ``tool.name``.

        A tool previously registered under the same name is silently replaced.

        Args:
            tool: Tool to register.
        """
        key = tool.name
        self._tools[key] = tool

    def get(self, name: str) -> Optional[ToolBase]:
        """Look up a tool by name.

        Args:
            name: Tool name.

        Returns:
            The registered tool, or None when no tool has that name.
        """
        try:
            return self._tools[name]
        except KeyError:
            return None

    def unregister(self, name: str) -> bool:
        """Remove the tool registered under ``name``, if there is one.

        Args:
            name: Tool name.

        Returns:
            True if a tool was removed, False if none was registered.
        """
        try:
            del self._tools[name]
        except KeyError:
            return False
        return True

    def list(self) -> List[ToolBase]:
        """Return a fresh list of all registered tools.

        Returns:
            Registered tools, in the underlying dict's insertion order.
        """
        return [*self._tools.values()]

    def list_names(self) -> List[str]:
        """Return a fresh list of all registered tool names.

        Returns:
            Tool names, in the underlying dict's insertion order.
        """
        return [*self._tools]

    def clear(self):
        """Remove every registered tool."""
        self._tools.clear()
|