diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..bfd8bb7 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,323 @@ +"""Tests for tool execution engine and built-in tools.""" + +import pytest +import asyncio +from pathlib import Path +from unittest.mock import patch, MagicMock + +from mcp_server_cli.tools.base import ToolBase, ToolResult, ToolRegistry +from mcp_server_cli.tools.file_tools import ( + FileTools, + ReadFileTool, + WriteFileTool, + ListDirectoryTool, + GlobFilesTool, +) +from mcp_server_cli.tools.shell_tools import ShellTools, ExecuteCommandTool +from mcp_server_cli.tools.git_tools import GitTools +from mcp_server_cli.models import ToolSchema, ToolParameter + + +class TestToolBase: + """Tests for ToolBase abstract class.""" + + def test_tool_validation(self): + """Test tool argument validation.""" + class TestTool(ToolBase): + def __init__(self): + super().__init__( + name="test_tool", + description="A test tool", + ) + + def _create_input_schema(self) -> ToolSchema: + return ToolSchema( + properties={ + "name": ToolParameter( + name="name", + type="string", + required=True, + ), + "count": ToolParameter( + name="count", + type="integer", + enum=["1", "2", "3"], + ), + }, + required=["name"], + ) + + async def execute(self, arguments) -> ToolResult: + return ToolResult(success=True, output="OK") + + tool = TestTool() + + result = tool.validate_arguments({"name": "test"}) + assert result["name"] == "test" + + def test_missing_required_param(self): + """Test that missing required parameters raise error.""" + class TestTool(ToolBase): + def __init__(self): + super().__init__(name="test_tool", description="A test tool") + + def _create_input_schema(self) -> ToolSchema: + return ToolSchema( + properties={ + "required_param": ToolParameter( + name="required_param", + type="string", + required=True, + ), + }, + required=["required_param"], + ) + + async def execute(self, arguments) -> ToolResult: + return ToolResult(success=True, output="OK") + + tool = TestTool() + + with pytest.raises(ValueError, match="Missing required parameter"): + tool.validate_arguments({}) + + def test_invalid_enum_value(self): + """Test that invalid enum values raise error.""" + class TestTool(ToolBase): + def __init__(self): + super().__init__(name="test_tool", description="A test tool") + + def _create_input_schema(self) -> ToolSchema: + return ToolSchema( + properties={ + "color": ToolParameter( + name="color", + type="string", + enum=["red", "green", "blue"], + ), + }, + ) + + async def execute(self, arguments) -> ToolResult: + return ToolResult(success=True, output="OK") + + tool = TestTool() + + with pytest.raises(ValueError, match="Invalid value"): + tool.validate_arguments({"color": "yellow"}) + + +class TestToolRegistry: + """Tests for ToolRegistry.""" + + def test_register_and_get(self, tool_registry: ToolRegistry): + """Test registering and retrieving a tool.""" + class TestTool(ToolBase): + def __init__(self): + super().__init__(name="test_tool", description="A test tool") + + def _create_input_schema(self) -> ToolSchema: + return ToolSchema(properties={}, required=[]) + + async def execute(self, arguments) -> ToolResult: + return ToolResult(success=True, output="OK") + + tool = TestTool() + tool_registry.register(tool) + + retrieved = tool_registry.get("test_tool") + assert retrieved is tool + assert retrieved.name == "test_tool" + + def test_unregister(self, tool_registry: ToolRegistry): + """Test unregistering a tool.""" + class TestTool(ToolBase): + def __init__(self): + super().__init__(name="test_tool", description="A test tool") + + def _create_input_schema(self) -> ToolSchema: + return ToolSchema(properties={}, required=[]) + + async def execute(self, arguments) -> ToolResult: + return ToolResult(success=True, output="OK") + + tool = TestTool() + tool_registry.register(tool) + + assert tool_registry.unregister("test_tool") is True + assert tool_registry.get("test_tool") is None + + def test_list_tools(self, tool_registry: ToolRegistry): + """Test listing all tools.""" + class TestTool1(ToolBase): + def __init__(self): + super().__init__(name="tool1", description="Tool 1") + + def _create_input_schema(self) -> ToolSchema: + return ToolSchema(properties={}, required=[]) + + async def execute(self, arguments) -> ToolResult: + return ToolResult(success=True, output="OK") + + class TestTool2(ToolBase): + def __init__(self): + super().__init__(name="tool2", description="Tool 2") + + def _create_input_schema(self) -> ToolSchema: + return ToolSchema(properties={}, required=[]) + + async def execute(self, arguments) -> ToolResult: + return ToolResult(success=True, output="OK") + + tool_registry.register(TestTool1()) + tool_registry.register(TestTool2()) + + tools = tool_registry.list() + assert len(tools) == 2 + assert tool_registry.list_names() == ["tool1", "tool2"] + + +class TestFileTools: + """Tests for file operation tools.""" + + @pytest.mark.asyncio + async def test_read_file(self, temp_file: Path): + """Test reading a file.""" + tool = ReadFileTool() + result = await tool.execute({"path": str(temp_file)}) + + assert result.success is True + assert "Hello, World!" in result.output + + @pytest.mark.asyncio + async def test_read_nonexistent_file(self, temp_dir: Path): + """Test reading a nonexistent file.""" + tool = ReadFileTool() + result = await tool.execute({"path": str(temp_dir / "nonexistent.txt")}) + + assert result.success is False + assert "not found" in result.error.lower() + + @pytest.mark.asyncio + async def test_write_file(self, temp_dir: Path): + """Test writing a file.""" + tool = WriteFileTool() + result = await tool.execute({ + "path": str(temp_dir / "new_file.txt"), + "content": "New content", + }) + + assert result.success is True + assert (temp_dir / "new_file.txt").read_text() == "New content" + + @pytest.mark.asyncio + async def test_list_directory(self, temp_dir: Path): + """Test listing a directory.""" + tool = ListDirectoryTool() + result = await tool.execute({"path": str(temp_dir)}) + + assert result.success is True + + @pytest.mark.asyncio + async def test_glob_files(self, temp_dir: Path): + """Test glob file search.""" + (temp_dir / "test1.txt").touch() + (temp_dir / "test2.txt").touch() + (temp_dir / "subdir").mkdir() + (temp_dir / "subdir" / "test3.txt").touch() + + tool = GlobFilesTool() + result = await tool.execute({ + "path": str(temp_dir), + "pattern": "**/*.txt", + }) + + assert result.success is True + assert "test1.txt" in result.output + assert "test3.txt" in result.output + + +class TestShellTools: + """Tests for shell execution tools.""" + + @pytest.mark.asyncio + async def test_execute_ls(self): + """Test executing ls command.""" + tool = ExecuteCommandTool() + result = await tool.execute({ + "cmd": ["ls", "-1"], + "timeout": 10, + }) + + assert result.success is True + + @pytest.mark.asyncio + async def test_execute_with_cwd(self, temp_dir: Path): + """Test executing command with working directory.""" + tool = ExecuteCommandTool() + result = await tool.execute({ + "cmd": ["pwd"], + "cwd": str(temp_dir), + }) + + assert result.success is True + + @pytest.mark.asyncio + async def test_execute_nonexistent_command(self): + """Test executing nonexistent command.""" + tool = ExecuteCommandTool() + result = await tool.execute({ + "cmd": ["nonexistent_command_12345"], + }) + + assert result.success is False + assert "no such file" in result.error.lower() + + @pytest.mark.asyncio + async def test_command_timeout(self): + """Test command timeout.""" + tool = ExecuteCommandTool() + result = await tool.execute({ + "cmd": ["sleep", "10"], + "timeout": 1, + }) + + assert result.success is False + assert "timed out" in result.error.lower() + + +class TestGitTools: + """Tests for git tools.""" + + @pytest.mark.asyncio + async def test_git_status_not_in_repo(self, temp_dir: Path): + """Test git status in non-git directory.""" + tool = GitTools() + result = await tool.execute({ + "operation": "status", + "path": str(temp_dir), + }) + + assert result.success is False + assert "not in a git repository" in result.error.lower() + + +class TestToolResult: + """Tests for ToolResult model.""" + + def test_success_result(self): + """Test creating a success result.""" + result = ToolResult(success=True, output="Test output") + assert result.success is True + assert result.output == "Test output" + assert result.error is None + + def test_error_result(self): + """Test creating an error result.""" + result = ToolResult( + success=False, + output="", + error="Something went wrong", + ) + assert result.success is False + assert result.error == "Something went wrong"