# File: testdata-cli/database/db.py (603 lines, 20 KiB, Python)

"""
7000%AUTO Database Operations
Async SQLAlchemy database setup and CRUD operations
"""
import logging
from typing import Optional, List
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy import select, func
from .models import Base, Idea, Project, AgentLog, IdeaSource, ProjectStatus, LogType
logger = logging.getLogger(__name__)
# Module-level database state. Both names start as None; init_db() assigns
# them and close_db() resets them to None.
_engine = None  # assigned the result of create_async_engine() by init_db()
_session_factory = None  # assigned an async_sessionmaker bound to _engine by init_db()
def async_session_factory():
    """Return a new session created by the module-level session factory.

    Despite its name, this function calls the configured factory and returns
    the resulting session object rather than the factory itself.

    Raises:
        RuntimeError: If the session factory has not been set by init_db().
    """
    factory = _session_factory
    if factory is not None:
        return factory()
    raise RuntimeError("Database not initialized. Call init_db() first.")
async def init_db(database_url: Optional[str] = None):
    """Initialize the database engine, the session factory, and the tables.

    Args:
        database_url: Connection URL. When None, ``settings.DATABASE_URL``
            from the ``config`` module is used.

    URLs beginning with ``postgres://`` or ``postgresql://`` are rewritten to
    the ``postgresql+asyncpg://`` scheme; any other URL is passed through
    unchanged. If an engine from an earlier call is still set, it is disposed
    of before being replaced.
    """
    global _engine, _session_factory
    if database_url is None:
        # Imported here, only when no URL is supplied by the caller.
        from config import settings
        database_url = settings.DATABASE_URL

    # Convert postgres:// to postgresql+asyncpg:// if needed. Only the scheme
    # prefix is inspected: a URL that starts with "postgresql://" names no
    # driver, even if the text "+asyncpg" occurs later (e.g. in a password).
    for plain_scheme in ("postgres://", "postgresql://"):
        if database_url.startswith(plain_scheme):
            database_url = "postgresql+asyncpg://" + database_url[len(plain_scheme):]
            break

    # Dispose of an engine left over from a previous init_db() call so its
    # connection pool is not leaked when the global is overwritten.
    if _engine is not None:
        await _engine.dispose()

    _engine = create_async_engine(
        database_url,
        echo=False,
        pool_pre_ping=True
    )
    _session_factory = async_sessionmaker(
        _engine,
        class_=AsyncSession,
        expire_on_commit=False
    )

    # Create tables
    async with _engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    logger.info("Database initialized successfully")
async def close_db():
    """Dispose of the engine and reset the module-level database state.

    Does nothing when no engine is currently set.
    """
    global _engine, _session_factory
    if not _engine:
        return
    await _engine.dispose()
    _engine, _session_factory = None, None
    logger.info("Database connection closed")
@asynccontextmanager
async def get_db():
    """Async context manager yielding a session from the session factory.

    The session is committed after the managed body finishes; if the body or
    the commit raises an ``Exception``, the session is rolled back and the
    exception is re-raised.

    Raises:
        RuntimeError: If the session factory has not been set by init_db().
    """
    factory = _session_factory
    if factory is None:
        raise RuntimeError("Database not initialized. Call init_db() first.")
    async with factory() as db_session:
        try:
            yield db_session
            # Reached only when the managed body completed without raising.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
# =============================================================================
# Idea CRUD Operations
# =============================================================================
async def create_idea(
    title: str,
    description: str,
    source: str,
    session: Optional[AsyncSession] = None
) -> Idea:
    """Insert a new Idea row and return the refreshed instance.

    Args:
        title: Idea title.
        description: Idea description.
        source: Source identifier; a non-string value is replaced by its
            ``.value`` attribute.
        session: Session to use. When omitted, a session from get_db() is
            used instead.
    """
    async def _insert(db: AsyncSession) -> Idea:
        source_value = source if isinstance(source, str) else source.value
        new_idea = Idea(
            title=title,
            description=description,
            source=source_value
        )
        db.add(new_idea)
        # Flush and refresh so the returned instance reflects the stored row.
        await db.flush()
        await db.refresh(new_idea)
        return new_idea

    if not session:
        async with get_db() as db:
            return await _insert(db)
    return await _insert(session)
async def get_idea_by_id(idea_id: int, session: Optional[AsyncSession] = None) -> Optional[Idea]:
    """Return the Idea whose id equals ``idea_id``, or None if there is none.

    When ``session`` is omitted, a session from get_db() is used.
    """
    async def _fetch(db: AsyncSession) -> Optional[Idea]:
        stmt = select(Idea).where(Idea.id == idea_id)
        rows = await db.execute(stmt)
        return rows.scalar_one_or_none()

    if not session:
        async with get_db() as db:
            return await _fetch(db)
    return await _fetch(session)
async def get_unused_ideas(
    limit: int = 10,
    source: Optional[str] = None,
    session: Optional[AsyncSession] = None
) -> List[Idea]:
    """Return up to ``limit`` ideas whose ``used`` flag is false, newest first.

    Args:
        limit: Maximum number of rows to return.
        source: When truthy, only ideas with this source are returned.
        session: Session to use. When omitted, a session from get_db() is
            used instead.
    """
    async def _fetch(db: AsyncSession) -> List[Idea]:
        # "== False" builds a SQL comparison here; it is not a Python truth test.
        stmt = select(Idea).where(Idea.used == False)  # noqa: E712
        if source:
            stmt = stmt.where(Idea.source == source)
        stmt = stmt.order_by(Idea.created_at.desc())
        stmt = stmt.limit(limit)
        rows = await db.execute(stmt)
        ideas = rows.scalars().all()
        return list(ideas)

    if not session:
        async with get_db() as db:
            return await _fetch(db)
    return await _fetch(session)
async def mark_idea_used(idea_id: int, session: Optional[AsyncSession] = None) -> bool:
    """Set ``used = True`` on the idea with the given id.

    Returns:
        True if the idea was found and flagged, False if no such idea exists.
    """
    async def _flag(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Idea).where(Idea.id == idea_id))
        found = lookup.scalar_one_or_none()
        if not found:
            return False
        found.used = True
        return True

    if not session:
        async with get_db() as db:
            return await _flag(db)
    return await _flag(session)
# =============================================================================
# Project CRUD Operations
# =============================================================================
async def create_project(
    idea_id: int,
    name: str,
    plan_json: Optional[dict] = None,
    session: Optional[AsyncSession] = None
) -> Project:
    """Insert a new Project row in the IDEATION status and return it.

    Args:
        idea_id: Id stored in the project's ``idea_id`` column.
        name: Project name.
        plan_json: Optional plan data stored on the project.
        session: Session to use. When omitted, a session from get_db() is
            used instead.
    """
    async def _insert(db: AsyncSession) -> Project:
        initial_status = ProjectStatus.IDEATION.value
        new_project = Project(
            idea_id=idea_id,
            name=name,
            plan_json=plan_json,
            status=initial_status
        )
        db.add(new_project)
        # Flush and refresh so the returned instance reflects the stored row.
        await db.flush()
        await db.refresh(new_project)
        return new_project

    if not session:
        async with get_db() as db:
            return await _insert(db)
    return await _insert(session)
async def get_project_by_id(project_id: int, session: Optional[AsyncSession] = None) -> Optional[Project]:
    """Return the Project whose id equals ``project_id``, or None if absent.

    When ``session`` is omitted, a session from get_db() is used.
    """
    async def _fetch(db: AsyncSession) -> Optional[Project]:
        stmt = select(Project).where(Project.id == project_id)
        rows = await db.execute(stmt)
        return rows.scalar_one_or_none()

    if not session:
        async with get_db() as db:
            return await _fetch(db)
    return await _fetch(session)
async def get_active_project(session: Optional[AsyncSession] = None) -> Optional[Project]:
    """Return the newest project that is neither COMPLETED nor FAILED.

    Returns None when every project is in one of those two statuses (or no
    project exists).
    """
    async def _fetch(db: AsyncSession) -> Optional[Project]:
        finished_statuses = [ProjectStatus.COMPLETED.value, ProjectStatus.FAILED.value]
        stmt = select(Project).where(Project.status.notin_(finished_statuses))
        # Newest first, and only the single most recent row is needed.
        stmt = stmt.order_by(Project.created_at.desc()).limit(1)
        rows = await db.execute(stmt)
        return rows.scalar_one_or_none()

    if not session:
        async with get_db() as db:
            return await _fetch(db)
    return await _fetch(session)
async def update_project_status(
    project_id: int,
    status: str,
    gitea_url: Optional[str] = None,
    x_post_url: Optional[str] = None,
    dev_test_iterations: Optional[int] = None,
    ci_test_iterations: Optional[int] = None,
    current_agent: Optional[str] = None,
    plan_json: Optional[dict] = None,
    idea_json: Optional[dict] = None,
    name: Optional[str] = None,
    session: Optional[AsyncSession] = None
) -> bool:
    """Set a project's status and any optional fields that were supplied.

    ``status`` is always written (a non-string value is replaced by its
    ``.value`` attribute). Each other field is written only when its argument
    is not None, so passing None leaves the stored value untouched.

    Returns:
        True if the project was found and updated, False otherwise.
    """
    async def _apply(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        target = lookup.scalar_one_or_none()
        if not target:
            return False

        target.status = status if isinstance(status, str) else status.value

        # Attribute name -> new value; None means "leave unchanged".
        optional_fields = {
            "gitea_url": gitea_url,
            "x_post_url": x_post_url,
            "dev_test_iterations": dev_test_iterations,
            "ci_test_iterations": ci_test_iterations,
            "current_agent": current_agent,
            "plan_json": plan_json,
            "idea_json": idea_json,
            "name": name,
        }
        for attr_name, new_value in optional_fields.items():
            if new_value is not None:
                setattr(target, attr_name, new_value)
        return True

    if not session:
        async with get_db() as db:
            return await _apply(db)
    return await _apply(session)
# =============================================================================
# AgentLog CRUD Operations
# =============================================================================
async def add_agent_log(
    project_id: int,
    agent_name: str,
    message: str,
    log_type: str = LogType.INFO.value,
    session: Optional[AsyncSession] = None
) -> AgentLog:
    """Insert an AgentLog row and return the refreshed instance.

    Args:
        project_id: Id stored in the log's ``project_id`` column.
        agent_name: Name of the agent that produced the message.
        message: Log text.
        log_type: Log type; a non-string value is replaced by its ``.value``
            attribute. Defaults to ``LogType.INFO.value``.
        session: Session to use. When omitted, a session from get_db() is
            used instead.
    """
    async def _insert(db: AsyncSession) -> AgentLog:
        type_value = log_type if isinstance(log_type, str) else log_type.value
        entry = AgentLog(
            project_id=project_id,
            agent_name=agent_name,
            message=message,
            log_type=type_value
        )
        db.add(entry)
        # Flush and refresh so the returned instance reflects the stored row.
        await db.flush()
        await db.refresh(entry)
        return entry

    if not session:
        async with get_db() as db:
            return await _insert(db)
    return await _insert(session)
async def get_recent_logs(
    limit: int = 50,
    log_type: Optional[str] = None,
    session: Optional[AsyncSession] = None
) -> List[AgentLog]:
    """Return up to ``limit`` log entries from all projects, newest first.

    Args:
        limit: Maximum number of rows to return.
        log_type: When truthy, only entries of this type are returned.
        session: Session to use. When omitted, a session from get_db() is
            used instead.
    """
    async def _fetch(db: AsyncSession) -> List[AgentLog]:
        stmt = select(AgentLog)
        if log_type:
            stmt = stmt.where(AgentLog.log_type == log_type)
        stmt = stmt.order_by(AgentLog.created_at.desc())
        stmt = stmt.limit(limit)
        rows = await db.execute(stmt)
        entries = rows.scalars().all()
        return list(entries)

    if not session:
        async with get_db() as db:
            return await _fetch(db)
    return await _fetch(session)
async def get_project_logs(
    project_id: int,
    limit: int = 100,
    log_type: Optional[str] = None,
    session: Optional[AsyncSession] = None
) -> List[AgentLog]:
    """Return up to ``limit`` log entries for one project, newest first.

    Args:
        project_id: Only entries with this ``project_id`` are returned.
        limit: Maximum number of rows to return.
        log_type: When truthy, only entries of this type are returned.
        session: Session to use. When omitted, a session from get_db() is
            used instead.
    """
    async def _fetch(db: AsyncSession) -> List[AgentLog]:
        stmt = select(AgentLog).where(AgentLog.project_id == project_id)
        if log_type:
            stmt = stmt.where(AgentLog.log_type == log_type)
        stmt = stmt.order_by(AgentLog.created_at.desc())
        stmt = stmt.limit(limit)
        rows = await db.execute(stmt)
        entries = rows.scalars().all()
        return list(entries)

    if not session:
        async with get_db() as db:
            return await _fetch(db)
    return await _fetch(session)
# =============================================================================
# Project JSON State Accessors and Statistics
# =============================================================================
async def get_project_idea_json(project_id: int, session: Optional[AsyncSession] = None) -> Optional[dict]:
    """Return the ``idea_json`` stored on a project.

    Returns None when the project does not exist.
    """
    async def _read(db: AsyncSession) -> Optional[dict]:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        return record.idea_json if record else None

    if not session:
        async with get_db() as db:
            return await _read(db)
    return await _read(session)
async def get_project_plan_json(project_id: int, session: Optional[AsyncSession] = None) -> Optional[dict]:
    """Return the ``plan_json`` stored on a project.

    Returns None when the project does not exist.
    """
    async def _read(db: AsyncSession) -> Optional[dict]:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        return record.plan_json if record else None

    if not session:
        async with get_db() as db:
            return await _read(db)
    return await _read(session)
async def set_project_idea_json(project_id: int, idea_json: dict, session: Optional[AsyncSession] = None) -> bool:
    """Store ``idea_json`` on a project (called by MCP submit_idea).

    Returns:
        True if the project was found and updated, False otherwise.
    """
    async def _write(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        if not record:
            return False
        record.idea_json = idea_json
        return True

    if not session:
        async with get_db() as db:
            return await _write(db)
    return await _write(session)
async def set_project_plan_json(project_id: int, plan_json: dict, session: Optional[AsyncSession] = None) -> bool:
    """Store ``plan_json`` on a project (called by MCP submit_plan).

    Returns:
        True if the project was found and updated, False otherwise.
    """
    async def _write(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        if not record:
            return False
        record.plan_json = plan_json
        return True

    if not session:
        async with get_db() as db:
            return await _write(db)
    return await _write(session)
async def get_project_test_result_json(project_id: int, session: Optional[AsyncSession] = None) -> Optional[dict]:
    """Return the ``test_result_json`` stored on a project.

    Returns None when the project does not exist.
    """
    async def _read(db: AsyncSession) -> Optional[dict]:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        return record.test_result_json if record else None

    if not session:
        async with get_db() as db:
            return await _read(db)
    return await _read(session)
async def set_project_test_result_json(project_id: int, test_result_json: dict, session: Optional[AsyncSession] = None) -> bool:
    """Store ``test_result_json`` on a project (called by MCP submit_test_result).

    Returns:
        True if the project was found and updated, False otherwise.
    """
    async def _write(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        if not record:
            return False
        record.test_result_json = test_result_json
        return True

    if not session:
        async with get_db() as db:
            return await _write(db)
    return await _write(session)
async def get_project_implementation_status_json(project_id: int, session: Optional[AsyncSession] = None) -> Optional[dict]:
    """Return the ``implementation_status_json`` stored on a project.

    Returns None when the project does not exist.
    """
    async def _read(db: AsyncSession) -> Optional[dict]:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        return record.implementation_status_json if record else None

    if not session:
        async with get_db() as db:
            return await _read(db)
    return await _read(session)
async def set_project_implementation_status_json(project_id: int, implementation_status_json: dict, session: Optional[AsyncSession] = None) -> bool:
    """Store ``implementation_status_json`` on a project (called by MCP submit_implementation_status).

    Returns:
        True if the project was found and updated, False otherwise.
    """
    async def _write(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        if not record:
            return False
        record.implementation_status_json = implementation_status_json
        return True

    if not session:
        async with get_db() as db:
            return await _write(db)
    return await _write(session)
async def clear_project_devtest_state(project_id: int, session: Optional[AsyncSession] = None) -> bool:
    """Reset a project's test result and implementation status to None.

    Used to start a new dev-test iteration from a clean state.

    Returns:
        True if the project was found and cleared, False otherwise.
    """
    async def _reset(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        if not record:
            return False
        record.test_result_json = None
        record.implementation_status_json = None
        return True

    if not session:
        async with get_db() as db:
            return await _reset(db)
    return await _reset(session)
async def get_project_ci_result_json(project_id: int, session: Optional[AsyncSession] = None) -> Optional[dict]:
    """Return the ``ci_result_json`` stored on a project.

    Returns None when the project does not exist.
    """
    async def _read(db: AsyncSession) -> Optional[dict]:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        return record.ci_result_json if record else None

    if not session:
        async with get_db() as db:
            return await _read(db)
    return await _read(session)
async def set_project_ci_result_json(project_id: int, ci_result_json: dict, session: Optional[AsyncSession] = None) -> bool:
    """Store ``ci_result_json`` on a project (called by MCP submit_ci_result).

    Returns:
        True if the project was found and updated, False otherwise.
    """
    async def _write(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        if not record:
            return False
        record.ci_result_json = ci_result_json
        return True

    if not session:
        async with get_db() as db:
            return await _write(db)
    return await _write(session)
async def get_project_upload_status_json(project_id: int, session: Optional[AsyncSession] = None) -> Optional[dict]:
    """Return the ``upload_status_json`` stored on a project.

    Returns None when the project does not exist.
    """
    async def _read(db: AsyncSession) -> Optional[dict]:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        return record.upload_status_json if record else None

    if not session:
        async with get_db() as db:
            return await _read(db)
    return await _read(session)
async def set_project_upload_status_json(project_id: int, upload_status_json: dict, session: Optional[AsyncSession] = None) -> bool:
    """Store ``upload_status_json`` on a project (called by MCP submit_upload_status).

    Returns:
        True if the project was found and updated, False otherwise.
    """
    async def _write(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        if not record:
            return False
        record.upload_status_json = upload_status_json
        return True

    if not session:
        async with get_db() as db:
            return await _write(db)
    return await _write(session)
async def clear_project_ci_state(project_id: int, session: Optional[AsyncSession] = None) -> bool:
    """Reset a project's CI result and upload status to None.

    Used to start a new CI iteration from a clean state.

    Returns:
        True if the project was found and cleared, False otherwise.
    """
    async def _reset(db: AsyncSession) -> bool:
        lookup = await db.execute(select(Project).where(Project.id == project_id))
        record = lookup.scalar_one_or_none()
        if not record:
            return False
        record.ci_result_json = None
        record.upload_status_json = None
        return True

    if not session:
        async with get_db() as db:
            return await _reset(db)
    return await _reset(session)
async def get_stats(session: Optional[AsyncSession] = None) -> dict:
    """Return row counts for ideas, projects, and completed projects.

    Returns:
        A dict with the keys ``total_ideas``, ``total_projects`` and
        ``completed_projects``; a missing count is reported as 0.
    """
    async def _collect(db: AsyncSession) -> dict:
        is_completed = Project.status == ProjectStatus.COMPLETED.value
        idea_rows = await db.execute(select(func.count(Idea.id)))
        project_rows = await db.execute(select(func.count(Project.id)))
        completed_rows = await db.execute(
            select(func.count(Project.id)).where(is_completed)
        )
        stats = {}
        stats["total_ideas"] = idea_rows.scalar() or 0
        stats["total_projects"] = project_rows.scalar() or 0
        stats["completed_projects"] = completed_rows.scalar() or 0
        return stats

    if not session:
        async with get_db() as db:
            return await _collect(db)
    return await _collect(session)