"""MCP Protocol Server implementation using FastAPI.""" import asyncio import json import logging from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Callable, Awaitable from enum import Enum from fastapi import FastAPI, HTTPException, Request, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel import sse_starlette.sse as sse from mcp_server_cli.models import ( MCPRequest, MCPResponse, MCPNotification, MCPMethod, ToolDefinition, ToolCallParams, ToolCallResult, InitializeParams, InitializeResult, ServerInfo, ServerCapabilities, ToolsListResult, ) from mcp_server_cli.config import AppConfig, ConfigManager from mcp_server_cli.tools import ToolBase, ToolResult logger = logging.getLogger(__name__) class MCPConnectionState(str, Enum): """State of MCP connection.""" DISCONNECTED = "disconnected" INITIALIZING = "initializing" READY = "ready" class MCPServer: """MCP Protocol Server implementation.""" def __init__(self, config: Optional[AppConfig] = None): """Initialize the MCP server. Args: config: Optional server configuration. """ self.config = config or AppConfig() self.tool_registry: Dict[str, ToolBase] = {} self.connection_state = MCPConnectionState.DISCONNECTED self._initialized = False def register_tool(self, tool: ToolBase): """Register a tool with the server. Args: tool: Tool to register. """ self.tool_registry[tool.name] = tool def get_tool(self, name: str) -> Optional[ToolBase]: """Get a tool by name. Args: name: Tool name. Returns: Tool or None if not found. """ return self.tool_registry.get(name) def list_tools(self) -> List[ToolDefinition]: """List all registered tools. Returns: List of tool definitions. """ return [ ToolDefinition( name=tool.name, description=tool.description, input_schema=tool.input_schema, annotations=tool.annotations, ) for tool in self.tool_registry.values() ] async def handle_request(self, request: MCPRequest) -> MCPResponse: """Handle an MCP request. Args: request: MCP request message. Returns: MCP response message. """ method = request.method params = request.params or {} try: if method == MCPMethod.INITIALIZE: result = await self._handle_initialize(InitializeParams(**params)) elif method == MCPMethod.TOOLS_LIST: result = await self._handle_tools_list() elif method == MCPMethod.TOOLS_CALL: result = await self._handle_tool_call(ToolCallParams(**params)) else: return MCPResponse( id=request.id, error={"code": -32601, "message": f"Method not found: {method}"}, ) return MCPResponse(id=request.id, result=result.model_dump()) except Exception as e: logger.error(f"Error handling request: {e}", exc_info=True) return MCPResponse( id=request.id, error={"code": -32603, "message": str(e)}, ) async def _handle_initialize(self, params: InitializeParams) -> InitializeResult: """Handle MCP initialize request. Args: params: Initialize parameters. Returns: Initialize result. """ self.connection_state = MCPConnectionState.INITIALIZING self._initialized = True self.connection_state = MCPConnectionState.READY return InitializeResult( protocol_version=params.protocol_version, server_info=ServerInfo( name="mcp-server-cli", version="0.1.0", ), capabilities=ServerCapabilities( tools={"listChanged": True}, resources={}, prompts={}, ), ) async def _handle_tools_list(self) -> ToolsListResult: """Handle tools/list request. Returns: List of available tools. """ return ToolsListResult(tools=self.list_tools()) async def _handle_tool_call(self, params: ToolCallParams) -> ToolCallResult: """Handle tools/call request. Args: params: Tool call parameters. Returns: Tool execution result. """ tool = self.get_tool(params.name) if not tool: return ToolCallResult( content=[], is_error=True, error_message=f"Tool not found: {params.name}", ) try: result = await tool.execute(params.arguments or {}) return ToolCallResult(content=[{"type": "text", "text": result.output}]) except Exception as e: return ToolCallResult( content=[], is_error=True, error_message=str(e), ) @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan context manager.""" logger.info("MCP Server starting up...") yield logger.info("MCP Server shutting down...") def create_app(config: Optional[AppConfig] = None) -> FastAPI: """Create and configure the FastAPI application. Args: config: Optional server configuration. Returns: Configured FastAPI application. """ mcp_server = MCPServer(config) app = FastAPI( title="MCP Server CLI", description="Model Context Protocol Server", version="0.1.0", lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/health") async def health_check(): """Health check endpoint.""" return {"status": "healthy", "state": mcp_server.connection_state} @app.get("/api/tools") async def list_tools(): """List all available tools.""" return {"tools": [t.model_dump() for t in mcp_server.list_tools()]} @app.post("/api/tools/call") async def call_tool(request: Request): """Call a tool by name.""" body = await request.json() tool_name = body.get("name") arguments = body.get("arguments", {}) tool = mcp_server.get_tool(tool_name) if not tool: raise HTTPException(status_code=404, detail=f"Tool not found: {tool_name}") try: result = await tool.execute(arguments) return {"success": True, "output": result.output} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/mcp") async def handle_mcp(request: MCPRequest): """Handle MCP protocol messages.""" response = await mcp_server.handle_request(request) return response.model_dump() @app.post("/mcp/{path:path}") async def handle_mcp_fallback(path: str, request: Request): """Handle MCP protocol messages at various paths.""" body = await request.json() mcp_request = MCPRequest(**body) response = await mcp_server.handle_request(mcp_request) return response.model_dump() return app def run_server( host: str = "127.0.0.1", port: int = 3000, config_path: Optional[str] = None, log_level: str = "INFO", ): """Run the MCP server using uvicorn. Args: host: Host to bind to. port: Port to listen on. config_path: Path to configuration file. log_level: Logging level. """ import uvicorn logging.basicConfig(level=getattr(logging, log_level.upper())) config = None if config_path: try: config_manager = ConfigManager() config = config_manager.load(Path(config_path)) host = config.server.host port = config.server.port except Exception as e: logger.warning(f"Failed to load config: {e}") app = create_app(config) uvicorn.run(app, host=host, port=port)