diff --git a/src/memory_manager/db/repository.py b/src/memory_manager/db/repository.py new file mode 100644 index 0000000..f7f15e1 --- /dev/null +++ b/src/memory_manager/db/repository.py @@ -0,0 +1,239 @@ +"""Async repository for memory entries and commits.""" +import hashlib +import json +from datetime import datetime +from typing import List, Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker +from sqlalchemy.orm import selectinload + +from memory_manager.db.models import Base, MemoryEntry, Commit, CommitEntry, MemoryCategory + + +class MemoryRepository: + def __init__(self, db_path: str): + self.db_path = db_path + self.engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False) + self.async_session = async_sessionmaker(self.engine, class_=AsyncSession) + + async def init_db(self): + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async def close(self): + await self.engine.dispose() + + async def create_entry( + self, + title: str, + content: str, + category: MemoryCategory, + tags: List[str], + agent_id: str, + project_path: str, + parent_id: Optional[int] = None, + ) -> MemoryEntry: + async with self.async_session() as session: + entry = MemoryEntry( + title=title, + content=content, + category=category, + tags=tags, + agent_id=agent_id, + project_path=project_path, + parent_id=parent_id, + ) + session.add(entry) + await session.commit() + await session.refresh(entry) + return entry + + async def get_entry(self, entry_id: int) -> Optional[MemoryEntry]: + async with self.async_session() as session: + result = await session.execute( + select(MemoryEntry).where(MemoryEntry.id == entry_id) + ) + return result.scalar_one_or_none() + + async def list_entries( + self, + category: Optional[MemoryCategory] = None, + agent_id: Optional[str] = None, + limit: int = 100, + offset: int = 0, + ) -> List[MemoryEntry]: + async with self.async_session() as session: + query = select(MemoryEntry) + if category: + query = query.where(MemoryEntry.category == category) + if agent_id: + query = query.where(MemoryEntry.agent_id == agent_id) + query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset) + result = await session.execute(query) + return list(result.scalars().all()) + + async def update_entry( + self, + entry_id: int, + title: Optional[str] = None, + content: Optional[str] = None, + category: Optional[MemoryCategory] = None, + tags: Optional[List[str]] = None, + ) -> Optional[MemoryEntry]: + async with self.async_session() as session: + result = await session.execute( + select(MemoryEntry).where(MemoryEntry.id == entry_id) + ) + entry = result.scalar_one_or_none() + if not entry: + return None + if title is not None: + entry.title = title + if content is not None: + entry.content = content + if category is not None: + entry.category = category + if tags is not None: + entry.tags = tags + entry.updated_at = datetime.utcnow() + await session.commit() + await session.refresh(entry) + return entry + + async def delete_entry(self, entry_id: int) -> bool: + async with self.async_session() as session: + result = await session.execute( + select(MemoryEntry).where(MemoryEntry.id == entry_id) + ) + entry = result.scalar_one_or_none() + if not entry: + return False + await session.delete(entry) + await session.commit() + return True + + async def search_entries( + self, + query: str, + category: Optional[MemoryCategory] = None, + tags: Optional[List[str]] = None, + ) -> List[MemoryEntry]: + async with self.async_session() as session: + search_pattern = f"%{query}%" + stmt = select(MemoryEntry).where( + (MemoryEntry.title.like(search_pattern)) | + (MemoryEntry.content.like(search_pattern)) + ) + if category: + stmt = stmt.where(MemoryEntry.category == category) + if tags: + for tag in tags: + stmt = stmt.where(MemoryEntry.tags.contains([tag])) + stmt = stmt.order_by(MemoryEntry.created_at.desc()) + result = await session.execute(stmt) + return list(result.scalars().all()) + + async def create_commit( + self, + message: str, + agent_id: str, + project_path: str, + ) -> Commit: + async with self.async_session() as session: + entries_result = await session.execute(select(MemoryEntry)) + entries = list(entries_result.scalars().all()) + + entries_snapshot = [] + for entry in entries: + entries_snapshot.append({ + "id": entry.id, + "title": entry.title, + "content": entry.content, + "category": entry.category.value, + "tags": entry.tags, + "agent_id": entry.agent_id, + "project_path": entry.project_path, + }) + + snapshot_json = json.dumps(entries_snapshot, sort_keys=True) + commit_hash = hashlib.sha1( + f"{snapshot_json}{datetime.utcnow().isoformat()}".encode() + ).hexdigest() + + commit = Commit( + hash=commit_hash, + message=message, + agent_id=agent_id, + project_path=project_path, + ) + session.add(commit) + await session.flush() + + for entry in entries: + commit_entry = CommitEntry( + commit_id=commit.id, + memory_entry_id=entry.id, + entry_snapshot={ + "id": entry.id, + "title": entry.title, + "content": entry.content, + "category": entry.category.value, + "tags": entry.tags, + "agent_id": entry.agent_id, + }, + ) + session.add(commit_entry) + + await session.commit() + await session.refresh(commit) + return commit + + async def get_commits(self, limit: int = 50) -> List[Commit]: + async with self.async_session() as session: + result = await session.execute( + select(Commit).order_by(Commit.created_at.desc()).limit(limit) + ) + return list(result.scalars().all()) + + async def get_commit(self, commit_hash: str) -> Optional[Commit]: + async with self.async_session() as session: + result = await session.execute( + select(Commit).where(Commit.hash == commit_hash) + ) + return result.scalar_one_or_none() + + async def get_commit_entries(self, commit_hash: str) -> List[CommitEntry]: + async with self.async_session() as session: + commit_result = await session.execute( + select(Commit).where(Commit.hash == commit_hash) + ) + commit = commit_result.scalar_one_or_none() + if not commit: + return [] + result = await session.execute( + select(CommitEntry) + .options(selectinload(CommitEntry.memory_entry)) + .where(CommitEntry.commit_id == commit.id) + ) + return list(result.scalars().all()) + + async def get_stats(self) -> dict: + async with self.async_session() as session: + total_entries = await session.scalar(select(func.count(MemoryEntry.id))) + total_commits = await session.scalar(select(func.count(Commit.id))) + + category_counts = {} + for category in MemoryCategory: + count = await session.scalar( + select(func.count(MemoryEntry.id)).where( + MemoryEntry.category == category + ) + ) + category_counts[category.value] = count or 0 + + return { + "total_entries": total_entries or 0, + "total_commits": total_commits or 0, + "category_counts": category_counts, + }