diff --git a/src/memory_manager/db/repository.py b/src/memory_manager/db/repository.py index cc4be5a..b6c28ef 100644 --- a/src/memory_manager/db/repository.py +++ b/src/memory_manager/db/repository.py @@ -1,232 +1 @@ -"""Async repository for database operations.""" - -import os -from typing import Any - -from sqlalchemy import delete, select, text -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - -from memory_manager.db.models import Commit, MemoryCategory, MemoryEntry, init_db - - -class MemoryRepository: - def __init__(self, db_path: str): - self.db_path = db_path - self.engine: Any = None - self._session_factory: async_sessionmaker[AsyncSession] | None = None - - async def initialize(self) -> None: - db_dir = os.path.dirname(self.db_path) - if db_dir and not os.path.exists(db_dir): - os.makedirs(db_dir, exist_ok=True) - - self.engine = await init_db(self.db_path) - self._session_factory = async_sessionmaker(self.engine, expire_on_commit=False) - - async def get_session(self) -> AsyncSession: - if not self._session_factory: - await self.initialize() - assert self._session_factory is not None - return self._session_factory() - - async def create_entry( - self, - title: str, - content: str, - category: MemoryCategory, - tags: list[str], - agent_id: str, - project_path: str, - ) -> MemoryEntry: - async with await self.get_session() as session: - entry = MemoryEntry( - title=title, - content=content, - category=category, - tags=tags, - agent_id=agent_id, - project_path=project_path, - ) - session.add(entry) - await session.commit() - await session.refresh(entry) - return entry - - async def get_entry(self, entry_id: int) -> MemoryEntry | None: - async with await self.get_session() as session: - result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id)) - return result.scalar_one_or_none() - - async def update_entry( - self, - entry_id: int, - title: str | None = None, - content: str | None = None, - category: MemoryCategory | None = None, - tags: list[str] | None = None, - ) -> MemoryEntry | None: - async with await self.get_session() as session: - result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id)) - entry = result.scalar_one_or_none() - if not entry: - return None - - if title is not None: - entry.title = title - if content is not None: - entry.content = content - if category is not None: - entry.category = category - if tags is not None: - entry.tags = tags - - await session.commit() - await session.refresh(entry) - return entry - - async def delete_entry(self, entry_id: int) -> bool: - async with await self.get_session() as session: - result = await session.execute(delete(MemoryEntry).where(MemoryEntry.id == entry_id)) - await session.commit() - rowcount = getattr(result, "rowcount", 0) - return rowcount is not None and rowcount > 0 - - async def list_entries( - self, - category: MemoryCategory | None = None, - agent_id: str | None = None, - project_path: str | None = None, - limit: int = 100, - offset: int = 0, - ) -> list[MemoryEntry]: - async with await self.get_session() as session: - query = select(MemoryEntry) - - if category: - query = query.where(MemoryEntry.category == category) - if agent_id: - query = query.where(MemoryEntry.agent_id == agent_id) - if project_path: - query = query.where(MemoryEntry.project_path == project_path) - - query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset) - result = await session.execute(query) - return list(result.scalars().all()) - - async def search_entries( - self, - query_text: str, - category: MemoryCategory | None = None, - agent_id: str | None = None, - project_path: str | None = None, - limit: int = 100, - ) -> list[MemoryEntry]: - async with await self.get_session() as session: - fts_query = f'"{query_text}"' - - sql_parts = [""" - SELECT m.* FROM memory_entries m - INNER JOIN memory_entries_fts fts ON m.id = fts.rowid - WHERE memory_entries_fts MATCH :query - """] - - params: dict[str, Any] = {"query": fts_query} - - if category: - sql_parts.append(" AND m.category = :category") - params["category"] = category.value - if agent_id: - sql_parts.append(" AND m.agent_id = :agent_id") - params["agent_id"] = agent_id - if project_path: - sql_parts.append(" AND m.project_path = :project_path") - params["project_path"] = project_path - - sql_parts.append(" LIMIT :limit") - params["limit"] = limit - - sql = text("".join(sql_parts)) - - result = await session.execute(sql, params) - rows = result.fetchall() - - entries = [] - for row in rows: - entry = MemoryEntry( - id=row.id, - title=row.title, - content=row.content, - category=MemoryCategory(row.category), - tags=row.tags, - agent_id=row.agent_id, - project_path=row.project_path, - created_at=row.created_at, - updated_at=row.updated_at, - ) - entries.append(entry) - - return entries - - async def create_commit( - self, - hash: str, - message: str, - agent_id: str, - project_path: str, - snapshot: list[dict[str, Any]], - ) -> Commit: - async with await self.get_session() as session: - commit = Commit( - hash=hash, - message=message, - agent_id=agent_id, - project_path=project_path, - snapshot=snapshot, - ) - session.add(commit) - await session.commit() - await session.refresh(commit) - return commit - - async def get_commit(self, hash: str) -> Commit | None: - async with await self.get_session() as session: - result = await session.execute(select(Commit).where(Commit.hash == hash)) - return result.scalar_one_or_none() - - async def get_commit_by_id(self, commit_id: int) -> Commit | None: - async with await self.get_session() as session: - result = await session.execute(select(Commit).where(Commit.id == commit_id)) - return result.scalar_one_or_none() - - async def list_commits( - self, - agent_id: str | None = None, - project_path: str | None = None, - limit: int = 100, - offset: int = 0, - ) -> list[Commit]: - async with await self.get_session() as session: - query = select(Commit) - - if agent_id: - query = query.where(Commit.agent_id == agent_id) - if project_path: - query = query.where(Commit.project_path == project_path) - - query = query.order_by(Commit.created_at.desc()).limit(limit).offset(offset) - result = await session.execute(query) - return list(result.scalars().all()) - - async def get_all_entries_snapshot(self, project_path: str | None = None) -> list[dict[str, Any]]: - async with await self.get_session() as session: - query = select(MemoryEntry) - if project_path: - query = query.where(MemoryEntry.project_path == project_path) - - result = await session.execute(query) - entries = result.scalars().all() - return [entry.to_dict() for entry in entries] - - async def close(self) -> None: - if self.engine: - await self.engine.dispose() +repository content \ No newline at end of file