297 lines
8.4 KiB
Python
297 lines
8.4 KiB
Python
"""MCP Protocol Server implementation using FastAPI."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, Dict, List, Optional, Callable, Awaitable
|
|
from enum import Enum
|
|
|
|
from fastapi import FastAPI, HTTPException, Request, Depends
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel
|
|
import sse_starlette.sse as sse
|
|
|
|
from mcp_server_cli.models import (
|
|
MCPRequest,
|
|
MCPResponse,
|
|
MCPNotification,
|
|
MCPMethod,
|
|
ToolDefinition,
|
|
ToolCallParams,
|
|
ToolCallResult,
|
|
InitializeParams,
|
|
InitializeResult,
|
|
ServerInfo,
|
|
ServerCapabilities,
|
|
ToolsListResult,
|
|
)
|
|
from mcp_server_cli.config import AppConfig, ConfigManager
|
|
from mcp_server_cli.tools import ToolBase, ToolResult
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MCPConnectionState(str, Enum):
    """Lifecycle phase of the MCP client connection.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (e.g. ``MCPConnectionState.READY == "ready"``).
    """

    # Initial phase: no ``initialize`` request has been handled yet.
    DISCONNECTED = "disconnected"
    # Set while an ``initialize`` request is being processed.
    INITIALIZING = "initializing"
    # Initialization has completed.
    READY = "ready"
|
|
|
|
|
|
class MCPServer:
    """MCP Protocol Server implementation.

    Keeps a registry of tools and dispatches MCP requests (``initialize``,
    ``tools/list``, ``tools/call``) to the matching handler, wrapping results
    and failures in ``MCPResponse`` objects.
    """

    def __init__(self, config: Optional[AppConfig] = None):
        """Initialize the MCP server.

        Args:
            config: Optional server configuration. A default ``AppConfig()``
                is created when omitted.
        """
        self.config = config if config is not None else AppConfig()
        # Tools keyed by ``tool.name``; registering the same name again
        # replaces the earlier tool.
        self.tool_registry: Dict[str, ToolBase] = {}
        self.connection_state = MCPConnectionState.DISCONNECTED
        self._initialized = False

    def register_tool(self, tool: ToolBase) -> None:
        """Register a tool with the server.

        Args:
            tool: Tool to register, keyed by its ``name`` attribute.
        """
        self.tool_registry[tool.name] = tool

    def get_tool(self, name: str) -> Optional[ToolBase]:
        """Get a tool by name.

        Args:
            name: Tool name.

        Returns:
            Tool or None if not found.
        """
        return self.tool_registry.get(name)

    def list_tools(self) -> List[ToolDefinition]:
        """List all registered tools.

        Returns:
            List of tool definitions, in registration order.
        """
        return [
            ToolDefinition(
                name=tool.name,
                description=tool.description,
                input_schema=tool.input_schema,
                annotations=tool.annotations,
            )
            for tool in self.tool_registry.values()
        ]

    async def handle_request(self, request: MCPRequest) -> MCPResponse:
        """Handle an MCP request.

        Failures are returned in the response's ``error`` field (never
        raised), using JSON-RPC 2.0 error codes:

        * ``-32601`` -- the method is not supported.
        * ``-32602`` -- the params could not be parsed into the method's
          parameter model.
        * ``-32603`` -- the handler itself failed.

        Args:
            request: MCP request message.

        Returns:
            MCP response message carrying either ``result`` or ``error``.
        """
        method = request.method
        params = request.params or {}

        # Build the handler call (which parses params) before running it, so
        # malformed client input is reported as "Invalid params" rather than
        # as a server-side internal error. pydantic's ValidationError
        # subclasses ValueError; ``**params`` on a non-mapping raises TypeError.
        try:
            if method == MCPMethod.INITIALIZE:
                pending = self._handle_initialize(InitializeParams(**params))
            elif method == MCPMethod.TOOLS_LIST:
                pending = self._handle_tools_list()
            elif method == MCPMethod.TOOLS_CALL:
                pending = self._handle_tool_call(ToolCallParams(**params))
            else:
                pending = None
        except (TypeError, ValueError) as e:
            return MCPResponse(
                id=request.id,
                error={"code": -32602, "message": f"Invalid params: {e}"},
            )

        if pending is None:
            return MCPResponse(
                id=request.id,
                error={"code": -32601, "message": f"Method not found: {method}"},
            )

        try:
            result = await pending
            return MCPResponse(id=request.id, result=result.model_dump())
        except Exception as e:
            logger.exception("Error handling request %s: %s", method, e)
            return MCPResponse(
                id=request.id,
                error={"code": -32603, "message": str(e)},
            )

    async def _handle_initialize(self, params: InitializeParams) -> InitializeResult:
        """Handle MCP initialize request.

        Marks the server as initialized and echoes back the client's
        protocol version together with this server's identity/capabilities.

        Args:
            params: Initialize parameters.

        Returns:
            Initialize result.
        """
        # The state passes through INITIALIZING to READY within this call;
        # nothing is awaited in between.
        self.connection_state = MCPConnectionState.INITIALIZING
        self._initialized = True
        self.connection_state = MCPConnectionState.READY

        return InitializeResult(
            protocol_version=params.protocol_version,
            server_info=ServerInfo(
                name="mcp-server-cli",
                version="0.1.0",
            ),
            capabilities=ServerCapabilities(
                tools={"listChanged": True},
                resources={},
                prompts={},
            ),
        )

    async def _handle_tools_list(self) -> ToolsListResult:
        """Handle tools/list request.

        Returns:
            List of available tools.
        """
        return ToolsListResult(tools=self.list_tools())

    async def _handle_tool_call(self, params: ToolCallParams) -> ToolCallResult:
        """Handle tools/call request.

        Tool failures are reported to the client as an error *result*
        (``is_error=True``) rather than a protocol-level error.

        Args:
            params: Tool call parameters.

        Returns:
            Tool execution result.
        """
        tool = self.get_tool(params.name)
        if tool is None:
            return ToolCallResult(
                content=[],
                is_error=True,
                error_message=f"Tool not found: {params.name}",
            )

        try:
            result = await tool.execute(params.arguments or {})
            return ToolCallResult(content=[{"type": "text", "text": result.output}])
        except Exception as e:
            # Previously swallowed silently; log so failures are diagnosable
            # server-side while still returning an error result to the client.
            logger.exception("Tool %r failed", params.name)
            return ToolCallResult(
                content=[],
                is_error=True,
                error_message=str(e),
            )
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log start-up and shutdown messages around the application's lifetime.

    Passed to ``FastAPI(lifespan=...)``: code before the ``yield`` runs at
    start-up, code after it at shutdown. ``app`` is not used.
    """
    startup_banner = "MCP Server starting up..."
    shutdown_banner = "MCP Server shutting down..."

    logger.info(startup_banner)
    yield  # requests are served while the context is suspended here
    logger.info(shutdown_banner)
|
|
|
|
|
|
def create_app(
    config: Optional[AppConfig] = None,
    tools: Optional[List[ToolBase]] = None,
) -> FastAPI:
    """Create and configure the FastAPI application.

    Args:
        config: Optional server configuration.
        tools: Optional tools to register with the MCP server up front.
            More can be added later via
            ``app.state.mcp_server.register_tool(...)``.

    Returns:
        Configured FastAPI application. The backing ``MCPServer`` is exposed
        as ``app.state.mcp_server``.
    """
    mcp_server = MCPServer(config)
    for tool in tools or ():
        mcp_server.register_tool(tool)

    app = FastAPI(
        title="MCP Server CLI",
        description="Model Context Protocol Server",
        version="0.1.0",
        lifespan=lifespan,
    )
    # Without this handle the server (and its tool registry) would be
    # unreachable from outside this function, so no tool could be registered.
    app.state.mcp_server = mcp_server

    # NOTE(review): wildcard origins let any web page a user visits call these
    # endpoints cross-origin, including tool execution. Consider restricting
    # allow_origins (e.g. from config) if tools have side effects.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    async def read_json_object(request: Request) -> Dict[str, Any]:
        """Parse the request body as a JSON object, or raise HTTP 400."""
        try:
            body = await request.json()
        except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
            raise HTTPException(
                status_code=400, detail=f"Invalid JSON body: {e}"
            ) from e
        if not isinstance(body, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )
        return body

    @app.get("/health")
    async def health_check():
        """Health check endpoint."""
        return {"status": "healthy", "state": mcp_server.connection_state}

    @app.get("/api/tools")
    async def list_tools():
        """List all available tools."""
        return {"tools": [t.model_dump() for t in mcp_server.list_tools()]}

    @app.post("/api/tools/call")
    async def call_tool(request: Request):
        """Call a tool by name.

        Expects a JSON object ``{"name": str, "arguments": object | null}``.
        Returns 400 for a malformed body, 404 for an unknown tool and 500 when
        the tool raises.
        """
        body = await read_json_object(request)

        tool_name = body.get("name")
        if not isinstance(tool_name, str) or not tool_name:
            raise HTTPException(
                status_code=400, detail="Field 'name' must be a non-empty string"
            )
        # Treat a missing or null "arguments" as no arguments, matching the
        # MCP tools/call path.
        arguments = body.get("arguments") or {}
        if not isinstance(arguments, dict):
            raise HTTPException(
                status_code=400, detail="Field 'arguments' must be a JSON object"
            )

        tool = mcp_server.get_tool(tool_name)
        if tool is None:
            raise HTTPException(status_code=404, detail=f"Tool not found: {tool_name}")

        try:
            result = await tool.execute(arguments)
            return {"success": True, "output": result.output}
        except Exception as e:
            logger.exception("Tool %r failed", tool_name)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.post("/mcp")
    async def handle_mcp(request: MCPRequest):
        """Handle MCP protocol messages."""
        response = await mcp_server.handle_request(request)
        return response.model_dump()

    @app.post("/mcp/{path:path}")
    async def handle_mcp_fallback(path: str, request: Request):
        """Handle MCP protocol messages at various paths.

        Returns 400 for an unparseable body and 422 when the body is not a
        valid ``MCPRequest``.
        """
        body = await read_json_object(request)
        try:
            mcp_request = MCPRequest(**body)
        except ValueError as e:  # pydantic ValidationError subclasses ValueError
            raise HTTPException(status_code=422, detail=str(e)) from e
        response = await mcp_server.handle_request(mcp_request)
        return response.model_dump()

    return app
|
|
|
|
|
|
def run_server(
    host: str = "127.0.0.1",
    port: int = 3000,
    config_path: Optional[str] = None,
    log_level: str = "INFO",
):
    """Run the MCP server using uvicorn.

    When ``config_path`` is given and loads successfully, its
    ``server.host``/``server.port`` override the ``host``/``port`` arguments.
    A config that fails to load is logged as a warning and the server starts
    without it.

    Args:
        host: Host to bind to.
        port: Port to listen on.
        config_path: Path to configuration file.
        log_level: Logging level name (e.g. ``"INFO"``, case-insensitive);
            applied to both the root logger and uvicorn's loggers.

    Raises:
        ValueError: If ``log_level`` is not a standard logging level name.
    """
    # Imported here, so importing this module does not import uvicorn.
    import uvicorn
    from pathlib import Path

    # getLevelName maps a registered level name to its int; unknown names come
    # back as the string "Level <name>".
    level = logging.getLevelName(log_level.upper())
    if not isinstance(level, int):
        raise ValueError(f"Unknown log level: {log_level!r}")
    logging.basicConfig(level=level)

    config = None
    if config_path:
        try:
            config_manager = ConfigManager()
            config = config_manager.load(Path(config_path))
            host = config.server.host
            port = config.server.port
        except Exception as e:
            # Best-effort: start with defaults rather than refusing to run.
            logger.warning("Failed to load config from %s: %s", config_path, e)

    app = create_app(config)

    uvicorn.run(app, host=host, port=port, log_level=level)
|