From 9ac55b05523f412af40c8b95c9d66133673db71f Mon Sep 17 00:00:00 2001 From: 7000pctAUTO Date: Sun, 22 Mar 2026 16:18:53 +0000 Subject: [PATCH] Add memory_manager source files and tests --- src/memory_manager/db/repository.py | 311 ++++++++++++++-------------- 1 file changed, 152 insertions(+), 159 deletions(-) diff --git a/src/memory_manager/db/repository.py b/src/memory_manager/db/repository.py index f7f15e1..cc4be5a 100644 --- a/src/memory_manager/db/repository.py +++ b/src/memory_manager/db/repository.py @@ -1,40 +1,44 @@ -"""Async repository for memory entries and commits.""" -import hashlib -import json -from datetime import datetime -from typing import List, Optional +"""Async repository for database operations.""" -from sqlalchemy import select, func -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker -from sqlalchemy.orm import selectinload +import os +from typing import Any -from memory_manager.db.models import Base, MemoryEntry, Commit, CommitEntry, MemoryCategory +from sqlalchemy import delete, select, text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from memory_manager.db.models import Commit, MemoryCategory, MemoryEntry, init_db class MemoryRepository: def __init__(self, db_path: str): self.db_path = db_path - self.engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False) - self.async_session = async_sessionmaker(self.engine, class_=AsyncSession) + self.engine: Any = None + self._session_factory: async_sessionmaker[AsyncSession] | None = None - async def init_db(self): - async with self.engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) + async def initialize(self) -> None: + db_dir = os.path.dirname(self.db_path) + if db_dir and not os.path.exists(db_dir): + os.makedirs(db_dir, exist_ok=True) - async def close(self): - await self.engine.dispose() + self.engine = await init_db(self.db_path) + self._session_factory = async_sessionmaker(self.engine, expire_on_commit=False) + + async def get_session(self) -> AsyncSession: + if not self._session_factory: + await self.initialize() + assert self._session_factory is not None + return self._session_factory() async def create_entry( self, title: str, content: str, category: MemoryCategory, - tags: List[str], + tags: list[str], agent_id: str, project_path: str, - parent_id: Optional[int] = None, ) -> MemoryEntry: - async with self.async_session() as session: + async with await self.get_session() as session: entry = MemoryEntry( title=title, content=content, @@ -42,52 +46,31 @@ class MemoryRepository: tags=tags, agent_id=agent_id, project_path=project_path, - parent_id=parent_id, ) session.add(entry) await session.commit() await session.refresh(entry) return entry - async def get_entry(self, entry_id: int) -> Optional[MemoryEntry]: - async with self.async_session() as session: - result = await session.execute( - select(MemoryEntry).where(MemoryEntry.id == entry_id) - ) + async def get_entry(self, entry_id: int) -> MemoryEntry | None: + async with await self.get_session() as session: + result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id)) return result.scalar_one_or_none() - async def list_entries( - self, - category: Optional[MemoryCategory] = None, - agent_id: Optional[str] = None, - limit: int = 100, - offset: int = 0, - ) -> List[MemoryEntry]: - async with self.async_session() as session: - query = select(MemoryEntry) - if category: - query = query.where(MemoryEntry.category == category) - if agent_id: - query = query.where(MemoryEntry.agent_id == agent_id) - query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset) - result = await session.execute(query) - return list(result.scalars().all()) - async def update_entry( self, entry_id: int, - title: Optional[str] = None, - content: Optional[str] = None, - category: Optional[MemoryCategory] = None, - tags: Optional[List[str]] = None, - ) -> Optional[MemoryEntry]: - async with self.async_session() as session: - result = await session.execute( - select(MemoryEntry).where(MemoryEntry.id == entry_id) - ) + title: str | None = None, + content: str | None = None, + category: MemoryCategory | None = None, + tags: list[str] | None = None, + ) -> MemoryEntry | None: + async with await self.get_session() as session: + result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id)) entry = result.scalar_one_or_none() if not entry: return None + if title is not None: entry.title = title if content is not None: @@ -96,144 +79,154 @@ class MemoryRepository: entry.category = category if tags is not None: entry.tags = tags - entry.updated_at = datetime.utcnow() + await session.commit() await session.refresh(entry) return entry async def delete_entry(self, entry_id: int) -> bool: - async with self.async_session() as session: - result = await session.execute( - select(MemoryEntry).where(MemoryEntry.id == entry_id) - ) - entry = result.scalar_one_or_none() - if not entry: - return False - await session.delete(entry) + async with await self.get_session() as session: + result = await session.execute(delete(MemoryEntry).where(MemoryEntry.id == entry_id)) await session.commit() - return True + rowcount = getattr(result, "rowcount", 0) + return rowcount is not None and rowcount > 0 + + async def list_entries( + self, + category: MemoryCategory | None = None, + agent_id: str | None = None, + project_path: str | None = None, + limit: int = 100, + offset: int = 0, + ) -> list[MemoryEntry]: + async with await self.get_session() as session: + query = select(MemoryEntry) + + if category: + query = query.where(MemoryEntry.category == category) + if agent_id: + query = query.where(MemoryEntry.agent_id == agent_id) + if project_path: + query = query.where(MemoryEntry.project_path == project_path) + + query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset) + result = await session.execute(query) + return list(result.scalars().all()) async def search_entries( self, - query: str, - category: Optional[MemoryCategory] = None, - tags: Optional[List[str]] = None, - ) -> List[MemoryEntry]: - async with self.async_session() as session: - search_pattern = f"%{query}%" - stmt = select(MemoryEntry).where( - (MemoryEntry.title.like(search_pattern)) | - (MemoryEntry.content.like(search_pattern)) - ) + query_text: str, + category: MemoryCategory | None = None, + agent_id: str | None = None, + project_path: str | None = None, + limit: int = 100, + ) -> list[MemoryEntry]: + async with await self.get_session() as session: + fts_query = f'"{query_text}"' + + sql_parts = [""" + SELECT m.* FROM memory_entries m + INNER JOIN memory_entries_fts fts ON m.id = fts.rowid + WHERE memory_entries_fts MATCH :query + """] + + params: dict[str, Any] = {"query": fts_query} + if category: - stmt = stmt.where(MemoryEntry.category == category) - if tags: - for tag in tags: - stmt = stmt.where(MemoryEntry.tags.contains([tag])) - stmt = stmt.order_by(MemoryEntry.created_at.desc()) - result = await session.execute(stmt) - return list(result.scalars().all()) + sql_parts.append(" AND m.category = :category") + params["category"] = category.value + if agent_id: + sql_parts.append(" AND m.agent_id = :agent_id") + params["agent_id"] = agent_id + if project_path: + sql_parts.append(" AND m.project_path = :project_path") + params["project_path"] = project_path + + sql_parts.append(" LIMIT :limit") + params["limit"] = limit + + sql = text("".join(sql_parts)) + + result = await session.execute(sql, params) + rows = result.fetchall() + + entries = [] + for row in rows: + entry = MemoryEntry( + id=row.id, + title=row.title, + content=row.content, + category=MemoryCategory(row.category), + tags=row.tags, + agent_id=row.agent_id, + project_path=row.project_path, + created_at=row.created_at, + updated_at=row.updated_at, + ) + entries.append(entry) + + return entries async def create_commit( self, + hash: str, message: str, agent_id: str, project_path: str, + snapshot: list[dict[str, Any]], ) -> Commit: - async with self.async_session() as session: - entries_result = await session.execute(select(MemoryEntry)) - entries = list(entries_result.scalars().all()) - - entries_snapshot = [] - for entry in entries: - entries_snapshot.append({ - "id": entry.id, - "title": entry.title, - "content": entry.content, - "category": entry.category.value, - "tags": entry.tags, - "agent_id": entry.agent_id, - "project_path": entry.project_path, - }) - - snapshot_json = json.dumps(entries_snapshot, sort_keys=True) - commit_hash = hashlib.sha1( - f"{snapshot_json}{datetime.utcnow().isoformat()}".encode() - ).hexdigest() - + async with await self.get_session() as session: commit = Commit( - hash=commit_hash, + hash=hash, message=message, agent_id=agent_id, project_path=project_path, + snapshot=snapshot, ) session.add(commit) - await session.flush() - - for entry in entries: - commit_entry = CommitEntry( - commit_id=commit.id, - memory_entry_id=entry.id, - entry_snapshot={ - "id": entry.id, - "title": entry.title, - "content": entry.content, - "category": entry.category.value, - "tags": entry.tags, - "agent_id": entry.agent_id, - }, - ) - session.add(commit_entry) - await session.commit() await session.refresh(commit) return commit - async def get_commits(self, limit: int = 50) -> List[Commit]: - async with self.async_session() as session: - result = await session.execute( - select(Commit).order_by(Commit.created_at.desc()).limit(limit) - ) - return list(result.scalars().all()) - - async def get_commit(self, commit_hash: str) -> Optional[Commit]: - async with self.async_session() as session: - result = await session.execute( - select(Commit).where(Commit.hash == commit_hash) - ) + async def get_commit(self, hash: str) -> Commit | None: + async with await self.get_session() as session: + result = await session.execute(select(Commit).where(Commit.hash == hash)) return result.scalar_one_or_none() - async def get_commit_entries(self, commit_hash: str) -> List[CommitEntry]: - async with self.async_session() as session: - commit_result = await session.execute( - select(Commit).where(Commit.hash == commit_hash) - ) - commit = commit_result.scalar_one_or_none() - if not commit: - return [] - result = await session.execute( - select(CommitEntry) - .options(selectinload(CommitEntry.memory_entry)) - .where(CommitEntry.commit_id == commit.id) - ) + async def get_commit_by_id(self, commit_id: int) -> Commit | None: + async with await self.get_session() as session: + result = await session.execute(select(Commit).where(Commit.id == commit_id)) + return result.scalar_one_or_none() + + async def list_commits( + self, + agent_id: str | None = None, + project_path: str | None = None, + limit: int = 100, + offset: int = 0, + ) -> list[Commit]: + async with await self.get_session() as session: + query = select(Commit) + + if agent_id: + query = query.where(Commit.agent_id == agent_id) + if project_path: + query = query.where(Commit.project_path == project_path) + + query = query.order_by(Commit.created_at.desc()).limit(limit).offset(offset) + result = await session.execute(query) return list(result.scalars().all()) - async def get_stats(self) -> dict: - async with self.async_session() as session: - total_entries = await session.scalar(select(func.count(MemoryEntry.id))) - total_commits = await session.scalar(select(func.count(Commit.id))) + async def get_all_entries_snapshot(self, project_path: str | None = None) -> list[dict[str, Any]]: + async with await self.get_session() as session: + query = select(MemoryEntry) + if project_path: + query = query.where(MemoryEntry.project_path == project_path) - category_counts = {} - for category in MemoryCategory: - count = await session.scalar( - select(func.count(MemoryEntry.id)).where( - MemoryEntry.category == category - ) - ) - category_counts[category.value] = count or 0 + result = await session.execute(query) + entries = result.scalars().all() + return [entry.to_dict() for entry in entries] - return { - "total_entries": total_entries or 0, - "total_commits": total_commits or 0, - "category_counts": category_counts, - } + async def close(self) -> None: + if self.engine: + await self.engine.dispose()