diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..2f6a7da --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,202 @@ +"""Tests for MCP protocol message models.""" + +import pytest +from pydantic import ValidationError + +from mcp_server_cli.models import ( + MCPRequest, + MCPResponse, + MCPNotification, + MCPMethod, + ToolDefinition, + ToolSchema, + ToolParameter, + ToolCallParams, + ToolCallResult, + InitializeParams, + InitializeResult, + ServerConfig, + LocalLLMConfig, + SecurityConfig, + AppConfig, +) + + +class TestMCPRequest: + """Tests for MCPRequest model.""" + + def test_create_request(self): + """Test creating a basic request.""" + request = MCPRequest( + id=1, + method=MCPMethod.INITIALIZE, + params={"protocol_version": "2024-11-05"}, + ) + assert request.jsonrpc == "2.0" + assert request.id == 1 + assert request.method == MCPMethod.INITIALIZE + assert request.params["protocol_version"] == "2024-11-05" + + def test_request_without_id(self): + """Test request without id (notification).""" + request = MCPRequest( + method=MCPMethod.TOOLS_LIST, + ) + assert request.id is None + assert request.method == MCPMethod.TOOLS_LIST + + +class TestMCPResponse: + """Tests for MCPResponse model.""" + + def test_create_response_with_result(self): + """Test creating a response with result.""" + response = MCPResponse( + id=1, + result={"tools": [{"name": "test_tool"}]}, + ) + assert response.jsonrpc == "2.0" + assert response.id == 1 + assert response.result["tools"][0]["name"] == "test_tool" + + def test_create_response_with_error(self): + """Test creating a response with error.""" + response = MCPResponse( + id=1, + error={"code": -32600, "message": "Invalid Request"}, + ) + assert response.error["code"] == -32600 + assert response.error["message"] == "Invalid Request" + + +class TestToolModels: + """Tests for tool-related models.""" + + def test_tool_definition(self): + """Test creating a tool definition.""" + tool = ToolDefinition( + name="read_file", + description="Read a file from disk", + input_schema=ToolSchema( + properties={ + "path": ToolParameter( + name="path", + type="string", + description="File path", + required=True, + ) + }, + required=["path"], + ), + ) + assert tool.name == "read_file" + assert "path" in tool.input_schema.required + + def test_tool_name_validation(self): + """Test that tool names must be valid identifiers.""" + with pytest.raises(ValidationError): + ToolDefinition( + name="invalid-name", + description="Tool with invalid name", + ) + + def test_tool_parameter_with_enum(self): + """Test tool parameter with enum values.""" + param = ToolParameter( + name="operation", + type="string", + enum=["read", "write", "delete"], + ) + assert param.enum == ["read", "write", "delete"] + + +class TestConfigModels: + """Tests for configuration models.""" + + def test_server_config_defaults(self): + """Test server config default values.""" + config = ServerConfig() + assert config.host == "127.0.0.1" + assert config.port == 3000 + assert config.log_level == "INFO" + + def test_local_llm_config_defaults(self): + """Test local LLM config defaults.""" + config = LocalLLMConfig() + assert config.enabled is False + assert config.base_url == "http://localhost:11434" + assert config.model == "llama2" + + def test_security_config_defaults(self): + """Test security config defaults.""" + config = SecurityConfig() + assert "ls" in config.allowed_commands + assert "/etc" in config.blocked_paths + assert config.max_shell_timeout == 30 + + def test_app_config_composition(self): + """Test app config with all components.""" + config = AppConfig( + server=ServerConfig(port=8080), + llm=LocalLLMConfig(enabled=True), + security=SecurityConfig(allowed_commands=["echo"]), + ) + assert config.server.port == 8080 + assert config.llm.enabled is True + assert "echo" in config.security.allowed_commands + + +class TestToolCallModels: + """Tests for tool call models.""" + + def test_tool_call_params(self): + """Test tool call parameters.""" + params = ToolCallParams( + name="read_file", + arguments={"path": "/tmp/test.txt"}, + ) + assert params.name == "read_file" + assert params.arguments["path"] == "/tmp/test.txt" + + def test_tool_call_result(self): + """Test tool call result.""" + result = ToolCallResult( + content=[{"type": "text", "text": "Hello"}], + is_error=False, + ) + assert result.content[0]["text"] == "Hello" + assert result.is_error is False + + def test_tool_call_result_error(self): + """Test tool call result with error.""" + result = ToolCallResult( + content=[], + is_error=True, + error_message="File not found", + ) + assert result.is_error is True + assert result.error_message == "File not found" + + +class TestInitializeModels: + """Tests for initialize-related models.""" + + def test_initialize_params(self): + """Test initialize parameters.""" + params = InitializeParams( + protocol_version="2024-11-05", + capabilities={"tools": {}}, + client_info={"name": "test-client"}, + ) + assert params.protocol_version == "2024-11-05" + assert params.client_info["name"] == "test-client" + + def test_initialize_result(self): + """Test initialize result.""" + result = InitializeResult( + protocol_version="2024-11-05", + server_info={"name": "mcp-server", "version": "0.1.0"}, + capabilities={"tools": {"listChanged": True}}, + ) + assert result.server_info.name == "mcp-server" + assert result.capabilities.tools["listChanged"] is True