Add memory_manager source files and tests
This commit is contained in:
@@ -1,40 +1,44 @@
|
|||||||
"""Async repository for memory entries and commits."""
|
"""Async repository for database operations."""
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from sqlalchemy import select, func
|
import os
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from typing import Any
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
from memory_manager.db.models import Base, MemoryEntry, Commit, CommitEntry, MemoryCategory
|
from sqlalchemy import delete, select, text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from memory_manager.db.models import Commit, MemoryCategory, MemoryEntry, init_db
|
||||||
|
|
||||||
|
|
||||||
class MemoryRepository:
|
class MemoryRepository:
|
||||||
def __init__(self, db_path: str):
|
def __init__(self, db_path: str):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self.engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False)
|
self.engine: Any = None
|
||||||
self.async_session = async_sessionmaker(self.engine, class_=AsyncSession)
|
self._session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||||
|
|
||||||
async def init_db(self):
|
async def initialize(self) -> None:
|
||||||
async with self.engine.begin() as conn:
|
db_dir = os.path.dirname(self.db_path)
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
if db_dir and not os.path.exists(db_dir):
|
||||||
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
|
|
||||||
async def close(self):
|
self.engine = await init_db(self.db_path)
|
||||||
await self.engine.dispose()
|
self._session_factory = async_sessionmaker(self.engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
async def get_session(self) -> AsyncSession:
|
||||||
|
if not self._session_factory:
|
||||||
|
await self.initialize()
|
||||||
|
assert self._session_factory is not None
|
||||||
|
return self._session_factory()
|
||||||
|
|
||||||
async def create_entry(
|
async def create_entry(
|
||||||
self,
|
self,
|
||||||
title: str,
|
title: str,
|
||||||
content: str,
|
content: str,
|
||||||
category: MemoryCategory,
|
category: MemoryCategory,
|
||||||
tags: List[str],
|
tags: list[str],
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
project_path: str,
|
project_path: str,
|
||||||
parent_id: Optional[int] = None,
|
|
||||||
) -> MemoryEntry:
|
) -> MemoryEntry:
|
||||||
async with self.async_session() as session:
|
async with await self.get_session() as session:
|
||||||
entry = MemoryEntry(
|
entry = MemoryEntry(
|
||||||
title=title,
|
title=title,
|
||||||
content=content,
|
content=content,
|
||||||
@@ -42,52 +46,31 @@ class MemoryRepository:
|
|||||||
tags=tags,
|
tags=tags,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
project_path=project_path,
|
project_path=project_path,
|
||||||
parent_id=parent_id,
|
|
||||||
)
|
)
|
||||||
session.add(entry)
|
session.add(entry)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(entry)
|
await session.refresh(entry)
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
async def get_entry(self, entry_id: int) -> Optional[MemoryEntry]:
|
async def get_entry(self, entry_id: int) -> MemoryEntry | None:
|
||||||
async with self.async_session() as session:
|
async with await self.get_session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
|
||||||
select(MemoryEntry).where(MemoryEntry.id == entry_id)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
async def list_entries(
|
|
||||||
self,
|
|
||||||
category: Optional[MemoryCategory] = None,
|
|
||||||
agent_id: Optional[str] = None,
|
|
||||||
limit: int = 100,
|
|
||||||
offset: int = 0,
|
|
||||||
) -> List[MemoryEntry]:
|
|
||||||
async with self.async_session() as session:
|
|
||||||
query = select(MemoryEntry)
|
|
||||||
if category:
|
|
||||||
query = query.where(MemoryEntry.category == category)
|
|
||||||
if agent_id:
|
|
||||||
query = query.where(MemoryEntry.agent_id == agent_id)
|
|
||||||
query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
|
|
||||||
result = await session.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
async def update_entry(
|
async def update_entry(
|
||||||
self,
|
self,
|
||||||
entry_id: int,
|
entry_id: int,
|
||||||
title: Optional[str] = None,
|
title: str | None = None,
|
||||||
content: Optional[str] = None,
|
content: str | None = None,
|
||||||
category: Optional[MemoryCategory] = None,
|
category: MemoryCategory | None = None,
|
||||||
tags: Optional[List[str]] = None,
|
tags: list[str] | None = None,
|
||||||
) -> Optional[MemoryEntry]:
|
) -> MemoryEntry | None:
|
||||||
async with self.async_session() as session:
|
async with await self.get_session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(select(MemoryEntry).where(MemoryEntry.id == entry_id))
|
||||||
select(MemoryEntry).where(MemoryEntry.id == entry_id)
|
|
||||||
)
|
|
||||||
entry = result.scalar_one_or_none()
|
entry = result.scalar_one_or_none()
|
||||||
if not entry:
|
if not entry:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if title is not None:
|
if title is not None:
|
||||||
entry.title = title
|
entry.title = title
|
||||||
if content is not None:
|
if content is not None:
|
||||||
@@ -96,144 +79,154 @@ class MemoryRepository:
|
|||||||
entry.category = category
|
entry.category = category
|
||||||
if tags is not None:
|
if tags is not None:
|
||||||
entry.tags = tags
|
entry.tags = tags
|
||||||
entry.updated_at = datetime.utcnow()
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(entry)
|
await session.refresh(entry)
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
async def delete_entry(self, entry_id: int) -> bool:
|
async def delete_entry(self, entry_id: int) -> bool:
|
||||||
async with self.async_session() as session:
|
async with await self.get_session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(delete(MemoryEntry).where(MemoryEntry.id == entry_id))
|
||||||
select(MemoryEntry).where(MemoryEntry.id == entry_id)
|
|
||||||
)
|
|
||||||
entry = result.scalar_one_or_none()
|
|
||||||
if not entry:
|
|
||||||
return False
|
|
||||||
await session.delete(entry)
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return True
|
rowcount = getattr(result, "rowcount", 0)
|
||||||
|
return rowcount is not None and rowcount > 0
|
||||||
|
|
||||||
|
async def list_entries(
|
||||||
|
self,
|
||||||
|
category: MemoryCategory | None = None,
|
||||||
|
agent_id: str | None = None,
|
||||||
|
project_path: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[MemoryEntry]:
|
||||||
|
async with await self.get_session() as session:
|
||||||
|
query = select(MemoryEntry)
|
||||||
|
|
||||||
|
if category:
|
||||||
|
query = query.where(MemoryEntry.category == category)
|
||||||
|
if agent_id:
|
||||||
|
query = query.where(MemoryEntry.agent_id == agent_id)
|
||||||
|
if project_path:
|
||||||
|
query = query.where(MemoryEntry.project_path == project_path)
|
||||||
|
|
||||||
|
query = query.order_by(MemoryEntry.created_at.desc()).limit(limit).offset(offset)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
async def search_entries(
|
async def search_entries(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query_text: str,
|
||||||
category: Optional[MemoryCategory] = None,
|
category: MemoryCategory | None = None,
|
||||||
tags: Optional[List[str]] = None,
|
agent_id: str | None = None,
|
||||||
) -> List[MemoryEntry]:
|
project_path: str | None = None,
|
||||||
async with self.async_session() as session:
|
limit: int = 100,
|
||||||
search_pattern = f"%{query}%"
|
) -> list[MemoryEntry]:
|
||||||
stmt = select(MemoryEntry).where(
|
async with await self.get_session() as session:
|
||||||
(MemoryEntry.title.like(search_pattern)) |
|
fts_query = f'"{query_text}"'
|
||||||
(MemoryEntry.content.like(search_pattern))
|
|
||||||
)
|
sql_parts = ["""
|
||||||
|
SELECT m.* FROM memory_entries m
|
||||||
|
INNER JOIN memory_entries_fts fts ON m.id = fts.rowid
|
||||||
|
WHERE memory_entries_fts MATCH :query
|
||||||
|
"""]
|
||||||
|
|
||||||
|
params: dict[str, Any] = {"query": fts_query}
|
||||||
|
|
||||||
if category:
|
if category:
|
||||||
stmt = stmt.where(MemoryEntry.category == category)
|
sql_parts.append(" AND m.category = :category")
|
||||||
if tags:
|
params["category"] = category.value
|
||||||
for tag in tags:
|
if agent_id:
|
||||||
stmt = stmt.where(MemoryEntry.tags.contains([tag]))
|
sql_parts.append(" AND m.agent_id = :agent_id")
|
||||||
stmt = stmt.order_by(MemoryEntry.created_at.desc())
|
params["agent_id"] = agent_id
|
||||||
result = await session.execute(stmt)
|
if project_path:
|
||||||
return list(result.scalars().all())
|
sql_parts.append(" AND m.project_path = :project_path")
|
||||||
|
params["project_path"] = project_path
|
||||||
|
|
||||||
|
sql_parts.append(" LIMIT :limit")
|
||||||
|
params["limit"] = limit
|
||||||
|
|
||||||
|
sql = text("".join(sql_parts))
|
||||||
|
|
||||||
|
result = await session.execute(sql, params)
|
||||||
|
rows = result.fetchall()
|
||||||
|
|
||||||
|
entries = []
|
||||||
|
for row in rows:
|
||||||
|
entry = MemoryEntry(
|
||||||
|
id=row.id,
|
||||||
|
title=row.title,
|
||||||
|
content=row.content,
|
||||||
|
category=MemoryCategory(row.category),
|
||||||
|
tags=row.tags,
|
||||||
|
agent_id=row.agent_id,
|
||||||
|
project_path=row.project_path,
|
||||||
|
created_at=row.created_at,
|
||||||
|
updated_at=row.updated_at,
|
||||||
|
)
|
||||||
|
entries.append(entry)
|
||||||
|
|
||||||
|
return entries
|
||||||
|
|
||||||
async def create_commit(
|
async def create_commit(
|
||||||
self,
|
self,
|
||||||
|
hash: str,
|
||||||
message: str,
|
message: str,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
project_path: str,
|
project_path: str,
|
||||||
|
snapshot: list[dict[str, Any]],
|
||||||
) -> Commit:
|
) -> Commit:
|
||||||
async with self.async_session() as session:
|
async with await self.get_session() as session:
|
||||||
entries_result = await session.execute(select(MemoryEntry))
|
|
||||||
entries = list(entries_result.scalars().all())
|
|
||||||
|
|
||||||
entries_snapshot = []
|
|
||||||
for entry in entries:
|
|
||||||
entries_snapshot.append({
|
|
||||||
"id": entry.id,
|
|
||||||
"title": entry.title,
|
|
||||||
"content": entry.content,
|
|
||||||
"category": entry.category.value,
|
|
||||||
"tags": entry.tags,
|
|
||||||
"agent_id": entry.agent_id,
|
|
||||||
"project_path": entry.project_path,
|
|
||||||
})
|
|
||||||
|
|
||||||
snapshot_json = json.dumps(entries_snapshot, sort_keys=True)
|
|
||||||
commit_hash = hashlib.sha1(
|
|
||||||
f"{snapshot_json}{datetime.utcnow().isoformat()}".encode()
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
commit = Commit(
|
commit = Commit(
|
||||||
hash=commit_hash,
|
hash=hash,
|
||||||
message=message,
|
message=message,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
project_path=project_path,
|
project_path=project_path,
|
||||||
|
snapshot=snapshot,
|
||||||
)
|
)
|
||||||
session.add(commit)
|
session.add(commit)
|
||||||
await session.flush()
|
|
||||||
|
|
||||||
for entry in entries:
|
|
||||||
commit_entry = CommitEntry(
|
|
||||||
commit_id=commit.id,
|
|
||||||
memory_entry_id=entry.id,
|
|
||||||
entry_snapshot={
|
|
||||||
"id": entry.id,
|
|
||||||
"title": entry.title,
|
|
||||||
"content": entry.content,
|
|
||||||
"category": entry.category.value,
|
|
||||||
"tags": entry.tags,
|
|
||||||
"agent_id": entry.agent_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
session.add(commit_entry)
|
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(commit)
|
await session.refresh(commit)
|
||||||
return commit
|
return commit
|
||||||
|
|
||||||
async def get_commits(self, limit: int = 50) -> List[Commit]:
|
async def get_commit(self, hash: str) -> Commit | None:
|
||||||
async with self.async_session() as session:
|
async with await self.get_session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(select(Commit).where(Commit.hash == hash))
|
||||||
select(Commit).order_by(Commit.created_at.desc()).limit(limit)
|
|
||||||
)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
async def get_commit(self, commit_hash: str) -> Optional[Commit]:
|
|
||||||
async with self.async_session() as session:
|
|
||||||
result = await session.execute(
|
|
||||||
select(Commit).where(Commit.hash == commit_hash)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
async def get_commit_entries(self, commit_hash: str) -> List[CommitEntry]:
|
async def get_commit_by_id(self, commit_id: int) -> Commit | None:
|
||||||
async with self.async_session() as session:
|
async with await self.get_session() as session:
|
||||||
commit_result = await session.execute(
|
result = await session.execute(select(Commit).where(Commit.id == commit_id))
|
||||||
select(Commit).where(Commit.hash == commit_hash)
|
return result.scalar_one_or_none()
|
||||||
)
|
|
||||||
commit = commit_result.scalar_one_or_none()
|
async def list_commits(
|
||||||
if not commit:
|
self,
|
||||||
return []
|
agent_id: str | None = None,
|
||||||
result = await session.execute(
|
project_path: str | None = None,
|
||||||
select(CommitEntry)
|
limit: int = 100,
|
||||||
.options(selectinload(CommitEntry.memory_entry))
|
offset: int = 0,
|
||||||
.where(CommitEntry.commit_id == commit.id)
|
) -> list[Commit]:
|
||||||
)
|
async with await self.get_session() as session:
|
||||||
|
query = select(Commit)
|
||||||
|
|
||||||
|
if agent_id:
|
||||||
|
query = query.where(Commit.agent_id == agent_id)
|
||||||
|
if project_path:
|
||||||
|
query = query.where(Commit.project_path == project_path)
|
||||||
|
|
||||||
|
query = query.order_by(Commit.created_at.desc()).limit(limit).offset(offset)
|
||||||
|
result = await session.execute(query)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
|
|
||||||
async def get_stats(self) -> dict:
|
async def get_all_entries_snapshot(self, project_path: str | None = None) -> list[dict[str, Any]]:
|
||||||
async with self.async_session() as session:
|
async with await self.get_session() as session:
|
||||||
total_entries = await session.scalar(select(func.count(MemoryEntry.id)))
|
query = select(MemoryEntry)
|
||||||
total_commits = await session.scalar(select(func.count(Commit.id)))
|
if project_path:
|
||||||
|
query = query.where(MemoryEntry.project_path == project_path)
|
||||||
|
|
||||||
category_counts = {}
|
result = await session.execute(query)
|
||||||
for category in MemoryCategory:
|
entries = result.scalars().all()
|
||||||
count = await session.scalar(
|
return [entry.to_dict() for entry in entries]
|
||||||
select(func.count(MemoryEntry.id)).where(
|
|
||||||
MemoryEntry.category == category
|
|
||||||
)
|
|
||||||
)
|
|
||||||
category_counts[category.value] = count or 0
|
|
||||||
|
|
||||||
return {
|
async def close(self) -> None:
|
||||||
"total_entries": total_entries or 0,
|
if self.engine:
|
||||||
"total_commits": total_commits or 0,
|
await self.engine.dispose()
|
||||||
"category_counts": category_counts,
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user