"""SQL output formatter."""

import json
import math
import re
from typing import Any, Dict, List


class SQLFormatter:
    """Formatter that outputs data as SQL INSERT statements.

    One ``INSERT`` statement is produced per record. The column list is taken
    from the keys of the first record; keys missing from later records are
    emitted as ``NULL`` and keys that only appear in later records are
    ignored.
    """

    # A plain (unquoted) SQL identifier. Always used with ``fullmatch`` so
    # that a trailing newline cannot slip past the check the way it does with
    # ``re.match(r"^...$")``.
    _IDENTIFIER_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")

    _DEFAULT_TABLE_NAME = "generated_table"
    _MAX_TABLE_NAME_LENGTH = 64

    # Built once at class creation instead of on every validation call.
    _RESERVED_WORDS = frozenset({
        "SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
        "ALTER", "TABLE", "DATABASE", "INDEX", "VIEW", "FROM",
        "WHERE", "AND", "OR", "NOT", "NULL", "TRUE", "FALSE",
    })

    def __init__(self, table_name: str = "generated_table"):
        """Initialize the SQL formatter.

        Args:
            table_name: Name of the table for INSERT statements. An empty
                value falls back to ``"generated_table"``.

        Raises:
            ValueError: If the table name is not a plain identifier, is a
                reserved word, or is longer than 64 characters.
        """
        self.table_name = self._validate_table_name(table_name)

    def format(self, records: List[Dict[str, Any]]) -> str:
        """Format records as SQL INSERT statements.

        Args:
            records: List of data records to format.

        Returns:
            Newline-separated INSERT statements, or an empty string when
            there are no records or the first record has no columns.
        """
        if not records or not records[0]:
            return ""

        columns = list(records[0])
        column_list = ", ".join(self._format_identifier(col) for col in columns)
        prefix = f"INSERT INTO {self.table_name} ({column_list}) VALUES ("

        statements = []
        for record in records:
            values_list = ", ".join(
                self._format_value(record.get(col)) for col in columns
            )
            statements.append(f"{prefix}{values_list});")

        return "\n".join(statements)

    def _format_identifier(self, name: Any) -> str:
        """Render a column name so it is safe to embed in a statement.

        Args:
            name: Column name taken from a record key.

        Returns:
            The name unchanged when it is a plain identifier; otherwise an
            ANSI double-quoted identifier with embedded double quotes
            doubled, so arbitrary key text cannot break out of the column
            list.
        """
        text = name if isinstance(name, str) else str(name)
        if self._IDENTIFIER_RE.fullmatch(text):
            return text
        escaped = text.replace('"', '""')
        return f'"{escaped}"'

    def _format_value(self, value: Any) -> str:
        """Format a value for SQL.

        Args:
            value: Value to format.

        Returns:
            SQL literal: ``NULL`` for ``None`` and non-finite floats,
            ``TRUE``/``FALSE`` for booleans, a bare number for ints and
            finite floats, and a single-quoted string (with ``'`` doubled)
            for everything else. Lists and dicts are serialized as JSON
            before quoting.
        """
        if value is None:
            return "NULL"

        # bool must be tested before int: bool is a subclass of int.
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"

        # str(nan) / str(inf) are not valid SQL numeric literals.
        if isinstance(value, float) and not math.isfinite(value):
            return "NULL"

        if isinstance(value, (int, float)):
            return str(value)

        if isinstance(value, str):
            text = value
        elif isinstance(value, (list, dict)):
            text = json.dumps(value)
        else:
            text = str(value)

        escaped = text.replace("'", "''")
        return f"'{escaped}'"

    def _validate_table_name(self, table_name: str) -> str:
        """Validate the table name to prevent SQL injection.

        Args:
            table_name: Table name to validate.

        Returns:
            The validated table name, or ``"generated_table"`` when the
            given name is empty.

        Raises:
            ValueError: If the table name contains invalid characters, is a
                reserved word, or exceeds 64 characters.
        """
        if not table_name:
            return self._DEFAULT_TABLE_NAME

        # fullmatch (not match with ``$``) so "name\n" is rejected too.
        if not self._IDENTIFIER_RE.fullmatch(table_name):
            raise ValueError(
                f"Invalid table name '{table_name}'. "
                "Table name must start with a letter or underscore "
                "and contain only letters, numbers, and underscores."
            )

        if table_name.upper() in self._RESERVED_WORDS:
            raise ValueError(
                f"Table name '{table_name}' is a reserved word. "
                "Please use a different table name."
            )

        if len(table_name) > self._MAX_TABLE_NAME_LENGTH:
            raise ValueError(
                f"Table name '{table_name}' is too long. "
                "Maximum length is 64 characters."
            )

        return table_name