diff --git a/devtoolbelt/commands/database.py b/devtoolbelt/commands/database.py new file mode 100644 index 0000000..d58fc70 --- /dev/null +++ b/devtoolbelt/commands/database.py @@ -0,0 +1,416 @@ +"""Database commands for Devtoolbelt.""" + +from typing import Any, Dict, Optional + +import click +from rich import print as rprint +from rich.table import Table +from sqlalchemy import create_engine, inspect, text +from sqlalchemy.exc import SQLAlchemyError + +from ..config import get_config +from ..utils import console + + +@click.group() +def database(): + """Database inspection and query commands.""" + pass + + +@database.command("list") +@click.option( + "--config", "-c", + type=click.Path(exists=True), + help="Path to configuration file." +) +def list_databases(config: Optional[str]): + """List configured databases.""" + cfg = get_config(config) + databases = cfg.get_database_configs() + + if not databases: + rprint("[yellow]No databases configured.[/yellow]") + rprint("Add databases to your config file or use 'db connect' command.") + return + + table = Table(title="Configured Databases") + table.add_column("Name", style="cyan") + table.add_column("Type", style="green") + table.add_column("Host", style="magenta") + table.add_column("Database", style="yellow") + + for name, db_config in databases.items(): + db_type = db_config.get("type", "unknown") + host = db_config.get("host", "localhost") + db_name = db_config.get("database", name) + table.add_row(name, db_type, host, db_name) + + console.print(table) + + +@database.command("connect") +@click.argument("name") +@click.option( + "--config", "-c", + type=click.Path(exists=True), + help="Path to configuration file." +) +def connect_db(name: str, config: Optional[str]): + """Connect to a database and show basic info.""" + cfg = get_config(config) + databases = cfg.get_database_configs() + + if name not in databases: + rprint(f"[red]Database '{name}' not found in configuration.[/red]") + return + + db_config = databases[name] + conn_str = _build_connection_string(db_config) + + try: + engine = create_engine(conn_str) + with engine.connect(): + inspector = inspect(engine) + tables = inspector.get_table_names() + db_name = db_config.get("database", name) + + info_table = Table(title=f"Database: {name}") + info_table.add_column("Property", style="cyan") + info_table.add_column("Value", style="green") + + info_table.add_row("Database Name", db_name) + info_table.add_row("Host", db_config.get("host", "localhost")) + info_table.add_row("Port", str(db_config.get("port", ""))) + info_table.add_row("Tables Count", str(len(tables))) + + console.print(info_table) + + if tables: + rprint("\n[bold cyan]Tables:[/bold cyan]") + for i, table in enumerate(tables, 1): + rprint(f" {i}. {table}") + + except SQLAlchemyError as e: + rprint(f"[red]Error connecting to database: {e}[/red]") + + +@database.command("tables") +@click.argument("name") +@click.option( + "--config", "-c", + type=click.Path(exists=True), + help="Path to configuration file." +) +@click.option( + "--verbose", "-v", + is_flag=True, + help="Show detailed table information." +) +def list_tables(name: str, config: Optional[str], verbose: bool): + """List tables in a database.""" + cfg = get_config(config) + databases = cfg.get_database_configs() + + if name not in databases: + rprint(f"[red]Database '{name}' not found in configuration.[/red]") + return + + db_config = databases[name] + conn_str = _build_connection_string(db_config) + + try: + engine = create_engine(conn_str) + inspector = inspect(engine) + tables = inspector.get_table_names() + + if verbose: + table = Table(title=f"Tables in {name}") + table.add_column("Table Name", style="cyan") + table.add_column("Columns", style="green") + table.add_column("Type", style="magenta") + + for t in tables: + columns = inspector.get_columns(t) + table.add_row( + t, + str(len(columns)), + "Table" + ) + + console.print(table) + else: + rprint(f"[bold]Tables in {name}:[/bold]") + for i, table in enumerate(tables, 1): + rprint(f" {i}. {table}") + + except SQLAlchemyError as e: + rprint(f"[red]Error listing tables: {e}[/red]") + + +@database.command("schema") +@click.argument("name") +@click.argument("table") +@click.option( + "--config", "-c", + type=click.Path(exists=True), + help="Path to configuration file." +) +def show_schema(name: str, table: str, config: Optional[str]): + """Show schema for a table.""" + cfg = get_config(config) + databases = cfg.get_database_configs() + + if name not in databases: + rprint(f"[red]Database '{name}' not found in configuration.[/red]") + return + + db_config = databases[name] + conn_str = _build_connection_string(db_config) + + try: + engine = create_engine(conn_str) + inspector = inspect(engine) + + if table not in inspector.get_table_names(): + rprint(f"[red]Table '{table}' not found in database.[/red]") + return + + columns = inspector.get_columns(table) + foreign_keys = inspector.get_foreign_keys(table) + indexes = inspector.get_indexes(table) + primary_key = inspector.get_pk_constraint(table) + + schema_table = Table(title=f"Schema: {table}") + schema_table.add_column("Column", style="cyan") + schema_table.add_column("Type", style="green") + schema_table.add_column("Nullable", style="magenta") + schema_table.add_column("Default", style="yellow") + + for col in columns: + col_type = str(col['type']) + nullable = "Yes" if col['nullable'] else "No" + default = str(col['default']) if col['default'] else "-" + schema_table.add_row( + col['name'], + col_type, + nullable, + default + ) + + console.print(schema_table) + + if primary_key.get('constrained_columns'): + rprint(f"\n[bold cyan]Primary Key:[/bold cyan] {', '.join(primary_key['constrained_columns'])}") + + if foreign_keys: + rprint("\n[bold cyan]Foreign Keys:[/bold cyan]") + for fk in foreign_keys: + rprint(f" {fk['constrained_columns']} -> {fk['referred_table']}.{fk['referred_columns']}") + + if indexes: + rprint("\n[bold cyan]Indexes:[/bold cyan]") + for idx in indexes: + col_names = [c for c in idx['column_names'] if c is not None] + rprint(f" {idx['name']}: {', '.join(col_names)}") + + except SQLAlchemyError as e: + rprint(f"[red]Error getting schema: {e}[/red]") + + +@database.command("query") +@click.argument("name") +@click.argument("query") +@click.option( + "--config", "-c", + type=click.Path(exists=True), + help="Path to configuration file." +) +@click.option( + "--limit", "-l", + type=int, + default=100, + help="Maximum number of rows to return." +) +@click.option( + "--format", "-f", + type=click.Choice(["table", "json", "csv"]), + default="table", + help="Output format." +) +def execute_query( + name: str, + query: str, + config: Optional[str], + limit: int, + format: str +): + """Execute a query on a database.""" + cfg = get_config(config) + databases = cfg.get_database_configs() + + if name not in databases: + rprint(f"[red]Database '{name}' not found in configuration.[/red]") + return + + db_config = databases[name] + conn_str = _build_connection_string(db_config) + + if limit > 0: + if not query.lower().strip().startswith("select"): + rprint("[yellow]Warning: LIMIT is only applied to SELECT queries.[/yellow]") + else: + if "limit" not in query.lower(): + query = f"{query} LIMIT {limit}" + + try: + engine = create_engine(conn_str) + with engine.connect() as conn: + result = conn.execute(text(query)) + columns = result.keys() + rows = result.fetchall() + + if format == "json": + import json + data = [dict(zip(columns, row)) for row in rows] + click.echo(json.dumps(data, indent=2, default=str)) + elif format == "csv": + import csv + import io + output = io.StringIO() + writer = csv.writer(output) + writer.writerow(columns) + writer.writerows(rows) + click.echo(output.getvalue()) + else: + table = Table(title=f"Query Results from {name}") + for col in columns: + table.add_column(str(col), style="green") + for row in rows: + table.add_row(*[str(cell) for cell in row]) + console.print(table) + + rprint(f"\n[dim]Rows returned: {len(rows)}[/dim]") + + except SQLAlchemyError as e: + rprint(f"[red]Query error: {e}[/red]") + + +@database.command("add") +@click.argument("name") +@click.option( + "--type", "-t", + type=click.Choice(["postgresql", "mysql", "sqlite", "mssql"]), + required=True, + help="Database type." +) +@click.option( + "--host", "-H", + default="localhost", + help="Database host." +) +@click.option( + "--port", "-p", + type=int, + help="Database port." +) +@click.option( + "--database", "-d", + required=True, + help="Database name." +) +@click.option( + "--user", "-u", + help="Database user." +) +@click.option( + "--password", "-P", + help="Database password." +) +@click.option( + "--config", "-c", + type=click.Path(exists=True), + help="Path to configuration file." +) +def add_database( + name: str, + type: str, + host: str, + port: Optional[int], + database: str, + user: Optional[str], + password: Optional[str], + config: Optional[str] +): + """Add a database configuration.""" + cfg = get_config(config) + databases = cfg.get_database_configs() + + if name in databases: + rprint(f"[yellow]Database '{name}' already exists. Overwriting.[/yellow]") + + db_config: Dict[str, Any] = { + "type": type, + "host": host, + "database": database + } + + if port: + db_config["port"] = port + if user: + db_config["user"] = user + if password: + db_config["password"] = password + + databases[name] = db_config + cfg.set("databases", databases) + cfg.save() + + rprint(f"[green]Database '{name}' added successfully.[/green]") + + +def _build_connection_string(db_config: Dict[str, Any]) -> str: + """Build SQLAlchemy connection string from config. + + Args: + db_config: Database configuration dictionary. + + Returns: + Connection string. + """ + db_type = db_config.get("type", "") + + if db_type == "postgresql": + user = db_config.get("user", "") + password = db_config.get("password", "") + host = db_config.get("host", "localhost") + port = db_config.get("port", 5432) + database = db_config.get("database", "") + if password: + return f"postgresql://{user}:{password}@{host}:{port}/{database}" + return f"postgresql://{user}@{host}:{port}/{database}" + + elif db_type == "mysql": + user = db_config.get("user", "") + password = db_config.get("password", "") + host = db_config.get("host", "localhost") + port = db_config.get("port", 3306) + database = db_config.get("database", "") + if password: + return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}" + return f"mysql+pymysql://{user}@{host}:{port}/{database}" + + elif db_type == "sqlite": + database = db_config.get("database", "") + return f"sqlite:///{database}" + + elif db_type == "mssql": + user = db_config.get("user", "") + password = db_config.get("password", "") + host = db_config.get("host", "localhost") + port = db_config.get("port", 1433) + database = db_config.get("database", "") + return f"mssql+pyodbc://{user}:{password}@{host}:{port}/{database}?driver=ODBC+Driver+17+for+SQL+Server" + + else: + raise ValueError(f"Unsupported database type: {db_type}")