Initial upload: testdata-cli with CI/CD workflow
This commit is contained in:
118
src/testdatagen/formatters/sql_formatter.py
Normal file
118
src/testdatagen/formatters/sql_formatter.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""SQL output formatter."""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class SQLFormatter:
|
||||
"""Formatter that outputs data as SQL INSERT statements."""
|
||||
|
||||
def __init__(self, table_name: str = "generated_table"):
|
||||
"""Initialize the SQL formatter.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table for INSERT statements
|
||||
"""
|
||||
self.table_name = self._validate_table_name(table_name)
|
||||
|
||||
def format(self, records: List[Dict[str, Any]]) -> str:
|
||||
"""Format records as SQL INSERT statements.
|
||||
|
||||
Args:
|
||||
records: List of data records to format
|
||||
|
||||
Returns:
|
||||
SQL INSERT statements
|
||||
"""
|
||||
if not records:
|
||||
return ""
|
||||
|
||||
if not records[0]:
|
||||
return ""
|
||||
|
||||
columns = list(records[0].keys())
|
||||
column_list = ", ".join(columns)
|
||||
|
||||
statements = []
|
||||
for record in records:
|
||||
values = []
|
||||
for col in columns:
|
||||
value = record.get(col)
|
||||
values.append(self._format_value(value))
|
||||
|
||||
values_list = ", ".join(values)
|
||||
statement = f"INSERT INTO {self.table_name} ({column_list}) VALUES ({values_list});"
|
||||
statements.append(statement)
|
||||
|
||||
return "\n".join(statements)
|
||||
|
||||
def _format_value(self, value: Any) -> str:
|
||||
"""Format a value for SQL.
|
||||
|
||||
Args:
|
||||
value: Value to format
|
||||
|
||||
Returns:
|
||||
SQL-formatted value string
|
||||
"""
|
||||
if value is None:
|
||||
return "NULL"
|
||||
|
||||
if isinstance(value, bool):
|
||||
return "TRUE" if value else "FALSE"
|
||||
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
|
||||
if isinstance(value, str):
|
||||
escaped = value.replace("'", "''")
|
||||
return f"'{escaped}'"
|
||||
|
||||
if isinstance(value, (list, dict)):
|
||||
import json
|
||||
json_str = json.dumps(value).replace("'", "''")
|
||||
return f"'{json_str}'"
|
||||
|
||||
return f"'{str(value).replace(chr(39), chr(39)+chr(39))}'"
|
||||
|
||||
def _validate_table_name(self, table_name: str) -> str:
|
||||
"""Validate and sanitize table name to prevent SQL injection.
|
||||
|
||||
Args:
|
||||
table_name: Table name to validate
|
||||
|
||||
Returns:
|
||||
Validated table name
|
||||
|
||||
Raises:
|
||||
ValueError: If table name contains invalid characters
|
||||
"""
|
||||
if not table_name:
|
||||
return "generated_table"
|
||||
|
||||
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', table_name):
|
||||
raise ValueError(
|
||||
f"Invalid table name '{table_name}'. "
|
||||
"Table name must start with a letter or underscore "
|
||||
"and contain only letters, numbers, and underscores."
|
||||
)
|
||||
|
||||
reserved_words = {
|
||||
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
|
||||
"ALTER", "TABLE", "DATABASE", "INDEX", "VIEW", "FROM",
|
||||
"WHERE", "AND", "OR", "NOT", "NULL", "TRUE", "FALSE"
|
||||
}
|
||||
|
||||
if table_name.upper() in reserved_words:
|
||||
raise ValueError(
|
||||
f"Table name '{table_name}' is a reserved word. "
|
||||
"Please use a different table name."
|
||||
)
|
||||
|
||||
if len(table_name) > 64:
|
||||
raise ValueError(
|
||||
f"Table name '{table_name}' is too long. "
|
||||
"Maximum length is 64 characters."
|
||||
)
|
||||
|
||||
return table_name
|
||||
Reference in New Issue
Block a user