Add memory_manager source files and tests
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
"""Core services for memory management."""
|
||||
from typing import List, Optional
|
||||
from memory_manager.db.models import MemoryCategory, MemoryEntry, Commit
|
||||
"""Core business logic services for the memory manager."""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from memory_manager.db.models import MemoryCategory
|
||||
from memory_manager.db.repository import MemoryRepository
|
||||
|
||||
|
||||
@@ -12,122 +16,195 @@ class MemoryService:
|
||||
self,
|
||||
title: str,
|
||||
content: str,
|
||||
category: str,
|
||||
tags: List[str],
|
||||
agent_id: str = "unknown",
|
||||
project_path: str = ".",
|
||||
parent_id: Optional[int] = None,
|
||||
) -> MemoryEntry:
|
||||
category_enum = MemoryCategory(category)
|
||||
return await self.repository.create_entry(
|
||||
category: str | MemoryCategory,
|
||||
tags: list[str] | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(category, str):
|
||||
category = MemoryCategory(category)
|
||||
|
||||
agent_id = agent_id or os.getenv("AGENT_ID", "unknown") or "unknown"
|
||||
project_path = project_path or os.getenv("MEMORY_PROJECT_PATH", ".") or "."
|
||||
|
||||
entry = await self.repository.create_entry(
|
||||
title=title,
|
||||
content=content,
|
||||
category=category_enum,
|
||||
tags=tags,
|
||||
category=category,
|
||||
tags=tags or [],
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
parent_id=parent_id,
|
||||
)
|
||||
return entry.to_dict()
|
||||
|
||||
async def get_entry(self, entry_id: int) -> Optional[MemoryEntry]:
|
||||
return await self.repository.get_entry(entry_id)
|
||||
|
||||
async def list_entries(
|
||||
self,
|
||||
category: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> List[MemoryEntry]:
|
||||
category_enum = MemoryCategory(category) if category else None
|
||||
return await self.repository.list_entries(
|
||||
category=category_enum,
|
||||
agent_id=agent_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
async def get_entry(self, entry_id: int) -> dict[str, Any] | None:
|
||||
entry = await self.repository.get_entry(entry_id)
|
||||
return entry.to_dict() if entry else None
|
||||
|
||||
async def update_entry(
|
||||
self,
|
||||
entry_id: int,
|
||||
title: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Optional[MemoryEntry]:
|
||||
category_enum = MemoryCategory(category) if category else None
|
||||
return await self.repository.update_entry(
|
||||
title: str | None = None,
|
||||
content: str | None = None,
|
||||
category: str | MemoryCategory | None = None,
|
||||
tags: list[str] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
if category is not None and isinstance(category, str):
|
||||
category = MemoryCategory(category)
|
||||
|
||||
entry = await self.repository.update_entry(
|
||||
entry_id=entry_id,
|
||||
title=title,
|
||||
content=content,
|
||||
category=category_enum,
|
||||
category=category,
|
||||
tags=tags,
|
||||
)
|
||||
return entry.to_dict() if entry else None
|
||||
|
||||
async def delete_entry(self, entry_id: int) -> bool:
|
||||
return await self.repository.delete_entry(entry_id)
|
||||
|
||||
async def search_entries(
|
||||
async def list_entries(
|
||||
self,
|
||||
category: str | MemoryCategory | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
if category is not None and isinstance(category, str):
|
||||
category = MemoryCategory(category)
|
||||
|
||||
entries = await self.repository.list_entries(
|
||||
category=category,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return [entry.to_dict() for entry in entries]
|
||||
|
||||
|
||||
class SearchService:
|
||||
def __init__(self, repository: MemoryRepository):
|
||||
self.repository = repository
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
category: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> List[MemoryEntry]:
|
||||
category_enum = MemoryCategory(category) if category else None
|
||||
return await self.repository.search_entries(
|
||||
query=query,
|
||||
category=category_enum,
|
||||
tags=tags,
|
||||
category: str | MemoryCategory | None = None,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict[str, Any]]:
|
||||
if category is not None and isinstance(category, str):
|
||||
category = MemoryCategory(category)
|
||||
|
||||
entries = await self.repository.search_entries(
|
||||
query_text=query,
|
||||
category=category,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
)
|
||||
return [entry.to_dict() for entry in entries]
|
||||
|
||||
|
||||
class CommitService:
|
||||
def __init__(self, repository: MemoryRepository):
|
||||
self.repository = repository
|
||||
|
||||
def _generate_hash(self, data: str) -> str:
|
||||
return hashlib.sha1(data.encode()).hexdigest()
|
||||
|
||||
async def create_commit(
|
||||
self,
|
||||
message: str,
|
||||
agent_id: str = "unknown",
|
||||
project_path: str = ".",
|
||||
) -> Commit:
|
||||
return await self.repository.create_commit(
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
agent_id = agent_id or os.getenv("AGENT_ID", "unknown") or "unknown"
|
||||
project_path = project_path or os.getenv("MEMORY_PROJECT_PATH", ".") or "."
|
||||
|
||||
snapshot = await self.repository.get_all_entries_snapshot(project_path)
|
||||
|
||||
snapshot_str = f"{snapshot}{message}{agent_id}"
|
||||
hash = self._generate_hash(snapshot_str)
|
||||
|
||||
commit = await self.repository.create_commit(
|
||||
hash=hash,
|
||||
message=message,
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
return commit.to_dict()
|
||||
|
||||
async def get_commits(self, limit: int = 50) -> List[Commit]:
|
||||
return await self.repository.get_commits(limit=limit)
|
||||
async def get_commit(self, hash: str) -> dict[str, Any] | None:
|
||||
commit = await self.repository.get_commit(hash)
|
||||
return commit.to_dict() if commit else None
|
||||
|
||||
async def get_commit(self, commit_hash: str) -> Optional[Commit]:
|
||||
return await self.repository.get_commit(commit_hash)
|
||||
async def list_commits(
|
||||
self,
|
||||
agent_id: str | None = None,
|
||||
project_path: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
commits = await self.repository.list_commits(
|
||||
agent_id=agent_id,
|
||||
project_path=project_path,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return [commit.to_dict() for commit in commits]
|
||||
|
||||
async def diff_commits(self, hash1: str, hash2: str) -> dict:
|
||||
entries1 = await self.repository.get_commit_entries(hash1)
|
||||
entries2 = await self.repository.get_commit_entries(hash2)
|
||||
async def diff(self, hash1: str, hash2: str) -> dict[str, Any] | None:
|
||||
commit1 = await self.repository.get_commit(hash1)
|
||||
commit2 = await self.repository.get_commit(hash2)
|
||||
|
||||
snapshot1 = {e.memory_entry_id: e.entry_snapshot for e in entries1}
|
||||
snapshot2 = {e.memory_entry_id: e.entry_snapshot for e in entries2}
|
||||
if not commit1 or not commit2:
|
||||
return None
|
||||
|
||||
snapshot1 = {entry["id"]: entry for entry in commit1.snapshot}
|
||||
snapshot2 = {entry["id"]: entry for entry in commit2.snapshot}
|
||||
|
||||
all_ids = set(snapshot1.keys()) | set(snapshot2.keys())
|
||||
diff_result = {
|
||||
"hash1": hash1,
|
||||
"hash2": hash2,
|
||||
"added": [],
|
||||
"removed": [],
|
||||
"modified": [],
|
||||
}
|
||||
|
||||
added = []
|
||||
removed = []
|
||||
modified = []
|
||||
|
||||
for entry_id in all_ids:
|
||||
if entry_id not in snapshot1:
|
||||
diff_result["added"].append(snapshot2[entry_id])
|
||||
added.append(snapshot2[entry_id])
|
||||
elif entry_id not in snapshot2:
|
||||
diff_result["removed"].append(snapshot1[entry_id])
|
||||
elif snapshot1[entry_id] != snapshot2[entry_id]:
|
||||
diff_result["modified"].append({
|
||||
"before": snapshot1[entry_id],
|
||||
"after": snapshot2[entry_id],
|
||||
})
|
||||
removed.append(snapshot1[entry_id])
|
||||
else:
|
||||
if snapshot1[entry_id] != snapshot2[entry_id]:
|
||||
modified.append({
|
||||
"before": snapshot1[entry_id],
|
||||
"after": snapshot2[entry_id],
|
||||
})
|
||||
|
||||
return diff_result
|
||||
return {
|
||||
"commit1": commit1.to_dict(),
|
||||
"commit2": commit2.to_dict(),
|
||||
"added": added,
|
||||
"removed": removed,
|
||||
"modified": modified,
|
||||
}
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self, repository: MemoryRepository):
|
||||
self.repository = repository
|
||||
self.memory_service = MemoryService(repository)
|
||||
self.search_service = SearchService(repository)
|
||||
self.commit_service = CommitService(repository)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await self.repository.initialize()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.repository.close()
|
||||
|
||||
Reference in New Issue
Block a user