fix: resolve CI/CD issues - all tests pass locally

This commit is contained in:
Developer
2026-02-05 13:32:25 +00:00
parent 155bc36ded
commit 1da735b646
30 changed files with 3982 additions and 605 deletions

View File

@@ -1 +1 @@
"""Tests for project_scaffold_cli package."""
"""Tests package for MCP Server CLI."""

View File

@@ -1,203 +1,143 @@
"""Test configuration and fixtures."""
"""Test configuration and fixtures for MCP Server CLI."""
import os
import sys
import tempfile
from pathlib import Path
from typing import Generator
import pytest
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
@pytest.fixture
def create_python_project(tmp_path: Path) -> Path:
"""Create a Python project structure for testing."""
src_dir = tmp_path / "src"
src_dir.mkdir()
(src_dir / "__init__.py").write_text('"""Package init."""')
(src_dir / "main.py").write_text('''"""Main module."""
def hello():
"""Say hello."""
print("Hello, World!")
class Calculator:
"""A simple calculator."""
def add(self, a: int, b: int) -> int:
"""Add two numbers."""
return a + b
def multiply(self, a: int, b: int) -> int:
"""Multiply two numbers."""
return a * b
''')
(tmp_path / "requirements.txt").write_text('''requests>=2.31.0
click>=8.0.0
pytest>=7.0.0
''')
(tmp_path / "pyproject.toml").write_text('''[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
name = "test-project"
version = "0.1.0"
description = "A test project"
requires-python = ">=3.9"
dependencies = [
"requests>=2.31.0",
"click>=8.0.0",
]
''')
return tmp_path
@pytest.fixture
def create_javascript_project(tmp_path: Path) -> Path:
"""Create a JavaScript project structure for testing."""
(tmp_path / "package.json").write_text('''{
"name": "test-js-project",
"version": "1.0.0",
"description": "A test JavaScript project",
"main": "index.js",
"dependencies": {
"express": "^4.18.0",
"lodash": "^4.17.0"
},
"devDependencies": {
"jest": "^29.0.0"
}
}
''')
(tmp_path / "index.js").write_text('''const express = require('express');
const _ = require('lodash');
function hello() {
return 'Hello, World!';
}
class Calculator {
add(a, b) {
return a + b;
}
}
module.exports = { hello, Calculator };
''')
return tmp_path
@pytest.fixture
def create_go_project(tmp_path: Path) -> Path:
"""Create a Go project structure for testing."""
(tmp_path / "go.mod").write_text('''module test-go-project
go 1.21
require (
github.com/gin-gonic/gin v1.9.0
github.com/stretchr/testify v1.8.0
from mcp_server_cli.config import AppConfig, ConfigManager
from mcp_server_cli.models import (
LocalLLMConfig,
SecurityConfig,
ServerConfig,
)
from mcp_server_cli.server import MCPServer
from mcp_server_cli.tools import (
FileTools,
GitTools,
ShellTools,
ToolRegistry,
)
''')
(tmp_path / "main.go").write_text('''package main
import "fmt"
func hello() string {
return "Hello, World!"
}
type Calculator struct{}
func (c *Calculator) Add(a, b int) int {
return a + b
}
func main() {
fmt.Println(hello())
}
''')
return tmp_path
@pytest.fixture
def create_rust_project(tmp_path: Path) -> Path:
"""Create a Rust project structure for testing."""
src_dir = tmp_path / "src"
src_dir.mkdir()
def temp_dir() -> Generator[Path, None, None]:
"""Create a temporary directory for tests."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
(tmp_path / "Cargo.toml").write_text('''[package]
name = "test-rust-project"
version = "0.1.0"
edition = "2021"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = "full" }
@pytest.fixture
def temp_file(temp_dir: Path) -> Generator[Path, None, None]:
"""Create a temporary file for tests."""
file_path = temp_dir / "test_file.txt"
file_path.write_text("Hello, World!")
yield file_path
[dev-dependencies]
assertions = "0.1"
''')
(src_dir / "main.rs").write_text('''fn hello() -> String {
"Hello, World!".to_string()
}
@pytest.fixture
def temp_yaml_file(temp_dir: Path) -> Generator[Path, None, None]:
"""Create a temporary YAML file for tests."""
file_path = temp_dir / "test_config.yaml"
content = """
server:
host: "127.0.0.1"
port: 8080
log_level: "DEBUG"
pub struct Calculator;
llm:
enabled: false
base_url: "http://localhost:11434"
model: "llama2"
impl Calculator {
pub fn add(a: i32, b: i32) -> i32 {
a + b
security:
allowed_commands:
- ls
- cat
- echo
blocked_paths:
- /etc
- /root
"""
file_path.write_text(content)
yield file_path
@pytest.fixture
def sample_tool_definition() -> dict:
"""Sample tool definition for testing."""
return {
"name": "test_tool",
"description": "A test tool",
"input_schema": {
"type": "object",
"properties": {
"param1": {
"type": "string",
"description": "First parameter",
"required": True
},
"param2": {
"type": "integer",
"description": "Second parameter",
"default": 10
}
},
"required": ["param1"]
}
}
}
fn main() {
println!("{}", hello());
}
''')
return tmp_path
@pytest.fixture
def create_mixed_project(tmp_path: Path) -> Path:
"""Create a project with multiple languages for testing."""
python_part = tmp_path / "python_part"
python_part.mkdir()
js_part = tmp_path / "js_part"
js_part.mkdir()
def default_config() -> AppConfig:
"""Create a default server configuration."""
return AppConfig(
server=ServerConfig(host="127.0.0.1", port=3000, log_level="INFO"),
llm=LocalLLMConfig(enabled=False, base_url="http://localhost:11434"),
security=SecurityConfig(
allowed_commands=["ls", "cat", "echo"],
blocked_paths=["/etc", "/root"],
),
)
src_dir = python_part / "src"
src_dir.mkdir()
(src_dir / "__init__.py").write_text('"""Package init."""')
(src_dir / "main.py").write_text('''"""Main module."""
def hello():
print("Hello")
''')
(python_part / "pyproject.toml").write_text('''[project]
name = "test-project"
version = "0.1.0"
description = "A test project"
requires-python = ">=3.9"
dependencies = ["requests>=2.31.0"]
''')
(js_part / "package.json").write_text('''{
"name": "test-js-project",
"version": "1.0.0",
"description": "A test JavaScript project"
}
''')
(js_part / "index.js").write_text('''function hello() {
return 'Hello';
}
module.exports = { hello };
''')
@pytest.fixture
def mcp_server(default_config: AppConfig) -> MCPServer:
"""Create an MCP server instance with registered tools."""
server = MCPServer(config=default_config)
server.register_tool(FileTools())
server.register_tool(GitTools())
server.register_tool(ShellTools())
return server
return tmp_path
@pytest.fixture
def tool_registry() -> ToolRegistry:
"""Create a tool registry instance."""
return ToolRegistry()
@pytest.fixture
def config_manager() -> ConfigManager:
"""Create a configuration manager instance."""
return ConfigManager()
@pytest.fixture
def mock_env():
"""Mock environment variables."""
env = {
"MCP_PORT": "4000",
"MCP_HOST": "0.0.0.0",
"MCP_LOG_LEVEL": "DEBUG",
}
original_env = os.environ.copy()
os.environ.update(env)
yield
os.environ.clear()
os.environ.update(original_env)

View File

@@ -1,73 +1,85 @@
"""Tests for CLI commands."""
import os
import tempfile
from pathlib import Path
import pytest
from click.testing import CliRunner
from project_scaffold_cli.cli import _to_kebab_case, _validate_project_name, main
from mcp_server_cli.main import main
class TestMain:
"""Test main CLI entry point."""
@pytest.fixture
def cli_runner():
"""Create a CLI runner for testing."""
return CliRunner()
def test_main_version(self):
"""Test --version flag."""
runner = CliRunner()
result = runner.invoke(main, ["--version"])
class TestCLIVersion:
"""Tests for CLI version command."""
def test_version(self, cli_runner):
"""Test --version option."""
result = cli_runner.invoke(main, ["--version"])
assert result.exit_code == 0
assert "1.0.0" in result.output
assert "0.1.0" in result.output
def test_main_help(self):
"""Test --help flag."""
runner = CliRunner()
result = runner.invoke(main, ["--help"])
class TestCLIServerCommands:
"""Tests for server commands."""
def test_server_status_no_config(self, cli_runner):
"""Test server status without config."""
result = cli_runner.invoke(main, ["server", "status"])
assert result.exit_code == 0
assert "create" in result.output
class TestCreateCommand:
"""Test create command."""
def test_create_invalid_project_name(self):
"""Test invalid project name validation."""
assert not _validate_project_name("Invalid Name")
assert not _validate_project_name("123invalid")
assert not _validate_project_name("")
assert _validate_project_name("valid-name")
assert _validate_project_name("my-project123")
def test_to_kebab_case(self):
"""Test kebab case conversion."""
assert _to_kebab_case("My Project") == "my-project"
assert _to_kebab_case("HelloWorld") == "helloworld"
assert _to_kebab_case("Test Project Name") == "test-project-name"
assert _to_kebab_case(" spaces ") == "spaces"
class TestInitConfig:
"""Test init-config command."""
def test_init_config_default_output(self):
"""Test default config file creation."""
runner = CliRunner()
with tempfile.TemporaryDirectory() as tmpdir:
original_dir = Path.cwd()
try:
os.chdir(tmpdir)
result = runner.invoke(main, ["init-config"])
assert result.exit_code == 0
assert Path("project.yaml").exists()
finally:
os.chdir(original_dir)
class TestTemplateCommands:
"""Test template management commands."""
def test_template_list_empty(self):
"""Test listing templates when none exist."""
runner = CliRunner()
result = runner.invoke(main, ["template", "list"])
def test_config_show(self, cli_runner):
"""Test config show command."""
result = cli_runner.invoke(main, ["config", "show"])
assert result.exit_code == 0
class TestCLIToolCommands:
"""Tests for tool commands."""
def test_tool_list(self, cli_runner):
"""Test tool list command."""
result = cli_runner.invoke(main, ["tool", "list"])
assert result.exit_code == 0
def test_tool_add_nonexistent_file(self, cli_runner):
"""Test tool add with nonexistent file."""
result = cli_runner.invoke(main, ["tool", "add", "nonexistent.yaml"])
assert result.exit_code != 0 or "nonexistent" in result.output.lower()
class TestCLIConfigCommands:
"""Tests for config commands."""
def test_config_init(self, cli_runner, tmp_path):
"""Test config init command."""
output_file = tmp_path / "config.yaml"
result = cli_runner.invoke(main, ["config", "init", "-o", str(output_file)])
assert result.exit_code == 0
assert output_file.exists()
def test_config_show_with_env(self, cli_runner):
"""Test config show with environment variables."""
result = cli_runner.invoke(main, ["config", "show"])
assert result.exit_code == 0
class TestCLIHealthCommand:
"""Tests for health check command."""
def test_health_check_not_running(self, cli_runner):
"""Test health check when server not running."""
result = cli_runner.invoke(main, ["health"])
assert result.exit_code == 0 or "not running" in result.output.lower()
class TestCLIInstallCompletions:
"""Tests for shell completion installation."""
def test_install_completions(self, cli_runner):
"""Test install completions command."""
result = cli_runner.invoke(main, ["install"])
assert result.exit_code == 0

View File

@@ -1,101 +1,134 @@
"""Tests for configuration handling."""
"""Tests for configuration management."""
import tempfile
from pathlib import Path
import os
import yaml
import pytest
from project_scaffold_cli.config import Config
from mcp_server_cli.config import (
ConfigManager,
create_config_template,
load_config_from_path,
)
from mcp_server_cli.models import AppConfig, LocalLLMConfig, ServerConfig
class TestConfig:
"""Test Config class."""
class TestConfigManager:
"""Tests for ConfigManager."""
def test_default_config(self):
"""Test default configuration."""
config = Config()
assert config.author is None
assert config.email is None
assert config.license is None
assert config.description is None
def test_load_default_config(self, tmp_path):
"""Test loading default configuration."""
config_path = tmp_path / "config.yaml"
config_path.write_text("")
def test_config_from_yaml(self):
"""Test loading configuration from YAML file."""
config_content = {
"project": {
"author": "Test Author",
"email": "test@example.com",
"license": "MIT",
"description": "Test description",
},
"defaults": {
"language": "python",
"ci": "github",
},
}
manager = ConfigManager(config_path)
config = manager.load()
with tempfile.TemporaryDirectory() as tmpdir:
config_file = Path(tmpdir) / "project.yaml"
with open(config_file, "w") as f:
yaml.dump(config_content, f)
assert isinstance(config, AppConfig)
assert config.server.port == 3000
assert config.server.host == "127.0.0.1"
config = Config.load(str(config_file))
def test_load_config_with_values(self, tmp_path):
"""Test loading configuration with custom values."""
config_file = tmp_path / "config.yaml"
config_file.write_text("""
server:
host: "127.0.0.1"
port: 8080
log_level: "DEBUG"
assert config.author == "Test Author"
assert config.email == "test@example.com"
assert config.license == "MIT"
assert config.description == "Test description"
assert config.default_language == "python"
assert config.default_ci == "github"
llm:
enabled: false
base_url: "http://localhost:11434"
model: "llama2"
def test_config_save(self):
"""Test saving configuration to file."""
config = Config(
author="Test Author",
email="test@example.com",
license="MIT",
default_language="go",
security:
allowed_commands:
- ls
- cat
- echo
blocked_paths:
- /etc
- /root
""")
manager = ConfigManager(config_file)
config = manager.load()
assert config.server.port == 8080
assert config.server.host == "127.0.0.1"
assert config.server.log_level == "DEBUG"
class TestConfigFromPath:
"""Tests for loading config from path."""
def test_load_from_path_success(self, tmp_path):
"""Test successful config loading from path."""
config_file = tmp_path / "config.yaml"
config_file.write_text("server:\n port: 8080")
config = load_config_from_path(str(config_file))
assert config.server.port == 8080
def test_load_from_path_not_found(self):
"""Test loading from nonexistent path."""
with pytest.raises(FileNotFoundError):
load_config_from_path("/nonexistent/path/config.yaml")
class TestConfigTemplate:
"""Tests for configuration template."""
def test_create_template(self):
"""Test creating a config template."""
template = create_config_template()
assert "server" in template
assert "llm" in template
assert "security" in template
assert "tools" in template
def test_template_has_required_fields(self):
"""Test that template has all required fields."""
template = create_config_template()
assert template["server"]["port"] == 3000
assert "allowed_commands" in template["security"]
class TestConfigValidation:
"""Tests for configuration validation."""
def test_valid_config(self):
"""Test creating a valid config."""
config = AppConfig(
server=ServerConfig(port=4000, host="localhost"),
llm=LocalLLMConfig(enabled=True, base_url="http://localhost:1234"),
)
assert config.server.port == 4000
assert config.llm.enabled is True
with tempfile.TemporaryDirectory() as tmpdir:
config_file = Path(tmpdir) / "config.yaml"
config.save(config_file)
def test_config_with_empty_tools(self):
"""Test config with empty tools list."""
config = AppConfig(tools=[])
assert len(config.tools) == 0
assert config_file.exists()
with open(config_file, "r") as f:
saved_data = yaml.safe_load(f)
class TestEnvVarMapping:
"""Tests for environment variable mappings."""
assert saved_data["project"]["author"] == "Test Author"
assert saved_data["defaults"]["language"] == "go"
def test_get_env_var_name(self):
"""Test environment variable name generation."""
manager = ConfigManager()
assert manager.get_env_var_name("server.port") == "MCP_SERVER_PORT"
assert manager.get_env_var_name("host") == "MCP_HOST"
def test_get_template_dirs(self):
"""Test getting template directories."""
config = Config()
dirs = config.get_template_dirs()
assert len(dirs) > 0
assert any("project-scaffold" in d for d in dirs)
def test_get_custom_templates_dir(self):
"""Test getting custom templates directory."""
config = Config()
custom_dir = config.get_custom_templates_dir()
assert "project-scaffold" in custom_dir
def test_get_template_vars(self):
"""Test getting template variables for language."""
config = Config(
template_vars={
"python": {"version": "3.11"},
"nodejs": {"version": "16"},
}
)
python_vars = config.get_template_vars("python")
assert python_vars.get("version") == "3.11"
nodejs_vars = config.get_template_vars("nodejs")
assert nodejs_vars.get("version") == "16"
other_vars = config.get_template_vars("go")
assert other_vars == {}
def test_get_from_env(self):
"""Test getting values from environment."""
manager = ConfigManager()
os.environ["MCP_TEST_VAR"] = "test_value"
try:
result = manager.get_from_env("test_var")
assert result == "test_value"
finally:
del os.environ["MCP_TEST_VAR"]

201
tests/test_models.py Normal file
View File

@@ -0,0 +1,201 @@
"""Tests for MCP protocol message models."""
import pytest
from pydantic import ValidationError
from mcp_server_cli.models import (
AppConfig,
InitializeParams,
InitializeResult,
LocalLLMConfig,
MCPMethod,
MCPRequest,
MCPResponse,
SecurityConfig,
ServerConfig,
ToolCallParams,
ToolCallResult,
ToolDefinition,
ToolParameter,
ToolSchema,
)
class TestMCPRequest:
"""Tests for MCPRequest model."""
def test_create_request(self):
"""Test creating a basic request."""
request = MCPRequest(
id=1,
method=MCPMethod.INITIALIZE,
params={"protocol_version": "2024-11-05"},
)
assert request.jsonrpc == "2.0"
assert request.id == 1
assert request.method == MCPMethod.INITIALIZE
assert request.params["protocol_version"] == "2024-11-05"
def test_request_without_id(self):
"""Test request without id (notification)."""
request = MCPRequest(
method=MCPMethod.TOOLS_LIST,
)
assert request.id is None
assert request.method == MCPMethod.TOOLS_LIST
class TestMCPResponse:
"""Tests for MCPResponse model."""
def test_create_response_with_result(self):
"""Test creating a response with result."""
response = MCPResponse(
id=1,
result={"tools": [{"name": "test_tool"}]},
)
assert response.jsonrpc == "2.0"
assert response.id == 1
assert response.result["tools"][0]["name"] == "test_tool"
def test_create_response_with_error(self):
"""Test creating a response with error."""
response = MCPResponse(
id=1,
error={"code": -32600, "message": "Invalid Request"},
)
assert response.error["code"] == -32600
assert response.error["message"] == "Invalid Request"
class TestToolModels:
"""Tests for tool-related models."""
def test_tool_definition(self):
"""Test creating a tool definition."""
tool = ToolDefinition(
name="read_file",
description="Read a file from disk",
input_schema=ToolSchema(
properties={
"path": ToolParameter(
name="path",
type="string",
description="File path",
required=True,
)
},
required=["path"],
),
)
assert tool.name == "read_file"
assert "path" in tool.input_schema.required
def test_tool_name_validation(self):
"""Test that tool names must be valid identifiers."""
with pytest.raises(ValidationError):
ToolDefinition(
name="invalid-name",
description="Tool with invalid name",
)
def test_tool_parameter_with_enum(self):
"""Test tool parameter with enum values."""
param = ToolParameter(
name="operation",
type="string",
enum=["read", "write", "delete"],
)
assert param.enum == ["read", "write", "delete"]
class TestConfigModels:
"""Tests for configuration models."""
def test_server_config_defaults(self):
"""Test server config default values."""
config = ServerConfig()
assert config.host == "127.0.0.1"
assert config.port == 3000
assert config.log_level == "INFO"
def test_local_llm_config_defaults(self):
"""Test local LLM config defaults."""
config = LocalLLMConfig()
assert config.enabled is False
assert config.base_url == "http://localhost:11434"
assert config.model == "llama2"
def test_security_config_defaults(self):
"""Test security config defaults."""
config = SecurityConfig()
assert "ls" in config.allowed_commands
assert "/etc" in config.blocked_paths
assert config.max_shell_timeout == 30
def test_app_config_composition(self):
"""Test app config with all components."""
config = AppConfig(
server=ServerConfig(port=8080),
llm=LocalLLMConfig(enabled=True),
security=SecurityConfig(allowed_commands=["echo"]),
)
assert config.server.port == 8080
assert config.llm.enabled is True
assert "echo" in config.security.allowed_commands
class TestToolCallModels:
"""Tests for tool call models."""
def test_tool_call_params(self):
"""Test tool call parameters."""
params = ToolCallParams(
name="read_file",
arguments={"path": "/tmp/test.txt"},
)
assert params.name == "read_file"
assert params.arguments["path"] == "/tmp/test.txt"
def test_tool_call_result(self):
"""Test tool call result."""
result = ToolCallResult(
content=[{"type": "text", "text": "Hello"}],
is_error=False,
)
assert result.content[0]["text"] == "Hello"
assert result.is_error is False
def test_tool_call_result_error(self):
"""Test tool call result with error."""
result = ToolCallResult(
content=[],
is_error=True,
error_message="File not found",
)
assert result.is_error is True
assert result.error_message == "File not found"
class TestInitializeModels:
"""Tests for initialize-related models."""
def test_initialize_params(self):
"""Test initialize parameters."""
params = InitializeParams(
protocol_version="2024-11-05",
capabilities={"tools": {}},
client_info={"name": "test-client"},
)
assert params.protocol_version == "2024-11-05"
assert params.client_info["name"] == "test-client"
def test_initialize_result(self):
"""Test initialize result."""
result = InitializeResult(
protocol_version="2024-11-05",
server_info={"name": "mcp-server", "version": "0.1.0"},
capabilities={"tools": {"listChanged": True}},
)
assert result.server_info.name == "mcp-server"
assert result.capabilities.tools["listChanged"] is True

99
tests/test_server.py Normal file
View File

@@ -0,0 +1,99 @@
"""Tests for MCP server."""
import pytest
from mcp_server_cli.models import MCPMethod, MCPRequest
from mcp_server_cli.server import MCPServer
from mcp_server_cli.tools import FileTools, GitTools, ShellTools
@pytest.fixture
def mcp_server():
"""Create an MCP server with registered tools."""
server = MCPServer()
server.register_tool(FileTools())
server.register_tool(GitTools())
server.register_tool(ShellTools())
return server
class TestMCPServer:
"""Tests for MCP server class."""
def test_server_creation(self):
"""Test server creation."""
server = MCPServer()
assert server.connection_state.value == "disconnected"
assert len(server.tool_registry) == 0
def test_register_tool(self, mcp_server):
"""Test tool registration."""
assert "file_tools" in mcp_server.tool_registry
assert len(mcp_server.list_tools()) == 3
def test_get_tool(self, mcp_server):
"""Test getting a registered tool."""
retrieved = mcp_server.get_tool("file_tools")
assert retrieved is not None
assert retrieved.name == "file_tools"
def test_list_tools(self, mcp_server):
"""Test listing all tools."""
tools = mcp_server.list_tools()
assert len(tools) >= 2
names = [t.name for t in tools]
assert "file_tools" in names
assert "git_tools" in names
class TestMCPProtocol:
"""Tests for MCP protocol implementation."""
def test_mcp_initialize(self, mcp_server):
"""Test MCP initialize request."""
request = MCPRequest(
id=1,
method=MCPMethod.INITIALIZE,
params={"protocol_version": "2024-11-05"},
)
import asyncio
response = asyncio.run(mcp_server.handle_request(request))
assert response.id == 1
assert response.result is not None
def test_mcp_tools_list(self, mcp_server):
"""Test MCP tools/list request."""
request = MCPRequest(
id=2,
method=MCPMethod.TOOLS_LIST,
)
import asyncio
response = asyncio.run(mcp_server.handle_request(request))
assert response.id == 2
assert response.result is not None
def test_mcp_invalid_method(self, mcp_server):
"""Test MCP request with invalid tool."""
request = MCPRequest(
id=3,
method=MCPMethod.TOOLS_CALL,
params={"name": "nonexistent"},
)
import asyncio
response = asyncio.run(mcp_server.handle_request(request))
assert response.error is not None or response.result.get("is_error") is True
class TestToolCall:
"""Tests for tool calling."""
def test_call_read_file_nonexistent(self, mcp_server):
"""Test calling read on nonexistent file."""
from mcp_server_cli.models import ToolCallParams
params = ToolCallParams(
name="read_file",
arguments={"path": "/nonexistent/file.txt"},
)
import asyncio
result = asyncio.run(mcp_server._handle_tool_call(params))
assert result.is_error is True

321
tests/test_tools.py Normal file
View File

@@ -0,0 +1,321 @@
"""Tests for tool execution engine and built-in tools."""
from pathlib import Path
import pytest
from mcp_server_cli.models import ToolParameter, ToolSchema
from mcp_server_cli.tools.base import ToolBase, ToolRegistry, ToolResult
from mcp_server_cli.tools.file_tools import (
GlobFilesTool,
ListDirectoryTool,
ReadFileTool,
WriteFileTool,
)
from mcp_server_cli.tools.git_tools import GitTools
from mcp_server_cli.tools.shell_tools import ExecuteCommandTool
class TestToolBase:
"""Tests for ToolBase abstract class."""
def test_tool_validation(self):
"""Test tool argument validation."""
class TestTool(ToolBase):
def __init__(self):
super().__init__(
name="test_tool",
description="A test tool",
)
def _create_input_schema(self) -> ToolSchema:
return ToolSchema(
properties={
"name": ToolParameter(
name="name",
type="string",
required=True,
),
"count": ToolParameter(
name="count",
type="integer",
enum=["1", "2", "3"],
),
},
required=["name"],
)
async def execute(self, arguments) -> ToolResult:
return ToolResult(success=True, output="OK")
tool = TestTool()
result = tool.validate_arguments({"name": "test"})
assert result["name"] == "test"
def test_missing_required_param(self):
"""Test that missing required parameters raise error."""
class TestTool(ToolBase):
def __init__(self):
super().__init__(name="test_tool", description="A test tool")
def _create_input_schema(self) -> ToolSchema:
return ToolSchema(
properties={
"required_param": ToolParameter(
name="required_param",
type="string",
required=True,
),
},
required=["required_param"],
)
async def execute(self, arguments) -> ToolResult:
return ToolResult(success=True, output="OK")
tool = TestTool()
with pytest.raises(ValueError, match="Missing required parameter"):
tool.validate_arguments({})
def test_invalid_enum_value(self):
"""Test that invalid enum values raise error."""
class TestTool(ToolBase):
def __init__(self):
super().__init__(name="test_tool", description="A test tool")
def _create_input_schema(self) -> ToolSchema:
return ToolSchema(
properties={
"color": ToolParameter(
name="color",
type="string",
enum=["red", "green", "blue"],
),
},
)
async def execute(self, arguments) -> ToolResult:
return ToolResult(success=True, output="OK")
tool = TestTool()
with pytest.raises(ValueError, match="Invalid value"):
tool.validate_arguments({"color": "yellow"})
class TestToolRegistry:
"""Tests for ToolRegistry."""
def test_register_and_get(self, tool_registry: ToolRegistry):
"""Test registering and retrieving a tool."""
class TestTool(ToolBase):
def __init__(self):
super().__init__(name="test_tool", description="A test tool")
def _create_input_schema(self) -> ToolSchema:
return ToolSchema(properties={}, required=[])
async def execute(self, arguments) -> ToolResult:
return ToolResult(success=True, output="OK")
tool = TestTool()
tool_registry.register(tool)
retrieved = tool_registry.get("test_tool")
assert retrieved is tool
assert retrieved.name == "test_tool"
def test_unregister(self, tool_registry: ToolRegistry):
"""Test unregistering a tool."""
class TestTool(ToolBase):
def __init__(self):
super().__init__(name="test_tool", description="A test tool")
def _create_input_schema(self) -> ToolSchema:
return ToolSchema(properties={}, required=[])
async def execute(self, arguments) -> ToolResult:
return ToolResult(success=True, output="OK")
tool = TestTool()
tool_registry.register(tool)
assert tool_registry.unregister("test_tool") is True
assert tool_registry.get("test_tool") is None
def test_list_tools(self, tool_registry: ToolRegistry):
"""Test listing all tools."""
class TestTool1(ToolBase):
def __init__(self):
super().__init__(name="tool1", description="Tool 1")
def _create_input_schema(self) -> ToolSchema:
return ToolSchema(properties={}, required=[])
async def execute(self, arguments) -> ToolResult:
return ToolResult(success=True, output="OK")
class TestTool2(ToolBase):
def __init__(self):
super().__init__(name="tool2", description="Tool 2")
def _create_input_schema(self) -> ToolSchema:
return ToolSchema(properties={}, required=[])
async def execute(self, arguments) -> ToolResult:
return ToolResult(success=True, output="OK")
tool_registry.register(TestTool1())
tool_registry.register(TestTool2())
tools = tool_registry.list()
assert len(tools) == 2
assert tool_registry.list_names() == ["tool1", "tool2"]
class TestFileTools:
"""Tests for file operation tools."""
@pytest.mark.asyncio
async def test_read_file(self, temp_file: Path):
"""Test reading a file."""
tool = ReadFileTool()
result = await tool.execute({"path": str(temp_file)})
assert result.success is True
assert "Hello, World!" in result.output
@pytest.mark.asyncio
async def test_read_nonexistent_file(self, temp_dir: Path):
"""Test reading a nonexistent file."""
tool = ReadFileTool()
result = await tool.execute({"path": str(temp_dir / "nonexistent.txt")})
assert result.success is False
assert "not found" in result.error.lower()
@pytest.mark.asyncio
async def test_write_file(self, temp_dir: Path):
"""Test writing a file."""
tool = WriteFileTool()
result = await tool.execute({
"path": str(temp_dir / "new_file.txt"),
"content": "New content",
})
assert result.success is True
assert (temp_dir / "new_file.txt").read_text() == "New content"
@pytest.mark.asyncio
async def test_list_directory(self, temp_dir: Path):
"""Test listing a directory."""
tool = ListDirectoryTool()
result = await tool.execute({"path": str(temp_dir)})
assert result.success is True
@pytest.mark.asyncio
async def test_glob_files(self, temp_dir: Path):
"""Test glob file search."""
(temp_dir / "test1.txt").touch()
(temp_dir / "test2.txt").touch()
(temp_dir / "subdir").mkdir()
(temp_dir / "subdir" / "test3.txt").touch()
tool = GlobFilesTool()
result = await tool.execute({
"path": str(temp_dir),
"pattern": "**/*.txt",
})
assert result.success is True
assert "test1.txt" in result.output
assert "test3.txt" in result.output
class TestShellTools:
"""Tests for shell execution tools."""
@pytest.mark.asyncio
async def test_execute_ls(self):
"""Test executing ls command."""
tool = ExecuteCommandTool()
result = await tool.execute({
"cmd": ["ls", "-1"],
"timeout": 10,
})
assert result.success is True
@pytest.mark.asyncio
async def test_execute_with_cwd(self, temp_dir: Path):
"""Test executing command with working directory."""
tool = ExecuteCommandTool()
result = await tool.execute({
"cmd": ["pwd"],
"cwd": str(temp_dir),
})
assert result.success is True
@pytest.mark.asyncio
async def test_execute_nonexistent_command(self):
"""Test executing nonexistent command."""
tool = ExecuteCommandTool()
result = await tool.execute({
"cmd": ["nonexistent_command_12345"],
})
assert result.success is False
assert "no such file" in result.error.lower()
@pytest.mark.asyncio
async def test_command_timeout(self):
"""Test command timeout."""
tool = ExecuteCommandTool()
result = await tool.execute({
"cmd": ["sleep", "10"],
"timeout": 1,
})
assert result.success is False
assert "timed out" in result.error.lower()
class TestGitTools:
"""Tests for git tools."""
@pytest.mark.asyncio
async def test_git_status_not_in_repo(self, temp_dir: Path):
"""Test git status in non-git directory."""
tool = GitTools()
result = await tool.execute({
"operation": "status",
"path": str(temp_dir),
})
assert result.success is False
assert "not in a git repository" in result.error.lower()
class TestToolResult:
"""Tests for ToolResult model."""
def test_success_result(self):
"""Test creating a success result."""
result = ToolResult(success=True, output="Test output")
assert result.success is True
assert result.output == "Test output"
assert result.error is None
def test_error_result(self):
"""Test creating an error result."""
result = ToolResult(
success=False,
output="",
error="Something went wrong",
)
assert result.success is False
assert result.error == "Something went wrong"