Add memory_manager source files and tests
Some checks failed
CI / test (push) Has been cancelled
CI / build (push) Has been cancelled

This commit is contained in:
2026-03-22 16:18:53 +00:00
parent b5708913d1
commit 9ac55b0552

View File

@@ -1,40 +1,44 @@
"""Async repository for memory entries and commits.""" """Async repository for database operations."""
import hashlib
import json
from datetime import datetime
from typing import List, Optional
from sqlalchemy import select, func import os
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from typing import Any
from sqlalchemy.orm import selectinload
from memory_manager.db.models import Base, MemoryEntry, Commit, CommitEntry, MemoryCategory from sqlalchemy import delete, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from memory_manager.db.models import Commit, MemoryCategory, MemoryEntry, init_db
class MemoryRepository: class MemoryRepository:
def __init__(self, db_path: str): def __init__(self, db_path: str):
self.db_path = db_path self.db_path = db_path
self.engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False) self.engine: Any = None
self.async_session = async_sessionmaker(self.engine, class_=AsyncSession) self._session_factory: async_sessionmaker[AsyncSession] | None = None
async def init_db(self): async def initialize(self) -> None:
async with self.engine.begin() as conn: db_dir = os.path.dirname(self.db_path)
await conn.run_sync(Base.metadata.create_all) if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
async def close(self): self.engine = await init_db(self.db_path)
await self.engine.dispose() self._session_factory = async_sessionmaker(self.engine, expire_on_commit=False)
async def get_session(self) -> AsyncSession:
if not self._session_factory:
await self.initialize()
assert self._session_factory is not None
return self._session_factory()
async def create_entry( async def create_entry(
self, self,
title: str, title: str,
content: str, content: str,
category: MemoryCategory, category: MemoryCategory,
tags: List[str], tags: list[str],
agent_id: str, agent_id: str,
project_path: str, project_path: str,
parent_id: Optional[int] = None,
) -> MemoryEntry: ) -> MemoryEntry:
async with self.async_session() as session: async with await self.get_session() as session:
entry = MemoryEntry( entry = MemoryEntry(
title=title, title=title,
content=content, content=content,
@@ -42,52 +46,31 @@ class MemoryRepository:
tags=tags, tags=tags,
agent_id=agent_id, agent_id=agent_id,
project_path=project_path, project_path=project_path,
parent_id=parent_id,
) )
session.add(entry) session.add(entry)
await session.commit() await session.commit()
await session.refresh(entry) await session.refresh(entry)
return entry return entry
async def get_entry(self, entry_id: int) -> Optional[MemoryEntry]: async def get_entry(self, entry_id: int) -> MemoryEntry | None:
async with self.async_session() as session: async with await self.get_session() as session:
result = await session.execute( result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
select(MemoryEntry).where(MemoryEntry.id == entry_id)
)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def list_entries(
self,
category: Optional[MemoryCategory] = None,
agent_id: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> List[MemoryEntry]:
async with self.async_session() as session:
query = select(MemoryEntry)
if category:
query = query.where(MemoryEntry.category == category)
if agent_id:
query = query.where(MemoryEntry.agent_id == agent_id)
query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
result = await session.execute(query)
return list(result.scalars().all())
async def update_entry( async def update_entry(
self, self,
entry_id: int, entry_id: int,
title: Optional[str] = None, title: str | None = None,
content: Optional[str] = None, content: str | None = None,
category: Optional[MemoryCategory] = None, category: MemoryCategory | None = None,
tags: Optional[List[str]] = None, tags: list[str] | None = None,
) -> Optional[MemoryEntry]: ) -> MemoryEntry | None:
async with self.async_session() as session: async with await self.get_session() as session:
result = await session.execute( result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
select(MemoryEntry).where(MemoryEntry.id == entry_id)
)
entry = result.scalar_one_or_none() entry = result.scalar_one_or_none()
if not entry: if not entry:
return None return None
if title is not None: if title is not None:
entry.title = title entry.title = title
if content is not None: if content is not None:
@@ -96,144 +79,154 @@ class MemoryRepository:
entry.category = category entry.category = category
if tags is not None: if tags is not None:
entry.tags = tags entry.tags = tags
entry.updated_at = datetime.utcnow()
await session.commit() await session.commit()
await session.refresh(entry) await session.refresh(entry)
return entry return entry
async def delete_entry(self, entry_id: int) -> bool: async def delete_entry(self, entry_id: int) -> bool:
async with self.async_session() as session: async with await self.get_session() as session:
result = await session.execute( result = await session.execute(delete(MemoryEntry).where(MemoryEntry.id == entry_id))
select(MemoryEntry).where(MemoryEntry.id == entry_id)
)
entry = result.scalar_one_or_none()
if not entry:
return False
await session.delete(entry)
await session.commit() await session.commit()
return True rowcount = getattr(result, "rowcount", 0)
return rowcount is not None and rowcount > 0
async def list_entries(
self,
category: MemoryCategory | None = None,
agent_id: str | None = None,
project_path: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[MemoryEntry]:
async with await self.get_session() as session:
query = select(MemoryEntry)
if category:
query = query.where(MemoryEntry.category == category)
if agent_id:
query = query.where(MemoryEntry.agent_id == agent_id)
if project_path:
query = query.where(MemoryEntry.project_path == project_path)
query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
result = await session.execute(query)
return list(result.scalars().all())
async def search_entries( async def search_entries(
self, self,
query: str, query_text: str,
category: Optional[MemoryCategory] = None, category: MemoryCategory | None = None,
tags: Optional[List[str]] = None, agent_id: str | None = None,
) -> List[MemoryEntry]: project_path: str | None = None,
async with self.async_session() as session: limit: int = 100,
search_pattern = f"%{query}%" ) -> list[MemoryEntry]:
stmt = select(MemoryEntry).where( async with await self.get_session() as session:
(MemoryEntry.title.like(search_pattern)) | fts_query = f'"{query_text}"'
(MemoryEntry.content.like(search_pattern))
) sql_parts = ["""
SELECT m.* FROM memory_entries m
INNER JOIN memory_entries_fts fts ON m.id = fts.rowid
WHERE memory_entries_fts MATCH :query
"""]
params: dict[str, Any] = {"query": fts_query}
if category: if category:
stmt = stmt.where(MemoryEntry.category == category) sql_parts.append(" AND m.category = :category")
if tags: params["category"] = category.value
for tag in tags: if agent_id:
stmt = stmt.where(MemoryEntry.tags.contains([tag])) sql_parts.append(" AND m.agent_id = :agent_id")
stmt = stmt.order_by(MemoryEntry.created_at.desc()) params["agent_id"] = agent_id
result = await session.execute(stmt) if project_path:
return list(result.scalars().all()) sql_parts.append(" AND m.project_path = :project_path")
params["project_path"] = project_path
sql_parts.append(" LIMIT :limit")
params["limit"] = limit
sql = text("".join(sql_parts))
result = await session.execute(sql, params)
rows = result.fetchall()
entries = []
for row in rows:
entry = MemoryEntry(
id=row.id,
title=row.title,
content=row.content,
category=MemoryCategory(row.category),
tags=row.tags,
agent_id=row.agent_id,
project_path=row.project_path,
created_at=row.created_at,
updated_at=row.updated_at,
)
entries.append(entry)
return entries
async def create_commit( async def create_commit(
self, self,
hash: str,
message: str, message: str,
agent_id: str, agent_id: str,
project_path: str, project_path: str,
snapshot: list[dict[str, Any]],
) -> Commit: ) -> Commit:
async with self.async_session() as session: async with await self.get_session() as session:
entries_result = await session.execute(select(MemoryEntry))
entries = list(entries_result.scalars().all())
entries_snapshot = []
for entry in entries:
entries_snapshot.append({
"id": entry.id,
"title": entry.title,
"content": entry.content,
"category": entry.category.value,
"tags": entry.tags,
"agent_id": entry.agent_id,
"project_path": entry.project_path,
})
snapshot_json = json.dumps(entries_snapshot, sort_keys=True)
commit_hash = hashlib.sha1(
f"{snapshot_json}{datetime.utcnow().isoformat()}".encode()
).hexdigest()
commit = Commit( commit = Commit(
hash=commit_hash, hash=hash,
message=message, message=message,
agent_id=agent_id, agent_id=agent_id,
project_path=project_path, project_path=project_path,
snapshot=snapshot,
) )
session.add(commit) session.add(commit)
await session.flush()
for entry in entries:
commit_entry = CommitEntry(
commit_id=commit.id,
memory_entry_id=entry.id,
entry_snapshot={
"id": entry.id,
"title": entry.title,
"content": entry.content,
"category": entry.category.value,
"tags": entry.tags,
"agent_id": entry.agent_id,
},
)
session.add(commit_entry)
await session.commit() await session.commit()
await session.refresh(commit) await session.refresh(commit)
return commit return commit
async def get_commits(self, limit: int = 50) -> List[Commit]: async def get_commit(self, hash: str) -> Commit | None:
async with self.async_session() as session: async with await self.get_session() as session:
result = await session.execute( result = await session.execute(select(Commit).where(Commit.hash == hash))
select(Commit).order_by(Commit.created_at.desc()).limit(limit)
)
return list(result.scalars().all())
async def get_commit(self, commit_hash: str) -> Optional[Commit]:
async with self.async_session() as session:
result = await session.execute(
select(Commit).where(Commit.hash == commit_hash)
)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def get_commit_entries(self, commit_hash: str) -> List[CommitEntry]: async def get_commit_by_id(self, commit_id: int) -> Commit | None:
async with self.async_session() as session: async with await self.get_session() as session:
commit_result = await session.execute( result = await session.execute(select(Commit).where(Commit.id == commit_id))
select(Commit).where(Commit.hash == commit_hash) return result.scalar_one_or_none()
)
commit = commit_result.scalar_one_or_none() async def list_commits(
if not commit: self,
return [] agent_id: str | None = None,
result = await session.execute( project_path: str | None = None,
select(CommitEntry) limit: int = 100,
.options(selectinload(CommitEntry.memory_entry)) offset: int = 0,
.where(CommitEntry.commit_id == commit.id) ) -> list[Commit]:
) async with await self.get_session() as session:
query = select(Commit)
if agent_id:
query = query.where(Commit.agent_id == agent_id)
if project_path:
query = query.where(Commit.project_path == project_path)
query = query.order_by(Commit.created_at.desc()).limit(limit).offset(offset)
result = await session.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
async def get_stats(self) -> dict: async def get_all_entries_snapshot(self, project_path: str | None = None) -> list[dict[str, Any]]:
async with self.async_session() as session: async with await self.get_session() as session:
total_entries = await session.scalar(select(func.count(MemoryEntry.id))) query = select(MemoryEntry)
total_commits = await session.scalar(select(func.count(Commit.id))) if project_path:
query = query.where(MemoryEntry.project_path == project_path)
category_counts = {} result = await session.execute(query)
for category in MemoryCategory: entries = result.scalars().all()
count = await session.scalar( return [entry.to_dict() for entry in entries]
select(func.count(MemoryEntry.id)).where(
MemoryEntry.category == category
)
)
category_counts[category.value] = count or 0
return { async def close(self) -> None:
"total_entries": total_entries or 0, if self.engine:
"total_commits": total_commits or 0, await self.engine.dispose()
"category_counts": category_counts,
}