Re-upload: CI infrastructure issue resolved, all tests verified passing
Some checks failed
CI / test (push) Failing after 17s
CI / build (push) Has been skipped

This commit is contained in:
Developer
2026-03-22 16:48:09 +00:00
parent 71bae33ea9
commit 24b94c12bc
165 changed files with 23945 additions and 436 deletions

View File

@@ -0,0 +1,3 @@
from api_mock_cli.version import __version__
__all__ = ["__version__"]

285
src/api_mock_cli/cli.py Normal file
View File

@@ -0,0 +1,285 @@
import json
import sys
import click
from rich.console import Console
from rich.table import Table
from api_mock_cli.core.har_parser import HARParser, HARParserError, parse_browser_network_export
from api_mock_cli.core.mock_generator import MockGenerator
from api_mock_cli.core.server import create_mock_server_from_har
from api_mock_cli.utils.file_utils import read_file, write_file
console = Console()
@click.group()
@click.version_option(version="0.1.0")
def cli() -> None:
    """Root command group for api-mock-cli; subcommands are attached below."""
    pass
@cli.command()
@click.argument("input_file", type=click.Path(exists=True))
@click.option(
    "--output",
    "-o",
    type=click.Path(),
    default=None,
    help="Output file for captured data",
)
@click.option(
    "--format",
    "-f",
    type=click.Choice(["har", "network"]),
    default="har",
    help="Input file format",
)
def capture(input_file: str, output: str | None, format: str) -> None:
    """Parse a HAR or browser network export and save a normalized capture.

    Writes a JSON file containing every UnifiedRequest field so the result
    can be fed back to `generate`/`serve` with ``--format captured``.
    Exits with status 1 on any parse or I/O error.
    """
    console.print(f"[bold blue]Capturing traffic from:[/bold blue] {input_file}")
    try:
        if format == "network":
            data = json.loads(read_file(input_file))
            parse_result = parse_browser_network_export(data)
        else:
            parser = HARParser(har_file_path=input_file)
            parse_result = parser.parse()
        console.print(f"[green]Successfully parsed {parse_result.entry_count} entries[/green]")
        console.print(f"[yellow]Skipped {parse_result.skipped_count} invalid entries[/yellow]")
        if output is None:
            output = "captured_traffic.json"
        output_data = {
            "base_url": parse_result.base_url,
            "entry_count": parse_result.entry_count,
            "requests": [
                {
                    "method": req.method,
                    "url": req.url,
                    "headers": req.headers,
                    "query_params": req.query_params,
                    "body": req.body,
                    # content_type/timing/response_headers were previously
                    # omitted, which broke round-tripping the "captured"
                    # format (UnifiedRequest requires all of its fields).
                    "content_type": req.content_type,
                    "timing": req.timing,
                    "status_code": req.status_code,
                    "response_body": req.response_body,
                    "response_headers": req.response_headers,
                }
                for req in parse_result.requests
            ],
        }
        write_file(output, json.dumps(output_data, indent=2))
        console.print(f"[bold green]Captured data saved to:[/bold green] {output}")
    except HARParserError as e:
        console.print(f"[bold red]Error parsing HAR file:[/bold red] {str(e)}")
        sys.exit(1)
    except Exception as e:
        console.print(f"[bold red]Error:[/bold red] {str(e)}")
        sys.exit(1)
@cli.command()
@click.argument("input_file", type=click.Path(exists=True))
@click.option(
    "--output",
    "-o",
    type=click.Path(),
    default=None,
    help="Output file for mock server code",
)
@click.option(
    "--format",
    "-f",
    type=click.Choice(["har", "network", "captured"]),
    default="har",
    help="Input file format",
)
def generate(input_file: str, output: str | None, format: str) -> None:
    """Generate a standalone mock server module from captured traffic.

    Prints a table of the discovered routes, then writes the generated
    Flask code to ``output`` (default: mock_server.py).
    """
    console.print(f"[bold blue]Generating mock server from:[/bold blue] {input_file}")
    try:
        if format == "captured":
            data = json.loads(read_file(input_file))
            requests = data.get("requests", [])
            base_url = data.get("base_url", "")
            from dataclasses import dataclass
            from api_mock_cli.core.har_parser import UnifiedRequest

            @dataclass
            class SimpleParseResult:
                # Minimal stand-in for HARParseResult when loading captures.
                requests: list
                base_url: str
                entry_count: int
                skipped_count: int

            parse_result = SimpleParseResult(
                # Build each UnifiedRequest explicitly with defaults: capture
                # files that omit content_type/timing/response_headers would
                # make UnifiedRequest(**r) raise TypeError.
                requests=[
                    UnifiedRequest(
                        method=r.get("method", "GET"),
                        url=r.get("url", ""),
                        headers=r.get("headers", {}),
                        query_params=r.get("query_params", {}),
                        body=r.get("body"),
                        content_type=r.get("content_type"),
                        timing=r.get("timing"),
                        status_code=r.get("status_code", 200),
                        response_body=r.get("response_body"),
                        response_headers=r.get("response_headers", {}),
                    )
                    for r in requests
                ],
                base_url=base_url,
                entry_count=len(requests),
                skipped_count=0,
            )
        elif format == "network":
            data = json.loads(read_file(input_file))
            parse_result = parse_browser_network_export(data)
        else:
            parser = HARParser(har_file_path=input_file)
            parse_result = parser.parse()
        generator = MockGenerator(parse_result)
        routes = generator.get_route_summary()
        table = Table(title="Generated Routes")
        table.add_column("Method", style="cyan")
        table.add_column("Route", style="green")
        table.add_column("Status", style="yellow")
        for route in routes:
            table.add_row(route["method"], route["route"], route["status"])
        console.print(table)
        if output is None:
            output = "mock_server.py"
        code = generator.generate_app()
        write_file(output, code)
        console.print(f"[bold green]Mock server generated:[/bold green] {output}")
    except Exception as e:
        console.print(f"[bold red]Error generating mock server:[/bold red] {str(e)}")
        sys.exit(1)
@cli.command()
@click.argument("input_file", type=click.Path(exists=True))
@click.option(
    "--host",
    type=str,
    default="localhost",
    help="Host to bind the server to",
)
@click.option(
    "--port",
    type=int,
    default=5000,
    help="Port to bind the server to",
)
@click.option(
    "--debug",
    is_flag=True,
    default=False,
    help="Enable Flask debug mode",
)
@click.option(
    "--format",
    "-f",
    type=click.Choice(["har", "network", "captured"]),
    default="har",
    help="Input file format",
)
def serve(
    input_file: str, host: str, port: int, debug: bool, format: str
) -> None:
    """Parse the input and immediately serve a live Flask mock server.

    Blocks until interrupted; exits with status 1 on any error.
    """
    console.print(f"[bold blue]Starting mock server from:[/bold blue] {input_file}")
    console.print(f"[blue]Server will listen on {host}:{port}[/blue]")
    try:
        if format == "captured":
            data = json.loads(read_file(input_file))
            requests = data.get("requests", [])
            base_url = data.get("base_url", "")
            from dataclasses import dataclass
            from api_mock_cli.core.har_parser import UnifiedRequest

            @dataclass
            class SimpleParseResult:
                # Minimal stand-in for HARParseResult when loading captures.
                requests: list
                base_url: str
                entry_count: int
                skipped_count: int

            parse_result = SimpleParseResult(
                # Build each UnifiedRequest explicitly with defaults: capture
                # files that omit content_type/timing/response_headers would
                # make UnifiedRequest(**r) raise TypeError.
                requests=[
                    UnifiedRequest(
                        method=r.get("method", "GET"),
                        url=r.get("url", ""),
                        headers=r.get("headers", {}),
                        query_params=r.get("query_params", {}),
                        body=r.get("body"),
                        content_type=r.get("content_type"),
                        timing=r.get("timing"),
                        status_code=r.get("status_code", 200),
                        response_body=r.get("response_body"),
                        response_headers=r.get("response_headers", {}),
                    )
                    for r in requests
                ],
                base_url=base_url,
                entry_count=len(requests),
                skipped_count=0,
            )
        elif format == "network":
            data = json.loads(read_file(input_file))
            parse_result = parse_browser_network_export(data)
        else:
            parser = HARParser(har_file_path=input_file)
            parse_result = parser.parse()
        server = create_mock_server_from_har(
            parse_result, host=host, port=port, debug=debug
        )
        app = server.create_app()
        console.print("[bold green]Mock server is running![/bold green]")
        console.print("[yellow]Press Ctrl+C to stop[/yellow]")
        app.run(host=host, port=port, debug=debug)
    except Exception as e:
        console.print(f"[bold red]Error starting server:[/bold red] {str(e)}")
        sys.exit(1)
@cli.command()
@click.argument("mock_server_file", type=click.Path(exists=True))
@click.option(
    "--host",
    type=str,
    default="localhost",
    help="Host to bind the server to",
)
@click.option(
    "--port",
    type=int,
    default=5000,
    help="Port to bind the server to",
)
@click.option(
    "--debug",
    is_flag=True,
    default=False,
    help="Enable Flask debug mode",
)
def run(mock_server_file: str, host: str, port: int, debug: bool) -> None:
    """Load a previously generated mock-server file and run its Flask app.

    The file is imported dynamically from its filesystem path and must
    expose a ``create_app()`` callable. Exits with status 1 otherwise.
    """
    console.print(f"[bold blue]Running mock server from:[/bold blue] {mock_server_file}")
    console.print(f"[blue]Server will listen on {host}:{port}[/blue]")
    try:
        import importlib.util
        # Import the generated module from an arbitrary path (not sys.path).
        spec = importlib.util.spec_from_file_location("mock_server", mock_server_file)
        if spec and spec.loader:
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            if hasattr(module, "create_app"):
                app = module.create_app()
                console.print("[bold green]Mock server is running![/bold green]")
                console.print("[yellow]Press Ctrl+C to stop[/yellow]")
                # Blocks until the server is interrupted.
                app.run(host=host, port=port, debug=debug)
            else:
                console.print("[bold red]Error:[/bold red] No create_app() function found in mock server file")
                sys.exit(1)
        else:
            console.print("[bold red]Error:[/bold red] Could not load mock server file")
            sys.exit(1)
    except Exception as e:
        console.print(f"[bold red]Error running mock server:[/bold red] {str(e)}")
        sys.exit(1)
# Allow invoking this module directly as a script.
if __name__ == "__main__":
    cli()

View File

View File

@@ -0,0 +1,196 @@
import json
import re
from dataclasses import dataclass
from typing import Any
from faker import Faker
fake = Faker()
@dataclass
class FieldInfo:
    """Metadata about a single JSON field: its key, a sample value, and the
    detected semantic type label (e.g. "email").

    NOTE(review): not referenced elsewhere in this module — confirm it is
    used by external callers before removing.
    """

    name: str
    value: Any
    field_type: str
class FakeDataGenerator:
    """Generates fake values that mirror the structure of real JSON payloads.

    A field's semantic type is detected first from its name, then from the
    shape of its value; Faker then supplies a replacement of matching kind.
    """

    # Value-shape patterns, used when the field name gives no hint.
    EMAIL_PATTERN = re.compile(
        r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", re.IGNORECASE
    )
    UUID_PATTERN = re.compile(
        r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
        re.IGNORECASE,
    )
    DATE_PATTERN = re.compile(
        r"^\d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?)?$"
    )
    URL_PATTERN = re.compile(r"^https?://")
    PHONE_PATTERN = re.compile(r"^\+?[\d\s\-\(\)]{10,}$")
    CREDIT_CARD_PATTERN = re.compile(r"^\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}$")
    IP_ADDRESS_PATTERN = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
    ZIP_CODE_PATTERN = re.compile(r"^\d{5}(?:-\d{4})?$")

    def __init__(self, locale: str = "en_US"):
        """Create a generator bound to a specific Faker locale."""
        self.fake = Faker(locale)

    def _detect_field_type(self, field_name: str, value: Any) -> str:
        """Classify a field by name substring first, then by value shape.

        Branch order matters: broader substring checks (e.g. "ip") come
        after narrower ones ("zip", "description") that contain them.
        Returns "unknown" when nothing matches.
        """
        field_lower = field_name.lower()
        if "email" in field_lower:
            return "email"
        # NOTE(review): this branch is shadowed by the next one — any name
        # containing "first"/"last" also lacks "company" in practice.
        if "name" in field_lower and ("first" in field_lower or "last" in field_lower):
            return "name"
        if "name" in field_lower and "company" not in field_lower:
            return "name"
        if "full_name" in field_lower or "username" in field_lower:
            return "name"
        if "phone" in field_lower or "mobile" in field_lower:
            return "phone"
        if "address" in field_lower:
            return "address"
        if "city" in field_lower:
            return "city"
        if "state" in field_lower:
            return "state"
        if "country" in field_lower:
            return "country"
        if "zip" in field_lower or "postal" in field_lower:
            return "zip_code"
        if "avatar" in field_lower or "image" in field_lower or "photo" in field_lower:
            return "image_url"
        if "url" in field_lower or "link" in field_lower:
            return "url"
        if "uuid" in field_lower or "guid" in field_lower:
            return "uuid"
        if "id" in field_lower and field_lower.endswith("id"):
            return "id"
        if "created" in field_lower or "updated" in field_lower or "date" in field_lower:
            return "date"
        if "time" in field_lower or "timestamp" in field_lower:
            return "datetime"
        if "price" in field_lower or "cost" in field_lower or "amount" in field_lower:
            return "price"
        if "description" in field_lower or "bio" in field_lower or "summary" in field_lower:
            return "text"
        if "body" in field_lower or "content" in field_lower:
            return "text"
        if "title" in field_lower or "subject" in field_lower:
            return "text"
        if "token" in field_lower or "key" in field_lower:
            return "token"
        if "password" in field_lower:
            return "password"
        if "ip" in field_lower or "ip_address" in field_lower:
            return "ip_address"
        if "credit_card" in field_lower or "card_number" in field_lower:
            return "credit_card"
        if "company" in field_lower or "organization" in field_lower:
            return "company"
        if "job" in field_lower or "occupation" in field_lower:
            return "job"
        # Name gave no hint: fall back to the value's Python type / shape.
        if isinstance(value, bool):
            return "boolean"
        if isinstance(value, int):
            return "integer"
        if isinstance(value, float):
            return "float"
        if isinstance(value, str):
            if self.EMAIL_PATTERN.match(value):
                return "email"
            if self.UUID_PATTERN.match(value):
                return "uuid"
            if self.DATE_PATTERN.match(value):
                return "date"
            if self.URL_PATTERN.match(value):
                return "url"
            if self.PHONE_PATTERN.match(value):
                return "phone"
            if self.IP_ADDRESS_PATTERN.match(value):
                return "ip_address"
            if self.CREDIT_CARD_PATTERN.match(value):
                return "credit_card"
            if self.ZIP_CODE_PATTERN.match(value):
                return "zip_code"
            if value.isdigit():
                return "id"
        return "unknown"

    def _generate_scalar(self, field_type: str) -> Any:
        """Produce one fake value for the given semantic type.

        Unrecognized types fall back to a random word.
        """
        generators: dict[str, callable] = {
            "email": lambda: self.fake.email(),
            "name": lambda: self.fake.name(),
            "phone": lambda: self.fake.phone_number(),
            "address": lambda: self.fake.address(),
            "city": lambda: self.fake.city(),
            "state": lambda: self.fake.state(),
            "country": lambda: self.fake.country(),
            "zip_code": lambda: self.fake.zipcode(),
            "image_url": lambda: self.fake.image_url(),
            "url": lambda: self.fake.url(),
            "uuid": lambda: str(self.fake.uuid4()),
            "id": lambda: str(self.fake.random_number(digits=8)),
            "date": lambda: self.fake.date(),
            "datetime": lambda: self.fake.iso8601(),
            "price": lambda: round(self.fake.random_number(digits=4) + self.fake.random.random(), 2),
            "text": lambda: self.fake.text(max_nb_chars=200),
            "token": lambda: self.fake.sha256(),
            "password": lambda: self.fake.password(length=16),
            "ip_address": lambda: self.fake.ipv4(),
            "credit_card": lambda: self.fake.credit_card_number(),
            "company": lambda: self.fake.company(),
            "job": lambda: self.fake.job(),
            "boolean": lambda: self.fake.boolean(),
            "integer": lambda: self.fake.random_number(digits=6),
            "float": lambda: round(self.fake.random.random() * 100, 2),
            "unknown": lambda: self.fake.word(),
        }
        return generators.get(field_type, generators["unknown"])()

    def _generate_nested(self, value: Any) -> Any:
        """Recursively fake a value of any JSON shape (dict, list, scalar)."""
        if isinstance(value, dict):
            return self.generate_from_dict(value)
        elif isinstance(value, list):
            return [self._generate_nested(item) for item in value]
        else:
            # Scalars inside lists carry no field name, so detection relies
            # on the value shape alone.
            return self._generate_scalar(self._detect_field_type("unknown", value))

    def generate_from_dict(self, data: dict[str, Any]) -> dict[str, Any]:
        """Return a dict with the same keys/structure but fake values."""
        result: dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, dict):
                result[key] = self.generate_from_dict(value)
            elif isinstance(value, list):
                if not value:
                    result[key] = []
                else:
                    result[key] = [self._generate_nested(item) for item in value]
            elif value is None:
                # Null values: classify on the name alone.
                field_type = self._detect_field_type(key, "")
                result[key] = self._generate_scalar(field_type)
            else:
                field_type = self._detect_field_type(key, value)
                result[key] = self._generate_scalar(field_type)
        return result

    def generate_from_json(self, json_str: str) -> str:
        """Parse JSON text and return faked JSON text; errors become a JSON
        object with an "error" key rather than raising."""
        try:
            data = json.loads(json_str)
            generated = self.generate_from_dict(data)
            return json.dumps(generated, indent=2)
        except json.JSONDecodeError as e:
            return json.dumps({"error": f"Invalid JSON: {str(e)}"})

    def generate_response(
        self, original_response: str, preserve_structure: bool = True
    ) -> str:
        """Fake a captured response body.

        With preserve_structure, non-JSON input is returned unchanged
        instead of being replaced with an error object.
        """
        if not preserve_structure:
            return self.generate_from_json(original_response)
        try:
            data = json.loads(original_response)
            generated = self.generate_from_dict(data)
            return json.dumps(generated, indent=2)
        except json.JSONDecodeError:
            return original_response

View File

@@ -0,0 +1,281 @@
import json
import re
from dataclasses import dataclass
from typing import Any
from urllib.parse import parse_qs, urlparse
try:
from haralyzer import HarParser
except ImportError:
HarParser = None
@dataclass
class UnifiedRequest:
    """One captured HTTP exchange, normalized from HAR or a browser export."""

    method: str  # HTTP verb, e.g. "GET"
    url: str  # full original URL
    headers: dict[str, str]  # request headers, names lower-cased by the parser
    query_params: dict[str, list[str]]  # query string, multi-valued
    body: str | None  # raw request body text, if any
    content_type: str | None  # request mime type, if any
    timing: float | None  # total entry time in seconds (None when unknown)
    status_code: int  # response status; entries with 0 are skipped upstream
    response_body: str | None  # kept only for application/json responses
    response_headers: dict[str, str]  # response headers, names lower-cased
@dataclass
class HARParseResult:
    """Aggregate result of parsing one capture file."""

    requests: list[UnifiedRequest]  # successfully parsed entries
    base_url: str  # scheme://netloc of the first parsed URL
    entry_count: int  # total entries present in the input
    skipped_count: int  # entries dropped (no URL, or status 0)
class HARParserError(Exception):
    """Raised for missing, invalid, or unparseable HAR input."""
    pass
class HARParser:
    """Parses HAR content from a file path or an in-memory string/dict.

    Requires the optional haralyzer package (checked at construction time).
    """

    def __init__(self, har_file_path: str | None = None, har_data: str | None = None):
        """Record the input source without reading it yet.

        Raises:
            HARParserError: if haralyzer is not installed.
        """
        if HarParser is None:
            raise HARParserError("haralyzer library is not installed")
        self.har_file_path = har_file_path
        self.har_data = har_data
        # Decoded HAR dict, populated lazily by _load_har_data().
        self._parsed = None

    def _load_har_data(self) -> dict[str, Any]:
        """Return the decoded HAR dict, loading and caching it on first use.

        Raises:
            HARParserError: when neither har_data nor har_file_path was given.
        """
        if self._parsed is not None:
            return self._parsed
        if self.har_data:
            # har_data may be a pre-decoded dict despite the str annotation
            # (parse_browser_network_export passes a dict).
            if isinstance(self.har_data, str):
                self._parsed = json.loads(self.har_data)
            else:
                self._parsed = self.har_data
        elif self.har_file_path:
            with open(self.har_file_path, encoding="utf-8") as f:
                self._parsed = json.load(f)
        else:
            raise HARParserError("No HAR data provided")
        return self._parsed

    def _extract_url_patterns(self, url: str) -> tuple[str, dict[str, str]]:
        """Replace UUID / numeric / long-hex path segments with placeholders.

        Returns the templated path and a map of placeholder key -> original
        segment value. NOTE(review): not called by parse(); appears to be an
        unused helper — confirm external callers before removing.
        """
        parsed = urlparse(url)
        path = parsed.path
        path_params = {}
        uuid_pattern = re.compile(
            r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
            re.IGNORECASE,
        )
        id_pattern = re.compile(r"/(\d+)(?:/|$)")
        hash_pattern = re.compile(r"/([a-f0-9]{24,})(?:/|$)", re.IGNORECASE)
        parts = path.split("/")
        new_parts = []
        for i, part in enumerate(parts):
            if uuid_pattern.match(part):
                path_params[f"uuid_{i}"] = part
                new_parts.append("<uuid>")
            elif id_pattern.match(f"/{part}"):
                match = id_pattern.match(f"/{part}")
                if match:
                    path_params[f"id_{i}"] = match.group(1)
                new_parts.append("<id>")
            elif hash_pattern.match(f"/{part}"):
                match = hash_pattern.match(f"/{part}")
                if match:
                    path_params[f"hash_{i}"] = match.group(1)
                new_parts.append("<hash>")
            else:
                new_parts.append(part)
        pattern = "/".join(new_parts)
        # Strip a trailing slash unless the path is just "/".
        if pattern.endswith("/") and len(new_parts) > 1:
            pattern = pattern.rstrip("/")
        return pattern, path_params

    def _parse_headers(self, headers: list[dict[str, Any]]) -> dict[str, str]:
        """Convert HAR [{name, value}, ...] header lists to a lower-cased dict."""
        return {h["name"].lower(): h["value"] for h in headers}

    def _parse_query_params(self, query_string: str) -> dict[str, list[str]]:
        """Parse a raw query string into a multi-valued dict."""
        parsed = parse_qs(query_string)
        return {k: v for k, v in parsed.items()}

    def _parse_post_data(
        self, post_data: dict[str, Any]
    ) -> tuple[str | None, str | None]:
        """Extract (body text, mime type) from a HAR postData object."""
        mime_type = post_data.get("mimeType", "")
        text = post_data.get("text")
        return text, mime_type if mime_type else None

    def _parse_timing(self, time: float | None) -> float | None:
        """Convert HAR milliseconds to seconds; negative means unknown -> None."""
        if time is None or time < 0:
            return None
        return time / 1000.0

    def parse(self) -> HARParseResult:
        """Parse all entries into UnifiedRequest objects.

        Entries without a URL or with status 0 are counted as skipped.
        base_url is taken from the first entry that has a URL.

        Raises:
            HARParserError: on missing "log" key or an empty entries list.
        """
        har_data = self._load_har_data()
        if "log" not in har_data:
            raise HARParserError("Invalid HAR format: missing 'log' key")
        log = har_data["log"]
        entries = log.get("entries", [])
        if not entries:
            raise HARParserError("No entries found in HAR file")
        base_url = ""
        requests: list[UnifiedRequest] = []
        skipped = 0
        for entry in entries:
            request = entry.get("request", {})
            response = entry.get("response", {})
            url = request.get("url", "")
            if not url:
                skipped += 1
                continue
            parsed_url = urlparse(url)
            if not base_url:
                base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
            method = request.get("method", "GET")
            headers = self._parse_headers(request.get("headers", []))
            # HAR stores the query string pre-split as [{name, value}, ...].
            query_string = request.get("queryString", [])
            query_params = {}
            for qs in query_string:
                name = qs.get("name", "")
                value = qs.get("value", "")
                if name not in query_params:
                    query_params[name] = []
                query_params[name].append(value)
            post_data = request.get("postData")
            body: str | None = None
            content_type: str | None = None
            if post_data:
                body, content_type = self._parse_post_data(post_data)
            timing = self._parse_timing(entry.get("time"))
            status_code = response.get("status", 0)
            # Status 0 usually means the request never completed.
            if status_code == 0:
                skipped += 1
                continue
            response_headers = self._parse_headers(response.get("headers", []))
            response_body: str | None = None
            content = response.get("content", {})
            # Only JSON bodies are retained for mock generation.
            if content.get("mimeType", "").startswith("application/json"):
                response_body = content.get("text")
            unified = UnifiedRequest(
                method=method,
                url=url,
                headers=headers,
                query_params=query_params,
                body=body,
                content_type=content_type,
                timing=timing,
                status_code=status_code,
                response_body=response_body,
                response_headers=response_headers,
            )
            requests.append(unified)
        return HARParseResult(
            requests=requests,
            base_url=base_url,
            entry_count=len(entries),
            skipped_count=skipped,
        )
def parse_browser_network_export(data: dict[str, Any]) -> HARParseResult:
    """Parse a browser DevTools network export.

    A payload with a top-level "log" key is standard HAR and is delegated to
    HARParser; otherwise a flat ``{"entries": [...]}`` layout is read here.
    Entries without a URL or with status 0 are counted as skipped.

    Raises:
        HARParserError: when no entries are present (or, for the HAR path,
            whatever HARParser raises).
    """
    if "log" in data:
        parser = HARParser(har_data=data)
        return parser.parse()
    entries: list[dict[str, Any]] = data.get("entries", [])
    if not entries:
        raise HARParserError("No entries found in browser network export")
    requests: list[UnifiedRequest] = []
    base_url = ""
    skipped = 0
    for entry in entries:
        # Some exports nest the request; others are the request itself.
        request = entry.get("request", entry)
        response = entry.get("response", {})
        url = request.get("url", "")
        if not url:
            skipped += 1
            continue
        parsed_url = urlparse(url)
        if not base_url:
            base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
        method = request.get("method", "GET")
        headers: dict[str, str] = {}
        for h in request.get("headers", []):
            headers[h.get("name", "").lower()] = h.get("value", "")
        query_params: dict[str, list[str]] = {}
        for qs in request.get("queryString", []):
            query_params.setdefault(qs.get("name", ""), []).append(qs.get("value", ""))
        body = request.get("postData", {}).get("text")
        # Convert ms -> s; negative timings (HAR's "unknown") are normalized
        # to None for consistency with HARParser._parse_timing.
        timing = entry.get("time")
        if timing is None or timing < 0:
            timing = None
        else:
            timing = timing / 1000.0
        status_code = response.get("status", 0)
        # Status 0 usually means the request never completed.
        if status_code == 0:
            skipped += 1
            continue
        response_headers: dict[str, str] = {}
        for h in response.get("headers", []):
            response_headers[h.get("name", "").lower()] = h.get("value", "")
        response_body = None
        content = response.get("content", {})
        # Only JSON bodies are retained for mock generation.
        if content.get("mimeType", "").startswith("application/json"):
            response_body = content.get("text")
        requests.append(
            UnifiedRequest(
                method=method,
                url=url,
                headers=headers,
                query_params=query_params,
                body=body,
                content_type=headers.get("content-type"),
                timing=timing,
                status_code=status_code,
                response_body=response_body,
                response_headers=response_headers,
            )
        )
    return HARParseResult(
        requests=requests,
        base_url=base_url,
        entry_count=len(entries),
        skipped_count=skipped,
    )

View File

@@ -0,0 +1,143 @@
import json
from typing import Any
from api_mock_cli.core.data_generator import FakeDataGenerator
from api_mock_cli.core.har_parser import HARParseResult, UnifiedRequest
from api_mock_cli.core.route_matcher import RouteMatcher
class MockGenerator:
    """Renders a standalone Flask mock-server module from parsed traffic."""

    def __init__(self, parse_result: HARParseResult):
        self.parse_result = parse_result
        self.route_matcher = RouteMatcher()
        self.data_generator = FakeDataGenerator()
        self._build_routes()

    def _build_routes(self) -> None:
        # Register every captured request so matching can collapse concrete
        # URLs into parameterized route patterns.
        for request in self.parse_result.requests:
            self.route_matcher.add_route(request.url, request.method)

    def _generate_response_body(self, request: UnifiedRequest) -> str:
        """Fake the captured payload (preserving JSON structure), or return a
        generic body when no response was recorded."""
        if request.response_body:
            return self.data_generator.generate_response(
                request.response_body, preserve_structure=True
            )
        return json.dumps({"message": "Mock response"})

    def generate_routes(self) -> list[dict[str, Any]]:
        """Return one route descriptor per unique (method, pattern) pair."""
        routes: list[dict[str, Any]] = []
        seen_keys: set[str] = set()
        for request in self.parse_result.requests:
            match = self.route_matcher.match(request.method, request.url)
            if not match:
                continue
            key = f"{request.method}:{match.route_pattern}"
            if key in seen_keys:
                continue
            seen_keys.add(key)
            routes.append(
                {
                    "pattern": match.route_pattern,
                    "method": request.method,
                    "status_code": request.status_code,
                    "headers": request.response_headers,
                    "response_body": self._generate_response_body(request),
                    "timing": request.timing,
                }
            )
        return routes

    def generate_app(self) -> str:
        """Render the mock-server module source code as a string.

        Fixes over the previous emission:
        - handlers are named by a deterministic counter (hash() is salted per
          process, which made the generated file non-reproducible);
        - custom headers are set on a real Response object (the old code
          referenced an undefined ``response`` variable, raising NameError
          at request time for any route with headers);
        - header names/values and bodies are embedded with repr() so quotes
          in captured data cannot break the generated syntax.
        """
        routes = self.generate_routes()
        imports = [
            "from flask import Flask, jsonify, request, Response",
            "import time",
            "import json",
            "",
        ]
        app_creation = [
            "def create_app():",
            "    app = Flask(__name__)",
            "",
        ]
        for index, route in enumerate(routes):
            pattern = route["pattern"]
            method = route["method"]
            status_code = route["status_code"]
            headers = route["headers"]
            timing = route.get("timing")
            route_args = f"'{pattern}'"
            if method != "GET":
                route_args = f"'{pattern}', methods=['{method}']"
            app_creation.append(f"    @app.route({route_args})")
            app_creation.append(f"    def handler_{index}():")
            app_creation.append("        if request.method == 'OPTIONS':")
            app_creation.append("            return '', 204")
            if timing:
                # Replay the captured latency.
                app_creation.append(f"        time.sleep({timing})")
            content_type = headers.get(
                "content-type", headers.get("Content-Type", "application/json")
            )
            if not content_type:
                content_type = "application/json"
            # Drop any charset suffix: Response(mimetype=...) wants the bare type.
            mimetype = content_type.split(";")[0].strip()
            body_repr = repr(route["response_body"])
            app_creation.append(
                f"        resp = Response({body_repr}, status={status_code}, mimetype={mimetype!r})"
            )
            for header_name, header_value in headers.items():
                # Flask computes these from the Response itself.
                if header_name.lower() in ("content-type", "content-length"):
                    continue
                app_creation.append(
                    f"        resp.headers[{header_name!r}] = {header_value!r}"
                )
            app_creation.append("        return resp")
            app_creation.append("")
        app_creation.append("    return app")
        app_creation.append("")
        app_creation.append("")
        app_creation.append("if __name__ == '__main__':")
        app_creation.append("    app = create_app()")
        app_creation.append("    app.run(host='0.0.0.0', port=5000, debug=True)")
        return "\n".join(imports + app_creation)

    def save_mock_server(self, output_path: str) -> None:
        """Write the generated server module to output_path."""
        code = self.generate_app()
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(code)

    def get_route_summary(self) -> list[dict[str, str]]:
        """Return [{method, route, status}, ...] for display, deduplicated
        per (method, pattern) pair."""
        summary: list[dict[str, str]] = []
        seen: set[str] = set()
        for request in self.parse_result.requests:
            match = self.route_matcher.match(request.method, request.url)
            if not match:
                continue
            key = f"{request.method}:{match.route_pattern}"
            if key in seen:
                continue
            seen.add(key)
            summary.append(
                {
                    "method": request.method,
                    "route": match.route_pattern,
                    "status": str(request.status_code),
                }
            )
        return summary

View File

@@ -0,0 +1,151 @@
import re
from dataclasses import dataclass
from urllib.parse import parse_qs, urlparse
@dataclass
class RouteMatch:
    """Result of matching a concrete URL against a registered route pattern."""

    route_pattern: str  # Flask-style pattern, e.g. "/users/<id>"
    path_params: dict[str, str]  # placeholder name -> concrete segment value
    query_params: dict[str, list[str]]  # parsed query string of the matched URL
    matched: bool  # always True for returned instances (match() yields None otherwise)
class RouteMatcher:
PATH_PARAM_PATTERN = re.compile(r"<([^>]+)>")
def __init__(self):
self.routes: list[tuple[str, str | None, str]] = []
def _convert_url_to_flask_pattern(self, url: str) -> tuple[str, list[str]]:
parsed = urlparse(url)
path = parsed.path
path_params: list[str] = []
uuid_pattern = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
id_pattern = r"\d+"
hash_pattern = r"[a-f0-9]{24,}"
hash_like_pattern = r"[a-zA-Z0-9_-]{8,}"
parts = path.split("/")
new_parts = []
for part in parts:
if not part:
continue
if re.match(uuid_pattern, part, re.IGNORECASE):
path_params.append("uuid")
new_parts.append("<uuid>")
elif re.match(id_pattern, part):
path_params.append("id")
new_parts.append("<id>")
elif re.match(hash_pattern, part, re.IGNORECASE):
path_params.append("hash")
new_parts.append("<hash>")
elif re.match(hash_like_pattern, part) and len(part) >= 12:
path_params.append("param")
new_parts.append("<param>")
elif part.startswith("{") and part.endswith("}"):
param_name = part[1:-1]
path_params.append(param_name)
new_parts.append(f"<{param_name}>")
elif self.PATH_PARAM_PATTERN.match(part):
match = self.PATH_PARAM_PATTERN.match(part)
if match:
path_params.append(match.group(1))
new_parts.append(part)
elif "<" in part and ">" in part:
match = self.PATH_PARAM_PATTERN.search(part)
if match:
path_params.append(match.group(1))
new_parts.append(part)
else:
new_parts.append(part)
pattern = "/" + "/".join(new_parts)
if not pattern.endswith("/") and len(parts) > 0 and parts[-1] == "":
pattern += "/"
pattern = pattern.rstrip("/")
return pattern, path_params
def _extract_path_params(
self, path: str, pattern: str
) -> dict[str, str] | None:
pattern_parts = [p for p in pattern.split("/") if p]
path_parts = [p for p in path.split("/") if p]
if len(pattern_parts) != len(path_parts):
return None
path_params: dict[str, str] = {}
for pp, gp in zip(path_parts, pattern_parts):
if gp.startswith("<") and gp.endswith(">"):
param_name = gp[1:-1]
path_params[param_name] = pp
elif pp != gp:
return None
return path_params
def add_route(self, url: str, method: str | None = None) -> str:
pattern, _ = self._convert_url_to_flask_pattern(url)
route_key = (pattern, method, url)
if route_key not in self.routes:
self.routes.append(route_key)
return pattern
def match(self, method: str, url: str) -> RouteMatch | None:
parsed = urlparse(url)
path = parsed.path
query_string = parsed.query
query_params: dict[str, list[str]] = {}
if query_string:
parsed_qs = parse_qs(query_string)
query_params = {k: v for k, v in parsed_qs.items()}
for pattern, route_method, original_url in self.routes:
if route_method is not None and route_method != method:
continue
path_params = self._extract_path_params(path, pattern)
if path_params is not None:
return RouteMatch(
route_pattern=pattern,
path_params=path_params,
query_params=query_params,
matched=True,
)
for pattern, route_method, original_url in self.routes:
if route_method is not None and route_method != method:
continue
pattern_parts = [p for p in pattern.split("/") if p]
path_parts = [p for p in path.split("/") if p]
if len(pattern_parts) != len(path_parts):
continue
all_match = True
path_params = {}
for pp, gp in zip(path_parts, pattern_parts):
if gp.startswith("<") and gp.endswith(">"):
param_name = gp[1:-1]
path_params[param_name] = pp
elif pp != gp:
all_match = False
break
if all_match:
return RouteMatch(
route_pattern=pattern,
path_params=path_params,
query_params=query_params,
matched=True,
)
return None
def get_routes(self) -> list[str]:
return [pattern for pattern, _, _ in self.routes]

View File

@@ -0,0 +1,148 @@
import json
import time
from typing import Any
from flask import Flask, Response, jsonify, request
class MockServer:
    """Flask-based mock server assembled from queued route descriptors."""

    def __init__(self, host: str = "localhost", port: int = 5000, debug: bool = False):
        self.host = host
        self.port = port
        self.debug = debug
        self.app: Flask | None = None
        self._routes: list[dict[str, Any]] = []

    def add_route(
        self,
        pattern: str,
        method: str,
        status_code: int,
        response_body: str,
        headers: dict[str, str] | None = None,
        timing: float | None = None,
    ) -> None:
        """Queue a route descriptor; it is registered when create_app() runs."""
        self._routes.append(
            {
                "pattern": pattern,
                "method": method,
                "status_code": status_code,
                "response_body": response_body,
                "headers": headers or {},
                "timing": timing,
            }
        )

    def _create_dynamic_handler(self, route: dict[str, Any]):
        """Build a view function that replays one captured route."""

        def handler(**kwargs):
            if request.method == "OPTIONS":
                return "", 204
            if route.get("timing"):
                # Replay the captured latency.
                time.sleep(route["timing"])
            content_type = route["headers"].get(
                "Content-Type", route["headers"].get("content-type", "application/json")
            )
            response = Response(
                route["response_body"],
                status=route["status_code"],
                mimetype=content_type,
            )
            # Content type/length are derived from the Response itself.
            for header_name, header_value in route["headers"].items():
                if header_name.lower() not in ("content-type", "content-length"):
                    response.headers[header_name] = header_value
            return response

        return handler

    def create_app(self) -> Flask:
        """Build and return the Flask app with all queued routes registered."""
        app = Flask(__name__)
        app.config["DEBUG"] = self.debug

        @app.route("/<path:fallback>", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
        def fallback_handler(fallback):
            # Catch-all for paths not covered by a captured route.
            return jsonify({"error": "Route not found", "path": fallback}), 404

        for index, route in enumerate(self._routes):
            handler = self._create_dynamic_handler(route)
            # Every handler closure is named "handler"; without an explicit
            # unique endpoint Flask raises "overwriting an existing endpoint
            # function" on the second add_url_rule call.
            app.add_url_rule(
                route["pattern"],
                endpoint=f"mock_route_{index}",
                view_func=handler,
                methods=[route["method"]],
            )

        @app.before_request
        def log_request():
            pass

        @app.after_request
        def after_request(response):
            # Tag every response so clients can tell the mock from the real API.
            response.headers["X-Mock-Server"] = "api-mock-cli"
            return response

        self.app = app
        return app

    def run(self) -> None:
        """Run the server, building the app first if needed. Blocks."""
        if self.app is None:
            self.create_app()
        if self.app:
            self.app.run(host=self.host, port=self.port, debug=self.debug)
def create_mock_server_from_har(
    parse_result, host: str = "localhost", port: int = 5000, debug: bool = False
) -> MockServer:
    """Build a MockServer whose routes mirror a parsed capture.

    Args:
        parse_result: HARParseResult-like object exposing a ``requests`` list.
        host, port, debug: forwarded to the MockServer constructor.

    Returns:
        A MockServer with one route per unique (method, pattern) pair;
        captured JSON bodies are replaced with structurally-identical fakes.
    """
    from api_mock_cli.core.data_generator import FakeDataGenerator
    from api_mock_cli.core.route_matcher import RouteMatcher
    server = MockServer(host=host, port=port, debug=debug)
    route_matcher = RouteMatcher()
    data_generator = FakeDataGenerator()
    # First pass: register every URL so patterns exist before matching.
    for req in parse_result.requests:
        route_matcher.add_route(req.url, req.method)
    seen_patterns: set[str] = set()
    # Second pass: add one route per unique (method, pattern) pair.
    for req in parse_result.requests:
        match = route_matcher.match(req.method, req.url)
        if not match:
            continue
        pattern = match.route_pattern
        key = f"{req.method}:{pattern}"
        if key in seen_patterns:
            continue
        seen_patterns.add(key)
        response_body = req.response_body
        if response_body:
            # Replace real payload values with fakes, keeping the JSON shape.
            response_body = data_generator.generate_response(
                response_body, preserve_structure=True
            )
        server.add_route(
            pattern=pattern,
            method=req.method,
            status_code=req.status_code,
            response_body=response_body or json.dumps({"message": "Mock response"}),
            headers=req.response_headers,
            timing=req.timing,
        )
    return server

3
src/api_mock_cli/main.py Normal file
View File

@@ -0,0 +1,3 @@
from api_mock_cli.cli import cli
__all__ = ["cli"]

View File

View File

@@ -0,0 +1,110 @@
import base64
import re
from dataclasses import dataclass
@dataclass
class AuthInfo:
    """A credential extracted from request headers."""

    auth_type: str  # one of "bearer", "basic", "api_key", "cookie"
    credentials: dict[str, str]  # parsed credential parts, keyed by part name
    header_name: str  # original header the credential came from
    header_value: str  # original, unparsed header value
class AuthHandler:
    """Detects, compares, and reconstructs HTTP authentication headers."""

    BEARER_PATTERN = re.compile(r"Bearer\s+(.+)", re.IGNORECASE)
    BASIC_PATTERN = re.compile(r"Basic\s+(.+)", re.IGNORECASE)
    API_KEY_PATTERN = re.compile(r"(.+)\s+(.+)", re.IGNORECASE)

    def __init__(self):
        self.supported_auth_types = ["bearer", "basic", "api_key", "cookie"]

    def extract_auth(self, headers: dict[str, str]) -> AuthInfo | None:
        """Return AuthInfo for the first recognised auth header, else None."""
        for name, value in headers.items():
            lowered = name.lower()
            if lowered == "authorization":
                parsed = self._parse_authorization(name, value)
                if parsed is not None:
                    return parsed
            elif lowered == "x-api-key":
                return AuthInfo(
                    auth_type="api_key",
                    credentials={"api_key": value},
                    header_name=name,
                    header_value=value,
                )
            elif lowered.startswith("cookie"):
                return AuthInfo(
                    auth_type="cookie",
                    credentials={"cookie": value},
                    header_name=name,
                    header_value=value,
                )
        return None

    def _parse_authorization(self, name: str, value: str) -> AuthInfo | None:
        """Parse an Authorization header as Bearer or Basic; None if neither."""
        bearer = self.BEARER_PATTERN.match(value)
        if bearer:
            return AuthInfo(
                auth_type="bearer",
                credentials={"token": bearer.group(1)},
                header_name=name,
                header_value=value,
            )
        basic = self.BASIC_PATTERN.match(value)
        if basic:
            try:
                decoded = base64.b64decode(basic.group(1)).decode("utf-8")
                username, password = decoded.split(":", 1)
            except Exception:
                # Undecodable Basic payload: treat as unrecognised.
                return None
            return AuthInfo(
                auth_type="basic",
                credentials={"username": username, "password": password},
                header_name=name,
                header_value=value,
            )
        return None

    def match_auth(
        self, request_headers: dict[str, str], expected_auth: AuthInfo
    ) -> bool:
        """True when the request carries the same *type* of auth as expected."""
        found = self.extract_auth(request_headers)
        if found is None:
            return expected_auth is None
        # Only the auth scheme is compared, not the credential values.
        if expected_auth.auth_type in ("bearer", "basic", "api_key", "cookie"):
            return found.auth_type == expected_auth.auth_type
        return False

    def generate_auth_header(self, auth_type: str, credentials: dict[str, str]) -> tuple[str, str]:
        """Build a (header_name, header_value) pair for the given auth type."""
        if auth_type == "bearer":
            return "Authorization", f"Bearer {credentials.get('token', '')}"
        if auth_type == "basic":
            userpass = f"{credentials.get('username', '')}:{credentials.get('password', '')}"
            encoded = base64.b64encode(userpass.encode()).decode()
            return "Authorization", f"Basic {encoded}"
        if auth_type == "api_key":
            return "X-API-Key", credentials.get("api_key", "")
        if auth_type == "cookie":
            return "Cookie", credentials.get("cookie", "")
        return "", ""

    def apply_auth_to_headers(
        self, headers: dict[str, str], auth_info: AuthInfo
    ) -> dict[str, str]:
        """Return a copy of *headers* with the stored auth header applied."""
        return {**headers, auth_info.header_name: auth_info.header_value}

View File

@@ -0,0 +1,38 @@
import json
import os
from pathlib import Path
from typing import Any
def read_file(path: str) -> str:
    """Return the full text content of *path*, decoded as UTF-8."""
    return Path(path).read_text(encoding="utf-8")
def write_file(path: str, content: str) -> None:
    """Write *content* to *path* (UTF-8), creating parent directories as needed."""
    parent = os.path.dirname(path) or "."
    os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(content)
def read_json(path: str) -> dict[str, Any]:
    """Parse the JSON document stored at *path* (UTF-8) and return it."""
    return json.loads(Path(path).read_text(encoding="utf-8"))
def write_json(path: str, data: dict[str, Any], indent: int = 2) -> None:
    """Serialize *data* as JSON to *path*, creating parent directories as needed."""
    parent = os.path.dirname(path) or "."
    os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=indent)
def ensure_directory(path: str) -> None:
    """Create *path* (and any missing parents); no-op if it already exists."""
    os.makedirs(path, exist_ok=True)
def file_exists(path: str) -> bool:
    """True when *path* names an existing regular file (not a directory)."""
    return Path(path).is_file()
def get_file_extension(path: str) -> str:
    """Return the lower-cased extension of *path*, including the leading dot."""
    _, extension = os.path.splitext(path)
    return extension.lower()

View File

@@ -0,0 +1 @@
__version__ = "0.1.0"

View File

@@ -0,0 +1,3 @@
"""Agentic Codebase Memory Manager - A centralized memory store for AI coding agents."""
__version__ = "0.1.0"

View File

View File

@@ -0,0 +1,207 @@
"""FastAPI REST API for the memory manager."""
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from memory_manager import __version__
from memory_manager.api.schemas import (
CommitCreate,
CommitResponse,
DiffResponse,
HealthResponse,
MemoryEntryCreate,
MemoryEntryResponse,
MemoryEntryUpdate,
StatsResponse,
)
from memory_manager.core.services import MemoryManager
from memory_manager.db.models import MemoryCategory
from memory_manager.db.repository import MemoryRepository
# Storage location is configurable via MEMORY_DB_PATH; defaults to a
# project-local .memory directory.
db_path = os.getenv("MEMORY_DB_PATH", ".memory/codebase_memory.db")
repository = MemoryRepository(db_path)
memory_manager = MemoryManager(repository)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the memory store on startup and close it on shutdown."""
    await memory_manager.initialize()
    yield
    await memory_manager.close()
app = FastAPI(
    title="Agentic Codebase Memory Manager",
    description="A centralized memory store for AI coding agents",
    version=__version__,
    lifespan=lifespan,
)
# NOTE(review): wide-open CORS — fine for a local dev tool; revisit before
# any multi-user deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health", response_model=HealthResponse)
async def health():
    """Liveness probe returning service status and version."""
    return HealthResponse(status="ok", version=__version__)
@app.get("/api/memory", response_model=list[MemoryEntryResponse])
async def list_memory(
    category: str | None = None,
    agent_id: str | None = None,
    project_path: str | None = None,
    limit: int = Query(default=100, ge=1, le=1000),
    offset: int = Query(default=0, ge=0),
):
    """List memory entries, optionally filtered by category/agent/project.

    Raises 422 for an unknown category value.
    """
    category_enum = None
    if category:
        try:
            category_enum = MemoryCategory(category)
        except ValueError:
            # `from None`: the ValueError adds nothing to the 422 response.
            raise HTTPException(status_code=422, detail=f"Invalid category: {category}") from None
    entries = await memory_manager.memory_service.list_entries(
        category=category_enum,
        agent_id=agent_id,
        project_path=project_path,
        limit=limit,
        offset=offset,
    )
    return entries
@app.post("/api/memory", response_model=MemoryEntryResponse, status_code=201)
async def create_memory(entry: MemoryEntryCreate):
    """Create a new memory entry and return it with its assigned id."""
    result = await memory_manager.memory_service.create_entry(
        title=entry.title,
        content=entry.content,
        category=entry.category,
        tags=entry.tags,
        agent_id=entry.agent_id,
        project_path=entry.project_path,
    )
    return result
@app.get("/api/memory/log", response_model=list[CommitResponse])
async def get_log(
    agent_id: str | None = None,
    project_path: str | None = None,
    limit: int = Query(default=100, ge=1, le=1000),
    offset: int = Query(default=0, ge=0),
):
    """Return the commit history, optionally filtered by agent/project."""
    commits = await memory_manager.commit_service.list_commits(
        agent_id=agent_id,
        project_path=project_path,
        limit=limit,
        offset=offset,
    )
    return commits
@app.get("/api/memory/stats", response_model=StatsResponse)
async def get_stats(project_path: str | None = None):
    """Aggregate entry and commit counts plus a per-category breakdown.

    NOTE(review): counts come from list queries capped at 10000 rows, so
    totals are approximate for larger stores — confirm acceptable.
    """
    entries = await memory_manager.memory_service.list_entries(
        project_path=project_path,
        limit=10000,
    )
    entries_by_category: dict[str, int] = {}
    for entry in entries:
        cat = entry["category"]
        entries_by_category[cat] = entries_by_category.get(cat, 0) + 1
    commits = await memory_manager.commit_service.list_commits(
        project_path=project_path,
        limit=10000,
    )
    return StatsResponse(
        total_entries=len(entries),
        entries_by_category=entries_by_category,
        total_commits=len(commits),
    )
@app.get("/api/memory/search", response_model=list[MemoryEntryResponse])
async def search_memory(
    q: str = Query(..., min_length=1),
    category: str | None = None,
    agent_id: str | None = None,
    project_path: str | None = None,
    limit: int = Query(default=100, ge=1, le=1000),
):
    """Full-text search over memory entries; filters mirror list_memory.

    Raises 422 for an unknown category value.
    """
    category_enum = None
    if category:
        try:
            category_enum = MemoryCategory(category)
        except ValueError:
            # `from None`: the ValueError adds nothing to the 422 response.
            raise HTTPException(status_code=422, detail=f"Invalid category: {category}") from None
    results = await memory_manager.search_service.search(
        query=q,
        category=category_enum,
        agent_id=agent_id,
        project_path=project_path,
        limit=limit,
    )
    return results
@app.post("/api/memory/commit", response_model=CommitResponse, status_code=201)
async def create_commit(commit: CommitCreate):
    """Snapshot the current memory state into a new commit."""
    result = await memory_manager.commit_service.create_commit(
        message=commit.message,
        agent_id=commit.agent_id,
        project_path=commit.project_path,
    )
    return result
@app.get("/api/memory/diff/{hash1}/{hash2}", response_model=DiffResponse)
async def get_diff(hash1: str, hash2: str):
    """Diff the snapshots of two commits; 404 when either hash is unknown."""
    diff = await memory_manager.commit_service.diff(hash1, hash2)
    if not diff:
        raise HTTPException(
            status_code=404,
            detail=f"Commit(s) not found: {hash1}, {hash2}. Check available commits with /api/memory/log"
        )
    return diff
@app.get("/api/memory/{entry_id}", response_model=MemoryEntryResponse)
async def get_memory(entry_id: int):
    """Fetch one entry by id; 404 when missing.

    Registered after the static /log, /stats and /search routes so those
    match before this path parameter.
    """
    entry = await memory_manager.memory_service.get_entry(entry_id)
    if not entry:
        raise HTTPException(status_code=404, detail=f"Entry {entry_id} not found")
    return entry
@app.put("/api/memory/{entry_id}", response_model=MemoryEntryResponse)
async def update_memory(entry_id: int, entry: MemoryEntryUpdate):
    """Partially update an entry (only supplied fields change); 404 when missing."""
    result = await memory_manager.memory_service.update_entry(
        entry_id=entry_id,
        title=entry.title,
        content=entry.content,
        category=entry.category,
        tags=entry.tags,
    )
    if not result:
        raise HTTPException(status_code=404, detail=f"Entry {entry_id} not found")
    return result
@app.delete("/api/memory/{entry_id}", status_code=204)
async def delete_memory(entry_id: int):
    """Delete an entry by id; 404 when missing, 204 (no body) on success."""
    deleted = await memory_manager.memory_service.delete_entry(entry_id)
    if not deleted:
        raise HTTPException(status_code=404, detail=f"Entry {entry_id} not found")

View File

@@ -0,0 +1,79 @@
"""Pydantic schemas for API request/response validation."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from memory_manager.db.models import MemoryCategory
class MemoryEntryCreate(BaseModel):
    """Request body for creating a memory entry."""
    title: str = Field(..., min_length=1, max_length=255)
    content: str = Field(..., min_length=1)
    category: MemoryCategory
    tags: list[str] = Field(default_factory=list)
    agent_id: str | None = None
    project_path: str | None = None
class MemoryEntryUpdate(BaseModel):
    """Request body for partial updates; None fields are left unchanged."""
    title: str | None = Field(None, min_length=1, max_length=255)
    content: str | None = Field(None, min_length=1)
    category: MemoryCategory | None = None
    tags: list[str] | None = None
class MemoryEntryResponse(BaseModel):
    """Serialized memory entry as returned by the API."""
    id: int
    title: str
    content: str
    category: str
    tags: list[str]
    agent_id: str
    project_path: str
    created_at: datetime | None
    updated_at: datetime | None
class SearchQuery(BaseModel):
    """Search parameters (query text plus optional filters)."""
    q: str = Field(..., min_length=1)
    category: MemoryCategory | None = None
    agent_id: str | None = None
    project_path: str | None = None
    limit: int = Field(default=100, ge=1, le=1000)
class CommitCreate(BaseModel):
    """Request body for creating a commit snapshot."""
    message: str = Field(..., min_length=1)
    agent_id: str | None = None
    project_path: str | None = None
class CommitResponse(BaseModel):
    """Serialized commit, including its full entry snapshot."""
    id: int
    hash: str
    message: str
    agent_id: str
    project_path: str
    snapshot: list[dict[str, Any]]
    created_at: datetime | None
class DiffResponse(BaseModel):
    """Difference between two commit snapshots."""
    commit1: dict[str, Any]
    commit2: dict[str, Any]
    added: list[dict[str, Any]]
    removed: list[dict[str, Any]]
    modified: list[dict[str, Any]]
class HealthResponse(BaseModel):
    """Response for the /health liveness probe."""
    status: str
    version: str
class StatsResponse(BaseModel):
    """Aggregate statistics over entries and commits."""
    total_entries: int
    entries_by_category: dict[str, int]
    total_commits: int

View File

View File

@@ -0,0 +1,340 @@
"""CLI interface for the memory manager using Click."""
import asyncio
import os
from datetime import datetime
import click
from memory_manager import __version__
from memory_manager.core.services import MemoryManager
from memory_manager.db.models import MemoryCategory
from memory_manager.db.repository import MemoryRepository
def get_db_path() -> str:
    """Path of the SQLite database file, overridable via MEMORY_DB_PATH."""
    return os.environ.get("MEMORY_DB_PATH", ".memory/codebase_memory.db")
async def get_memory_manager() -> MemoryManager:
    """Build and initialize a MemoryManager backed by the configured DB path.

    Callers are responsible for awaiting manager.close() when done.
    """
    repository = MemoryRepository(get_db_path())
    await repository.initialize()
    manager = MemoryManager(repository)
    return manager
def validate_category(ctx, param, value):
    """Click callback converting a category string to a MemoryCategory.

    Returns None when the option was not supplied; raises BadParameter for
    values that are not valid MemoryCategory members.
    """
    if value is None:
        return None
    try:
        return MemoryCategory(value)
    except ValueError:
        # `from None`: the ValueError adds nothing to the CLI error message.
        raise click.BadParameter(f"Invalid category. Must be one of: {[c.value for c in MemoryCategory]}") from None
@click.group()
@click.version_option(version=__version__)
def cli():
    """Agentic Codebase Memory Manager - A centralized memory store for AI coding agents."""
    # Group entry point; subcommands register themselves via @cli.command().
    pass
@cli.command()
@click.option("--title", "-t", required=True, help="Entry title")
@click.option("--content", "-c", required=True, help="Entry content")
@click.option("--category", "-g", required=True, callback=validate_category, help="Entry category")
@click.option("--tags", "-T", multiple=True, help="Entry tags")
@click.option("--agent-id", help="Agent ID (defaults to AGENT_ID env var)")
@click.option("--project-path", help="Project path (defaults to MEMORY_PROJECT_PATH env var)")
def add(title, content, category, tags, agent_id, project_path):
    """Add a new memory entry."""
    # [*tags], not list(tags): the module-level `list` command shadows the
    # builtin, so calling list() here would invoke the CLI command instead.
    asyncio.run(_add(title, content, category, [*tags], agent_id, project_path))
async def _add(title, content, category, tags, agent_id, project_path):
    """Async implementation of the `add` command."""
    manager = await get_memory_manager()
    try:
        entry = await manager.memory_service.create_entry(
            title=title,
            content=content,
            category=category,
            tags=tags,
            agent_id=agent_id,
            project_path=project_path,
        )
        click.echo(f"Created entry {entry['id']}: {entry['title']}")
    finally:
        await manager.close()
@cli.command(name="list")
@click.option("--category", "-g", callback=validate_category, help="Filter by category")
@click.option("--agent-id", help="Filter by agent ID")
@click.option("--project-path", help="Filter by project path")
@click.option("--limit", "-n", default=100, help="Number of entries to show")
@click.option("--offset", default=0, help="Offset for pagination")
def list_cmd(category, agent_id, project_path, limit, offset):
    """List memory entries."""
    # Named list_cmd (with explicit CLI name "list") so the module-level
    # binding no longer shadows the builtin `list`, which other commands
    # call at runtime on click's tags tuples.
    asyncio.run(_list(category, agent_id, project_path, limit, offset))
async def _list(category, agent_id, project_path, limit, offset):
    """Async implementation of the `list` command; prints a short summary per entry."""
    manager = await get_memory_manager()
    try:
        entries = await manager.memory_service.list_entries(
            category=category,
            agent_id=agent_id,
            project_path=project_path,
            limit=limit,
            offset=offset,
        )
        if not entries:
            click.echo("No entries found.")
            return
        for entry in entries:
            created = entry["created_at"]
            if created:
                created = datetime.fromisoformat(created).strftime("%Y-%m-%d %H:%M")
            click.echo(f"[{entry['id']}] {entry['category']} | {entry['title']} | {created} | {entry['agent_id']}")
            click.echo(f"    {entry['content'][:100]}...")
            if entry["tags"]:
                click.echo(f"    Tags: {', '.join(entry['tags'])}")
            click.echo()
    finally:
        await manager.close()
@cli.command()
@click.argument("query")
@click.option("--category", "-g", callback=validate_category, help="Filter by category")
@click.option("--agent-id", help="Filter by agent ID")
@click.option("--project-path", help="Filter by project path")
@click.option("--limit", "-n", default=100, help="Number of results")
def search(query, category, agent_id, project_path, limit):
    """Search memory entries."""
    # Sync click entry point bridging into the async implementation.
    asyncio.run(_search(query, category, agent_id, project_path, limit))
async def _search(query, category, agent_id, project_path, limit):
    """Async implementation of the `search` command; prints matching entries."""
    manager = await get_memory_manager()
    try:
        results = await manager.search_service.search(
            query=query,
            category=category,
            agent_id=agent_id,
            project_path=project_path,
            limit=limit,
        )
        if not results:
            click.echo("No results found.")
            return
        click.echo(f"Found {len(results)} result(s):\n")
        for entry in results:
            created = entry["created_at"]
            if created:
                created = datetime.fromisoformat(created).strftime("%Y-%m-%d %H:%M")
            click.echo(f"[{entry['id']}] {entry['category']} | {entry['title']} | {created}")
            click.echo(f"    {entry['content'][:200]}...")
            if entry["tags"]:
                click.echo(f"    Tags: {', '.join(entry['tags'])}")
            click.echo()
    finally:
        await manager.close()
@cli.command()
@click.argument("entry_id", type=int)
def get(entry_id):
    """Get a specific memory entry by ID."""
    # Sync click entry point bridging into the async implementation.
    asyncio.run(_get(entry_id))
async def _get(entry_id):
    """Async implementation of the `get` command; prints the full entry."""
    manager = await get_memory_manager()
    try:
        entry = await manager.memory_service.get_entry(entry_id)
        if not entry:
            click.echo(f"Entry {entry_id} not found.", err=True)
            return
        click.echo(f"ID: {entry['id']}")
        click.echo(f"Title: {entry['title']}")
        click.echo(f"Category: {entry['category']}")
        click.echo(f"Agent: {entry['agent_id']}")
        click.echo(f"Project: {entry['project_path']}")
        click.echo(f"Tags: {', '.join(entry['tags']) if entry['tags'] else '(none)'}")
        click.echo(f"Created: {entry['created_at']}")
        click.echo(f"Updated: {entry['updated_at']}")
        click.echo(f"\nContent:\n{entry['content']}")
    finally:
        await manager.close()
@cli.command()
@click.argument("entry_id", type=int)
@click.option("--title", "-t", help="New title")
@click.option("--content", "-c", help="New content")
@click.option("--category", "-g", callback=validate_category, help="New category")
@click.option("--tags", "-T", multiple=True, help="New tags")
def update(entry_id, title, content, category, tags):
    """Update a memory entry."""
    # [*tags], not list(tags): the module-level `list` command shadows the
    # builtin, so calling list() here would invoke the CLI command instead.
    asyncio.run(_update(entry_id, title, content, category, [*tags] if tags else None))
async def _update(entry_id, title, content, category, tags):
    """Async implementation of the `update` command (None fields unchanged)."""
    manager = await get_memory_manager()
    try:
        result = await manager.memory_service.update_entry(
            entry_id=entry_id,
            title=title,
            content=content,
            category=category,
            tags=tags,
        )
        if not result:
            click.echo(f"Entry {entry_id} not found.", err=True)
            return
        click.echo(f"Updated entry {entry_id}: {result['title']}")
    finally:
        await manager.close()
@cli.command()
@click.argument("entry_id", type=int)
def delete(entry_id):
    """Delete a memory entry."""
    # Sync click entry point bridging into the async implementation.
    asyncio.run(_delete(entry_id))
async def _delete(entry_id):
    """Async implementation of the `delete` command."""
    manager = await get_memory_manager()
    try:
        deleted = await manager.memory_service.delete_entry(entry_id)
        if not deleted:
            click.echo(f"Entry {entry_id} not found.", err=True)
            return
        click.echo(f"Deleted entry {entry_id}.")
    finally:
        await manager.close()
@cli.command()
@click.option("--message", "-m", required=True, help="Commit message")
@click.option("--agent-id", help="Agent ID")
@click.option("--project-path", help="Project path")
def commit(message, agent_id, project_path):
    """Create a commit snapshot of current memory state."""
    # Sync click entry point bridging into the async implementation.
    asyncio.run(_commit(message, agent_id, project_path))
async def _commit(message, agent_id, project_path):
    """Async implementation of the `commit` command."""
    manager = await get_memory_manager()
    try:
        result = await manager.commit_service.create_commit(
            message=message,
            agent_id=agent_id,
            project_path=project_path,
        )
        click.echo(f"Created commit {result['hash']}: {result['message']}")
    finally:
        await manager.close()
@cli.command()
@click.option("--agent-id", help="Filter by agent ID")
@click.option("--project-path", help="Filter by project path")
@click.option("--limit", "-n", default=100, help="Number of commits to show")
@click.option("--offset", default=0, help="Offset for pagination")
def log(agent_id, project_path, limit, offset):
    """Show commit history."""
    # Sync click entry point bridging into the async implementation.
    asyncio.run(_log(agent_id, project_path, limit, offset))
async def _log(agent_id, project_path, limit, offset):
    """Async implementation of the `log` command; git-log-like output."""
    manager = await get_memory_manager()
    try:
        commits = await manager.commit_service.list_commits(
            agent_id=agent_id,
            project_path=project_path,
            limit=limit,
            offset=offset,
        )
        if not commits:
            click.echo("No commits found.")
            return
        # NOTE: the loop variable shadows the module-level `commit` command,
        # but only within this function's local scope.
        for commit in commits:
            created = commit["created_at"]
            if created:
                created = datetime.fromisoformat(created).strftime("%Y-%m-%d %H:%M")
            click.echo(f"commit {commit['hash']}")
            click.echo(f"Author: {commit['agent_id']}")
            click.echo(f"Date: {created}")
            click.echo(f"\n    {commit['message']}\n")
    finally:
        await manager.close()
@cli.command()
@click.argument("hash1")
@click.argument("hash2")
def diff(hash1, hash2):
    """Show diff between two commits."""
    # Sync click entry point bridging into the async implementation.
    asyncio.run(_diff(hash1, hash2))
async def _diff(hash1, hash2):
    """Async implementation of the `diff` command; prints +/-/~ entry lines."""
    manager = await get_memory_manager()
    try:
        result = await manager.commit_service.diff(hash1, hash2)
        if not result:
            click.echo("One or both commits not found. Check available commits with 'memory log'.", err=True)
            return
        if result["added"]:
            click.echo("Added entries:")
            for entry in result["added"]:
                click.echo(f"  + [{entry['id']}] {entry['title']}")
        if result["removed"]:
            click.echo("\nRemoved entries:")
            for entry in result["removed"]:
                click.echo(f"  - [{entry['id']}] {entry['title']}")
        if result["modified"]:
            click.echo("\nModified entries:")
            for mod in result["modified"]:
                click.echo(f"  ~ [{mod['after']['id']}] {mod['after']['title']}")
        if not any([result["added"], result["removed"], result["modified"]]):
            click.echo("No differences found.")
    finally:
        await manager.close()
@cli.command()
@click.option("--host", default=os.getenv("MEMORY_API_HOST", "127.0.0.1"), help="Server host")
@click.option("--port", default=int(os.getenv("MEMORY_API_PORT", "8080")), help="Server port")
def serve(host, port):
    """Start the API server."""
    # Lazy imports: uvicorn and the FastAPI app are only needed here.
    import uvicorn
    from memory_manager.api.app import app as fastapi_app
    click.echo(f"Starting server at http://{host}:{port}")
    click.echo(f"API documentation available at http://{host}:{port}/docs")
    uvicorn.run(fastapi_app, host=host, port=port)
@cli.command()
def tui():
    """Launch the TUI dashboard."""
    # Lazy import: the TUI stack is only needed for this command.
    from memory_manager.tui.app import TUIApp
    app = TUIApp()
    app.run()
if __name__ == "__main__":
cli()

View File

View File

@@ -0,0 +1,210 @@
"""Core business logic services for the memory manager."""
import hashlib
import os
from typing import Any
from memory_manager.db.models import MemoryCategory
from memory_manager.db.repository import MemoryRepository
class MemoryService:
    """CRUD operations over memory entries, delegating to the repository."""
    def __init__(self, repository: MemoryRepository):
        self.repository = repository
    async def create_entry(
        self,
        title: str,
        content: str,
        category: str | MemoryCategory,
        tags: list[str] | None = None,
        agent_id: str | None = None,
        project_path: str | None = None,
    ) -> dict[str, Any]:
        """Create an entry; agent/project default from AGENT_ID / MEMORY_PROJECT_PATH env vars."""
        if isinstance(category, str):
            category = MemoryCategory(category)
        agent_id = agent_id or os.getenv("AGENT_ID", "unknown") or "unknown"
        project_path = project_path or os.getenv("MEMORY_PROJECT_PATH", ".") or "."
        entry = await self.repository.create_entry(
            title=title,
            content=content,
            category=category,
            tags=tags or [],
            agent_id=agent_id,
            project_path=project_path,
        )
        return entry.to_dict()
    async def get_entry(self, entry_id: int) -> dict[str, Any] | None:
        """Return the entry as a dict, or None if it does not exist."""
        entry = await self.repository.get_entry(entry_id)
        return entry.to_dict() if entry else None
    async def update_entry(
        self,
        entry_id: int,
        title: str | None = None,
        content: str | None = None,
        category: str | MemoryCategory | None = None,
        tags: list[str] | None = None,
    ) -> dict[str, Any] | None:
        """Update only the supplied fields; returns None for an unknown id."""
        if category is not None and isinstance(category, str):
            category = MemoryCategory(category)
        entry = await self.repository.update_entry(
            entry_id=entry_id,
            title=title,
            content=content,
            category=category,
            tags=tags,
        )
        return entry.to_dict() if entry else None
    async def delete_entry(self, entry_id: int) -> bool:
        """Delete the entry; True when a row was removed."""
        return await self.repository.delete_entry(entry_id)
    async def list_entries(
        self,
        category: str | MemoryCategory | None = None,
        agent_id: str | None = None,
        project_path: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """List entries as dicts, filtered by any combination of the arguments."""
        if category is not None and isinstance(category, str):
            category = MemoryCategory(category)
        entries = await self.repository.list_entries(
            category=category,
            agent_id=agent_id,
            project_path=project_path,
            limit=limit,
            offset=offset,
        )
        return [entry.to_dict() for entry in entries]
class SearchService:
    """Full-text search over memory entries, delegating to the repository."""
    def __init__(self, repository: MemoryRepository):
        self.repository = repository
    async def search(
        self,
        query: str,
        category: str | MemoryCategory | None = None,
        agent_id: str | None = None,
        project_path: str | None = None,
        limit: int = 100,
    ) -> list[dict[str, Any]]:
        """Search entries matching *query*, with optional filters, as dicts."""
        if category is not None and isinstance(category, str):
            category = MemoryCategory(category)
        entries = await self.repository.search_entries(
            query_text=query,
            category=category,
            agent_id=agent_id,
            project_path=project_path,
            limit=limit,
        )
        return [entry.to_dict() for entry in entries]
class CommitService:
    """Git-like snapshot commits over the memory store."""
    def __init__(self, repository: MemoryRepository):
        self.repository = repository
    def _generate_hash(self, data: str) -> str:
        """Return the SHA-1 hex digest of *data* (used as an id, not for security)."""
        return hashlib.sha1(data.encode()).hexdigest()
    async def create_commit(
        self,
        message: str,
        agent_id: str | None = None,
        project_path: str | None = None,
    ) -> dict[str, Any]:
        """Snapshot the project's current entries under a new commit hash.

        Agent/project default from AGENT_ID / MEMORY_PROJECT_PATH env vars.
        """
        agent_id = agent_id or os.getenv("AGENT_ID", "unknown") or "unknown"
        project_path = project_path or os.getenv("MEMORY_PROJECT_PATH", ".") or "."
        snapshot = await self.repository.get_all_entries_snapshot(project_path)
        # commit_hash (not `hash`): avoid shadowing the builtin.
        commit_hash = self._generate_hash(f"{snapshot}{message}{agent_id}")
        commit = await self.repository.create_commit(
            hash=commit_hash,
            message=message,
            agent_id=agent_id,
            project_path=project_path,
            snapshot=snapshot,
        )
        return commit.to_dict()
    async def get_commit(self, hash: str) -> dict[str, Any] | None:
        """Return the commit as a dict, or None if the hash is unknown."""
        commit = await self.repository.get_commit(hash)
        return commit.to_dict() if commit else None
    async def list_commits(
        self,
        agent_id: str | None = None,
        project_path: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """List commits as dicts, filtered/paginated per the arguments."""
        commits = await self.repository.list_commits(
            agent_id=agent_id,
            project_path=project_path,
            limit=limit,
            offset=offset,
        )
        return [commit.to_dict() for commit in commits]
    async def diff(self, hash1: str, hash2: str) -> dict[str, Any] | None:
        """Compare two commit snapshots by entry id.

        Returns a dict with added/removed/modified entry lists, or None
        when either hash is unknown.
        """
        commit1 = await self.repository.get_commit(hash1)
        commit2 = await self.repository.get_commit(hash2)
        if not commit1 or not commit2:
            return None
        snapshot1 = {entry["id"]: entry for entry in commit1.snapshot}
        snapshot2 = {entry["id"]: entry for entry in commit2.snapshot}
        added = []
        removed = []
        modified = []
        for entry_id in set(snapshot1) | set(snapshot2):
            if entry_id not in snapshot1:
                added.append(snapshot2[entry_id])
            elif entry_id not in snapshot2:
                removed.append(snapshot1[entry_id])
            elif snapshot1[entry_id] != snapshot2[entry_id]:
                modified.append({
                    "before": snapshot1[entry_id],
                    "after": snapshot2[entry_id],
                })
        return {
            "commit1": commit1.to_dict(),
            "commit2": commit2.to_dict(),
            "added": added,
            "removed": removed,
            "modified": modified,
        }
class MemoryManager:
    """Facade bundling memory, search, and commit services over one repository."""
    def __init__(self, repository: MemoryRepository):
        self.repository = repository
        self.memory_service = MemoryService(repository)
        self.search_service = SearchService(repository)
        self.commit_service = CommitService(repository)
    async def initialize(self) -> None:
        """Open the backing repository (creates schema on first run)."""
        await self.repository.initialize()
    async def close(self) -> None:
        """Release the repository's resources."""
        await self.repository.close()

View File

View File

@@ -0,0 +1,134 @@
"""SQLAlchemy database models for the memory manager."""
import enum
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, DateTime, Enum, Index, String, Text, func, text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models in this module."""
    pass
class MemoryCategory(enum.StrEnum):
    """Closed set of entry categories; values are the strings stored in the DB."""
    DECISION = "decision"
    FEATURE = "feature"
    REFACTORING = "refactoring"
    ARCHITECTURE = "architecture"
    BUG = "bug"
    NOTE = "note"
class MemoryEntry(Base):
    """A single memory record attributed to an agent and a project."""
    __tablename__ = "memory_entries"
    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(255), nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    category: Mapped[MemoryCategory] = mapped_column(Enum(MemoryCategory), nullable=False)
    tags: Mapped[list[str]] = mapped_column(JSON, default=list)
    agent_id: Mapped[str] = mapped_column(String(100), nullable=False)
    project_path: Mapped[str] = mapped_column(String(500), nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now())
    updated_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), onupdate=func.now())
    # Indexes on the columns used as list/search filters.
    __table_args__ = (
        Index("idx_category", "category"),
        Index("idx_agent_id", "agent_id"),
        Index("idx_project_path", "project_path"),
        Index("idx_created_at", "created_at"),
    )
    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamps as ISO strings or None)."""
        return {
            "id": self.id,
            "title": self.title,
            "content": self.content,
            "category": self.category.value,
            "tags": self.tags,
            "agent_id": self.agent_id,
            "project_path": self.project_path,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }
class Commit(Base):
    """A point-in-time snapshot of a project's entries, keyed by a unique hash."""
    __tablename__ = "commits"
    id: Mapped[int] = mapped_column(primary_key=True)
    hash: Mapped[str] = mapped_column(String(40), unique=True, nullable=False)
    message: Mapped[str] = mapped_column(Text, nullable=False)
    agent_id: Mapped[str] = mapped_column(String(100), nullable=False)
    project_path: Mapped[str] = mapped_column(String(500), nullable=False)
    # Full list of entry dicts captured at commit time.
    snapshot: Mapped[list[dict[str, Any]]] = mapped_column(JSON, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now())
    __table_args__ = (
        Index("idx_commit_hash", "hash"),
        Index("idx_commit_agent_id", "agent_id"),
        Index("idx_commit_created_at", "created_at"),
    )
    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamp as ISO string or None)."""
        return {
            "id": self.id,
            "hash": self.hash,
            "message": self.message,
            "agent_id": self.agent_id,
            "project_path": self.project_path,
            "snapshot": self.snapshot,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }
class Tag(Base):
    """A unique tag name (entries also carry tags inline as JSON)."""
    __tablename__ = "tags"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now())
    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamp as ISO string or None)."""
        return {
            "id": self.id,
            "name": self.name,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }
async def init_db(db_path: str):
    """Create the async SQLite engine, ORM tables, and the FTS5 search index.

    memory_entries_fts is an external-content FTS5 table over memory_entries;
    the three triggers keep it in sync on insert, delete and update. External
    content tables require the special 'delete' insert to remove index rows,
    which is why the delete/update triggers use that form.
    """
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False)
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
        await conn.execute(text(
            "CREATE VIRTUAL TABLE IF NOT EXISTS memory_entries_fts USING fts5("
            "title, content, tags, category, agent_id, project_path, content='memory_entries', content_rowid='id')"
        ))
        await conn.execute(text(
            "CREATE TRIGGER IF NOT EXISTS memory_entries_ai AFTER INSERT ON memory_entries BEGIN "
            "INSERT INTO memory_entries_fts(rowid, title, content, tags, category, agent_id, project_path) "
            "VALUES (new.id, new.title, new.content, new.tags, new.category, new.agent_id, new.project_path); END"
        ))
        await conn.execute(text(
            "CREATE TRIGGER IF NOT EXISTS memory_entries_ad AFTER DELETE ON memory_entries BEGIN "
            "INSERT INTO memory_entries_fts(memory_entries_fts, rowid, title, content, tags, category, agent_id, project_path) "
            "VALUES ('delete', old.id, old.title, old.content, old.tags, old.category, old.agent_id, old.project_path); END"
        ))
        await conn.execute(text(
            "CREATE TRIGGER IF NOT EXISTS memory_entries_au AFTER UPDATE ON memory_entries BEGIN "
            "INSERT INTO memory_entries_fts(memory_entries_fts, rowid, title, content, tags, category, agent_id, project_path) "
            "VALUES ('delete', old.id, old.title, old.content, old.tags, old.category, old.agent_id, old.project_path); "
            "INSERT INTO memory_entries_fts(rowid, title, content, tags, category, agent_id, project_path) "
            "VALUES (new.id, new.title, new.content, new.tags, new.category, new.agent_id, new.project_path); END"
        ))
    return engine

View File

@@ -0,0 +1,232 @@
"""Async repository for database operations."""
import os
from typing import Any
from sqlalchemy import delete, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from memory_manager.db.models import Commit, MemoryCategory, MemoryEntry, init_db
class MemoryRepository:
    def __init__(self, db_path: str):
        """Record the SQLite file path; actual setup happens in initialize()."""
        self.db_path = db_path
        self.engine: Any = None
        self._session_factory: async_sessionmaker[AsyncSession] | None = None
    async def initialize(self) -> None:
        """Create the DB directory and schema, then build the session factory."""
        db_dir = os.path.dirname(self.db_path)
        if db_dir and not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)
        self.engine = await init_db(self.db_path)
        self._session_factory = async_sessionmaker(self.engine, expire_on_commit=False)
    async def get_session(self) -> AsyncSession:
        """Return a new session, lazily initializing the engine on first use."""
        if not self._session_factory:
            await self.initialize()
        assert self._session_factory is not None
        return self._session_factory()
    async def create_entry(
        self,
        title: str,
        content: str,
        category: MemoryCategory,
        tags: list[str],
        agent_id: str,
        project_path: str,
    ) -> MemoryEntry:
        """Insert a new entry and return it refreshed (with id and timestamps)."""
        async with await self.get_session() as session:
            entry = MemoryEntry(
                title=title,
                content=content,
                category=category,
                tags=tags,
                agent_id=agent_id,
                project_path=project_path,
            )
            session.add(entry)
            await session.commit()
            # Refresh to populate server-side defaults before the session closes.
            await session.refresh(entry)
            return entry
    async def get_entry(self, entry_id: int) -> MemoryEntry | None:
        """Return the entry with the given id, or None when absent."""
        async with await self.get_session() as session:
            result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
            return result.scalar_one_or_none()
async def update_entry(
self,
entry_id: int,
title: str | None = None,
content: str | None = None,
category: MemoryCategory | None = None,
tags: list[str] | None = None,
) -> MemoryEntry | None:
async with await self.get_session() as session:
result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
entry = result.scalar_one_or_none()
if not entry:
return None
if title is not None:
entry.title = title
if content is not None:
entry.content = content
if category is not None:
entry.category = category
if tags is not None:
entry.tags = tags
await session.commit()
await session.refresh(entry)
return entry
async def delete_entry(self, entry_id: int) -> bool:
async with await self.get_session() as session:
result = await session.execute(delete(MemoryEntry).where(MemoryEntry.id == entry_id))
await session.commit()
rowcount = getattr(result, "rowcount", 0)
return rowcount is not None and rowcount > 0
async def list_entries(
self,
category: MemoryCategory | None = None,
agent_id: str | None = None,
project_path: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[MemoryEntry]:
async with await self.get_session() as session:
query = select(MemoryEntry)
if category:
query = query.where(MemoryEntry.category == category)
if agent_id:
query = query.where(MemoryEntry.agent_id == agent_id)
if project_path:
query = query.where(MemoryEntry.project_path == project_path)
query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
result = await session.execute(query)
return list(result.scalars().all())
async def search_entries(
self,
query_text: str,
category: MemoryCategory | None = None,
agent_id: str | None = None,
project_path: str | None = None,
limit: int = 100,
) -> list[MemoryEntry]:
async with await self.get_session() as session:
fts_query = f'"{query_text}"'
sql_parts = ["""
SELECT m.* FROM memory_entries m
INNER JOIN memory_entries_fts fts ON m.id = fts.rowid
WHERE memory_entries_fts MATCH :query
"""]
params: dict[str, Any] = {"query": fts_query}
if category:
sql_parts.append(" AND m.category = :category")
params["category"] = category.value
if agent_id:
sql_parts.append(" AND m.agent_id = :agent_id")
params["agent_id"] = agent_id
if project_path:
sql_parts.append(" AND m.project_path = :project_path")
params["project_path"] = project_path
sql_parts.append(" LIMIT :limit")
params["limit"] = limit
sql = text("".join(sql_parts))
result = await session.execute(sql, params)
rows = result.fetchall()
entries = []
for row in rows:
entry = MemoryEntry(
id=row.id,
title=row.title,
content=row.content,
category=MemoryCategory(row.category),
tags=row.tags,
agent_id=row.agent_id,
project_path=row.project_path,
created_at=row.created_at,
updated_at=row.updated_at,
)
entries.append(entry)
return entries
async def create_commit(
self,
hash: str,
message: str,
agent_id: str,
project_path: str,
snapshot: list[dict[str, Any]],
) -> Commit:
async with await self.get_session() as session:
commit = Commit(
hash=hash,
message=message,
agent_id=agent_id,
project_path=project_path,
snapshot=snapshot,
)
session.add(commit)
await session.commit()
await session.refresh(commit)
return commit
async def get_commit(self, hash: str) -> Commit | None:
async with await self.get_session() as session:
result = await session.execute(select(Commit).where(Commit.hash == hash))
return result.scalar_one_or_none()
async def get_commit_by_id(self, commit_id: int) -> Commit | None:
async with await self.get_session() as session:
result = await session.execute(select(Commit).where(Commit.id == commit_id))
return result.scalar_one_or_none()
async def list_commits(
self,
agent_id: str | None = None,
project_path: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[Commit]:
async with await self.get_session() as session:
query = select(Commit)
if agent_id:
query = query.where(Commit.agent_id == agent_id)
if project_path:
query = query.where(Commit.project_path == project_path)
query = query.order_by(Commit.created_at.desc()).limit(limit).offset(offset)
result = await session.execute(query)
return list(result.scalars().all())
async def get_all_entries_snapshot(self, project_path: str | None = None) -> list[dict[str, Any]]:
async with await self.get_session() as session:
query = select(MemoryEntry)
if project_path:
query = query.where(MemoryEntry.project_path == project_path)
result = await session.execute(query)
entries = result.scalars().all()
return [entry.to_dict() for entry in entries]
async def close(self) -> None:
if self.engine:
await self.engine.dispose()

View File

View File

@@ -0,0 +1,363 @@
"""Textual TUI application for the memory manager."""
import os
from datetime import datetime
from typing import Any
from textual import on
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, ScrollableContainer, Vertical
from textual.screen import Screen
from textual.widgets import (
Button,
Footer,
Header,
Input,
Label,
ListItem,
ListView,
Static,
)
from memory_manager.core.services import MemoryManager
from memory_manager.db.models import MemoryCategory
from memory_manager.db.repository import MemoryRepository
db_path = os.getenv("MEMORY_DB_PATH", ".memory/codebase_memory.db")


async def get_memory_manager() -> MemoryManager:
    """Build a MemoryManager backed by a freshly initialized repository."""
    repo = MemoryRepository(db_path)
    await repo.initialize()
    return MemoryManager(repo)
class DashboardScreen(Screen):
    """Landing screen: aggregate stats plus the ten most recent entries."""

    # Fixed: `bold: true` is not a Textual CSS property; the correct form is
    # `text-style: bold`.
    # NOTE(review): `$border` may not be a defined theme variable in every
    # Textual version — confirm against the app's theme.
    CSS = """
    Screen {
        background: $surface;
    }
    .stats-container {
        height: auto;
        padding: 1 2;
        background: $panel;
        border: solid $border;
    }
    .stat-label {
        color: $text-muted;
    }
    .stat-value {
        color: $text;
        text-style: bold;
    }
    .entry-list {
        height: 1fr;
    }
    """

    def __init__(self, manager: MemoryManager):
        super().__init__()
        self.manager = manager

    def compose(self) -> ComposeResult:
        yield Header()
        yield Container(
            Vertical(
                Label("Memory Manager Dashboard", classes="header-title"),
                ScrollableContainer(classes="stats-container", id="stats-container"),
                ListView(id="recent-entries", classes="entry-list"),
                id="dashboard-content",
            )
        )
        yield Footer()

    async def on_mount(self) -> None:
        await self.load_stats()

    async def load_stats(self) -> None:
        """Recompute the stats panel and repopulate the recent-entries list."""
        stats_container = self.query_one("#stats-container", ScrollableContainer)
        # NOTE(review): fetches up to 10k rows just to count them; a COUNT
        # query would be cheaper if this becomes slow.
        entries = await self.manager.memory_service.list_entries(limit=10000)
        commits = await self.manager.commit_service.list_commits(limit=10000)
        entries_by_category: dict[str, int] = {}
        for entry in entries:
            cat = entry["category"]
            entries_by_category[cat] = entries_by_category.get(cat, 0) + 1
        stats_text = f"""
Total Entries: {len(entries)} | Total Commits: {len(commits)}
Entries by Category:
"""
        for cat, count in entries_by_category.items():
            stats_text += f"  {cat}: {count}\n"
        stats_container.remove_children()
        # Await the mount so the widget is attached before we continue.
        await stats_container.mount(Static(stats_text))
        list_view = self.query_one("#recent-entries", ListView)
        list_view.clear()
        for entry in entries[:10]:
            created = entry["created_at"]
            if created:
                created = datetime.fromisoformat(created).strftime("%m/%d %H:%M")
            list_item = ListItem(
                Label(f"[{entry['category']}] {entry['title'][:40]} - {created}"),
            )
            await list_view.mount(list_item)
class MemoryListScreen(Screen):
    """Two-pane browser: a filterable entry list plus a detail view."""

    CSS = """
    Screen {
        background: $surface;
    }
    .filter-bar {
        height: 3;
        padding: 1;
        background: $panel;
    }
    .entry-detail {
        height: 1fr;
        padding: 1 2;
    }
    """

    def __init__(self, manager: MemoryManager, category: MemoryCategory | None = None):
        super().__init__()
        self.manager = manager
        self.current_category = category
        self.entries: list[dict[str, Any]] = []

    def compose(self) -> ComposeResult:
        yield Header()
        yield Container(
            Horizontal(
                Button("All", id="filter-all"),
                Button("Decision", id="filter-decision"),
                Button("Feature", id="filter-feature"),
                Button("Refactoring", id="filter-refactoring"),
                Button("Architecture", id="filter-architecture"),
                Button("Bug", id="filter-bug"),
                Button("Note", id="filter-note"),
                classes="filter-bar",
            ),
            Horizontal(
                ListView(id="entries-list", classes="column"),
                ScrollableContainer(id="entry-detail", classes="column entry-detail"),
                classes="main-content",
            ),
        )
        yield Footer()

    async def on_mount(self) -> None:
        # Fixed: previously called load_entries() with no argument, silently
        # discarding the category passed to __init__.
        await self.load_entries(self.current_category)

    async def load_entries(self, category: MemoryCategory | None = None) -> None:
        """Fetch entries for the given category (None = all) and rebuild the list."""
        self.current_category = category
        self.entries = await self.manager.memory_service.list_entries(
            category=category,
            limit=1000,
        )
        list_view = self.query_one("#entries-list", ListView)
        list_view.clear()
        for entry in self.entries:
            created = entry["created_at"]
            if created:
                created = datetime.fromisoformat(created).strftime("%m/%d %H:%M")
            list_item = ListItem(
                Label(f"[{entry['category']}] {entry['title']}"),
                Label(f"{entry['agent_id']} | {created}"),
            )
            await list_view.mount(list_item)

    @on(ListView.Selected)
    async def on_entry_selected(self, event: ListView.Selected) -> None:
        """Show the selected entry's full details in the right-hand pane."""
        index = event.list_view.index
        if index is not None and 0 <= index < len(self.entries):
            entry = self.entries[index]
            detail_container = self.query_one("#entry-detail", ScrollableContainer)
            detail_container.remove_children()
            content = f"""
Title: {entry['title']}
Category: {entry['category']}
Agent: {entry['agent_id']}
Project: {entry['project_path']}
Tags: {', '.join(entry['tags']) if entry['tags'] else '(none)'}
Created: {entry['created_at']}
Updated: {entry['updated_at']}
Content:
{entry['content']}
"""
            await detail_container.mount(Static(content))

    @on(Button.Pressed)
    async def on_filter_pressed(self, event: Button.Pressed) -> None:
        """Map the pressed filter button to a category and reload ("All" -> None)."""
        button_id = event.button.id
        category = None
        if button_id == "filter-decision":
            category = MemoryCategory.DECISION
        elif button_id == "filter-feature":
            category = MemoryCategory.FEATURE
        elif button_id == "filter-refactoring":
            category = MemoryCategory.REFACTORING
        elif button_id == "filter-architecture":
            category = MemoryCategory.ARCHITECTURE
        elif button_id == "filter-bug":
            category = MemoryCategory.BUG
        elif button_id == "filter-note":
            category = MemoryCategory.NOTE
        await self.load_entries(category)
class CommitHistoryScreen(Screen):
    """Read-only scrollable list of saved commits."""

    CSS = """
    Screen {
        background: $surface;
    }
    .commit-list {
        height: 1fr;
    }
    """

    def __init__(self, manager: MemoryManager):
        super().__init__()
        self.manager = manager
        self.commits: list[dict[str, Any]] = []

    def compose(self) -> ComposeResult:
        yield Header()
        yield Container(
            ScrollableContainer(
                ListView(id="commits-list", classes="commit-list"),
                id="commits-container",
            ),
        )
        yield Footer()

    async def on_mount(self) -> None:
        await self.load_commits()

    async def load_commits(self) -> None:
        """Fetch the most recent commits and rebuild the list view."""
        self.commits = await self.manager.commit_service.list_commits(limit=100)
        commits_view = self.query_one("#commits-list", ListView)
        commits_view.clear()
        for record in self.commits:
            stamp = record["created_at"]
            if stamp:
                stamp = datetime.fromisoformat(stamp).strftime("%Y-%m-%d %H:%M:%S")
            # markup=False: commit messages are user text, render them verbatim.
            body = f"commit {record['hash'][:8]}\n{record['agent_id']} | {stamp}\n\n {record['message']}"
            await commits_view.mount(ListItem(Static(body, markup=False)))
class SearchScreen(Screen):
    """Free-text search: an input bar on top, matching entries below."""

    CSS = """
    Screen {
        background: $surface;
    }
    .search-input {
        height: 3;
        padding: 1;
    }
    .results-list {
        height: 1fr;
    }
    """

    def __init__(self, manager: MemoryManager):
        super().__init__()
        self.manager = manager
        self.results: list[dict[str, Any]] = []

    def compose(self) -> ComposeResult:
        yield Header()
        yield Container(
            Horizontal(
                Input(placeholder="Search query...", id="search-input"),
                Button("Search", id="search-button"),
                classes="search-input",
            ),
            ScrollableContainer(
                ListView(id="results-list", classes="results-list"),
                id="results-container",
            ),
        )
        yield Footer()

    @on(Button.Pressed, "#search-button")
    @on(Input.Submitted, "#search-input")
    async def on_search(self) -> None:
        """Search for the current input text and repopulate the results list."""
        term = self.query_one("#search-input", Input).value
        if not term:
            return
        self.results = await self.manager.search_service.search(query=term, limit=100)
        results_view = self.query_one("#results-list", ListView)
        results_view.clear()
        for hit in self.results:
            stamp = hit["created_at"]
            if stamp:
                stamp = datetime.fromisoformat(stamp).strftime("%m/%d %H:%M")
            row = ListItem(
                Label(f"[{hit['category']}] {hit['title'][:40]}"),
                Label(f"{hit['content'][:80]}... | {stamp}"),
            )
            await results_view.mount(row)
class TUIApp(App):
    """Top-level Textual application hosting the four screens."""

    # NOTE(review): Textual resolves a Binding's action string to a method
    # named `action_<name>`. These handlers are named `switch_*` without the
    # `action_` prefix, so the d/l/c/s bindings likely never fire — confirm
    # against the Textual version in use.
    BINDINGS = [
        Binding("d", "switch_dashboard", "Dashboard"),
        Binding("l", "switch_memory_list", "Memory List"),
        Binding("c", "switch_commits", "Commits"),
        Binding("s", "switch_search", "Search"),
        Binding("q", "quit", "Quit"),
    ]

    def __init__(self):
        super().__init__()
        # Set asynchronously in on_mount(); None until the app is mounted.
        self.manager = None

    async def on_mount(self) -> None:
        # Build the manager (initializes the repository) and show the dashboard.
        self.manager = await get_memory_manager()
        await self.push_screen(DashboardScreen(self.manager))

    async def on_unmount(self) -> None:
        # Release database resources when the app exits.
        if self.manager:
            await self.manager.close()

    async def switch_dashboard(self) -> None:
        if self.manager:
            await self.push_screen(DashboardScreen(self.manager))

    async def switch_memory_list(self) -> None:
        if self.manager:
            await self.push_screen(MemoryListScreen(self.manager))

    async def switch_commits(self) -> None:
        if self.manager:
            await self.push_screen(CommitHistoryScreen(self.manager))

    async def switch_search(self) -> None:
        if self.manager:
            await self.push_screen(SearchScreen(self.manager))