150 lines
4.7 KiB
Python
150 lines
4.7 KiB
Python
"""Utilities for flattening and unflattening nested dictionaries."""
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
|
def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
    """Flatten a nested dictionary to dot-notation keys.

    Nested dict keys are joined with ``sep``; list elements use bracket
    notation (``key[0]``). Empty dicts and lists are kept as leaf values:
    recursing into them would yield no keys at all, so they would silently
    disappear and could not be restored by ``unflatten_dict``.

    Args:
        data: The nested dictionary to flatten.
        parent_key: Prefix for every produced key (used internally for recursion).
        sep: The separator placed between dict keys (default: ".").

    Returns:
        A new flat dictionary mapping path keys to leaf values.

    Example:
        >>> flatten_dict({"a": {"b": 1, "c": []}, "d": [2, {"e": 3}]})
        {'a.b': 1, 'a.c': [], 'd[0]': 2, 'd[1].e': 3}
    """
    items: Dict[str, Any] = {}
    for key, value in data.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict) and value:
            items.update(flatten_dict(value, new_key, sep))
        elif isinstance(value, list) and value:
            items.update(flatten_list(value, new_key, sep))
        else:
            # Scalars, plus empty containers kept as-is so they survive a round trip.
            items[new_key] = value
    return items
|
|
|
|
|
|
def flatten_list(data: List[Any], parent_key: str, sep: str = ".") -> Dict[str, Any]:
    """Flatten a list to bracket-notation keys.

    Each element is addressed as ``parent_key[index]``. Nested dicts and
    lists are flattened recursively; empty ones are kept as leaf values so
    their positions are not lost (otherwise ``[1, {}, 3]`` would come back
    from ``unflatten_dict`` as ``[1, None, 3]``).

    Args:
        data: The list to flatten.
        parent_key: The key prefix for the list itself.
        sep: Separator forwarded to nested dict flattening; list indices
            always use brackets.

    Returns:
        A new flat dictionary with bracket-notation keys.
    """
    items: Dict[str, Any] = {}
    for index, value in enumerate(data):
        new_key = f"{parent_key}[{index}]"
        if isinstance(value, dict) and value:
            items.update(flatten_dict(value, new_key, sep))
        elif isinstance(value, list) and value:
            items.update(flatten_list(value, new_key, sep))
        else:
            # Scalars, plus empty containers kept as-is so they survive a round trip.
            items[new_key] = value
    return items
|
|
|
|
|
|
def unflatten_dict(data: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Rebuild a nested structure from a flat mapping of path keys.

    Every key is parsed into path segments (``sep``-delimited names and
    ``[index]`` list positions), and its value is written at that path.
    Intermediate dicts/lists are created along the way.

    Args:
        data: Mapping whose keys are flattened paths such as ``"a.b[0]"``.
        sep: The separator used between dict keys in the paths (default: ".").

    Returns:
        A new nested dictionary.

    Raises:
        ValueError: If a key has unterminated bracket notation or two paths
            conflict with each other.
    """
    nested: Dict[str, Any] = {}
    for path, leaf in data.items():
        _set_nested(nested, _split_key(path, sep), leaf)
    return nested
|
|
|
|
|
|
def _split_key(key: str, sep: str) -> List[str]:
|
|
"""Split a key, handling both dot notation and bracket notation."""
|
|
result = []
|
|
current = ""
|
|
|
|
i = 0
|
|
while i < len(key):
|
|
if key[i] == "[":
|
|
if current:
|
|
result.append(current)
|
|
current = ""
|
|
j = key.find("]", i)
|
|
if j == -1:
|
|
raise ValueError(f"Invalid bracket notation in key: {key}")
|
|
result.append(key[i + 1:j])
|
|
i = j + 1
|
|
elif key[i] == sep:
|
|
if current:
|
|
result.append(current)
|
|
current = ""
|
|
i += 1
|
|
else:
|
|
current += key[i]
|
|
i += 1
|
|
|
|
if current:
|
|
result.append(current)
|
|
|
|
return result
|
|
|
|
|
|
def _set_nested(d: Any, parts: List[str], value: Any) -> None:
|
|
"""Set a value in a nested structure, creating intermediate dicts or lists as needed."""
|
|
current = d
|
|
|
|
for i, part in enumerate(parts[:-1]):
|
|
if isinstance(current, list):
|
|
if not part.isdigit():
|
|
raise ValueError(f"Cannot use string key '{part}' with list")
|
|
idx = int(part)
|
|
while len(current) <= idx:
|
|
current.append(None)
|
|
if current[idx] is None:
|
|
if parts[i + 1].isdigit():
|
|
current[idx] = []
|
|
else:
|
|
current[idx] = {}
|
|
current = current[idx]
|
|
elif isinstance(current, dict):
|
|
if part in current:
|
|
existing = current[part]
|
|
if isinstance(existing, list):
|
|
current = existing
|
|
elif isinstance(existing, dict):
|
|
current = existing
|
|
else:
|
|
raise ValueError(f"Key conflict: '{part}' exists as both leaf and intermediate")
|
|
else:
|
|
next_part = parts[i + 1] if i + 1 < len(parts) else None
|
|
if next_part is not None and next_part.isdigit():
|
|
current[part] = []
|
|
current = current[part]
|
|
else:
|
|
current[part] = {}
|
|
current = current[part]
|
|
else:
|
|
raise ValueError("Cannot traverse through non-dict/list object")
|
|
|
|
last_part = parts[-1]
|
|
if isinstance(current, list):
|
|
if last_part.isdigit():
|
|
idx = int(last_part)
|
|
while len(current) <= idx:
|
|
current.append(None)
|
|
current[idx] = value
|
|
else:
|
|
raise ValueError(f"Cannot use string key '{last_part}' with list")
|
|
else:
|
|
current[last_part] = value
|