"""Utilities for flattening and unflattening nested dictionaries."""

import re
from collections.abc import MutableMapping
from typing import Any, Dict, List, Tuple


def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
    """Flatten a nested dictionary to dot-notation keys.

    Nested dicts are joined with ``sep`` and list elements are addressed with
    bracket notation, e.g. ``{"a": {"b": [1]}}`` becomes ``{"a.b[0]": 1}``.
    Empty dicts and empty lists contribute no keys to the result.

    Args:
        data: The nested dictionary to flatten
        parent_key: The parent key prefix (used internally for recursion)
        sep: The separator to use between keys (default: ".")

    Returns:
        A flattened dictionary with dot-notation keys
    """
    items: Dict[str, Any] = {}
    for key, value in data.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            items.update(flatten_dict(value, new_key, sep))
        elif isinstance(value, list):
            items.update(flatten_list(value, new_key, sep))
        else:
            items[new_key] = value
    return items


def flatten_list(data: List[Any], parent_key: str, sep: str = ".") -> Dict[str, Any]:
    """Flatten a list to bracket-notation keys.

    Each element is keyed as ``parent_key[index]``; nested dicts and lists are
    flattened recursively. An empty list contributes no keys.

    Args:
        data: The list to flatten
        parent_key: The parent key prefix
        sep: The separator passed through to nested dictionaries

    Returns:
        A dictionary with bracket-notation keys
    """
    items: Dict[str, Any] = {}
    for index, value in enumerate(data):
        new_key = f"{parent_key}[{index}]"
        if isinstance(value, dict):
            items.update(flatten_dict(value, new_key, sep))
        elif isinstance(value, list):
            items.update(flatten_list(value, new_key, sep))
        else:
            items[new_key] = value
    return items


def unflatten_dict(data: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Unflatten a dictionary with dot-notation keys back to nested structure.

    Keys are processed in iteration order. A path segment made only of digits
    is treated as a list index, whether it was written as ``a[0]`` or ``a.0``.
    Skipped list positions are filled with ``None``.

    Args:
        data: The flattened dictionary with dot-notation keys
        sep: The separator used in the keys (default: "."); may be longer
            than one character

    Returns:
        A nested dictionary

    Raises:
        ValueError: If a key has an unclosed bracket, contains no path
            segments, uses a non-numeric segment on a list, or needs to
            descend through a value that is already stored as a leaf.
    """
    result: Dict[str, Any] = {}

    for flat_key, value in data.items():
        parts = _split_key(flat_key, sep)
        _set_nested(result, parts, value)

    return result


def _split_key(key: str, sep: str) -> List[str]:
    """Split a key, handling both dot notation and bracket notation.

    Args:
        key: The flattened key, e.g. ``"a.b[0].c"``
        sep: The separator between dictionary keys; an empty separator
            never matches

    Returns:
        The path segments in order. Bracket contents are returned verbatim
        (without the brackets); empty segments between separators are dropped.

    Raises:
        ValueError: If an opening bracket has no matching closing bracket.
    """
    parts: List[str] = []
    buffer: List[str] = []  # characters of the segment currently being read

    def flush() -> None:
        # Emit the pending segment, skipping empty ones (e.g. from "a..b").
        if buffer:
            parts.append("".join(buffer))
            buffer.clear()

    i = 0
    while i < len(key):
        if key[i] == "[":
            flush()
            close = key.find("]", i)
            if close == -1:
                raise ValueError(f"Invalid bracket notation in key: {key}")
            parts.append(key[i + 1:close])
            i = close + 1
        elif sep and key.startswith(sep, i):
            # startswith (rather than a single-character comparison) lets
            # multi-character separators such as "__" be recognised.
            flush()
            i += len(sep)
        else:
            buffer.append(key[i])
            i += 1

    flush()
    return parts


def _new_container(next_part: str) -> Any:
    """Return an empty list if the next segment is numeric, else an empty dict."""
    return [] if next_part.isdigit() else {}


def _list_index(part: str) -> int:
    """Convert a path segment to a list index, rejecting non-numeric segments."""
    if not part.isdigit():
        raise ValueError(f"Cannot use string key '{part}' with list")
    return int(part)


def _pad_list(target: List[Any], idx: int) -> None:
    """Extend ``target`` with ``None`` so that ``target[idx]`` is addressable."""
    missing = idx + 1 - len(target)
    if missing > 0:
        target.extend([None] * missing)


def _set_nested(d: Any, parts: List[str], value: Any) -> None:
    """Set a value in a nested structure, creating intermediate dicts or lists as needed.

    Args:
        d: The root container, modified in place
        parts: The path segments as produced by ``_split_key``
        value: The value to store at the end of the path

    Raises:
        ValueError: If ``parts`` is empty, if a non-numeric segment addresses
            a list, or if the path must descend through an existing value
            that is neither a dict nor a list.
    """
    if not parts:
        raise ValueError("Cannot set a value for an empty key path")

    current = d

    # Walk every segment except the last, pairing each with its successor so
    # that missing containers can be created with the right type.
    for part, next_part in zip(parts, parts[1:]):
        if isinstance(current, list):
            idx = _list_index(part)
            _pad_list(current, idx)
            if current[idx] is None:
                current[idx] = _new_container(next_part)
            elif not isinstance(current[idx], (dict, list)):
                # Without this check the walk would step into a scalar and
                # fail later with an unrelated TypeError.
                raise ValueError(f"Key conflict: '{part}' exists as both leaf and intermediate")
            current = current[idx]
        elif isinstance(current, dict):
            if part not in current:
                current[part] = _new_container(next_part)
            elif not isinstance(current[part], (dict, list)):
                raise ValueError(f"Key conflict: '{part}' exists as both leaf and intermediate")
            current = current[part]
        else:
            raise ValueError("Cannot traverse through non-dict/list object")

    last_part = parts[-1]
    if isinstance(current, list):
        idx = _list_index(last_part)
        _pad_list(current, idx)
        current[idx] = value
    else:
        current[last_part] = value