This commit is contained in:
160
src/mcp_server_cli/tools/base.py
Normal file
160
src/mcp_server_cli/tools/base.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Base tool infrastructure for MCP Server CLI."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp_server_cli.models import ToolSchema, ToolParameter
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""Result of tool execution."""
|
||||
|
||||
success: bool = True
|
||||
output: str
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ToolBase(ABC):
|
||||
"""Abstract base class for all MCP tools."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
annotations: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Initialize a tool.
|
||||
|
||||
Args:
|
||||
name: Tool name.
|
||||
description: Tool description.
|
||||
annotations: Optional tool annotations.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.annotations = annotations
|
||||
self._input_schema: Optional[ToolSchema] = None
|
||||
|
||||
@property
|
||||
def input_schema(self) -> ToolSchema:
|
||||
"""Get the tool's input schema."""
|
||||
if self._input_schema is None:
|
||||
self._input_schema = self._create_input_schema()
|
||||
return self._input_schema
|
||||
|
||||
@abstractmethod
|
||||
def _create_input_schema(self) -> ToolSchema:
|
||||
"""Create the tool's input schema.
|
||||
|
||||
Returns:
|
||||
ToolSchema defining expected parameters.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_parameters(self) -> Dict[str, ToolParameter]:
|
||||
"""Get the tool's parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameter name to ToolParameter.
|
||||
"""
|
||||
return self.input_schema.properties
|
||||
|
||||
def validate_arguments(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate and sanitize tool arguments.
|
||||
|
||||
Args:
|
||||
arguments: Input arguments to validate.
|
||||
|
||||
Returns:
|
||||
Validated arguments.
|
||||
"""
|
||||
required = self.input_schema.required
|
||||
properties = self.input_schema.properties
|
||||
|
||||
for param_name in required:
|
||||
if param_name not in arguments:
|
||||
raise ValueError(f"Missing required parameter: {param_name}")
|
||||
|
||||
validated = {}
|
||||
for name, param in properties.items():
|
||||
if name in arguments:
|
||||
value = arguments[name]
|
||||
if param.enum and value not in param.enum:
|
||||
raise ValueError(f"Invalid value for {name}: {value}")
|
||||
validated[name] = value
|
||||
|
||||
return validated
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, arguments: Dict[str, Any]) -> ToolResult:
|
||||
"""Execute the tool with given arguments.
|
||||
|
||||
Args:
|
||||
arguments: Tool arguments.
|
||||
|
||||
Returns:
|
||||
ToolResult with execution output.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Registry for managing tools."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the tool registry."""
|
||||
self._tools: Dict[str, ToolBase] = {}
|
||||
|
||||
def register(self, tool: ToolBase):
|
||||
"""Register a tool.
|
||||
|
||||
Args:
|
||||
tool: Tool to register.
|
||||
"""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def get(self, name: str) -> Optional[ToolBase]:
|
||||
"""Get a tool by name.
|
||||
|
||||
Args:
|
||||
name: Tool name.
|
||||
|
||||
Returns:
|
||||
Tool or None if not found.
|
||||
"""
|
||||
return self._tools.get(name)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Unregister a tool.
|
||||
|
||||
Args:
|
||||
name: Tool name.
|
||||
|
||||
Returns:
|
||||
True if tool was unregistered.
|
||||
"""
|
||||
if name in self._tools:
|
||||
del self._tools[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def list(self) -> List[ToolBase]:
|
||||
"""List all registered tools.
|
||||
|
||||
Returns:
|
||||
List of registered tools.
|
||||
"""
|
||||
return list(self._tools.values())
|
||||
|
||||
def list_names(self) -> List[str]:
|
||||
"""List all tool names.
|
||||
|
||||
Returns:
|
||||
List of tool names.
|
||||
"""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def clear(self):
|
||||
"""Clear all registered tools."""
|
||||
self._tools.clear()
|
||||
Reference in New Issue
Block a user