Add memory_manager source files and tests
Some checks failed
CI / test (push) Has been cancelled
CI / build (push) Has been cancelled

This commit is contained in:
2026-03-22 16:18:53 +00:00
parent b5708913d1
commit 9ac55b0552

View File

@@ -1,40 +1,44 @@
"""Async repository for memory entries and commits."""
import hashlib
import json
from datetime import datetime
from typing import List, Optional
"""Async repository for database operations."""
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import selectinload
import os
from typing import Any
from memory_manager.db.models import Base, MemoryEntry, Commit, CommitEntry, MemoryCategory
from sqlalchemy import delete, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from memory_manager.db.models import Commit, MemoryCategory, MemoryEntry, init_db
class MemoryRepository:
def __init__(self, db_path: str):
self.db_path = db_path
self.engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False)
self.async_session = async_sessionmaker(self.engine, class_=AsyncSession)
self.engine: Any = None
self._session_factory: async_sessionmaker[AsyncSession] | None = None
async def init_db(self):
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def initialize(self) -> None:
db_dir = os.path.dirname(self.db_path)
if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
async def close(self):
await self.engine.dispose()
self.engine = await init_db(self.db_path)
self._session_factory = async_sessionmaker(self.engine, expire_on_commit=False)
async def get_session(self) -> AsyncSession:
if not self._session_factory:
await self.initialize()
assert self._session_factory is not None
return self._session_factory()
async def create_entry(
self,
title: str,
content: str,
category: MemoryCategory,
tags: List[str],
tags: list[str],
agent_id: str,
project_path: str,
parent_id: Optional[int] = None,
) -> MemoryEntry:
async with self.async_session() as session:
async with await self.get_session() as session:
entry = MemoryEntry(
title=title,
content=content,
@@ -42,52 +46,31 @@ class MemoryRepository:
tags=tags,
agent_id=agent_id,
project_path=project_path,
parent_id=parent_id,
)
session.add(entry)
await session.commit()
await session.refresh(entry)
return entry
async def get_entry(self, entry_id: int) -> Optional[MemoryEntry]:
async with self.async_session() as session:
result = await session.execute(
select(MemoryEntry).where(MemoryEntry.id == entry_id)
)
async def get_entry(self, entry_id: int) -> MemoryEntry | None:
async with await self.get_session() as session:
result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
return result.scalar_one_or_none()
async def list_entries(
self,
category: Optional[MemoryCategory] = None,
agent_id: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> List[MemoryEntry]:
async with self.async_session() as session:
query = select(MemoryEntry)
if category:
query = query.where(MemoryEntry.category == category)
if agent_id:
query = query.where(MemoryEntry.agent_id == agent_id)
query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
result = await session.execute(query)
return list(result.scalars().all())
async def update_entry(
self,
entry_id: int,
title: Optional[str] = None,
content: Optional[str] = None,
category: Optional[MemoryCategory] = None,
tags: Optional[List[str]] = None,
) -> Optional[MemoryEntry]:
async with self.async_session() as session:
result = await session.execute(
select(MemoryEntry).where(MemoryEntry.id == entry_id)
)
title: str | None = None,
content: str | None = None,
category: MemoryCategory | None = None,
tags: list[str] | None = None,
) -> MemoryEntry | None:
async with await self.get_session() as session:
result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
entry = result.scalar_one_or_none()
if not entry:
return None
if title is not None:
entry.title = title
if content is not None:
@@ -96,144 +79,154 @@ class MemoryRepository:
entry.category = category
if tags is not None:
entry.tags = tags
entry.updated_at = datetime.utcnow()
await session.commit()
await session.refresh(entry)
return entry
async def delete_entry(self, entry_id: int) -> bool:
async with self.async_session() as session:
result = await session.execute(
select(MemoryEntry).where(MemoryEntry.id == entry_id)
)
entry = result.scalar_one_or_none()
if not entry:
return False
await session.delete(entry)
async with await self.get_session() as session:
result = await session.execute(delete(MemoryEntry).where(MemoryEntry.id == entry_id))
await session.commit()
return True
rowcount = getattr(result, "rowcount", 0)
return rowcount is not None and rowcount > 0
async def list_entries(
self,
category: MemoryCategory | None = None,
agent_id: str | None = None,
project_path: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[MemoryEntry]:
async with await self.get_session() as session:
query = select(MemoryEntry)
if category:
query = query.where(MemoryEntry.category == category)
if agent_id:
query = query.where(MemoryEntry.agent_id == agent_id)
if project_path:
query = query.where(MemoryEntry.project_path == project_path)
query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
result = await session.execute(query)
return list(result.scalars().all())
async def search_entries(
self,
query: str,
category: Optional[MemoryCategory] = None,
tags: Optional[List[str]] = None,
) -> List[MemoryEntry]:
async with self.async_session() as session:
search_pattern = f"%{query}%"
stmt = select(MemoryEntry).where(
(MemoryEntry.title.like(search_pattern)) |
(MemoryEntry.content.like(search_pattern))
)
query_text: str,
category: MemoryCategory | None = None,
agent_id: str | None = None,
project_path: str | None = None,
limit: int = 100,
) -> list[MemoryEntry]:
async with await self.get_session() as session:
fts_query = f'"{query_text}"'
sql_parts = ["""
SELECT m.* FROM memory_entries m
INNER JOIN memory_entries_fts fts ON m.id = fts.rowid
WHERE memory_entries_fts MATCH :query
"""]
params: dict[str, Any] = {"query": fts_query}
if category:
stmt = stmt.where(MemoryEntry.category == category)
if tags:
for tag in tags:
stmt = stmt.where(MemoryEntry.tags.contains([tag]))
stmt = stmt.order_by(MemoryEntry.created_at.desc())
result = await session.execute(stmt)
return list(result.scalars().all())
sql_parts.append(" AND m.category = :category")
params["category"] = category.value
if agent_id:
sql_parts.append(" AND m.agent_id = :agent_id")
params["agent_id"] = agent_id
if project_path:
sql_parts.append(" AND m.project_path = :project_path")
params["project_path"] = project_path
sql_parts.append(" LIMIT :limit")
params["limit"] = limit
sql = text("".join(sql_parts))
result = await session.execute(sql, params)
rows = result.fetchall()
entries = []
for row in rows:
entry = MemoryEntry(
id=row.id,
title=row.title,
content=row.content,
category=MemoryCategory(row.category),
tags=row.tags,
agent_id=row.agent_id,
project_path=row.project_path,
created_at=row.created_at,
updated_at=row.updated_at,
)
entries.append(entry)
return entries
async def create_commit(
self,
hash: str,
message: str,
agent_id: str,
project_path: str,
snapshot: list[dict[str, Any]],
) -> Commit:
async with self.async_session() as session:
entries_result = await session.execute(select(MemoryEntry))
entries = list(entries_result.scalars().all())
entries_snapshot = []
for entry in entries:
entries_snapshot.append({
"id": entry.id,
"title": entry.title,
"content": entry.content,
"category": entry.category.value,
"tags": entry.tags,
"agent_id": entry.agent_id,
"project_path": entry.project_path,
})
snapshot_json = json.dumps(entries_snapshot, sort_keys=True)
commit_hash = hashlib.sha1(
f"{snapshot_json}{datetime.utcnow().isoformat()}".encode()
).hexdigest()
async with await self.get_session() as session:
commit = Commit(
hash=commit_hash,
hash=hash,
message=message,
agent_id=agent_id,
project_path=project_path,
snapshot=snapshot,
)
session.add(commit)
await session.flush()
for entry in entries:
commit_entry = CommitEntry(
commit_id=commit.id,
memory_entry_id=entry.id,
entry_snapshot={
"id": entry.id,
"title": entry.title,
"content": entry.content,
"category": entry.category.value,
"tags": entry.tags,
"agent_id": entry.agent_id,
},
)
session.add(commit_entry)
await session.commit()
await session.refresh(commit)
return commit
async def get_commits(self, limit: int = 50) -> List[Commit]:
async with self.async_session() as session:
result = await session.execute(
select(Commit).order_by(Commit.created_at.desc()).limit(limit)
)
return list(result.scalars().all())
async def get_commit(self, commit_hash: str) -> Optional[Commit]:
async with self.async_session() as session:
result = await session.execute(
select(Commit).where(Commit.hash == commit_hash)
)
async def get_commit(self, hash: str) -> Commit | None:
async with await self.get_session() as session:
result = await session.execute(select(Commit).where(Commit.hash == hash))
return result.scalar_one_or_none()
async def get_commit_entries(self, commit_hash: str) -> List[CommitEntry]:
async with self.async_session() as session:
commit_result = await session.execute(
select(Commit).where(Commit.hash == commit_hash)
)
commit = commit_result.scalar_one_or_none()
if not commit:
return []
result = await session.execute(
select(CommitEntry)
.options(selectinload(CommitEntry.memory_entry))
.where(CommitEntry.commit_id == commit.id)
)
async def get_commit_by_id(self, commit_id: int) -> Commit | None:
async with await self.get_session() as session:
result = await session.execute(select(Commit).where(Commit.id == commit_id))
return result.scalar_one_or_none()
async def list_commits(
self,
agent_id: str | None = None,
project_path: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[Commit]:
async with await self.get_session() as session:
query = select(Commit)
if agent_id:
query = query.where(Commit.agent_id == agent_id)
if project_path:
query = query.where(Commit.project_path == project_path)
query = query.order_by(Commit.created_at.desc()).limit(limit).offset(offset)
result = await session.execute(query)
return list(result.scalars().all())
async def get_stats(self) -> dict:
async with self.async_session() as session:
total_entries = await session.scalar(select(func.count(MemoryEntry.id)))
total_commits = await session.scalar(select(func.count(Commit.id)))
async def get_all_entries_snapshot(self, project_path: str | None = None) -> list[dict[str, Any]]:
async with await self.get_session() as session:
query = select(MemoryEntry)
if project_path:
query = query.where(MemoryEntry.project_path == project_path)
category_counts = {}
for category in MemoryCategory:
count = await session.scalar(
select(func.count(MemoryEntry.id)).where(
MemoryEntry.category == category
)
)
category_counts[category.value] = count or 0
result = await session.execute(query)
entries = result.scalars().all()
return [entry.to_dict() for entry in entries]
return {
"total_entries": total_entries or 0,
"total_commits": total_commits or 0,
"category_counts": category_counts,
}
async def close(self) -> None:
if self.engine:
await self.engine.dispose()