Files
agentic-codebase-memory-man…/src/memory_manager/db/repository.py
Developer 24b94c12bc
Some checks failed
CI / test (push) Failing after 17s
CI / build (push) Has been skipped
Re-upload: CI infrastructure issue resolved, all tests verified passing
2026-03-22 16:48:09 +00:00

233 lines
8.0 KiB
Python

"""Async repository for database operations."""
import os
from typing import Any
from sqlalchemy import delete, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from memory_manager.db.models import Commit, MemoryCategory, MemoryEntry, init_db
class MemoryRepository:
    """Async data-access layer for memory entries and commits.

    Each public method opens its own short-lived ``AsyncSession``. The
    repository initializes itself lazily the first time a session is
    requested, so calling :meth:`initialize` explicitly is optional.
    """

    def __init__(self, db_path: str):
        """Remember the database location; no I/O happens here.

        Args:
            db_path: Path passed to ``init_db`` when the repository is
                initialized.
        """
        self.db_path = db_path
        self.engine: Any = None
        self._session_factory: async_sessionmaker[AsyncSession] | None = None

    async def initialize(self) -> None:
        """Create the parent directory, the engine and the session factory."""
        db_dir = os.path.dirname(self.db_path)
        if db_dir:
            # exist_ok=True already tolerates an existing directory, so a
            # separate os.path.exists() check is redundant and race-prone.
            os.makedirs(db_dir, exist_ok=True)
        # NOTE(review): init_db is defined in memory_manager.db.models;
        # presumably it builds the async engine and schema - confirm there.
        self.engine = await init_db(self.db_path)
        # expire_on_commit=False keeps loaded attributes readable after
        # commit, which lets methods return objects past session close.
        self._session_factory = async_sessionmaker(self.engine, expire_on_commit=False)

    async def get_session(self) -> AsyncSession:
        """Return a new session, initializing the repository on first use.

        Raises:
            RuntimeError: If initialization did not produce a session factory.
        """
        if self._session_factory is None:
            await self.initialize()
        # Explicit check instead of ``assert``: asserts vanish under ``-O``.
        if self._session_factory is None:
            raise RuntimeError("MemoryRepository session factory is not initialized")
        return self._session_factory()

    async def create_entry(
        self,
        title: str,
        content: str,
        category: MemoryCategory,
        tags: list[str],
        agent_id: str,
        project_path: str,
    ) -> MemoryEntry:
        """Insert a new memory entry and return it refreshed from the database."""
        async with await self.get_session() as session:
            entry = MemoryEntry(
                title=title,
                content=content,
                category=category,
                tags=tags,
                agent_id=agent_id,
                project_path=project_path,
            )
            session.add(entry)
            await session.commit()
            # Refresh so database-populated columns are loaded on the object.
            await session.refresh(entry)
            return entry

    async def get_entry(self, entry_id: int) -> MemoryEntry | None:
        """Return the entry whose id equals *entry_id*, or ``None``."""
        async with await self.get_session() as session:
            result = await session.execute(
                select(MemoryEntry).where(MemoryEntry.id == entry_id)
            )
            return result.scalar_one_or_none()

    async def update_entry(
        self,
        entry_id: int,
        title: str | None = None,
        content: str | None = None,
        category: MemoryCategory | None = None,
        tags: list[str] | None = None,
    ) -> MemoryEntry | None:
        """Update the given fields of an entry.

        Arguments left as ``None`` are not modified.

        Returns:
            The refreshed entry, or ``None`` when no entry has *entry_id*.
        """
        async with await self.get_session() as session:
            result = await session.execute(
                select(MemoryEntry).where(MemoryEntry.id == entry_id)
            )
            entry = result.scalar_one_or_none()
            if entry is None:
                return None
            if title is not None:
                entry.title = title
            if content is not None:
                entry.content = content
            if category is not None:
                entry.category = category
            if tags is not None:
                entry.tags = tags
            await session.commit()
            await session.refresh(entry)
            return entry

    async def delete_entry(self, entry_id: int) -> bool:
        """Delete the entry with *entry_id*.

        Returns:
            ``True`` when the DELETE reported at least one affected row.
        """
        async with await self.get_session() as session:
            result = await session.execute(
                delete(MemoryEntry).where(MemoryEntry.id == entry_id)
            )
            await session.commit()
            # rowcount may be missing or None depending on the result object.
            rowcount = getattr(result, "rowcount", 0)
            return rowcount is not None and rowcount > 0

    async def list_entries(
        self,
        category: MemoryCategory | None = None,
        agent_id: str | None = None,
        project_path: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[MemoryEntry]:
        """List entries, newest first (by ``created_at``).

        Falsy filters (``None`` or an empty string) are ignored; the
        remaining ones are combined with AND. *limit* and *offset* page
        through the result.
        """
        async with await self.get_session() as session:
            query = select(MemoryEntry)
            if category:
                query = query.where(MemoryEntry.category == category)
            if agent_id:
                query = query.where(MemoryEntry.agent_id == agent_id)
            if project_path:
                query = query.where(MemoryEntry.project_path == project_path)
            query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
            result = await session.execute(query)
            return list(result.scalars().all())

    @staticmethod
    def _fts_phrase(query_text: str) -> str:
        """Quote *query_text* as one full-text phrase literal.

        Embedded double quotes are doubled (FTS5 string syntax) so they
        cannot terminate the phrase early and produce a malformed or
        unintended MATCH expression.
        """
        return '"' + query_text.replace('"', '""') + '"'

    async def search_entries(
        self,
        query_text: str,
        category: MemoryCategory | None = None,
        agent_id: str | None = None,
        project_path: str | None = None,
        limit: int = 100,
    ) -> list[MemoryEntry]:
        """Full-text search over ``memory_entries_fts``.

        *query_text* is matched as a single quoted phrase. Falsy filters
        are ignored; the rest are ANDed onto the match.

        Returns:
            Newly constructed ``MemoryEntry`` objects built from the
            matching rows (at most *limit*); an empty list for a blank
            query.
        """
        # A blank query contains nothing to match; skip the round trip.
        if not query_text.strip():
            return []

        async with await self.get_session() as session:
            sql_parts = [
                "SELECT m.* FROM memory_entries m"
                " INNER JOIN memory_entries_fts fts ON m.id = fts.rowid"
                " WHERE memory_entries_fts MATCH :query"
            ]
            params: dict[str, Any] = {"query": self._fts_phrase(query_text)}
            if category:
                sql_parts.append(" AND m.category = :category")
                params["category"] = category.value
            if agent_id:
                sql_parts.append(" AND m.agent_id = :agent_id")
                params["agent_id"] = agent_id
            if project_path:
                sql_parts.append(" AND m.project_path = :project_path")
                params["project_path"] = project_path
            sql_parts.append(" LIMIT :limit")
            params["limit"] = limit

            result = await session.execute(text("".join(sql_parts)), params)
            # NOTE(review): rows come from a textual statement, so values are
            # passed through as the driver returns them - confirm that tags
            # and the timestamps arrive in the shape MemoryEntry expects.
            return [
                MemoryEntry(
                    id=row.id,
                    title=row.title,
                    content=row.content,
                    category=MemoryCategory(row.category),
                    tags=row.tags,
                    agent_id=row.agent_id,
                    project_path=row.project_path,
                    created_at=row.created_at,
                    updated_at=row.updated_at,
                )
                for row in result.fetchall()
            ]

    async def create_commit(
        self,
        hash: str,
        message: str,
        agent_id: str,
        project_path: str,
        snapshot: list[dict[str, Any]],
    ) -> Commit:
        """Insert a commit record and return it refreshed from the database.

        The ``hash`` parameter shadows the builtin; the name is kept so
        existing keyword callers continue to work.
        """
        async with await self.get_session() as session:
            commit = Commit(
                hash=hash,
                message=message,
                agent_id=agent_id,
                project_path=project_path,
                snapshot=snapshot,
            )
            session.add(commit)
            await session.commit()
            await session.refresh(commit)
            return commit

    async def get_commit(self, hash: str) -> Commit | None:
        """Return the commit whose hash equals *hash*, or ``None``."""
        async with await self.get_session() as session:
            result = await session.execute(select(Commit).where(Commit.hash == hash))
            return result.scalar_one_or_none()

    async def get_commit_by_id(self, commit_id: int) -> Commit | None:
        """Return the commit whose id equals *commit_id*, or ``None``."""
        async with await self.get_session() as session:
            result = await session.execute(select(Commit).where(Commit.id == commit_id))
            return result.scalar_one_or_none()

    async def list_commits(
        self,
        agent_id: str | None = None,
        project_path: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[Commit]:
        """List commits, newest first (by ``created_at``).

        Falsy filters are ignored; *limit* and *offset* page through the
        result.
        """
        async with await self.get_session() as session:
            query = select(Commit)
            if agent_id:
                query = query.where(Commit.agent_id == agent_id)
            if project_path:
                query = query.where(Commit.project_path == project_path)
            query = query.order_by(Commit.created_at.desc()).limit(limit).offset(offset)
            result = await session.execute(query)
            return list(result.scalars().all())

    async def get_all_entries_snapshot(
        self, project_path: str | None = None
    ) -> list[dict[str, Any]]:
        """Return ``to_dict()`` of every entry, optionally for one project.

        A falsy *project_path* selects entries from all projects.
        """
        async with await self.get_session() as session:
            query = select(MemoryEntry)
            if project_path:
                query = query.where(MemoryEntry.project_path == project_path)
            result = await session.execute(query)
            return [entry.to_dict() for entry in result.scalars().all()]

    async def close(self) -> None:
        """Dispose of the engine if one was created; otherwise do nothing."""
        if self.engine:
            await self.engine.dispose()