"""Async repository for database operations."""

import os
from typing import Any

from sqlalchemy import column, delete, select, table, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from memory_manager.db.models import Commit, MemoryCategory, MemoryEntry, init_db

# Lightweight (non-ORM) handle on the full-text index table. Only its implicit
# ``rowid`` is referenced; search matches it against ``MemoryEntry.id``.
_MEMORY_ENTRIES_FTS = table("memory_entries_fts", column("rowid"))


def _fts5_phrase(query_text: str) -> str:
    """Quote *query_text* as a single FTS5 phrase string.

    FTS5 strings are delimited by double quotes, and an embedded double quote
    must be written as two. Without this escaping, a search term containing
    ``"`` produces an FTS5 syntax error instead of a match.
    """
    return '"' + query_text.replace('"', '""') + '"'


class MemoryRepository:
    """Async CRUD and full-text search over memory entries and commits.

    The engine and session factory are created lazily: the first call to
    :meth:`get_session` runs :meth:`initialize` if it has not been called yet.
    Every public method opens its own short-lived session.
    """

    def __init__(self, db_path: str):
        """Remember the database location; no I/O happens until :meth:`initialize`.

        Args:
            db_path: Filesystem path of the database, passed to ``init_db``.
        """
        self.db_path = db_path
        self.engine: Any = None
        self._session_factory: async_sessionmaker[AsyncSession] | None = None

    async def initialize(self) -> None:
        """Ensure the parent directory exists, then build the engine and session factory."""
        db_dir = os.path.dirname(self.db_path)
        if db_dir:
            # exist_ok=True already tolerates an existing directory, so no
            # separate (racy) existence check is needed.
            os.makedirs(db_dir, exist_ok=True)
        self.engine = await init_db(self.db_path)
        # expire_on_commit=False keeps loaded attributes after commit, so the
        # objects returned by this repository stay readable once their
        # session has been closed.
        self._session_factory = async_sessionmaker(self.engine, expire_on_commit=False)

    async def get_session(self) -> AsyncSession:
        """Return a new session, initializing the repository on first use."""
        if self._session_factory is None:
            await self.initialize()
        assert self._session_factory is not None  # set by initialize()
        return self._session_factory()

    async def create_entry(
        self,
        title: str,
        content: str,
        category: MemoryCategory,
        tags: list[str],
        agent_id: str,
        project_path: str,
    ) -> MemoryEntry:
        """Insert a new memory entry and return it with database-assigned fields loaded."""
        async with await self.get_session() as session:
            entry = MemoryEntry(
                title=title,
                content=content,
                category=category,
                tags=tags,
                agent_id=agent_id,
                project_path=project_path,
            )
            session.add(entry)
            await session.commit()
            # Reload so server-generated values (e.g. the primary key) are present.
            await session.refresh(entry)
            return entry

    async def get_entry(self, entry_id: int) -> MemoryEntry | None:
        """Return the entry with primary key *entry_id*, or ``None`` if absent."""
        async with await self.get_session() as session:
            result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
            return result.scalar_one_or_none()

    async def update_entry(
        self,
        entry_id: int,
        title: str | None = None,
        content: str | None = None,
        category: MemoryCategory | None = None,
        tags: list[str] | None = None,
    ) -> MemoryEntry | None:
        """Partially update an entry; arguments left as ``None`` are not changed.

        Returns:
            The refreshed entry, or ``None`` if no entry has id *entry_id*.
        """
        async with await self.get_session() as session:
            result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
            entry = result.scalar_one_or_none()
            if not entry:
                return None
            if title is not None:
                entry.title = title
            if content is not None:
                entry.content = content
            if category is not None:
                entry.category = category
            if tags is not None:
                entry.tags = tags
            await session.commit()
            await session.refresh(entry)
            return entry

    async def delete_entry(self, entry_id: int) -> bool:
        """Delete the entry with id *entry_id*; return ``True`` if a row was removed."""
        async with await self.get_session() as session:
            result = await session.execute(delete(MemoryEntry).where(MemoryEntry.id == entry_id))
            await session.commit()
            # Some drivers report rowcount as None/-1 when unknown; treat as "not deleted".
            rowcount = getattr(result, "rowcount", 0)
            return rowcount is not None and rowcount > 0

    async def list_entries(
        self,
        category: MemoryCategory | None = None,
        agent_id: str | None = None,
        project_path: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[MemoryEntry]:
        """List entries newest-first, optionally filtered.

        Falsy filter values (``None`` or an empty string) are ignored.
        """
        async with await self.get_session() as session:
            query = select(MemoryEntry)
            if category:
                query = query.where(MemoryEntry.category == category)
            if agent_id:
                query = query.where(MemoryEntry.agent_id == agent_id)
            if project_path:
                query = query.where(MemoryEntry.project_path == project_path)
            query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
            result = await session.execute(query)
            return list(result.scalars().all())

    async def search_entries(
        self,
        query_text: str,
        category: MemoryCategory | None = None,
        agent_id: str | None = None,
        project_path: str | None = None,
        limit: int = 100,
    ) -> list[MemoryEntry]:
        """Full-text search for entries containing *query_text* as an exact phrase.

        The text is matched against the ``memory_entries_fts`` index as a single
        quoted FTS5 phrase. Embedded double quotes are escaped, so arbitrary
        user input cannot break the FTS query syntax. Optional filters behave
        like those of :meth:`list_entries` (falsy values are ignored).

        Returns:
            Up to *limit* matching entries, newest first. A blank *query_text*
            returns ``[]`` without touching the database.
        """
        if not query_text.strip():
            return []

        match_clause = text("memory_entries_fts MATCH :query").bindparams(
            query=_fts5_phrase(query_text)
        )
        matching_ids = select(_MEMORY_ENTRIES_FTS.c.rowid).where(match_clause)

        # Load full ORM entities, rather than building objects from raw SQL
        # rows, so every column goes through its mapped type's result
        # processing. Filters compare via the mapped columns, matching
        # list_entries.
        query = select(MemoryEntry).where(MemoryEntry.id.in_(matching_ids))
        if category:
            query = query.where(MemoryEntry.category == category)
        if agent_id:
            query = query.where(MemoryEntry.agent_id == agent_id)
        if project_path:
            query = query.where(MemoryEntry.project_path == project_path)
        # Deterministic order (the raw query had none) so `limit` is stable.
        query = query.order_by(MemoryEntry.created_at.desc()).limit(limit)

        async with await self.get_session() as session:
            result = await session.execute(query)
            return list(result.scalars().all())

    async def create_commit(
        self,
        hash: str,
        message: str,
        agent_id: str,
        project_path: str,
        snapshot: list[dict[str, Any]],
    ) -> Commit:
        """Insert a commit record holding *snapshot* and return it refreshed.

        ``hash`` shadows the builtin but is kept for keyword-argument callers.
        """
        async with await self.get_session() as session:
            commit = Commit(
                hash=hash,
                message=message,
                agent_id=agent_id,
                project_path=project_path,
                snapshot=snapshot,
            )
            session.add(commit)
            await session.commit()
            await session.refresh(commit)
            return commit

    async def get_commit(self, hash: str) -> Commit | None:
        """Return the commit whose ``hash`` column equals *hash*, or ``None``."""
        async with await self.get_session() as session:
            result = await session.execute(select(Commit).where(Commit.hash == hash))
            return result.scalar_one_or_none()

    async def get_commit_by_id(self, commit_id: int) -> Commit | None:
        """Return the commit with primary key *commit_id*, or ``None``."""
        async with await self.get_session() as session:
            result = await session.execute(select(Commit).where(Commit.id == commit_id))
            return result.scalar_one_or_none()

    async def list_commits(
        self,
        agent_id: str | None = None,
        project_path: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[Commit]:
        """List commits newest-first; falsy filter values are ignored."""
        async with await self.get_session() as session:
            query = select(Commit)
            if agent_id:
                query = query.where(Commit.agent_id == agent_id)
            if project_path:
                query = query.where(Commit.project_path == project_path)
            query = query.order_by(Commit.created_at.desc()).limit(limit).offset(offset)
            result = await session.execute(query)
            return list(result.scalars().all())

    async def get_all_entries_snapshot(self, project_path: str | None = None) -> list[dict[str, Any]]:
        """Return ``to_dict()`` of every entry, optionally limited to *project_path*."""
        async with await self.get_session() as session:
            query = select(MemoryEntry)
            if project_path:
                query = query.where(MemoryEntry.project_path == project_path)
            result = await session.execute(query)
            entries = result.scalars().all()
            # Serialize while the session is still open.
            return [entry.to_dict() for entry in entries]

    async def close(self) -> None:
        """Dispose of the engine's connection pool, if an engine was created."""
        if self.engine is not None:
            await self.engine.dispose()