324 lines
10 KiB
Python
324 lines
10 KiB
Python
"""Tests for tool execution engine and built-in tools."""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from mcp_server_cli.tools.base import ToolBase, ToolResult, ToolRegistry
|
|
from mcp_server_cli.tools.file_tools import (
|
|
FileTools,
|
|
ReadFileTool,
|
|
WriteFileTool,
|
|
ListDirectoryTool,
|
|
GlobFilesTool,
|
|
)
|
|
from mcp_server_cli.tools.shell_tools import ShellTools, ExecuteCommandTool
|
|
from mcp_server_cli.tools.git_tools import GitTools
|
|
from mcp_server_cli.models import ToolSchema, ToolParameter
|
|
|
|
|
|
class TestToolBase:
    """Tests for ToolBase abstract class."""

    @staticmethod
    def _make_tool(build_schema) -> ToolBase:
        """Return a minimal concrete ``ToolBase`` named ``test_tool``.

        Args:
            build_schema: Zero-argument callable producing the ``ToolSchema``
                returned from ``_create_input_schema``. It is invoked on every
                schema request, so each call yields a freshly built schema,
                exactly like the per-test stub classes this helper replaces.

        Returns:
            An instance of a private ``ToolBase`` subclass whose ``execute``
            always reports success with output ``"OK"``.
        """

        class _StubTool(ToolBase):
            def __init__(self):
                super().__init__(name="test_tool", description="A test tool")

            def _create_input_schema(self) -> ToolSchema:
                return build_schema()

            async def execute(self, arguments) -> ToolResult:
                return ToolResult(success=True, output="OK")

        return _StubTool()

    def test_tool_validation(self):
        """Test tool argument validation."""
        tool = self._make_tool(
            lambda: ToolSchema(
                properties={
                    "name": ToolParameter(
                        name="name",
                        type="string",
                        required=True,
                    ),
                    # NOTE(review): declared type is "integer" but the enum
                    # values are strings; confirm which form ToolParameter
                    # expects before relying on enum checks for this field.
                    "count": ToolParameter(
                        name="count",
                        type="integer",
                        enum=["1", "2", "3"],
                    ),
                },
                required=["name"],
            )
        )

        result = tool.validate_arguments({"name": "test"})
        assert result["name"] == "test"

    def test_missing_required_param(self):
        """Test that missing required parameters raise error."""
        tool = self._make_tool(
            lambda: ToolSchema(
                properties={
                    "required_param": ToolParameter(
                        name="required_param",
                        type="string",
                        required=True,
                    ),
                },
                required=["required_param"],
            )
        )

        with pytest.raises(ValueError, match="Missing required parameter"):
            tool.validate_arguments({})

    def test_invalid_enum_value(self):
        """Test that invalid enum values raise error."""
        # ``required`` is deliberately omitted so ToolSchema's default applies,
        # matching the original stub for this test.
        tool = self._make_tool(
            lambda: ToolSchema(
                properties={
                    "color": ToolParameter(
                        name="color",
                        type="string",
                        enum=["red", "green", "blue"],
                    ),
                },
            )
        )

        with pytest.raises(ValueError, match="Invalid value"):
            tool.validate_arguments({"color": "yellow"})
|
|
|
|
|
|
class TestToolRegistry:
    """Tests for ToolRegistry."""

    @staticmethod
    def _make_tool(
        name: str = "test_tool", description: str = "A test tool"
    ) -> ToolBase:
        """Return a minimal concrete ``ToolBase`` with an empty input schema.

        Args:
            name: Tool name passed to ``ToolBase.__init__``; this is the key
                the registry tests look the tool up by.
            description: Tool description passed to ``ToolBase.__init__``.

        Returns:
            An instance of a private ``ToolBase`` subclass whose schema has no
            properties and whose ``execute`` always succeeds with ``"OK"``.
        """

        class _StubTool(ToolBase):
            def __init__(self):
                super().__init__(name=name, description=description)

            def _create_input_schema(self) -> ToolSchema:
                return ToolSchema(properties={}, required=[])

            async def execute(self, arguments) -> ToolResult:
                return ToolResult(success=True, output="OK")

        return _StubTool()

    def test_register_and_get(self, tool_registry: ToolRegistry):
        """Test registering and retrieving a tool."""
        tool = self._make_tool()
        tool_registry.register(tool)

        retrieved = tool_registry.get("test_tool")
        assert retrieved is tool
        assert retrieved.name == "test_tool"

    def test_unregister(self, tool_registry: ToolRegistry):
        """Test unregistering a tool."""
        tool = self._make_tool()
        tool_registry.register(tool)

        assert tool_registry.unregister("test_tool") is True
        assert tool_registry.get("test_tool") is None

    def test_list_tools(self, tool_registry: ToolRegistry):
        """Test listing all tools."""
        tool_registry.register(self._make_tool("tool1", "Tool 1"))
        tool_registry.register(self._make_tool("tool2", "Tool 2"))

        tools = tool_registry.list()
        assert len(tools) == 2
        # Expects names in registration order.
        assert tool_registry.list_names() == ["tool1", "tool2"]
|
|
|
|
|
|
class TestFileTools:
    """Tests for file operation tools."""

    @pytest.mark.asyncio
    async def test_read_file(self, temp_file: Path):
        """Test reading a file."""
        tool = ReadFileTool()
        result = await tool.execute({"path": str(temp_file)})

        assert result.success is True
        # Relies on the ``temp_file`` fixture containing this text.
        assert "Hello, World!" in result.output

    @pytest.mark.asyncio
    async def test_read_nonexistent_file(self, temp_dir: Path):
        """Test reading a nonexistent file."""
        tool = ReadFileTool()
        result = await tool.execute({"path": str(temp_dir / "nonexistent.txt")})

        assert result.success is False
        assert "not found" in result.error.lower()

    @pytest.mark.asyncio
    async def test_write_file(self, temp_dir: Path):
        """Test writing a file."""
        tool = WriteFileTool()
        result = await tool.execute({
            "path": str(temp_dir / "new_file.txt"),
            "content": "New content",
        })

        assert result.success is True
        assert (temp_dir / "new_file.txt").read_text() == "New content"

    @pytest.mark.asyncio
    async def test_list_directory(self, temp_dir: Path):
        """Test that listing a directory reports the entries it contains."""
        # Create a known entry so the test checks content, not just success.
        (temp_dir / "listed_entry.txt").touch()

        tool = ListDirectoryTool()
        result = await tool.execute({"path": str(temp_dir)})

        assert result.success is True
        assert "listed_entry.txt" in result.output

    @pytest.mark.asyncio
    async def test_glob_files(self, temp_dir: Path):
        """Test glob file search, including a nested match."""
        (temp_dir / "test1.txt").touch()
        (temp_dir / "test2.txt").touch()
        (temp_dir / "subdir").mkdir()
        (temp_dir / "subdir" / "test3.txt").touch()

        tool = GlobFilesTool()
        result = await tool.execute({
            "path": str(temp_dir),
            "pattern": "**/*.txt",
        })

        assert result.success is True
        assert "test1.txt" in result.output
        assert "test2.txt" in result.output
        # The "**" pattern must descend into subdirectories.
        assert "test3.txt" in result.output
|
|
|
|
|
|
class TestShellTools:
    """Tests for shell execution tools.

    These tests invoke POSIX utilities (``ls``, ``pwd``, ``sleep``) directly.
    """

    @pytest.mark.asyncio
    async def test_execute_ls(self):
        """Test executing ls command."""
        tool = ExecuteCommandTool()
        result = await tool.execute({
            "cmd": ["ls", "-1"],
            "timeout": 10,
        })

        assert result.success is True

    @pytest.mark.asyncio
    async def test_execute_with_cwd(self, temp_dir: Path):
        """Test that the command actually runs in the requested directory."""
        tool = ExecuteCommandTool()
        result = await tool.execute({
            "cmd": ["pwd"],
            "cwd": str(temp_dir),
        })

        assert result.success is True
        # Compare only the final path component: pwd may print a resolved
        # path (e.g. /private/var/... on macOS) that differs in its prefix.
        assert temp_dir.name in result.output

    @pytest.mark.asyncio
    async def test_execute_nonexistent_command(self):
        """Test executing nonexistent command."""
        tool = ExecuteCommandTool()
        result = await tool.execute({
            "cmd": ["nonexistent_command_12345"],
        })

        assert result.success is False
        # NOTE(review): this wording matches the POSIX errno text; other
        # platforms may phrase the error differently.
        assert "no such file" in result.error.lower()

    @pytest.mark.asyncio
    async def test_command_timeout(self):
        """Test command timeout."""
        tool = ExecuteCommandTool()
        result = await tool.execute({
            "cmd": ["sleep", "10"],
            "timeout": 1,
        })

        assert result.success is False
        assert "timed out" in result.error.lower()
|
|
|
|
|
|
class TestGitTools:
    """Exercise the git tool wrapper."""

    @pytest.mark.asyncio
    async def test_git_status_not_in_repo(self, temp_dir: Path):
        """A status request outside any repository must fail with a clear error."""
        git = GitTools()
        request = {"operation": "status", "path": str(temp_dir)}

        outcome = await git.execute(request)

        assert outcome.success is False
        assert "not in a git repository" in outcome.error.lower()
|
|
|
|
|
|
class TestToolResult:
    """Check construction of the ToolResult model."""

    def test_success_result(self):
        """A successful result keeps its output and carries no error."""
        ok = ToolResult(success=True, output="Test output")

        assert ok.success is True
        assert ok.output == "Test output"
        assert ok.error is None

    def test_error_result(self):
        """A failed result exposes the error message it was built with."""
        failure = ToolResult(success=False, output="", error="Something went wrong")

        assert failure.success is False
        assert failure.error == "Something went wrong"
|