Re-upload: CI infrastructure issue resolved, all tests verified passing
This commit is contained in:
3
src/api_mock_cli/__init__.py
Normal file
3
src/api_mock_cli/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from api_mock_cli.version import __version__
|
||||
|
||||
__all__ = ["__version__"]
|
||||
285
src/api_mock_cli/cli.py
Normal file
285
src/api_mock_cli/cli.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from api_mock_cli.core.har_parser import HARParser, HARParserError, parse_browser_network_export
|
||||
from api_mock_cli.core.mock_generator import MockGenerator
|
||||
from api_mock_cli.core.server import create_mock_server_from_har
|
||||
from api_mock_cli.utils.file_utils import read_file, write_file
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version="0.1.0")
|
||||
def cli() -> None:
|
||||
pass
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("input_file", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--output",
|
||||
"-o",
|
||||
type=click.Path(),
|
||||
default=None,
|
||||
help="Output file for captured data",
|
||||
)
|
||||
@click.option(
|
||||
"--format",
|
||||
"-f",
|
||||
type=click.Choice(["har", "network"]),
|
||||
default="har",
|
||||
help="Input file format",
|
||||
)
|
||||
def capture(input_file: str, output: str | None, format: str) -> None:
|
||||
console.print(f"[bold blue]Capturing traffic from:[/bold blue] {input_file}")
|
||||
|
||||
try:
|
||||
if format == "network":
|
||||
data = json.loads(read_file(input_file))
|
||||
parse_result = parse_browser_network_export(data)
|
||||
else:
|
||||
parser = HARParser(har_file_path=input_file)
|
||||
parse_result = parser.parse()
|
||||
|
||||
console.print(f"[green]Successfully parsed {parse_result.entry_count} entries[/green]")
|
||||
console.print(f"[yellow]Skipped {parse_result.skipped_count} invalid entries[/yellow]")
|
||||
|
||||
if output is None:
|
||||
output = "captured_traffic.json"
|
||||
|
||||
output_data = {
|
||||
"base_url": parse_result.base_url,
|
||||
"entry_count": parse_result.entry_count,
|
||||
"requests": [
|
||||
{
|
||||
"method": req.method,
|
||||
"url": req.url,
|
||||
"headers": req.headers,
|
||||
"query_params": req.query_params,
|
||||
"body": req.body,
|
||||
"status_code": req.status_code,
|
||||
"response_body": req.response_body,
|
||||
}
|
||||
for req in parse_result.requests
|
||||
],
|
||||
}
|
||||
|
||||
write_file(output, json.dumps(output_data, indent=2))
|
||||
console.print(f"[bold green]Captured data saved to:[/bold green] {output}")
|
||||
|
||||
except HARParserError as e:
|
||||
console.print(f"[bold red]Error parsing HAR file:[/bold red] {str(e)}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]Error:[/bold red] {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("input_file", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--output",
|
||||
"-o",
|
||||
type=click.Path(),
|
||||
default=None,
|
||||
help="Output file for mock server code",
|
||||
)
|
||||
@click.option(
|
||||
"--format",
|
||||
"-f",
|
||||
type=click.Choice(["har", "network", "captured"]),
|
||||
default="har",
|
||||
help="Input file format",
|
||||
)
|
||||
def generate(input_file: str, output: str | None, format: str) -> None:
|
||||
console.print(f"[bold blue]Generating mock server from:[/bold blue] {input_file}")
|
||||
|
||||
try:
|
||||
if format == "captured":
|
||||
data = json.loads(read_file(input_file))
|
||||
requests = data.get("requests", [])
|
||||
base_url = data.get("base_url", "")
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from api_mock_cli.core.har_parser import UnifiedRequest
|
||||
|
||||
@dataclass
|
||||
class SimpleParseResult:
|
||||
requests: list
|
||||
base_url: str
|
||||
entry_count: int
|
||||
skipped_count: int
|
||||
|
||||
parse_result = SimpleParseResult(
|
||||
requests=[UnifiedRequest(**r) for r in requests],
|
||||
base_url=base_url,
|
||||
entry_count=len(requests),
|
||||
skipped_count=0,
|
||||
)
|
||||
elif format == "network":
|
||||
data = json.loads(read_file(input_file))
|
||||
parse_result = parse_browser_network_export(data)
|
||||
else:
|
||||
parser = HARParser(har_file_path=input_file)
|
||||
parse_result = parser.parse()
|
||||
|
||||
generator = MockGenerator(parse_result)
|
||||
routes = generator.get_route_summary()
|
||||
|
||||
table = Table(title="Generated Routes")
|
||||
table.add_column("Method", style="cyan")
|
||||
table.add_column("Route", style="green")
|
||||
table.add_column("Status", style="yellow")
|
||||
|
||||
for route in routes:
|
||||
table.add_row(route["method"], route["route"], route["status"])
|
||||
|
||||
console.print(table)
|
||||
|
||||
if output is None:
|
||||
output = "mock_server.py"
|
||||
|
||||
code = generator.generate_app()
|
||||
write_file(output, code)
|
||||
|
||||
console.print(f"[bold green]Mock server generated:[/bold green] {output}")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]Error generating mock server:[/bold red] {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("input_file", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Host to bind the server to",
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Port to bind the server to",
|
||||
)
|
||||
@click.option(
|
||||
"--debug",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Enable Flask debug mode",
|
||||
)
|
||||
@click.option(
|
||||
"--format",
|
||||
"-f",
|
||||
type=click.Choice(["har", "network", "captured"]),
|
||||
default="har",
|
||||
help="Input file format",
|
||||
)
|
||||
def serve(
|
||||
input_file: str, host: str, port: int, debug: bool, format: str
|
||||
) -> None:
|
||||
console.print(f"[bold blue]Starting mock server from:[/bold blue] {input_file}")
|
||||
console.print(f"[blue]Server will listen on {host}:{port}[/blue]")
|
||||
|
||||
try:
|
||||
if format == "captured":
|
||||
data = json.loads(read_file(input_file))
|
||||
requests = data.get("requests", [])
|
||||
base_url = data.get("base_url", "")
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from api_mock_cli.core.har_parser import UnifiedRequest
|
||||
|
||||
@dataclass
|
||||
class SimpleParseResult:
|
||||
requests: list
|
||||
base_url: str
|
||||
entry_count: int
|
||||
skipped_count: int
|
||||
|
||||
parse_result = SimpleParseResult(
|
||||
requests=[UnifiedRequest(**r) for r in requests],
|
||||
base_url=base_url,
|
||||
entry_count=len(requests),
|
||||
skipped_count=0,
|
||||
)
|
||||
elif format == "network":
|
||||
data = json.loads(read_file(input_file))
|
||||
parse_result = parse_browser_network_export(data)
|
||||
else:
|
||||
parser = HARParser(har_file_path=input_file)
|
||||
parse_result = parser.parse()
|
||||
|
||||
server = create_mock_server_from_har(
|
||||
parse_result, host=host, port=port, debug=debug
|
||||
)
|
||||
app = server.create_app()
|
||||
|
||||
console.print("[bold green]Mock server is running![/bold green]")
|
||||
console.print("[yellow]Press Ctrl+C to stop[/yellow]")
|
||||
|
||||
app.run(host=host, port=port, debug=debug)
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]Error starting server:[/bold red] {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("mock_server_file", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Host to bind the server to",
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Port to bind the server to",
|
||||
)
|
||||
@click.option(
|
||||
"--debug",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Enable Flask debug mode",
|
||||
)
|
||||
def run(mock_server_file: str, host: str, port: int, debug: bool) -> None:
|
||||
console.print(f"[bold blue]Running mock server from:[/bold blue] {mock_server_file}")
|
||||
console.print(f"[blue]Server will listen on {host}:{port}[/blue]")
|
||||
|
||||
try:
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("mock_server", mock_server_file)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
if hasattr(module, "create_app"):
|
||||
app = module.create_app()
|
||||
console.print("[bold green]Mock server is running![/bold green]")
|
||||
console.print("[yellow]Press Ctrl+C to stop[/yellow]")
|
||||
app.run(host=host, port=port, debug=debug)
|
||||
else:
|
||||
console.print("[bold red]Error:[/bold red] No create_app() function found in mock server file")
|
||||
sys.exit(1)
|
||||
else:
|
||||
console.print("[bold red]Error:[/bold red] Could not load mock server file")
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]Error running mock server:[/bold red] {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
0
src/api_mock_cli/core/__init__.py
Normal file
0
src/api_mock_cli/core/__init__.py
Normal file
196
src/api_mock_cli/core/data_generator.py
Normal file
196
src/api_mock_cli/core/data_generator.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from faker import Faker
|
||||
|
||||
fake = Faker()
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldInfo:
|
||||
name: str
|
||||
value: Any
|
||||
field_type: str
|
||||
|
||||
|
||||
class FakeDataGenerator:
|
||||
EMAIL_PATTERN = re.compile(
|
||||
r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", re.IGNORECASE
|
||||
)
|
||||
UUID_PATTERN = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
DATE_PATTERN = re.compile(
|
||||
r"^\d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?)?$"
|
||||
)
|
||||
URL_PATTERN = re.compile(r"^https?://")
|
||||
PHONE_PATTERN = re.compile(r"^\+?[\d\s\-\(\)]{10,}$")
|
||||
CREDIT_CARD_PATTERN = re.compile(r"^\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}$")
|
||||
IP_ADDRESS_PATTERN = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
|
||||
ZIP_CODE_PATTERN = re.compile(r"^\d{5}(?:-\d{4})?$")
|
||||
|
||||
def __init__(self, locale: str = "en_US"):
|
||||
self.fake = Faker(locale)
|
||||
|
||||
def _detect_field_type(self, field_name: str, value: Any) -> str:
|
||||
field_lower = field_name.lower()
|
||||
|
||||
if "email" in field_lower:
|
||||
return "email"
|
||||
if "name" in field_lower and ("first" in field_lower or "last" in field_lower):
|
||||
return "name"
|
||||
if "name" in field_lower and "company" not in field_lower:
|
||||
return "name"
|
||||
if "full_name" in field_lower or "username" in field_lower:
|
||||
return "name"
|
||||
if "phone" in field_lower or "mobile" in field_lower:
|
||||
return "phone"
|
||||
if "address" in field_lower:
|
||||
return "address"
|
||||
if "city" in field_lower:
|
||||
return "city"
|
||||
if "state" in field_lower:
|
||||
return "state"
|
||||
if "country" in field_lower:
|
||||
return "country"
|
||||
if "zip" in field_lower or "postal" in field_lower:
|
||||
return "zip_code"
|
||||
if "avatar" in field_lower or "image" in field_lower or "photo" in field_lower:
|
||||
return "image_url"
|
||||
if "url" in field_lower or "link" in field_lower:
|
||||
return "url"
|
||||
if "uuid" in field_lower or "guid" in field_lower:
|
||||
return "uuid"
|
||||
if "id" in field_lower and field_lower.endswith("id"):
|
||||
return "id"
|
||||
if "created" in field_lower or "updated" in field_lower or "date" in field_lower:
|
||||
return "date"
|
||||
if "time" in field_lower or "timestamp" in field_lower:
|
||||
return "datetime"
|
||||
if "price" in field_lower or "cost" in field_lower or "amount" in field_lower:
|
||||
return "price"
|
||||
if "description" in field_lower or "bio" in field_lower or "summary" in field_lower:
|
||||
return "text"
|
||||
if "body" in field_lower or "content" in field_lower:
|
||||
return "text"
|
||||
if "title" in field_lower or "subject" in field_lower:
|
||||
return "text"
|
||||
if "token" in field_lower or "key" in field_lower:
|
||||
return "token"
|
||||
if "password" in field_lower:
|
||||
return "password"
|
||||
if "ip" in field_lower or "ip_address" in field_lower:
|
||||
return "ip_address"
|
||||
if "credit_card" in field_lower or "card_number" in field_lower:
|
||||
return "credit_card"
|
||||
if "company" in field_lower or "organization" in field_lower:
|
||||
return "company"
|
||||
if "job" in field_lower or "occupation" in field_lower:
|
||||
return "job"
|
||||
if isinstance(value, bool):
|
||||
return "boolean"
|
||||
if isinstance(value, int):
|
||||
return "integer"
|
||||
if isinstance(value, float):
|
||||
return "float"
|
||||
if isinstance(value, str):
|
||||
if self.EMAIL_PATTERN.match(value):
|
||||
return "email"
|
||||
if self.UUID_PATTERN.match(value):
|
||||
return "uuid"
|
||||
if self.DATE_PATTERN.match(value):
|
||||
return "date"
|
||||
if self.URL_PATTERN.match(value):
|
||||
return "url"
|
||||
if self.PHONE_PATTERN.match(value):
|
||||
return "phone"
|
||||
if self.IP_ADDRESS_PATTERN.match(value):
|
||||
return "ip_address"
|
||||
if self.CREDIT_CARD_PATTERN.match(value):
|
||||
return "credit_card"
|
||||
if self.ZIP_CODE_PATTERN.match(value):
|
||||
return "zip_code"
|
||||
if value.isdigit():
|
||||
return "id"
|
||||
return "unknown"
|
||||
|
||||
def _generate_scalar(self, field_type: str) -> Any:
|
||||
generators: dict[str, callable] = {
|
||||
"email": lambda: self.fake.email(),
|
||||
"name": lambda: self.fake.name(),
|
||||
"phone": lambda: self.fake.phone_number(),
|
||||
"address": lambda: self.fake.address(),
|
||||
"city": lambda: self.fake.city(),
|
||||
"state": lambda: self.fake.state(),
|
||||
"country": lambda: self.fake.country(),
|
||||
"zip_code": lambda: self.fake.zipcode(),
|
||||
"image_url": lambda: self.fake.image_url(),
|
||||
"url": lambda: self.fake.url(),
|
||||
"uuid": lambda: str(self.fake.uuid4()),
|
||||
"id": lambda: str(self.fake.random_number(digits=8)),
|
||||
"date": lambda: self.fake.date(),
|
||||
"datetime": lambda: self.fake.iso8601(),
|
||||
"price": lambda: round(self.fake.random_number(digits=4) + self.fake.random.random(), 2),
|
||||
"text": lambda: self.fake.text(max_nb_chars=200),
|
||||
"token": lambda: self.fake.sha256(),
|
||||
"password": lambda: self.fake.password(length=16),
|
||||
"ip_address": lambda: self.fake.ipv4(),
|
||||
"credit_card": lambda: self.fake.credit_card_number(),
|
||||
"company": lambda: self.fake.company(),
|
||||
"job": lambda: self.fake.job(),
|
||||
"boolean": lambda: self.fake.boolean(),
|
||||
"integer": lambda: self.fake.random_number(digits=6),
|
||||
"float": lambda: round(self.fake.random.random() * 100, 2),
|
||||
"unknown": lambda: self.fake.word(),
|
||||
}
|
||||
return generators.get(field_type, generators["unknown"])()
|
||||
|
||||
def _generate_nested(self, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return self.generate_from_dict(value)
|
||||
elif isinstance(value, list):
|
||||
return [self._generate_nested(item) for item in value]
|
||||
else:
|
||||
return self._generate_scalar(self._detect_field_type("unknown", value))
|
||||
|
||||
def generate_from_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
result[key] = self.generate_from_dict(value)
|
||||
elif isinstance(value, list):
|
||||
if not value:
|
||||
result[key] = []
|
||||
else:
|
||||
result[key] = [self._generate_nested(item) for item in value]
|
||||
elif value is None:
|
||||
field_type = self._detect_field_type(key, "")
|
||||
result[key] = self._generate_scalar(field_type)
|
||||
else:
|
||||
field_type = self._detect_field_type(key, value)
|
||||
result[key] = self._generate_scalar(field_type)
|
||||
return result
|
||||
|
||||
def generate_from_json(self, json_str: str) -> str:
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
generated = self.generate_from_dict(data)
|
||||
return json.dumps(generated, indent=2)
|
||||
except json.JSONDecodeError as e:
|
||||
return json.dumps({"error": f"Invalid JSON: {str(e)}"})
|
||||
|
||||
def generate_response(
|
||||
self, original_response: str, preserve_structure: bool = True
|
||||
) -> str:
|
||||
if not preserve_structure:
|
||||
return self.generate_from_json(original_response)
|
||||
|
||||
try:
|
||||
data = json.loads(original_response)
|
||||
generated = self.generate_from_dict(data)
|
||||
return json.dumps(generated, indent=2)
|
||||
except json.JSONDecodeError:
|
||||
return original_response
|
||||
281
src/api_mock_cli/core/har_parser.py
Normal file
281
src/api_mock_cli/core/har_parser.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
try:
|
||||
from haralyzer import HarParser
|
||||
except ImportError:
|
||||
HarParser = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedRequest:
|
||||
method: str
|
||||
url: str
|
||||
headers: dict[str, str]
|
||||
query_params: dict[str, list[str]]
|
||||
body: str | None
|
||||
content_type: str | None
|
||||
timing: float | None
|
||||
status_code: int
|
||||
response_body: str | None
|
||||
response_headers: dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class HARParseResult:
|
||||
requests: list[UnifiedRequest]
|
||||
base_url: str
|
||||
entry_count: int
|
||||
skipped_count: int
|
||||
|
||||
|
||||
class HARParserError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class HARParser:
|
||||
def __init__(self, har_file_path: str | None = None, har_data: str | None = None):
|
||||
if HarParser is None:
|
||||
raise HARParserError("haralyzer library is not installed")
|
||||
|
||||
self.har_file_path = har_file_path
|
||||
self.har_data = har_data
|
||||
self._parsed = None
|
||||
|
||||
def _load_har_data(self) -> dict[str, Any]:
|
||||
if self._parsed is not None:
|
||||
return self._parsed
|
||||
|
||||
if self.har_data:
|
||||
if isinstance(self.har_data, str):
|
||||
self._parsed = json.loads(self.har_data)
|
||||
else:
|
||||
self._parsed = self.har_data
|
||||
elif self.har_file_path:
|
||||
with open(self.har_file_path, encoding="utf-8") as f:
|
||||
self._parsed = json.load(f)
|
||||
else:
|
||||
raise HARParserError("No HAR data provided")
|
||||
|
||||
return self._parsed
|
||||
|
||||
def _extract_url_patterns(self, url: str) -> tuple[str, dict[str, str]]:
|
||||
parsed = urlparse(url)
|
||||
path = parsed.path
|
||||
path_params = {}
|
||||
|
||||
uuid_pattern = re.compile(
|
||||
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
id_pattern = re.compile(r"/(\d+)(?:/|$)")
|
||||
hash_pattern = re.compile(r"/([a-f0-9]{24,})(?:/|$)", re.IGNORECASE)
|
||||
|
||||
parts = path.split("/")
|
||||
new_parts = []
|
||||
for i, part in enumerate(parts):
|
||||
if uuid_pattern.match(part):
|
||||
path_params[f"uuid_{i}"] = part
|
||||
new_parts.append("<uuid>")
|
||||
elif id_pattern.match(f"/{part}"):
|
||||
match = id_pattern.match(f"/{part}")
|
||||
if match:
|
||||
path_params[f"id_{i}"] = match.group(1)
|
||||
new_parts.append("<id>")
|
||||
elif hash_pattern.match(f"/{part}"):
|
||||
match = hash_pattern.match(f"/{part}")
|
||||
if match:
|
||||
path_params[f"hash_{i}"] = match.group(1)
|
||||
new_parts.append("<hash>")
|
||||
else:
|
||||
new_parts.append(part)
|
||||
|
||||
pattern = "/".join(new_parts)
|
||||
if pattern.endswith("/") and len(new_parts) > 1:
|
||||
pattern = pattern.rstrip("/")
|
||||
|
||||
return pattern, path_params
|
||||
|
||||
def _parse_headers(self, headers: list[dict[str, Any]]) -> dict[str, str]:
|
||||
return {h["name"].lower(): h["value"] for h in headers}
|
||||
|
||||
def _parse_query_params(self, query_string: str) -> dict[str, list[str]]:
|
||||
parsed = parse_qs(query_string)
|
||||
return {k: v for k, v in parsed.items()}
|
||||
|
||||
def _parse_post_data(
|
||||
self, post_data: dict[str, Any]
|
||||
) -> tuple[str | None, str | None]:
|
||||
mime_type = post_data.get("mimeType", "")
|
||||
text = post_data.get("text")
|
||||
return text, mime_type if mime_type else None
|
||||
|
||||
def _parse_timing(self, time: float | None) -> float | None:
|
||||
if time is None or time < 0:
|
||||
return None
|
||||
return time / 1000.0
|
||||
|
||||
def parse(self) -> HARParseResult:
|
||||
har_data = self._load_har_data()
|
||||
|
||||
if "log" not in har_data:
|
||||
raise HARParserError("Invalid HAR format: missing 'log' key")
|
||||
|
||||
log = har_data["log"]
|
||||
entries = log.get("entries", [])
|
||||
|
||||
if not entries:
|
||||
raise HARParserError("No entries found in HAR file")
|
||||
|
||||
base_url = ""
|
||||
requests: list[UnifiedRequest] = []
|
||||
skipped = 0
|
||||
|
||||
for entry in entries:
|
||||
request = entry.get("request", {})
|
||||
response = entry.get("response", {})
|
||||
|
||||
url = request.get("url", "")
|
||||
if not url:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not base_url:
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
method = request.get("method", "GET")
|
||||
headers = self._parse_headers(request.get("headers", []))
|
||||
query_string = request.get("queryString", [])
|
||||
query_params = {}
|
||||
for qs in query_string:
|
||||
name = qs.get("name", "")
|
||||
value = qs.get("value", "")
|
||||
if name not in query_params:
|
||||
query_params[name] = []
|
||||
query_params[name].append(value)
|
||||
|
||||
post_data = request.get("postData")
|
||||
body: str | None = None
|
||||
content_type: str | None = None
|
||||
if post_data:
|
||||
body, content_type = self._parse_post_data(post_data)
|
||||
|
||||
timing = self._parse_timing(entry.get("time"))
|
||||
|
||||
status_code = response.get("status", 0)
|
||||
if status_code == 0:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
response_headers = self._parse_headers(response.get("headers", []))
|
||||
|
||||
response_body: str | None = None
|
||||
content = response.get("content", {})
|
||||
if content.get("mimeType", "").startswith("application/json"):
|
||||
response_body = content.get("text")
|
||||
|
||||
unified = UnifiedRequest(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
query_params=query_params,
|
||||
body=body,
|
||||
content_type=content_type,
|
||||
timing=timing,
|
||||
status_code=status_code,
|
||||
response_body=response_body,
|
||||
response_headers=response_headers,
|
||||
)
|
||||
requests.append(unified)
|
||||
|
||||
return HARParseResult(
|
||||
requests=requests,
|
||||
base_url=base_url,
|
||||
entry_count=len(entries),
|
||||
skipped_count=skipped,
|
||||
)
|
||||
|
||||
|
||||
def parse_browser_network_export(data: dict[str, Any]) -> HARParseResult:
|
||||
if "log" in data:
|
||||
parser = HARParser(har_data=data)
|
||||
return parser.parse()
|
||||
|
||||
entries: list[dict[str, Any]] = data.get("entries", [])
|
||||
if not entries:
|
||||
raise HARParserError("No entries found in browser network export")
|
||||
|
||||
requests: list[UnifiedRequest] = []
|
||||
base_url = ""
|
||||
skipped = 0
|
||||
|
||||
for entry in entries:
|
||||
request = entry.get("request", entry)
|
||||
response = entry.get("response", {})
|
||||
|
||||
url = request.get("url", "")
|
||||
if not url:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not base_url:
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
method = request.get("method", "GET")
|
||||
headers: dict[str, str] = {}
|
||||
for h in request.get("headers", []):
|
||||
headers[h.get("name", "").lower()] = h.get("value", "")
|
||||
|
||||
query_params: dict[str, list[str]] = {}
|
||||
for qs in request.get("queryString", []):
|
||||
name = qs.get("name", "")
|
||||
value = qs.get("value", "")
|
||||
if name not in query_params:
|
||||
query_params[name] = []
|
||||
query_params[name].append(value)
|
||||
|
||||
body = request.get("postData", {}).get("text")
|
||||
|
||||
timing = entry.get("time")
|
||||
if timing is not None:
|
||||
timing = timing / 1000.0
|
||||
|
||||
status_code = response.get("status", 0)
|
||||
if status_code == 0:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
response_headers: dict[str, str] = {}
|
||||
for h in response.get("headers", []):
|
||||
response_headers[h.get("name", "").lower()] = h.get("value", "")
|
||||
|
||||
response_body = None
|
||||
content = response.get("content", {})
|
||||
if content.get("mimeType", "").startswith("application/json"):
|
||||
response_body = content.get("text")
|
||||
|
||||
unified = UnifiedRequest(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
query_params=query_params,
|
||||
body=body,
|
||||
content_type=headers.get("content-type"),
|
||||
timing=timing,
|
||||
status_code=status_code,
|
||||
response_body=response_body,
|
||||
response_headers=response_headers,
|
||||
)
|
||||
requests.append(unified)
|
||||
|
||||
return HARParseResult(
|
||||
requests=requests,
|
||||
base_url=base_url,
|
||||
entry_count=len(entries),
|
||||
skipped_count=skipped,
|
||||
)
|
||||
143
src/api_mock_cli/core/mock_generator.py
Normal file
143
src/api_mock_cli/core/mock_generator.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from api_mock_cli.core.data_generator import FakeDataGenerator
|
||||
from api_mock_cli.core.har_parser import HARParseResult, UnifiedRequest
|
||||
from api_mock_cli.core.route_matcher import RouteMatcher
|
||||
|
||||
|
||||
class MockGenerator:
|
||||
def __init__(self, parse_result: HARParseResult):
|
||||
self.parse_result = parse_result
|
||||
self.route_matcher = RouteMatcher()
|
||||
self.data_generator = FakeDataGenerator()
|
||||
self._build_routes()
|
||||
|
||||
def _build_routes(self):
|
||||
for request in self.parse_result.requests:
|
||||
self.route_matcher.add_route(request.url, request.method)
|
||||
|
||||
def _generate_response_body(self, request: UnifiedRequest) -> str:
|
||||
if request.response_body:
|
||||
return self.data_generator.generate_response(
|
||||
request.response_body, preserve_structure=True
|
||||
)
|
||||
return json.dumps({"message": "Mock response"})
|
||||
|
||||
def generate_routes(self) -> list[dict[str, Any]]:
|
||||
routes: list[dict[str, Any]] = []
|
||||
seen_keys: set[str] = set()
|
||||
|
||||
for request in self.parse_result.requests:
|
||||
match = self.route_matcher.match(request.method, request.url)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
pattern = match.route_pattern
|
||||
method = request.method
|
||||
key = f"{method}:{pattern}"
|
||||
if key in seen_keys:
|
||||
continue
|
||||
seen_keys.add(key)
|
||||
|
||||
route_info: dict[str, Any] = {
|
||||
"pattern": pattern,
|
||||
"method": request.method,
|
||||
"status_code": request.status_code,
|
||||
"headers": request.response_headers,
|
||||
"response_body": self._generate_response_body(request),
|
||||
"timing": request.timing,
|
||||
}
|
||||
routes.append(route_info)
|
||||
|
||||
return routes
|
||||
|
||||
def generate_app(self) -> str:
|
||||
routes = self.generate_routes()
|
||||
|
||||
imports = [
|
||||
"from flask import Flask, jsonify, request, Response",
|
||||
"import time",
|
||||
"import json",
|
||||
"",
|
||||
]
|
||||
|
||||
app_creation = [
|
||||
"def create_app():",
|
||||
" app = Flask(__name__)",
|
||||
"",
|
||||
]
|
||||
|
||||
for route in routes:
|
||||
pattern = route["pattern"]
|
||||
method = route["method"]
|
||||
status_code = route["status_code"]
|
||||
headers = route["headers"]
|
||||
response_body = route["response_body"]
|
||||
timing = route.get("timing")
|
||||
|
||||
route_args = f"'{pattern}'"
|
||||
if method != "GET":
|
||||
route_args = f"'{pattern}', methods=['{method}']"
|
||||
|
||||
app_creation.append(f" @app.route({route_args})")
|
||||
app_creation.append(" def handler_" + str(hash(f"{method}:{pattern}") & 0xFFFFFF) + "():")
|
||||
app_creation.append(" if request.method == 'OPTIONS':")
|
||||
app_creation.append(" return '', 204")
|
||||
|
||||
if timing:
|
||||
app_creation.append(f" time.sleep({timing})")
|
||||
|
||||
for header_name, header_value in headers.items():
|
||||
if header_name in ("content-type", "Content-Type"):
|
||||
continue
|
||||
app_creation.append(
|
||||
f" response.headers['{header_name}'] = '{header_value}'"
|
||||
)
|
||||
|
||||
content_type = headers.get("content-type", headers.get("Content-Type", "application/json"))
|
||||
if not content_type:
|
||||
content_type = "application/json"
|
||||
|
||||
body_repr = repr(response_body)
|
||||
app_creation.append(f" return {body_repr}, {status_code}")
|
||||
app_creation.append("")
|
||||
|
||||
app_creation.append(" return app")
|
||||
app_creation.append("")
|
||||
app_creation.append("")
|
||||
app_creation.append("if __name__ == '__main__':")
|
||||
app_creation.append(" app = create_app()")
|
||||
app_creation.append(" app.run(host='0.0.0.0', port=5000, debug=True)")
|
||||
|
||||
code = "\n".join(imports + app_creation)
|
||||
return code
|
||||
|
||||
def save_mock_server(self, output_path: str) -> None:
|
||||
code = self.generate_app()
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
def get_route_summary(self) -> list[dict[str, str]]:
|
||||
summary: list[dict[str, str]] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for request in self.parse_result.requests:
|
||||
match = self.route_matcher.match(request.method, request.url)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
key = f"{request.method}:{match.route_pattern}"
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
|
||||
summary.append(
|
||||
{
|
||||
"method": request.method,
|
||||
"route": match.route_pattern,
|
||||
"status": str(request.status_code),
|
||||
}
|
||||
)
|
||||
|
||||
return summary
|
||||
151
src/api_mock_cli/core/route_matcher.py
Normal file
151
src/api_mock_cli/core/route_matcher.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouteMatch:
|
||||
route_pattern: str
|
||||
path_params: dict[str, str]
|
||||
query_params: dict[str, list[str]]
|
||||
matched: bool
|
||||
|
||||
|
||||
class RouteMatcher:
|
||||
PATH_PARAM_PATTERN = re.compile(r"<([^>]+)>")
|
||||
|
||||
def __init__(self):
|
||||
self.routes: list[tuple[str, str | None, str]] = []
|
||||
|
||||
def _convert_url_to_flask_pattern(self, url: str) -> tuple[str, list[str]]:
|
||||
parsed = urlparse(url)
|
||||
path = parsed.path
|
||||
path_params: list[str] = []
|
||||
|
||||
uuid_pattern = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
id_pattern = r"\d+"
|
||||
hash_pattern = r"[a-f0-9]{24,}"
|
||||
hash_like_pattern = r"[a-zA-Z0-9_-]{8,}"
|
||||
|
||||
parts = path.split("/")
|
||||
new_parts = []
|
||||
for part in parts:
|
||||
if not part:
|
||||
continue
|
||||
|
||||
if re.match(uuid_pattern, part, re.IGNORECASE):
|
||||
path_params.append("uuid")
|
||||
new_parts.append("<uuid>")
|
||||
elif re.match(id_pattern, part):
|
||||
path_params.append("id")
|
||||
new_parts.append("<id>")
|
||||
elif re.match(hash_pattern, part, re.IGNORECASE):
|
||||
path_params.append("hash")
|
||||
new_parts.append("<hash>")
|
||||
elif re.match(hash_like_pattern, part) and len(part) >= 12:
|
||||
path_params.append("param")
|
||||
new_parts.append("<param>")
|
||||
elif part.startswith("{") and part.endswith("}"):
|
||||
param_name = part[1:-1]
|
||||
path_params.append(param_name)
|
||||
new_parts.append(f"<{param_name}>")
|
||||
elif self.PATH_PARAM_PATTERN.match(part):
|
||||
match = self.PATH_PARAM_PATTERN.match(part)
|
||||
if match:
|
||||
path_params.append(match.group(1))
|
||||
new_parts.append(part)
|
||||
elif "<" in part and ">" in part:
|
||||
match = self.PATH_PARAM_PATTERN.search(part)
|
||||
if match:
|
||||
path_params.append(match.group(1))
|
||||
new_parts.append(part)
|
||||
else:
|
||||
new_parts.append(part)
|
||||
|
||||
pattern = "/" + "/".join(new_parts)
|
||||
if not pattern.endswith("/") and len(parts) > 0 and parts[-1] == "":
|
||||
pattern += "/"
|
||||
|
||||
pattern = pattern.rstrip("/")
|
||||
|
||||
return pattern, path_params
|
||||
|
||||
def _extract_path_params(
|
||||
self, path: str, pattern: str
|
||||
) -> dict[str, str] | None:
|
||||
pattern_parts = [p for p in pattern.split("/") if p]
|
||||
path_parts = [p for p in path.split("/") if p]
|
||||
|
||||
if len(pattern_parts) != len(path_parts):
|
||||
return None
|
||||
|
||||
path_params: dict[str, str] = {}
|
||||
for pp, gp in zip(path_parts, pattern_parts):
|
||||
if gp.startswith("<") and gp.endswith(">"):
|
||||
param_name = gp[1:-1]
|
||||
path_params[param_name] = pp
|
||||
elif pp != gp:
|
||||
return None
|
||||
|
||||
return path_params
|
||||
|
||||
def add_route(self, url: str, method: str | None = None) -> str:
|
||||
pattern, _ = self._convert_url_to_flask_pattern(url)
|
||||
route_key = (pattern, method, url)
|
||||
if route_key not in self.routes:
|
||||
self.routes.append(route_key)
|
||||
return pattern
|
||||
|
||||
def match(self, method: str, url: str) -> RouteMatch | None:
|
||||
parsed = urlparse(url)
|
||||
path = parsed.path
|
||||
query_string = parsed.query
|
||||
|
||||
query_params: dict[str, list[str]] = {}
|
||||
if query_string:
|
||||
parsed_qs = parse_qs(query_string)
|
||||
query_params = {k: v for k, v in parsed_qs.items()}
|
||||
|
||||
for pattern, route_method, original_url in self.routes:
|
||||
if route_method is not None and route_method != method:
|
||||
continue
|
||||
path_params = self._extract_path_params(path, pattern)
|
||||
if path_params is not None:
|
||||
return RouteMatch(
|
||||
route_pattern=pattern,
|
||||
path_params=path_params,
|
||||
query_params=query_params,
|
||||
matched=True,
|
||||
)
|
||||
|
||||
for pattern, route_method, original_url in self.routes:
|
||||
if route_method is not None and route_method != method:
|
||||
continue
|
||||
pattern_parts = [p for p in pattern.split("/") if p]
|
||||
path_parts = [p for p in path.split("/") if p]
|
||||
|
||||
if len(pattern_parts) != len(path_parts):
|
||||
continue
|
||||
|
||||
all_match = True
|
||||
path_params = {}
|
||||
for pp, gp in zip(path_parts, pattern_parts):
|
||||
if gp.startswith("<") and gp.endswith(">"):
|
||||
param_name = gp[1:-1]
|
||||
path_params[param_name] = pp
|
||||
elif pp != gp:
|
||||
all_match = False
|
||||
break
|
||||
|
||||
if all_match:
|
||||
return RouteMatch(
|
||||
route_pattern=pattern,
|
||||
path_params=path_params,
|
||||
query_params=query_params,
|
||||
matched=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def get_routes(self) -> list[str]:
|
||||
return [pattern for pattern, _, _ in self.routes]
|
||||
148
src/api_mock_cli/core/server.py
Normal file
148
src/api_mock_cli/core/server.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, Response, jsonify, request
|
||||
|
||||
|
||||
class MockServer:
|
||||
def __init__(self, host: str = "localhost", port: int = 5000, debug: bool = False):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.debug = debug
|
||||
self.app: Flask | None = None
|
||||
self._routes: list[dict[str, Any]] = []
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
pattern: str,
|
||||
method: str,
|
||||
status_code: int,
|
||||
response_body: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timing: float | None = None,
|
||||
) -> None:
|
||||
route_info: dict[str, Any] = {
|
||||
"pattern": pattern,
|
||||
"method": method,
|
||||
"status_code": status_code,
|
||||
"response_body": response_body,
|
||||
"headers": headers or {},
|
||||
"timing": timing,
|
||||
}
|
||||
self._routes.append(route_info)
|
||||
|
||||
def _create_dynamic_handler(self, route: dict[str, Any]):
|
||||
def handler(**kwargs):
|
||||
if request.method == "OPTIONS":
|
||||
return "", 204
|
||||
|
||||
if route.get("timing"):
|
||||
time.sleep(route["timing"])
|
||||
|
||||
for header_name, header_value in route["headers"].items():
|
||||
if header_name.lower() not in ("content-type", "content-length"):
|
||||
pass
|
||||
|
||||
content_type = route["headers"].get("Content-Type", route["headers"].get("content-type", "application/json"))
|
||||
response = Response(
|
||||
route["response_body"],
|
||||
status=route["status_code"],
|
||||
mimetype=content_type,
|
||||
)
|
||||
|
||||
for header_name, header_value in route["headers"].items():
|
||||
if header_name.lower() not in ("content-type", "content-length"):
|
||||
response.headers[header_name] = header_value
|
||||
|
||||
return response
|
||||
|
||||
return handler
|
||||
|
||||
def create_app(self) -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["DEBUG"] = self.debug
|
||||
|
||||
@app.route("/<path:fallback>", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
|
||||
def fallback_handler(fallback):
|
||||
return jsonify({"error": "Route not found", "path": fallback}), 404
|
||||
|
||||
for route in self._routes:
|
||||
pattern = route["pattern"]
|
||||
method = route["method"]
|
||||
handler = self._create_dynamic_handler(route)
|
||||
|
||||
if method == "GET":
|
||||
app.add_url_rule(pattern, view_func=handler, methods=["GET"])
|
||||
elif method == "POST":
|
||||
app.add_url_rule(pattern, view_func=handler, methods=["POST"])
|
||||
elif method == "PUT":
|
||||
app.add_url_rule(pattern, view_func=handler, methods=["PUT"])
|
||||
elif method == "DELETE":
|
||||
app.add_url_rule(pattern, view_func=handler, methods=["DELETE"])
|
||||
elif method == "PATCH":
|
||||
app.add_url_rule(pattern, view_func=handler, methods=["PATCH"])
|
||||
elif method == "OPTIONS":
|
||||
app.add_url_rule(pattern, view_func=handler, methods=["OPTIONS"])
|
||||
|
||||
@app.before_request
|
||||
def log_request():
|
||||
pass
|
||||
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
response.headers["X-Mock-Server"] = "api-mock-cli"
|
||||
return response
|
||||
|
||||
self.app = app
|
||||
return app
|
||||
|
||||
def run(self) -> None:
|
||||
if self.app is None:
|
||||
self.create_app()
|
||||
if self.app:
|
||||
self.app.run(host=self.host, port=self.port, debug=self.debug)
|
||||
|
||||
|
||||
def create_mock_server_from_har(
|
||||
parse_result, host: str = "localhost", port: int = 5000, debug: bool = False
|
||||
) -> MockServer:
|
||||
from api_mock_cli.core.data_generator import FakeDataGenerator
|
||||
from api_mock_cli.core.route_matcher import RouteMatcher
|
||||
|
||||
server = MockServer(host=host, port=port, debug=debug)
|
||||
route_matcher = RouteMatcher()
|
||||
data_generator = FakeDataGenerator()
|
||||
|
||||
for req in parse_result.requests:
|
||||
route_matcher.add_route(req.url, req.method)
|
||||
|
||||
seen_patterns: set[str] = set()
|
||||
|
||||
for req in parse_result.requests:
|
||||
match = route_matcher.match(req.method, req.url)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
pattern = match.route_pattern
|
||||
key = f"{req.method}:{pattern}"
|
||||
if key in seen_patterns:
|
||||
continue
|
||||
seen_patterns.add(key)
|
||||
|
||||
response_body = req.response_body
|
||||
if response_body:
|
||||
response_body = data_generator.generate_response(
|
||||
response_body, preserve_structure=True
|
||||
)
|
||||
|
||||
server.add_route(
|
||||
pattern=pattern,
|
||||
method=req.method,
|
||||
status_code=req.status_code,
|
||||
response_body=response_body or json.dumps({"message": "Mock response"}),
|
||||
headers=req.response_headers,
|
||||
timing=req.timing,
|
||||
)
|
||||
|
||||
return server
|
||||
3
src/api_mock_cli/main.py
Normal file
3
src/api_mock_cli/main.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from api_mock_cli.cli import cli
|
||||
|
||||
__all__ = ["cli"]
|
||||
0
src/api_mock_cli/utils/__init__.py
Normal file
0
src/api_mock_cli/utils/__init__.py
Normal file
110
src/api_mock_cli/utils/auth_handler.py
Normal file
110
src/api_mock_cli/utils/auth_handler.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import base64
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthInfo:
|
||||
auth_type: str
|
||||
credentials: dict[str, str]
|
||||
header_name: str
|
||||
header_value: str
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
BEARER_PATTERN = re.compile(r"Bearer\s+(.+)", re.IGNORECASE)
|
||||
BASIC_PATTERN = re.compile(r"Basic\s+(.+)", re.IGNORECASE)
|
||||
API_KEY_PATTERN = re.compile(r"(.+)\s+(.+)", re.IGNORECASE)
|
||||
|
||||
def __init__(self):
|
||||
self.supported_auth_types = ["bearer", "basic", "api_key", "cookie"]
|
||||
|
||||
def extract_auth(self, headers: dict[str, str]) -> AuthInfo | None:
|
||||
for header_name, header_value in headers.items():
|
||||
header_lower = header_name.lower()
|
||||
|
||||
if header_lower == "authorization":
|
||||
bearer_match = self.BEARER_PATTERN.match(header_value)
|
||||
if bearer_match:
|
||||
token = bearer_match.group(1)
|
||||
return AuthInfo(
|
||||
auth_type="bearer",
|
||||
credentials={"token": token},
|
||||
header_name=header_name,
|
||||
header_value=header_value,
|
||||
)
|
||||
|
||||
basic_match = self.BASIC_PATTERN.match(header_value)
|
||||
if basic_match:
|
||||
encoded = basic_match.group(1)
|
||||
try:
|
||||
decoded = base64.b64decode(encoded).decode("utf-8")
|
||||
username, password = decoded.split(":", 1)
|
||||
return AuthInfo(
|
||||
auth_type="basic",
|
||||
credentials={"username": username, "password": password},
|
||||
header_name=header_name,
|
||||
header_value=header_value,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
elif header_lower == "x-api-key":
|
||||
return AuthInfo(
|
||||
auth_type="api_key",
|
||||
credentials={"api_key": header_value},
|
||||
header_name=header_name,
|
||||
header_value=header_value,
|
||||
)
|
||||
|
||||
elif header_lower.startswith("cookie"):
|
||||
return AuthInfo(
|
||||
auth_type="cookie",
|
||||
credentials={"cookie": header_value},
|
||||
header_name=header_name,
|
||||
header_value=header_value,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def match_auth(
|
||||
self, request_headers: dict[str, str], expected_auth: AuthInfo
|
||||
) -> bool:
|
||||
request_auth = self.extract_auth(request_headers)
|
||||
if request_auth is None:
|
||||
return expected_auth is None
|
||||
|
||||
if expected_auth.auth_type == "bearer":
|
||||
return request_auth.auth_type == "bearer"
|
||||
elif expected_auth.auth_type == "basic":
|
||||
return request_auth.auth_type == "basic"
|
||||
elif expected_auth.auth_type == "api_key":
|
||||
return request_auth.auth_type == "api_key"
|
||||
elif expected_auth.auth_type == "cookie":
|
||||
return request_auth.auth_type == "cookie"
|
||||
|
||||
return False
|
||||
|
||||
def generate_auth_header(self, auth_type: str, credentials: dict[str, str]) -> tuple[str, str]:
|
||||
if auth_type == "bearer":
|
||||
token = credentials.get("token", "")
|
||||
return "Authorization", f"Bearer {token}"
|
||||
elif auth_type == "basic":
|
||||
username = credentials.get("username", "")
|
||||
password = credentials.get("password", "")
|
||||
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
return "Authorization", f"Basic {encoded}"
|
||||
elif auth_type == "api_key":
|
||||
key = credentials.get("api_key", "")
|
||||
return "X-API-Key", key
|
||||
elif auth_type == "cookie":
|
||||
return "Cookie", credentials.get("cookie", "")
|
||||
|
||||
return "", ""
|
||||
|
||||
def apply_auth_to_headers(
|
||||
self, headers: dict[str, str], auth_info: AuthInfo
|
||||
) -> dict[str, str]:
|
||||
new_headers = dict(headers)
|
||||
new_headers[auth_info.header_name] = auth_info.header_value
|
||||
return new_headers
|
||||
38
src/api_mock_cli/utils/file_utils.py
Normal file
38
src/api_mock_cli/utils/file_utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def read_file(path: str) -> str:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def write_file(path: str, content: str) -> None:
|
||||
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def read_json(path: str) -> dict[str, Any]:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(path: str, data: dict[str, Any], indent: int = 2) -> None:
|
||||
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=indent)
|
||||
|
||||
|
||||
def ensure_directory(path: str) -> None:
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def file_exists(path: str) -> bool:
|
||||
return os.path.isfile(path)
|
||||
|
||||
|
||||
def get_file_extension(path: str) -> str:
|
||||
return os.path.splitext(path)[1].lower()
|
||||
1
src/api_mock_cli/version.py
Normal file
1
src/api_mock_cli/version.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
||||
3
src/memory_manager/__init__.py
Normal file
3
src/memory_manager/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Agentic Codebase Memory Manager - A centralized memory store for AI coding agents."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
0
src/memory_manager/api/__init__.py
Normal file
0
src/memory_manager/api/__init__.py
Normal file
207
src/memory_manager/api/app.py
Normal file
207
src/memory_manager/api/app.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""FastAPI REST API for the memory manager."""
|
||||
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from memory_manager import __version__
|
||||
from memory_manager.api.schemas import (
|
||||
CommitCreate,
|
||||
CommitResponse,
|
||||
DiffResponse,
|
||||
HealthResponse,
|
||||
MemoryEntryCreate,
|
||||
MemoryEntryResponse,
|
||||
MemoryEntryUpdate,
|
||||
StatsResponse,
|
||||
)
|
||||
from memory_manager.core.services import MemoryManager
|
||||
from memory_manager.db.models import MemoryCategory
|
||||
from memory_manager.db.repository import MemoryRepository
|
||||
|
||||
db_path = os.getenv("MEMORY_DB_PATH", ".memory/codebase_memory.db")
|
||||
repository = MemoryRepository(db_path)
|
||||
memory_manager = MemoryManager(repository)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await memory_manager.initialize()
|
||||
yield
|
||||
await memory_manager.close()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Agentic Codebase Memory Manager",
|
||||
description="A centralized memory store for AI coding agents",
|
||||
version=__version__,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health():
|
||||
return HealthResponse(status="ok", version=__version__)
|
||||
|
||||
|
||||
@app.get("/api/memory", response_model=list[MemoryEntryResponse])
|
||||
async def list_memory(
|
||||
category: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = Query(default=100, ge=1, le=1000),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
):
|
||||
category_enum = None
|
||||
if category:
|
||||
try:
|
||||
category_enum = MemoryCategory(category)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid category: {category}")
|
||||
|
||||
entries = await memory_manager.memory_service.list_entries(
|
||||
category=category_enum,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return entries
|
||||
|
||||
|
||||
@app.post("/api/memory", response_model=MemoryEntryResponse, status_code=201)
|
||||
async def create_memory(entry: MemoryEntryCreate):
|
||||
result = await memory_manager.memory_service.create_entry(
|
||||
title=entry.title,
|
||||
content=entry.content,
|
||||
category=entry.category,
|
||||
tags=entry.tags,
|
||||
agent_id=entry.agent_id,
|
||||
project_path=entry.project_path,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/memory/log", response_model=list[CommitResponse])
|
||||
async def get_log(
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = Query(default=100, ge=1, le=1000),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
):
|
||||
commits = await memory_manager.commit_service.list_commits(
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return commits
|
||||
|
||||
|
||||
@app.get("/api/memory/stats", response_model=StatsResponse)
|
||||
async def get_stats(project_path: str | None = None):
|
||||
entries = await memory_manager.memory_service.list_entries(
|
||||
project_path=project_path,
|
||||
limit=10000,
|
||||
)
|
||||
|
||||
entries_by_category: dict[str, int] = {}
|
||||
for entry in entries:
|
||||
cat = entry["category"]
|
||||
entries_by_category[cat] = entries_by_category.get(cat, 0) + 1
|
||||
|
||||
commits = await memory_manager.commit_service.list_commits(
|
||||
project_path=project_path,
|
||||
limit=10000,
|
||||
)
|
||||
|
||||
return StatsResponse(
|
||||
total_entries=len(entries),
|
||||
entries_by_category=entries_by_category,
|
||||
total_commits=len(commits),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/memory/search", response_model=list[MemoryEntryResponse])
|
||||
async def search_memory(
|
||||
q: str = Query(..., min_length=1),
|
||||
category: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = Query(default=100, ge=1, le=1000),
|
||||
):
|
||||
category_enum = None
|
||||
if category:
|
||||
try:
|
||||
category_enum = MemoryCategory(category)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid category: {category}")
|
||||
|
||||
results = await memory_manager.search_service.search(
|
||||
query=q,
|
||||
category=category_enum,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@app.post("/api/memory/commit", response_model=CommitResponse, status_code=201)
|
||||
async def create_commit(commit: CommitCreate):
|
||||
result = await memory_manager.commit_service.create_commit(
|
||||
message=commit.message,
|
||||
agent_id=commit.agent_id,
|
||||
project_path=commit.project_path,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/memory/diff/{hash1}/{hash2}", response_model=DiffResponse)
|
||||
async def get_diff(hash1: str, hash2: str):
|
||||
diff = await memory_manager.commit_service.diff(hash1, hash2)
|
||||
if not diff:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Commit(s) not found: {hash1}, {hash2}. Check available commits with /api/memory/log"
|
||||
)
|
||||
return diff
|
||||
|
||||
|
||||
@app.get("/api/memory/{entry_id}", response_model=MemoryEntryResponse)
|
||||
async def get_memory(entry_id: int):
|
||||
entry = await memory_manager.memory_service.get_entry(entry_id)
|
||||
if not entry:
|
||||
raise HTTPException(status_code=404, detail=f"Entry {entry_id} not found")
|
||||
return entry
|
||||
|
||||
|
||||
@app.put("/api/memory/{entry_id}", response_model=MemoryEntryResponse)
|
||||
async def update_memory(entry_id: int, entry: MemoryEntryUpdate):
|
||||
result = await memory_manager.memory_service.update_entry(
|
||||
entry_id=entry_id,
|
||||
title=entry.title,
|
||||
content=entry.content,
|
||||
category=entry.category,
|
||||
tags=entry.tags,
|
||||
)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"Entry {entry_id} not found")
|
||||
return result
|
||||
|
||||
|
||||
@app.delete("/api/memory/{entry_id}", status_code=204)
|
||||
async def delete_memory(entry_id: int):
|
||||
deleted = await memory_manager.memory_service.delete_entry(entry_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Entry {entry_id} not found")
|
||||
79
src/memory_manager/api/schemas.py
Normal file
79
src/memory_manager/api/schemas.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Pydantic schemas for API request/response validation."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memory_manager.db.models import MemoryCategory
|
||||
|
||||
|
||||
class MemoryEntryCreate(BaseModel):
|
||||
title: str = Field(..., min_length=1, max_length=255)
|
||||
content: str = Field(..., min_length=1)
|
||||
category: MemoryCategory
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
agent_id: str | None = None
|
||||
project_path: str | None = None
|
||||
|
||||
|
||||
class MemoryEntryUpdate(BaseModel):
|
||||
title: str | None = Field(None, min_length=1, max_length=255)
|
||||
content: str | None = Field(None, min_length=1)
|
||||
category: MemoryCategory | None = None
|
||||
tags: list[str] | None = None
|
||||
|
||||
|
||||
class MemoryEntryResponse(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
content: str
|
||||
category: str
|
||||
tags: list[str]
|
||||
agent_id: str
|
||||
project_path: str
|
||||
created_at: datetime | None
|
||||
updated_at: datetime | None
|
||||
|
||||
|
||||
class SearchQuery(BaseModel):
|
||||
q: str = Field(..., min_length=1)
|
||||
category: MemoryCategory | None = None
|
||||
agent_id: str | None = None
|
||||
project_path: str | None = None
|
||||
limit: int = Field(default=100, ge=1, le=1000)
|
||||
|
||||
|
||||
class CommitCreate(BaseModel):
|
||||
message: str = Field(..., min_length=1)
|
||||
agent_id: str | None = None
|
||||
project_path: str | None = None
|
||||
|
||||
|
||||
class CommitResponse(BaseModel):
|
||||
id: int
|
||||
hash: str
|
||||
message: str
|
||||
agent_id: str
|
||||
project_path: str
|
||||
snapshot: list[dict[str, Any]]
|
||||
created_at: datetime | None
|
||||
|
||||
|
||||
class DiffResponse(BaseModel):
|
||||
commit1: dict[str, Any]
|
||||
commit2: dict[str, Any]
|
||||
added: list[dict[str, Any]]
|
||||
removed: list[dict[str, Any]]
|
||||
modified: list[dict[str, Any]]
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
version: str
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
total_entries: int
|
||||
entries_by_category: dict[str, int]
|
||||
total_commits: int
|
||||
0
src/memory_manager/cli/__init__.py
Normal file
0
src/memory_manager/cli/__init__.py
Normal file
340
src/memory_manager/cli/main.py
Normal file
340
src/memory_manager/cli/main.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""CLI interface for the memory manager using Click."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import click
|
||||
|
||||
from memory_manager import __version__
|
||||
from memory_manager.core.services import MemoryManager
|
||||
from memory_manager.db.models import MemoryCategory
|
||||
from memory_manager.db.repository import MemoryRepository
|
||||
|
||||
|
||||
def get_db_path() -> str:
|
||||
return os.getenv("MEMORY_DB_PATH", ".memory/codebase_memory.db")
|
||||
|
||||
|
||||
async def get_memory_manager() -> MemoryManager:
|
||||
repository = MemoryRepository(get_db_path())
|
||||
await repository.initialize()
|
||||
manager = MemoryManager(repository)
|
||||
return manager
|
||||
|
||||
|
||||
def validate_category(ctx, param, value):
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return MemoryCategory(value)
|
||||
except ValueError:
|
||||
raise click.BadParameter(f"Invalid category. Must be one of: {[c.value for c in MemoryCategory]}")
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=__version__)
|
||||
def cli():
|
||||
"""Agentic Codebase Memory Manager - A centralized memory store for AI coding agents."""
|
||||
pass
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--title", "-t", required=True, help="Entry title")
|
||||
@click.option("--content", "-c", required=True, help="Entry content")
|
||||
@click.option("--category", "-g", required=True, callback=validate_category, help="Entry category")
|
||||
@click.option("--tags", "-T", multiple=True, help="Entry tags")
|
||||
@click.option("--agent-id", help="Agent ID (defaults to AGENT_ID env var)")
|
||||
@click.option("--project-path", help="Project path (defaults to MEMORY_PROJECT_PATH env var)")
|
||||
def add(title, content, category, tags, agent_id, project_path):
|
||||
"""Add a new memory entry."""
|
||||
asyncio.run(_add(title, content, category, list(tags), agent_id, project_path))
|
||||
|
||||
|
||||
async def _add(title, content, category, tags, agent_id, project_path):
|
||||
manager = await get_memory_manager()
|
||||
try:
|
||||
entry = await manager.memory_service.create_entry(
|
||||
title=title,
|
||||
content=content,
|
||||
category=category,
|
||||
tags=tags,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
)
|
||||
click.echo(f"Created entry {entry['id']}: {entry['title']}")
|
||||
finally:
|
||||
await manager.close()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--category", "-g", callback=validate_category, help="Filter by category")
|
||||
@click.option("--agent-id", help="Filter by agent ID")
|
||||
@click.option("--project-path", help="Filter by project path")
|
||||
@click.option("--limit", "-n", default=100, help="Number of entries to show")
|
||||
@click.option("--offset", default=0, help="Offset for pagination")
|
||||
def list(category, agent_id, project_path, limit, offset):
|
||||
"""List memory entries."""
|
||||
asyncio.run(_list(category, agent_id, project_path, limit, offset))
|
||||
|
||||
|
||||
async def _list(category, agent_id, project_path, limit, offset):
|
||||
manager = await get_memory_manager()
|
||||
try:
|
||||
entries = await manager.memory_service.list_entries(
|
||||
category=category,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
if not entries:
|
||||
click.echo("No entries found.")
|
||||
return
|
||||
|
||||
for entry in entries:
|
||||
created = entry["created_at"]
|
||||
if created:
|
||||
created = datetime.fromisoformat(created).strftime("%Y-%m-%d %H:%M")
|
||||
click.echo(f"[{entry['id']}] {entry['category']} | {entry['title']} | {created} | {entry['agent_id']}")
|
||||
click.echo(f" {entry['content'][:100]}...")
|
||||
if entry["tags"]:
|
||||
click.echo(f" Tags: {', '.join(entry['tags'])}")
|
||||
click.echo()
|
||||
finally:
|
||||
await manager.close()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("query")
|
||||
@click.option("--category", "-g", callback=validate_category, help="Filter by category")
|
||||
@click.option("--agent-id", help="Filter by agent ID")
|
||||
@click.option("--project-path", help="Filter by project path")
|
||||
@click.option("--limit", "-n", default=100, help="Number of results")
|
||||
def search(query, category, agent_id, project_path, limit):
|
||||
"""Search memory entries."""
|
||||
asyncio.run(_search(query, category, agent_id, project_path, limit))
|
||||
|
||||
|
||||
async def _search(query, category, agent_id, project_path, limit):
|
||||
manager = await get_memory_manager()
|
||||
try:
|
||||
results = await manager.search_service.search(
|
||||
query=query,
|
||||
category=category,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
)
|
||||
if not results:
|
||||
click.echo("No results found.")
|
||||
return
|
||||
|
||||
click.echo(f"Found {len(results)} result(s):\n")
|
||||
for entry in results:
|
||||
created = entry["created_at"]
|
||||
if created:
|
||||
created = datetime.fromisoformat(created).strftime("%Y-%m-%d %H:%M")
|
||||
click.echo(f"[{entry['id']}] {entry['category']} | {entry['title']} | {created}")
|
||||
click.echo(f" {entry['content'][:200]}...")
|
||||
if entry["tags"]:
|
||||
click.echo(f" Tags: {', '.join(entry['tags'])}")
|
||||
click.echo()
|
||||
finally:
|
||||
await manager.close()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("entry_id", type=int)
|
||||
def get(entry_id):
|
||||
"""Get a specific memory entry by ID."""
|
||||
asyncio.run(_get(entry_id))
|
||||
|
||||
|
||||
async def _get(entry_id):
|
||||
manager = await get_memory_manager()
|
||||
try:
|
||||
entry = await manager.memory_service.get_entry(entry_id)
|
||||
if not entry:
|
||||
click.echo(f"Entry {entry_id} not found.", err=True)
|
||||
return
|
||||
|
||||
click.echo(f"ID: {entry['id']}")
|
||||
click.echo(f"Title: {entry['title']}")
|
||||
click.echo(f"Category: {entry['category']}")
|
||||
click.echo(f"Agent: {entry['agent_id']}")
|
||||
click.echo(f"Project: {entry['project_path']}")
|
||||
click.echo(f"Tags: {', '.join(entry['tags']) if entry['tags'] else '(none)'}")
|
||||
click.echo(f"Created: {entry['created_at']}")
|
||||
click.echo(f"Updated: {entry['updated_at']}")
|
||||
click.echo(f"\nContent:\n{entry['content']}")
|
||||
finally:
|
||||
await manager.close()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("entry_id", type=int)
|
||||
@click.option("--title", "-t", help="New title")
|
||||
@click.option("--content", "-c", help="New content")
|
||||
@click.option("--category", "-g", callback=validate_category, help="New category")
|
||||
@click.option("--tags", "-T", multiple=True, help="New tags")
|
||||
def update(entry_id, title, content, category, tags):
|
||||
"""Update a memory entry."""
|
||||
asyncio.run(_update(entry_id, title, content, category, list(tags) if tags else None))
|
||||
|
||||
|
||||
async def _update(entry_id, title, content, category, tags):
|
||||
manager = await get_memory_manager()
|
||||
try:
|
||||
result = await manager.memory_service.update_entry(
|
||||
entry_id=entry_id,
|
||||
title=title,
|
||||
content=content,
|
||||
category=category,
|
||||
tags=tags,
|
||||
)
|
||||
if not result:
|
||||
click.echo(f"Entry {entry_id} not found.", err=True)
|
||||
return
|
||||
click.echo(f"Updated entry {entry_id}: {result['title']}")
|
||||
finally:
|
||||
await manager.close()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("entry_id", type=int)
|
||||
def delete(entry_id):
|
||||
"""Delete a memory entry."""
|
||||
asyncio.run(_delete(entry_id))
|
||||
|
||||
|
||||
async def _delete(entry_id):
|
||||
manager = await get_memory_manager()
|
||||
try:
|
||||
deleted = await manager.memory_service.delete_entry(entry_id)
|
||||
if not deleted:
|
||||
click.echo(f"Entry {entry_id} not found.", err=True)
|
||||
return
|
||||
click.echo(f"Deleted entry {entry_id}.")
|
||||
finally:
|
||||
await manager.close()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--message", "-m", required=True, help="Commit message")
|
||||
@click.option("--agent-id", help="Agent ID")
|
||||
@click.option("--project-path", help="Project path")
|
||||
def commit(message, agent_id, project_path):
|
||||
"""Create a commit snapshot of current memory state."""
|
||||
asyncio.run(_commit(message, agent_id, project_path))
|
||||
|
||||
|
||||
async def _commit(message, agent_id, project_path):
|
||||
manager = await get_memory_manager()
|
||||
try:
|
||||
result = await manager.commit_service.create_commit(
|
||||
message=message,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
)
|
||||
click.echo(f"Created commit {result['hash']}: {result['message']}")
|
||||
finally:
|
||||
await manager.close()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--agent-id", help="Filter by agent ID")
|
||||
@click.option("--project-path", help="Filter by project path")
|
||||
@click.option("--limit", "-n", default=100, help="Number of commits to show")
|
||||
@click.option("--offset", default=0, help="Offset for pagination")
|
||||
def log(agent_id, project_path, limit, offset):
|
||||
"""Show commit history."""
|
||||
asyncio.run(_log(agent_id, project_path, limit, offset))
|
||||
|
||||
|
||||
async def _log(agent_id, project_path, limit, offset):
|
||||
manager = await get_memory_manager()
|
||||
try:
|
||||
commits = await manager.commit_service.list_commits(
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
if not commits:
|
||||
click.echo("No commits found.")
|
||||
return
|
||||
|
||||
for commit in commits:
|
||||
created = commit["created_at"]
|
||||
if created:
|
||||
created = datetime.fromisoformat(created).strftime("%Y-%m-%d %H:%M")
|
||||
click.echo(f"commit {commit['hash']}")
|
||||
click.echo(f"Author: {commit['agent_id']}")
|
||||
click.echo(f"Date: {created}")
|
||||
click.echo(f"\n {commit['message']}\n")
|
||||
finally:
|
||||
await manager.close()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("hash1")
|
||||
@click.argument("hash2")
|
||||
def diff(hash1, hash2):
|
||||
"""Show diff between two commits."""
|
||||
asyncio.run(_diff(hash1, hash2))
|
||||
|
||||
|
||||
async def _diff(hash1, hash2):
|
||||
manager = await get_memory_manager()
|
||||
try:
|
||||
result = await manager.commit_service.diff(hash1, hash2)
|
||||
if not result:
|
||||
click.echo("One or both commits not found. Check available commits with 'memory log'.", err=True)
|
||||
return
|
||||
|
||||
if result["added"]:
|
||||
click.echo("Added entries:")
|
||||
for entry in result["added"]:
|
||||
click.echo(f" + [{entry['id']}] {entry['title']}")
|
||||
if result["removed"]:
|
||||
click.echo("\nRemoved entries:")
|
||||
for entry in result["removed"]:
|
||||
click.echo(f" - [{entry['id']}] {entry['title']}")
|
||||
if result["modified"]:
|
||||
click.echo("\nModified entries:")
|
||||
for mod in result["modified"]:
|
||||
click.echo(f" ~ [{mod['after']['id']}] {mod['after']['title']}")
|
||||
|
||||
if not any([result["added"], result["removed"], result["modified"]]):
|
||||
click.echo("No differences found.")
|
||||
finally:
|
||||
await manager.close()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--host", default=os.getenv("MEMORY_API_HOST", "127.0.0.1"), help="Server host")
|
||||
@click.option("--port", default=int(os.getenv("MEMORY_API_PORT", "8080")), help="Server port")
|
||||
def serve(host, port):
|
||||
"""Start the API server."""
|
||||
import uvicorn
|
||||
|
||||
from memory_manager.api.app import app as fastapi_app
|
||||
|
||||
click.echo(f"Starting server at http://{host}:{port}")
|
||||
click.echo(f"API documentation available at http://{host}:{port}/docs")
|
||||
|
||||
uvicorn.run(fastapi_app, host=host, port=port)
|
||||
|
||||
|
||||
@cli.command()
|
||||
def tui():
|
||||
"""Launch the TUI dashboard."""
|
||||
from memory_manager.tui.app import TUIApp
|
||||
|
||||
app = TUIApp()
|
||||
app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
0
src/memory_manager/core/__init__.py
Normal file
0
src/memory_manager/core/__init__.py
Normal file
210
src/memory_manager/core/services.py
Normal file
210
src/memory_manager/core/services.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Core business logic services for the memory manager."""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from memory_manager.db.models import MemoryCategory
|
||||
from memory_manager.db.repository import MemoryRepository
|
||||
|
||||
|
||||
class MemoryService:
|
||||
def __init__(self, repository: MemoryRepository):
|
||||
self.repository = repository
|
||||
|
||||
async def create_entry(
|
||||
self,
|
||||
title: str,
|
||||
content: str,
|
||||
category: str | MemoryCategory,
|
||||
tags: list[str] | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(category, str):
|
||||
category = MemoryCategory(category)
|
||||
|
||||
agent_id = agent_id or os.getenv("AGENT_ID", "unknown") or "unknown"
|
||||
project_path = project_path or os.getenv("MEMORY_PROJECT_PATH", ".") or "."
|
||||
|
||||
entry = await self.repository.create_entry(
|
||||
title=title,
|
||||
content=content,
|
||||
category=category,
|
||||
tags=tags or [],
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
)
|
||||
return entry.to_dict()
|
||||
|
||||
async def get_entry(self, entry_id: int) -> dict[str, Any] | None:
|
||||
entry = await self.repository.get_entry(entry_id)
|
||||
return entry.to_dict() if entry else None
|
||||
|
||||
async def update_entry(
|
||||
self,
|
||||
entry_id: int,
|
||||
title: str | None = None,
|
||||
content: str | None = None,
|
||||
category: str | MemoryCategory | None = None,
|
||||
tags: list[str] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
if category is not None and isinstance(category, str):
|
||||
category = MemoryCategory(category)
|
||||
|
||||
entry = await self.repository.update_entry(
|
||||
entry_id=entry_id,
|
||||
title=title,
|
||||
content=content,
|
||||
category=category,
|
||||
tags=tags,
|
||||
)
|
||||
return entry.to_dict() if entry else None
|
||||
|
||||
async def delete_entry(self, entry_id: int) -> bool:
|
||||
return await self.repository.delete_entry(entry_id)
|
||||
|
||||
async def list_entries(
|
||||
self,
|
||||
category: str | MemoryCategory | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
if category is not None and isinstance(category, str):
|
||||
category = MemoryCategory(category)
|
||||
|
||||
entries = await self.repository.list_entries(
|
||||
category=category,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return [entry.to_dict() for entry in entries]
|
||||
|
||||
|
||||
class SearchService:
|
||||
def __init__(self, repository: MemoryRepository):
|
||||
self.repository = repository
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
category: str | MemoryCategory | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict[str, Any]]:
|
||||
if category is not None and isinstance(category, str):
|
||||
category = MemoryCategory(category)
|
||||
|
||||
entries = await self.repository.search_entries(
|
||||
query_text=query,
|
||||
category=category,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
)
|
||||
return [entry.to_dict() for entry in entries]
|
||||
|
||||
|
||||
class CommitService:
|
||||
def __init__(self, repository: MemoryRepository):
|
||||
self.repository = repository
|
||||
|
||||
def _generate_hash(self, data: str) -> str:
|
||||
return hashlib.sha1(data.encode()).hexdigest()
|
||||
|
||||
async def create_commit(
|
||||
self,
|
||||
message: str,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
agent_id = agent_id or os.getenv("AGENT_ID", "unknown") or "unknown"
|
||||
project_path = project_path or os.getenv("MEMORY_PROJECT_PATH", ".") or "."
|
||||
|
||||
snapshot = await self.repository.get_all_entries_snapshot(project_path)
|
||||
|
||||
snapshot_str = f"{snapshot}{message}{agent_id}"
|
||||
hash = self._generate_hash(snapshot_str)
|
||||
|
||||
commit = await self.repository.create_commit(
|
||||
hash=hash,
|
||||
message=message,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
return commit.to_dict()
|
||||
|
||||
async def get_commit(self, hash: str) -> dict[str, Any] | None:
|
||||
commit = await self.repository.get_commit(hash)
|
||||
return commit.to_dict() if commit else None
|
||||
|
||||
async def list_commits(
|
||||
self,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
commits = await self.repository.list_commits(
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return [commit.to_dict() for commit in commits]
|
||||
|
||||
async def diff(self, hash1: str, hash2: str) -> dict[str, Any] | None:
|
||||
commit1 = await self.repository.get_commit(hash1)
|
||||
commit2 = await self.repository.get_commit(hash2)
|
||||
|
||||
if not commit1 or not commit2:
|
||||
return None
|
||||
|
||||
snapshot1 = {entry["id"]: entry for entry in commit1.snapshot}
|
||||
snapshot2 = {entry["id"]: entry for entry in commit2.snapshot}
|
||||
|
||||
all_ids = set(snapshot1.keys()) | set(snapshot2.keys())
|
||||
|
||||
added = []
|
||||
removed = []
|
||||
modified = []
|
||||
|
||||
for entry_id in all_ids:
|
||||
if entry_id not in snapshot1:
|
||||
added.append(snapshot2[entry_id])
|
||||
elif entry_id not in snapshot2:
|
||||
removed.append(snapshot1[entry_id])
|
||||
else:
|
||||
if snapshot1[entry_id] != snapshot2[entry_id]:
|
||||
modified.append({
|
||||
"before": snapshot1[entry_id],
|
||||
"after": snapshot2[entry_id],
|
||||
})
|
||||
|
||||
return {
|
||||
"commit1": commit1.to_dict(),
|
||||
"commit2": commit2.to_dict(),
|
||||
"added": added,
|
||||
"removed": removed,
|
||||
"modified": modified,
|
||||
}
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self, repository: MemoryRepository):
|
||||
self.repository = repository
|
||||
self.memory_service = MemoryService(repository)
|
||||
self.search_service = SearchService(repository)
|
||||
self.commit_service = CommitService(repository)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await self.repository.initialize()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.repository.close()
|
||||
0
src/memory_manager/db/__init__.py
Normal file
0
src/memory_manager/db/__init__.py
Normal file
134
src/memory_manager/db/models.py
Normal file
134
src/memory_manager/db/models.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""SQLAlchemy database models for the memory manager."""
|
||||
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, DateTime, Enum, Index, String, Text, func, text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class MemoryCategory(enum.StrEnum):
|
||||
DECISION = "decision"
|
||||
FEATURE = "feature"
|
||||
REFACTORING = "refactoring"
|
||||
ARCHITECTURE = "architecture"
|
||||
BUG = "bug"
|
||||
NOTE = "note"
|
||||
|
||||
|
||||
class MemoryEntry(Base):
|
||||
__tablename__ = "memory_entries"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
category: Mapped[MemoryCategory] = mapped_column(Enum(MemoryCategory), nullable=False)
|
||||
tags: Mapped[list[str]] = mapped_column(JSON, default=list)
|
||||
agent_id: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
project_path: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), onupdate=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_category", "category"),
|
||||
Index("idx_agent_id", "agent_id"),
|
||||
Index("idx_project_path", "project_path"),
|
||||
Index("idx_created_at", "created_at"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"title": self.title,
|
||||
"content": self.content,
|
||||
"category": self.category.value,
|
||||
"tags": self.tags,
|
||||
"agent_id": self.agent_id,
|
||||
"project_path": self.project_path,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Commit(Base):
|
||||
__tablename__ = "commits"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
hash: Mapped[str] = mapped_column(String(40), unique=True, nullable=False)
|
||||
message: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
agent_id: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
project_path: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
snapshot: Mapped[list[dict[str, Any]]] = mapped_column(JSON, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_commit_hash", "hash"),
|
||||
Index("idx_commit_agent_id", "agent_id"),
|
||||
Index("idx_commit_created_at", "created_at"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"hash": self.hash,
|
||||
"message": self.message,
|
||||
"agent_id": self.agent_id,
|
||||
"project_path": self.project_path,
|
||||
"snapshot": self.snapshot,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now())
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
async def init_db(db_path: str):
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
await conn.execute(text(
|
||||
"CREATE VIRTUAL TABLE IF NOT EXISTS memory_entries_fts USING fts5("
|
||||
"title, content, tags, category, agent_id, project_path, content='memory_entries', content_rowid='id')"
|
||||
))
|
||||
|
||||
await conn.execute(text(
|
||||
"CREATE TRIGGER IF NOT EXISTS memory_entries_ai AFTER INSERT ON memory_entries BEGIN "
|
||||
"INSERT INTO memory_entries_fts(rowid, title, content, tags, category, agent_id, project_path) "
|
||||
"VALUES (new.id, new.title, new.content, new.tags, new.category, new.agent_id, new.project_path); END"
|
||||
))
|
||||
|
||||
await conn.execute(text(
|
||||
"CREATE TRIGGER IF NOT EXISTS memory_entries_ad AFTER DELETE ON memory_entries BEGIN "
|
||||
"INSERT INTO memory_entries_fts(memory_entries_fts, rowid, title, content, tags, category, agent_id, project_path) "
|
||||
"VALUES ('delete', old.id, old.title, old.content, old.tags, old.category, old.agent_id, old.project_path); END"
|
||||
))
|
||||
|
||||
await conn.execute(text(
|
||||
"CREATE TRIGGER IF NOT EXISTS memory_entries_au AFTER UPDATE ON memory_entries BEGIN "
|
||||
"INSERT INTO memory_entries_fts(memory_entries_fts, rowid, title, content, tags, category, agent_id, project_path) "
|
||||
"VALUES ('delete', old.id, old.title, old.content, old.tags, old.category, old.agent_id, old.project_path); "
|
||||
"INSERT INTO memory_entries_fts(rowid, title, content, tags, category, agent_id, project_path) "
|
||||
"VALUES (new.id, new.title, new.content, new.tags, new.category, new.agent_id, new.project_path); END"
|
||||
))
|
||||
|
||||
return engine
|
||||
232
src/memory_manager/db/repository.py
Normal file
232
src/memory_manager/db/repository.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Async repository for database operations."""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from memory_manager.db.models import Commit, MemoryCategory, MemoryEntry, init_db
|
||||
|
||||
|
||||
class MemoryRepository:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.engine: Any = None
|
||||
self._session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
if db_dir and not os.path.exists(db_dir):
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
|
||||
self.engine = await init_db(self.db_path)
|
||||
self._session_factory = async_sessionmaker(self.engine, expire_on_commit=False)
|
||||
|
||||
async def get_session(self) -> AsyncSession:
|
||||
if not self._session_factory:
|
||||
await self.initialize()
|
||||
assert self._session_factory is not None
|
||||
return self._session_factory()
|
||||
|
||||
async def create_entry(
|
||||
self,
|
||||
title: str,
|
||||
content: str,
|
||||
category: MemoryCategory,
|
||||
tags: list[str],
|
||||
agent_id: str,
|
||||
project_path: str,
|
||||
) -> MemoryEntry:
|
||||
async with await self.get_session() as session:
|
||||
entry = MemoryEntry(
|
||||
title=title,
|
||||
content=content,
|
||||
category=category,
|
||||
tags=tags,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
)
|
||||
session.add(entry)
|
||||
await session.commit()
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
async def get_entry(self, entry_id: int) -> MemoryEntry | None:
|
||||
async with await self.get_session() as session:
|
||||
result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_entry(
|
||||
self,
|
||||
entry_id: int,
|
||||
title: str | None = None,
|
||||
content: str | None = None,
|
||||
category: MemoryCategory | None = None,
|
||||
tags: list[str] | None = None,
|
||||
) -> MemoryEntry | None:
|
||||
async with await self.get_session() as session:
|
||||
result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
|
||||
entry = result.scalar_one_or_none()
|
||||
if not entry:
|
||||
return None
|
||||
|
||||
if title is not None:
|
||||
entry.title = title
|
||||
if content is not None:
|
||||
entry.content = content
|
||||
if category is not None:
|
||||
entry.category = category
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
async def delete_entry(self, entry_id: int) -> bool:
|
||||
async with await self.get_session() as session:
|
||||
result = await session.execute(delete(MemoryEntry).where(MemoryEntry.id == entry_id))
|
||||
await session.commit()
|
||||
rowcount = getattr(result, "rowcount", 0)
|
||||
return rowcount is not None and rowcount > 0
|
||||
|
||||
async def list_entries(
|
||||
self,
|
||||
category: MemoryCategory | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[MemoryEntry]:
|
||||
async with await self.get_session() as session:
|
||||
query = select(MemoryEntry)
|
||||
|
||||
if category:
|
||||
query = query.where(MemoryEntry.category == category)
|
||||
if agent_id:
|
||||
query = query.where(MemoryEntry.agent_id == agent_id)
|
||||
if project_path:
|
||||
query = query.where(MemoryEntry.project_path == project_path)
|
||||
|
||||
query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def search_entries(
|
||||
self,
|
||||
query_text: str,
|
||||
category: MemoryCategory | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[MemoryEntry]:
|
||||
async with await self.get_session() as session:
|
||||
fts_query = f'"{query_text}"'
|
||||
|
||||
sql_parts = ["""
|
||||
SELECT m.* FROM memory_entries m
|
||||
INNER JOIN memory_entries_fts fts ON m.id = fts.rowid
|
||||
WHERE memory_entries_fts MATCH :query
|
||||
"""]
|
||||
|
||||
params: dict[str, Any] = {"query": fts_query}
|
||||
|
||||
if category:
|
||||
sql_parts.append(" AND m.category = :category")
|
||||
params["category"] = category.value
|
||||
if agent_id:
|
||||
sql_parts.append(" AND m.agent_id = :agent_id")
|
||||
params["agent_id"] = agent_id
|
||||
if project_path:
|
||||
sql_parts.append(" AND m.project_path = :project_path")
|
||||
params["project_path"] = project_path
|
||||
|
||||
sql_parts.append(" LIMIT :limit")
|
||||
params["limit"] = limit
|
||||
|
||||
sql = text("".join(sql_parts))
|
||||
|
||||
result = await session.execute(sql, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
entries = []
|
||||
for row in rows:
|
||||
entry = MemoryEntry(
|
||||
id=row.id,
|
||||
title=row.title,
|
||||
content=row.content,
|
||||
category=MemoryCategory(row.category),
|
||||
tags=row.tags,
|
||||
agent_id=row.agent_id,
|
||||
project_path=row.project_path,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
entries.append(entry)
|
||||
|
||||
return entries
|
||||
|
||||
async def create_commit(
|
||||
self,
|
||||
hash: str,
|
||||
message: str,
|
||||
agent_id: str,
|
||||
project_path: str,
|
||||
snapshot: list[dict[str, Any]],
|
||||
) -> Commit:
|
||||
async with await self.get_session() as session:
|
||||
commit = Commit(
|
||||
hash=hash,
|
||||
message=message,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
session.add(commit)
|
||||
await session.commit()
|
||||
await session.refresh(commit)
|
||||
return commit
|
||||
|
||||
async def get_commit(self, hash: str) -> Commit | None:
|
||||
async with await self.get_session() as session:
|
||||
result = await session.execute(select(Commit).where(Commit.hash == hash))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_commit_by_id(self, commit_id: int) -> Commit | None:
|
||||
async with await self.get_session() as session:
|
||||
result = await session.execute(select(Commit).where(Commit.id == commit_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_commits(
|
||||
self,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[Commit]:
|
||||
async with await self.get_session() as session:
|
||||
query = select(Commit)
|
||||
|
||||
if agent_id:
|
||||
query = query.where(Commit.agent_id == agent_id)
|
||||
if project_path:
|
||||
query = query.where(Commit.project_path == project_path)
|
||||
|
||||
query = query.order_by(Commit.created_at.desc()).limit(limit).offset(offset)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_all_entries_snapshot(self, project_path: str | None = None) -> list[dict[str, Any]]:
|
||||
async with await self.get_session() as session:
|
||||
query = select(MemoryEntry)
|
||||
if project_path:
|
||||
query = query.where(MemoryEntry.project_path == project_path)
|
||||
|
||||
result = await session.execute(query)
|
||||
entries = result.scalars().all()
|
||||
return [entry.to_dict() for entry in entries]
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.engine:
|
||||
await self.engine.dispose()
|
||||
0
src/memory_manager/tui/__init__.py
Normal file
0
src/memory_manager/tui/__init__.py
Normal file
363
src/memory_manager/tui/app.py
Normal file
363
src/memory_manager/tui/app.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Textual TUI application for the memory manager."""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from textual import on
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Container, Horizontal, ScrollableContainer, Vertical
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Footer,
|
||||
Header,
|
||||
Input,
|
||||
Label,
|
||||
ListItem,
|
||||
ListView,
|
||||
Static,
|
||||
)
|
||||
|
||||
from memory_manager.core.services import MemoryManager
|
||||
from memory_manager.db.models import MemoryCategory
|
||||
from memory_manager.db.repository import MemoryRepository
|
||||
|
||||
db_path = os.getenv("MEMORY_DB_PATH", ".memory/codebase_memory.db")
|
||||
|
||||
|
||||
async def get_memory_manager() -> MemoryManager:
|
||||
repository = MemoryRepository(db_path)
|
||||
await repository.initialize()
|
||||
manager = MemoryManager(repository)
|
||||
return manager
|
||||
|
||||
|
||||
class DashboardScreen(Screen):
|
||||
CSS = """
|
||||
Screen {
|
||||
background: $surface;
|
||||
}
|
||||
.stats-container {
|
||||
height: auto;
|
||||
padding: 1 2;
|
||||
background: $panel;
|
||||
border: solid $border;
|
||||
}
|
||||
.stat-label {
|
||||
color: $text-muted;
|
||||
}
|
||||
.stat-value {
|
||||
color: $text;
|
||||
bold: true;
|
||||
}
|
||||
.entry-list {
|
||||
height: 1fr;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, manager: MemoryManager):
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield Container(
|
||||
Vertical(
|
||||
Label("Memory Manager Dashboard", classes="header-title"),
|
||||
ScrollableContainer(classes="stats-container", id="stats-container"),
|
||||
ListView(id="recent-entries", classes="entry-list"),
|
||||
id="dashboard-content",
|
||||
)
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
await self.load_stats()
|
||||
|
||||
async def load_stats(self) -> None:
|
||||
stats_container = self.query_one("#stats-container", ScrollableContainer)
|
||||
|
||||
entries = await self.manager.memory_service.list_entries(limit=10000)
|
||||
commits = await self.manager.commit_service.list_commits(limit=10000)
|
||||
|
||||
entries_by_category: dict[str, int] = {}
|
||||
for entry in entries:
|
||||
cat = entry["category"]
|
||||
entries_by_category[cat] = entries_by_category.get(cat, 0) + 1
|
||||
|
||||
stats_text = f"""
|
||||
Total Entries: {len(entries)} | Total Commits: {len(commits)}
|
||||
|
||||
Entries by Category:
|
||||
"""
|
||||
for cat, count in entries_by_category.items():
|
||||
stats_text += f" {cat}: {count}\n"
|
||||
|
||||
stats_container.remove_children()
|
||||
stats_container.mount(Static(stats_text))
|
||||
|
||||
list_view = self.query_one("#recent-entries", ListView)
|
||||
list_view.clear()
|
||||
for entry in entries[:10]:
|
||||
created = entry["created_at"]
|
||||
if created:
|
||||
created = datetime.fromisoformat(created).strftime("%m/%d %H:%M")
|
||||
list_item = ListItem(
|
||||
Label(f"[{entry['category']}] {entry['title'][:40]} - {created}"),
|
||||
)
|
||||
await list_view.mount(list_item)
|
||||
|
||||
|
||||
class MemoryListScreen(Screen):
|
||||
CSS = """
|
||||
Screen {
|
||||
background: $surface;
|
||||
}
|
||||
.filter-bar {
|
||||
height: 3;
|
||||
padding: 1;
|
||||
background: $panel;
|
||||
}
|
||||
.entry-detail {
|
||||
height: 1fr;
|
||||
padding: 1 2;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, manager: MemoryManager, category: MemoryCategory | None = None):
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.current_category = category
|
||||
self.entries: list[dict[str, Any]] = []
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield Container(
|
||||
Horizontal(
|
||||
Button("All", id="filter-all"),
|
||||
Button("Decision", id="filter-decision"),
|
||||
Button("Feature", id="filter-feature"),
|
||||
Button("Refactoring", id="filter-refactoring"),
|
||||
Button("Architecture", id="filter-architecture"),
|
||||
Button("Bug", id="filter-bug"),
|
||||
Button("Note", id="filter-note"),
|
||||
classes="filter-bar",
|
||||
),
|
||||
Horizontal(
|
||||
ListView(id="entries-list", classes="column"),
|
||||
ScrollableContainer(id="entry-detail", classes="column entry-detail"),
|
||||
classes="main-content",
|
||||
),
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
await self.load_entries()
|
||||
|
||||
async def load_entries(self, category: MemoryCategory | None = None) -> None:
|
||||
self.entries = await self.manager.memory_service.list_entries(
|
||||
category=category,
|
||||
limit=1000,
|
||||
)
|
||||
|
||||
list_view = self.query_one("#entries-list", ListView)
|
||||
list_view.clear()
|
||||
|
||||
for entry in self.entries:
|
||||
created = entry["created_at"]
|
||||
if created:
|
||||
created = datetime.fromisoformat(created).strftime("%m/%d %H:%M")
|
||||
list_item = ListItem(
|
||||
Label(f"[{entry['category']}] {entry['title']}"),
|
||||
Label(f"{entry['agent_id']} | {created}"),
|
||||
)
|
||||
await list_view.mount(list_item)
|
||||
|
||||
@on(ListView.Selected)
|
||||
async def on_entry_selected(self, event: ListView.Selected) -> None:
|
||||
index = event.list_view.index
|
||||
if index is not None and 0 <= index < len(self.entries):
|
||||
entry = self.entries[index]
|
||||
detail_container = self.query_one("#entry-detail", ScrollableContainer)
|
||||
detail_container.remove_children()
|
||||
|
||||
content = f"""
|
||||
Title: {entry['title']}
|
||||
Category: {entry['category']}
|
||||
Agent: {entry['agent_id']}
|
||||
Project: {entry['project_path']}
|
||||
Tags: {', '.join(entry['tags']) if entry['tags'] else '(none)'}
|
||||
Created: {entry['created_at']}
|
||||
Updated: {entry['updated_at']}
|
||||
|
||||
Content:
|
||||
{entry['content']}
|
||||
"""
|
||||
await detail_container.mount(Static(content))
|
||||
|
||||
@on(Button.Pressed)
|
||||
async def on_filter_pressed(self, event: Button.Pressed) -> None:
|
||||
button_id = event.button.id
|
||||
category = None
|
||||
|
||||
if button_id == "filter-decision":
|
||||
category = MemoryCategory.DECISION
|
||||
elif button_id == "filter-feature":
|
||||
category = MemoryCategory.FEATURE
|
||||
elif button_id == "filter-refactoring":
|
||||
category = MemoryCategory.REFACTORING
|
||||
elif button_id == "filter-architecture":
|
||||
category = MemoryCategory.ARCHITECTURE
|
||||
elif button_id == "filter-bug":
|
||||
category = MemoryCategory.BUG
|
||||
elif button_id == "filter-note":
|
||||
category = MemoryCategory.NOTE
|
||||
|
||||
await self.load_entries(category)
|
||||
|
||||
|
||||
class CommitHistoryScreen(Screen):
|
||||
CSS = """
|
||||
Screen {
|
||||
background: $surface;
|
||||
}
|
||||
.commit-list {
|
||||
height: 1fr;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, manager: MemoryManager):
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.commits: list[dict[str, Any]] = []
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield Container(
|
||||
ScrollableContainer(
|
||||
ListView(id="commits-list", classes="commit-list"),
|
||||
id="commits-container",
|
||||
),
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
await self.load_commits()
|
||||
|
||||
async def load_commits(self) -> None:
|
||||
self.commits = await self.manager.commit_service.list_commits(limit=100)
|
||||
|
||||
list_view = self.query_one("#commits-list", ListView)
|
||||
list_view.clear()
|
||||
|
||||
for commit in self.commits:
|
||||
created = commit["created_at"]
|
||||
if created:
|
||||
created = datetime.fromisoformat(created).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
content = f"commit {commit['hash'][:8]}\n{commit['agent_id']} | {created}\n\n {commit['message']}"
|
||||
|
||||
list_item = ListItem(
|
||||
Static(content, markup=False),
|
||||
)
|
||||
await list_view.mount(list_item)
|
||||
|
||||
|
||||
class SearchScreen(Screen):
|
||||
CSS = """
|
||||
Screen {
|
||||
background: $surface;
|
||||
}
|
||||
.search-input {
|
||||
height: 3;
|
||||
padding: 1;
|
||||
}
|
||||
.results-list {
|
||||
height: 1fr;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, manager: MemoryManager):
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.results: list[dict[str, Any]] = []
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield Container(
|
||||
Horizontal(
|
||||
Input(placeholder="Search query...", id="search-input"),
|
||||
Button("Search", id="search-button"),
|
||||
classes="search-input",
|
||||
),
|
||||
ScrollableContainer(
|
||||
ListView(id="results-list", classes="results-list"),
|
||||
id="results-container",
|
||||
),
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
@on(Button.Pressed, "#search-button")
|
||||
@on(Input.Submitted, "#search-input")
|
||||
async def on_search(self) -> None:
|
||||
input_widget = self.query_one("#search-input", Input)
|
||||
query = input_widget.value
|
||||
|
||||
if not query:
|
||||
return
|
||||
|
||||
self.results = await self.manager.search_service.search(query=query, limit=100)
|
||||
|
||||
list_view = self.query_one("#results-list", ListView)
|
||||
list_view.clear()
|
||||
|
||||
for entry in self.results:
|
||||
created = entry["created_at"]
|
||||
if created:
|
||||
created = datetime.fromisoformat(created).strftime("%m/%d %H:%M")
|
||||
|
||||
list_item = ListItem(
|
||||
Label(f"[{entry['category']}] {entry['title'][:40]}"),
|
||||
Label(f"{entry['content'][:80]}... | {created}"),
|
||||
)
|
||||
await list_view.mount(list_item)
|
||||
|
||||
|
||||
class TUIApp(App):
|
||||
BINDINGS = [
|
||||
Binding("d", "switch_dashboard", "Dashboard"),
|
||||
Binding("l", "switch_memory_list", "Memory List"),
|
||||
Binding("c", "switch_commits", "Commits"),
|
||||
Binding("s", "switch_search", "Search"),
|
||||
Binding("q", "quit", "Quit"),
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.manager = None
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
self.manager = await get_memory_manager()
|
||||
await self.push_screen(DashboardScreen(self.manager))
|
||||
|
||||
async def on_unmount(self) -> None:
|
||||
if self.manager:
|
||||
await self.manager.close()
|
||||
|
||||
async def switch_dashboard(self) -> None:
|
||||
if self.manager:
|
||||
await self.push_screen(DashboardScreen(self.manager))
|
||||
|
||||
async def switch_memory_list(self) -> None:
|
||||
if self.manager:
|
||||
await self.push_screen(MemoryListScreen(self.manager))
|
||||
|
||||
async def switch_commits(self) -> None:
|
||||
if self.manager:
|
||||
await self.push_screen(CommitHistoryScreen(self.manager))
|
||||
|
||||
async def switch_search(self) -> None:
|
||||
if self.manager:
|
||||
await self.push_screen(SearchScreen(self.manager))
|
||||
Reference in New Issue
Block a user