This commit is contained in:
296
src/mcp_server_cli/server.py
Normal file
296
src/mcp_server_cli/server.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""MCP Protocol Server implementation using FastAPI."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, List, Optional, Callable, Awaitable
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import sse_starlette.sse as sse
|
||||
|
||||
from mcp_server_cli.models import (
|
||||
MCPRequest,
|
||||
MCPResponse,
|
||||
MCPNotification,
|
||||
MCPMethod,
|
||||
ToolDefinition,
|
||||
ToolCallParams,
|
||||
ToolCallResult,
|
||||
InitializeParams,
|
||||
InitializeResult,
|
||||
ServerInfo,
|
||||
ServerCapabilities,
|
||||
ToolsListResult,
|
||||
)
|
||||
from mcp_server_cli.config import AppConfig, ConfigManager
|
||||
from mcp_server_cli.tools import ToolBase, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPConnectionState(str, Enum):
|
||||
"""State of MCP connection."""
|
||||
|
||||
DISCONNECTED = "disconnected"
|
||||
INITIALIZING = "initializing"
|
||||
READY = "ready"
|
||||
|
||||
|
||||
class MCPServer:
|
||||
"""MCP Protocol Server implementation."""
|
||||
|
||||
def __init__(self, config: Optional[AppConfig] = None):
|
||||
"""Initialize the MCP server.
|
||||
|
||||
Args:
|
||||
config: Optional server configuration.
|
||||
"""
|
||||
self.config = config or AppConfig()
|
||||
self.tool_registry: Dict[str, ToolBase] = {}
|
||||
self.connection_state = MCPConnectionState.DISCONNECTED
|
||||
self._initialized = False
|
||||
|
||||
def register_tool(self, tool: ToolBase):
|
||||
"""Register a tool with the server.
|
||||
|
||||
Args:
|
||||
tool: Tool to register.
|
||||
"""
|
||||
self.tool_registry[tool.name] = tool
|
||||
|
||||
def get_tool(self, name: str) -> Optional[ToolBase]:
|
||||
"""Get a tool by name.
|
||||
|
||||
Args:
|
||||
name: Tool name.
|
||||
|
||||
Returns:
|
||||
Tool or None if not found.
|
||||
"""
|
||||
return self.tool_registry.get(name)
|
||||
|
||||
def list_tools(self) -> List[ToolDefinition]:
|
||||
"""List all registered tools.
|
||||
|
||||
Returns:
|
||||
List of tool definitions.
|
||||
"""
|
||||
return [
|
||||
ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
input_schema=tool.input_schema,
|
||||
annotations=tool.annotations,
|
||||
)
|
||||
for tool in self.tool_registry.values()
|
||||
]
|
||||
|
||||
async def handle_request(self, request: MCPRequest) -> MCPResponse:
|
||||
"""Handle an MCP request.
|
||||
|
||||
Args:
|
||||
request: MCP request message.
|
||||
|
||||
Returns:
|
||||
MCP response message.
|
||||
"""
|
||||
method = request.method
|
||||
params = request.params or {}
|
||||
|
||||
try:
|
||||
if method == MCPMethod.INITIALIZE:
|
||||
result = await self._handle_initialize(InitializeParams(**params))
|
||||
elif method == MCPMethod.TOOLS_LIST:
|
||||
result = await self._handle_tools_list()
|
||||
elif method == MCPMethod.TOOLS_CALL:
|
||||
result = await self._handle_tool_call(ToolCallParams(**params))
|
||||
else:
|
||||
return MCPResponse(
|
||||
id=request.id,
|
||||
error={"code": -32601, "message": f"Method not found: {method}"},
|
||||
)
|
||||
|
||||
return MCPResponse(id=request.id, result=result.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling request: {e}", exc_info=True)
|
||||
return MCPResponse(
|
||||
id=request.id,
|
||||
error={"code": -32603, "message": str(e)},
|
||||
)
|
||||
|
||||
async def _handle_initialize(self, params: InitializeParams) -> InitializeResult:
|
||||
"""Handle MCP initialize request.
|
||||
|
||||
Args:
|
||||
params: Initialize parameters.
|
||||
|
||||
Returns:
|
||||
Initialize result.
|
||||
"""
|
||||
self.connection_state = MCPConnectionState.INITIALIZING
|
||||
self._initialized = True
|
||||
self.connection_state = MCPConnectionState.READY
|
||||
|
||||
return InitializeResult(
|
||||
protocol_version=params.protocol_version,
|
||||
server_info=ServerInfo(
|
||||
name="mcp-server-cli",
|
||||
version="0.1.0",
|
||||
),
|
||||
capabilities=ServerCapabilities(
|
||||
tools={"listChanged": True},
|
||||
resources={},
|
||||
prompts={},
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_tools_list(self) -> ToolsListResult:
|
||||
"""Handle tools/list request.
|
||||
|
||||
Returns:
|
||||
List of available tools.
|
||||
"""
|
||||
return ToolsListResult(tools=self.list_tools())
|
||||
|
||||
async def _handle_tool_call(self, params: ToolCallParams) -> ToolCallResult:
|
||||
"""Handle tools/call request.
|
||||
|
||||
Args:
|
||||
params: Tool call parameters.
|
||||
|
||||
Returns:
|
||||
Tool execution result.
|
||||
"""
|
||||
tool = self.get_tool(params.name)
|
||||
if not tool:
|
||||
return ToolCallResult(
|
||||
content=[],
|
||||
is_error=True,
|
||||
error_message=f"Tool not found: {params.name}",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await tool.execute(params.arguments or {})
|
||||
return ToolCallResult(content=[{"type": "text", "text": result.output}])
|
||||
except Exception as e:
|
||||
return ToolCallResult(
|
||||
content=[],
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan context manager."""
|
||||
logger.info("MCP Server starting up...")
|
||||
yield
|
||||
logger.info("MCP Server shutting down...")
|
||||
|
||||
|
||||
def create_app(config: Optional[AppConfig] = None) -> FastAPI:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
Args:
|
||||
config: Optional server configuration.
|
||||
|
||||
Returns:
|
||||
Configured FastAPI application.
|
||||
"""
|
||||
mcp_server = MCPServer(config)
|
||||
|
||||
app = FastAPI(
|
||||
title="MCP Server CLI",
|
||||
description="Model Context Protocol Server",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy", "state": mcp_server.connection_state}
|
||||
|
||||
@app.get("/api/tools")
|
||||
async def list_tools():
|
||||
"""List all available tools."""
|
||||
return {"tools": [t.model_dump() for t in mcp_server.list_tools()]}
|
||||
|
||||
@app.post("/api/tools/call")
|
||||
async def call_tool(request: Request):
|
||||
"""Call a tool by name."""
|
||||
body = await request.json()
|
||||
tool_name = body.get("name")
|
||||
arguments = body.get("arguments", {})
|
||||
|
||||
tool = mcp_server.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail=f"Tool not found: {tool_name}")
|
||||
|
||||
try:
|
||||
result = await tool.execute(arguments)
|
||||
return {"success": True, "output": result.output}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/mcp")
|
||||
async def handle_mcp(request: MCPRequest):
|
||||
"""Handle MCP protocol messages."""
|
||||
response = await mcp_server.handle_request(request)
|
||||
return response.model_dump()
|
||||
|
||||
@app.post("/mcp/{path:path}")
|
||||
async def handle_mcp_fallback(path: str, request: Request):
|
||||
"""Handle MCP protocol messages at various paths."""
|
||||
body = await request.json()
|
||||
mcp_request = MCPRequest(**body)
|
||||
response = await mcp_server.handle_request(mcp_request)
|
||||
return response.model_dump()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def run_server(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 3000,
|
||||
config_path: Optional[str] = None,
|
||||
log_level: str = "INFO",
|
||||
):
|
||||
"""Run the MCP server using uvicorn.
|
||||
|
||||
Args:
|
||||
host: Host to bind to.
|
||||
port: Port to listen on.
|
||||
config_path: Path to configuration file.
|
||||
log_level: Logging level.
|
||||
"""
|
||||
import uvicorn
|
||||
|
||||
logging.basicConfig(level=getattr(logging, log_level.upper()))
|
||||
|
||||
config = None
|
||||
if config_path:
|
||||
try:
|
||||
config_manager = ConfigManager()
|
||||
config = config_manager.load(Path(config_path))
|
||||
host = config.server.host
|
||||
port = config.server.port
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load config: {e}")
|
||||
|
||||
app = create_app(config)
|
||||
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
Reference in New Issue
Block a user