"""Pattern merging logic for handling multiple stacks."""

from typing import Dict, List, Optional, Set, Tuple


class PatternMerger:
    """Merges patterns from multiple sources with deduplication and conflict resolution.

    Patterns are grouped by category name. Blank entries and entries whose
    stripped text starts with ``#`` are ignored everywhere. Two patterns are
    considered the same when they are equal after stripping surrounding
    whitespace; the first occurrence is kept with its original spelling.
    """

    def __init__(self) -> None:
        # Weight used by the "priority" strategy: categories are emitted in
        # ascending weight. Categories missing from this table get weight 0.
        self.priority_weights = {
            "os": 10,
            "ide": 20,
            "language": 30,
            "framework": 40,
            "custom": 50,
        }

    def merge(
        self,
        patterns_by_category: Dict[str, List[str]],
        custom_patterns: Optional[List[str]] = None,
        strategy: str = "priority",
    ) -> Tuple[List[str], List[str]]:
        """Merge patterns from multiple categories.

        The caller's ``patterns_by_category`` mapping is never modified.

        Args:
            patterns_by_category: Dict mapping category names to pattern lists.
            custom_patterns: List of custom user patterns. When non-empty it
                is used as the "custom" category, replacing any "custom"
                entry already present in ``patterns_by_category``.
            strategy: Merge strategy - "priority", "append", or "prepend".
                Any other value is treated as "priority".

        Returns:
            Tuple of (merged_patterns, conflicts). ``conflicts`` is only
            populated by the "priority" strategy.
        """
        # Work on a shallow copy so that injecting the "custom" category does
        # not leak into the dict owned by the caller.
        categories: Dict[str, List[str]] = dict(patterns_by_category)
        if custom_patterns:
            categories["custom"] = custom_patterns

        if strategy == "append":
            return self._merge_append(categories), []
        if strategy == "prepend":
            return self._merge_prepend(categories), []
        return self._merge_priority(categories)

    def _get_category_order(self, categories: Set[str]) -> List[str]:
        """Get ordered list of categories based on priority.

        Args:
            categories: Set of category names.

        Returns:
            The well-known categories that are present, in the fixed order
            os, ide, language, framework, custom, followed by every other
            category in alphabetical order.
        """
        known_order = ("os", "ide", "language", "framework", "custom")
        ordered = [name for name in known_order if name in categories]
        ordered.extend(sorted(categories.difference(known_order)))
        return ordered

    def _collect_unique(self, patterns: List[str], seen: Set[str]) -> List[str]:
        """Return the patterns not seen before, recording them in ``seen``.

        Args:
            patterns: Candidate patterns, possibly with blanks and comments.
            seen: Normalized patterns already emitted; updated in place.

        Returns:
            The original (un-normalized) patterns that are new, in order.
        """
        fresh: List[str] = []
        for pattern in patterns:
            normalized = self._normalize_pattern(pattern)
            if normalized is None or normalized in seen:
                continue
            seen.add(normalized)
            fresh.append(pattern)
        return fresh

    def _merge_priority(self, patterns_by_category: Dict[str, List[str]]) -> Tuple[List[str], List[str]]:
        """Merge patterns by priority (lower priority categories first).

        Args:
            patterns_by_category: Dict mapping category names to pattern lists.

        Returns:
            Tuple of (merged_patterns, conflicts). A conflict is recorded when
            a pattern already contributed by one category shows up again in a
            different category; each distinct conflict message appears once.
        """
        # sorted() is stable, so categories sharing a weight (notably unknown
        # categories, all weight 0) keep the order from _get_category_order.
        ordered_categories = sorted(
            self._get_category_order(set(patterns_by_category)),
            key=lambda name: self.priority_weights.get(name, 0),
        )

        owner_of: Dict[str, str] = {}
        merged: List[str] = []
        conflicts: List[str] = []
        reported: Set[str] = set()

        for category in ordered_categories:
            for pattern in patterns_by_category.get(category, []):
                normalized = self._normalize_pattern(pattern)
                if normalized is None:
                    continue

                owner = owner_of.get(normalized)
                if owner is None:
                    owner_of[normalized] = category
                    merged.append(pattern)
                    continue

                if owner == category:
                    # Plain duplicate inside one category: not a conflict.
                    continue

                message = f"{normalized} (from {owner} and {category})"
                # A pattern repeated within the later category must not
                # produce the same conflict message more than once.
                if message not in reported:
                    reported.add(message)
                    conflicts.append(message)

        return merged, conflicts

    def _merge_append(self, patterns_by_category: Dict[str, List[str]]) -> List[str]:
        """Merge patterns by appending in order.

        Categories are visited in ``_get_category_order`` order; when a
        pattern occurs in several categories the earliest one keeps it.

        Args:
            patterns_by_category: Dict mapping category names to pattern lists.

        Returns:
            List of merged patterns.
        """
        seen: Set[str] = set()
        merged: List[str] = []
        for category in self._get_category_order(set(patterns_by_category)):
            merged.extend(self._collect_unique(patterns_by_category.get(category, []), seen))
        return merged

    def _merge_prepend(self, patterns_by_category: Dict[str, List[str]]) -> List[str]:
        """Merge patterns by prepending in order.

        Categories are visited from last to first and each category's new
        patterns are placed in front of what was collected so far. The output
        therefore lists categories in ``_get_category_order`` order, but a
        pattern occurring in several categories is kept by the latest one.

        Args:
            patterns_by_category: Dict mapping category names to pattern lists.

        Returns:
            List of merged patterns.
        """
        seen: Set[str] = set()
        sections: List[List[str]] = []
        for category in reversed(self._get_category_order(set(patterns_by_category))):
            sections.append(self._collect_unique(patterns_by_category.get(category, []), seen))

        # Sections were gathered last-category-first; flipping them once is
        # equivalent to prepending each section as it is produced.
        return [pattern for section in reversed(sections) for pattern in section]

    def _normalize_pattern(self, pattern: str) -> Optional[str]:
        """Normalize a pattern for comparison.

        Args:
            pattern: Pattern to normalize.

        Returns:
            The pattern with surrounding whitespace stripped, or None when
            the pattern is empty, blank, or a ``#`` comment.
        """
        if not pattern:
            return None

        stripped = pattern.strip()
        if not stripped or stripped.startswith("#"):
            return None

        return stripped

    def deduplicate(self, patterns: List[str]) -> List[str]:
        """Remove duplicate patterns while preserving order.

        Blank entries and ``#`` comments are dropped as well.

        Args:
            patterns: List of patterns to deduplicate.

        Returns:
            Deduplicated list of patterns.
        """
        return self._collect_unique(patterns, set())

    def resolve_conflict(self, pattern: str, existing: str, resolution: str) -> str:
        """Resolve a pattern conflict.

        Args:
            pattern: New pattern being added.
            existing: Existing pattern.
            resolution: Resolution strategy - "keep_existing", "keep_new", "combine".
                Any other value behaves like "keep_existing".

        Returns:
            Resolved pattern. For "combine" this is the existing pattern and
            the new pattern joined by a newline.
        """
        if resolution == "keep_new":
            return pattern
        if resolution == "combine":
            return f"{existing}\n{pattern}"
        return existing

    def get_pattern_sources(self, patterns: List[str]) -> Dict[str, Set[str]]:
        """Group patterns by the category their content suggests.

        Args:
            patterns: List of patterns.

        Returns:
            Dict mapping pattern categories (as decided by
            ``_categorize_pattern``) to the set of normalized patterns in
            that category. Blank entries and comments are skipped.
        """
        sources: Dict[str, Set[str]] = {}
        for pattern in patterns:
            normalized = self._normalize_pattern(pattern)
            if normalized is None:
                continue
            sources.setdefault(self._categorize_pattern(normalized), set()).add(normalized)
        return sources

    def _categorize_pattern(self, pattern: str) -> str:
        """Categorize a pattern based on its content.

        The checks are substring matches applied in order: OS indicators,
        then IDE indicators, then language indicators.

        Args:
            pattern: Pattern to categorize.

        Returns:
            Category name: "os", "ide", "language", or "other".
        """
        os_indicators = (".DS_Store", "Thumbs.db", "~")
        ide_indicators = (".vscode", ".idea", ".project", ".settings")
        language_indicators = ("node_modules", "*.log")

        if any(indicator in pattern for indicator in os_indicators):
            return "os"
        if any(indicator in pattern for indicator in ide_indicators):
            return "ide"
        if any(indicator in pattern for indicator in language_indicators):
            return "language"
        return "other"