from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Set

from codechunk.core.chunking import ParsedChunk

# Top-level package names that `_resolve_import` never reports as a dependency.
_EXTERNAL_MODULES = frozenset({
    'os', 'sys', 're', 'json', 'yaml', 'typing', 'collections', 'itertools',
    'functools', 'abc', 'dataclasses', 'enum', 'pathlib', 'threading',
    'multiprocessing', 'asyncio', 'requests', 'flask', 'django', 'fastapi',
    'numpy', 'pandas', 'torch', 'tensorflow', 'matplotlib', 'scipy', 'sklearn',
})


@dataclass
class DependencyNode:
    """One entry of the dependency graph built by `DependencyAnalyzer`."""

    chunk_name: str  # key of this node in `DependencyAnalyzer.dependency_graph`
    file_path: Path  # file the chunk was parsed from
    module_name: str  # value produced by `DependencyAnalyzer._get_module_name`
    dependencies: Set[str] = field(default_factory=set)  # resolved import targets
    dependents: Set[str] = field(default_factory=set)  # chunk names depending on this node
    is_circular: bool = False  # True when the node lies on a dependency cycle


class DependencyAnalyzer:
    """Build and query a dependency graph over parsed chunks."""

    def __init__(self) -> None:
        # Chunk name -> node; rebuilt from scratch by `analyze_dependencies`.
        self.dependency_graph: Dict[str, DependencyNode] = {}
        # Module name -> chunk names, in the order the chunks were analysed.
        self.module_to_chunks: Dict[str, List[str]] = {}

    def analyze_dependencies(self, chunks: List[ParsedChunk], project_files: List[Path]) -> Dict[str, DependencyNode]:
        """Analyze dependencies between chunks.

        Discards any previous graph, creates one node per chunk (keyed by
        ``chunk.name``; a later chunk with the same name replaces the earlier
        node), resolves each entry of ``chunk.metadata.imports``, then fills
        in reverse links and cycle flags.

        Args:
            chunks: Parsed chunks; ``name``, ``metadata.file_path`` and
                ``metadata.imports`` are read from each.
            project_files: Paths used to derive module names.

        Returns:
            The freshly built ``dependency_graph`` mapping.
        """
        self.dependency_graph = {}
        self.module_to_chunks = {}

        module_cache = self._build_module_cache(project_files)
        # Reduce the file set to its single shortest entry once, so the
        # per-chunk calls below do not rescan every project file.
        root_hint = self._root_hint(set(project_files))

        for chunk in chunks:
            file_path = chunk.metadata.file_path
            if file_path in module_cache:
                module_name = module_cache[file_path]
            else:
                module_name = self._get_module_name(file_path, root_hint)

            node = DependencyNode(
                chunk_name=chunk.name,
                file_path=file_path,
                module_name=module_name,
            )
            for imp in chunk.metadata.imports:
                resolved = self._resolve_import(imp, file_path, root_hint, module_cache)
                if resolved:
                    node.dependencies.add(resolved)

            self.dependency_graph[chunk.name] = node
            self.module_to_chunks.setdefault(module_name, []).append(chunk.name)

        self._build_dependency_links()
        self._detect_circular_dependencies()
        return self.dependency_graph

    @staticmethod
    def _root_hint(project_root: Set[Path]) -> Set[Path]:
        """Return a set holding only the entry `_get_module_name` would pick as root.

        `_get_module_name` selects the entry with the fewest path parts, so
        passing this one-element set gives the same result as passing the
        whole set while avoiding a rescan on every call. An empty input
        yields an empty set.
        """
        if not project_root:
            return set()
        return {min(project_root, key=lambda p: len(p.parts))}

    def _build_module_cache(self, project_files: List[Path]) -> Dict[Path, str]:
        """Build cache of file to module name mappings.

        Args:
            project_files: Paths to name.

        Returns:
            Mapping of each path to the result of `_get_module_name`.
        """
        # The root is computed once instead of once per file.
        root_hint = self._root_hint(set(project_files))
        return {
            file_path: self._get_module_name(file_path, root_hint)
            for file_path in project_files
        }

    def _get_module_name(self, file_path: Path, project_root: Set[Path]) -> str:
        """Get module name from file path.

        The entry of ``project_root`` with the fewest path parts is used as
        the root. The path relative to it is joined with dots, dropping a
        trailing ``__init__.py`` or the last extension of the file name.

        Args:
            file_path: Path to name.
            project_root: Candidate root paths; when empty, ``file_path`` is
                used as-is.

        Returns:
            The dotted name, or ``file_path.stem`` when the relative path is
            empty or ``file_path`` is not located under the chosen root.
        """
        # NOTE(review): if callers pass only file paths (no root directory),
        # `relative_to` presumably fails for almost every file and the bare
        # stem is returned - confirm this is intended.
        try:
            if project_root:
                root = min(project_root, key=lambda p: len(p.parts))
                rel_path = file_path.relative_to(root)
            else:
                rel_path = file_path

            parts = list(rel_path.parts)
            if not parts:
                return file_path.stem

            if parts[-1] == "__init__.py":
                parts = parts[:-1]
            else:
                parts[-1] = parts[-1].rsplit('.', 1)[0]
            return '.'.join(parts)
        except (ValueError, IndexError):
            return file_path.stem

    def _resolve_import(self, import_str: str, current_file: Path, project_root: Set[Path], module_cache: Dict[Path, str]) -> Optional[str]:
        """Resolve import string to module name.

        Args:
            import_str: Dotted import target; leading dots mark it relative.
            current_file: File containing the import.
            project_root: Forwarded to `_get_module_name`.
            module_cache: File path -> module name mapping to search.

        Returns:
            * ``None`` for a blank string or when the first component is in
              ``_EXTERNAL_MODULES`` (absolute imports only).
            * The current file's module name for ``self``/``self.x`` and for
              a relative import made only of dots.
            * Otherwise the first cached module name that equals, or ends at
              a dot boundary with, the full path, then the parent path, then
              the first component.
            * The stripped input when nothing matches.
        """
        clean_import = import_str.strip()
        if not clean_import:
            return None

        is_relative = clean_import.startswith('.')
        # Splitting on '.' can never yield a '.' component, so relative
        # imports are detected from the leading dots instead.
        parts = [part for part in clean_import.lstrip('.').split('.') if part]
        if not parts:
            return self._get_module_name(current_file, project_root)

        base_module = parts[0]
        if not is_relative:
            if base_module == 'self':
                return self._get_module_name(current_file, project_root)
            if base_module in _EXTERNAL_MODULES:
                return None

        # Most specific candidate first, so a parent-package match is not
        # shadowed by a top-level package that merely appears earlier.
        candidates = ['.'.join(parts)]
        if len(parts) > 1:
            candidates.append('.'.join(parts[:-1]))
        candidates.append(base_module)

        for candidate in dict.fromkeys(candidates):
            # Require a dot boundary so 'io' does not match 'audio'.
            dotted_suffix = '.' + candidate
            for module_name in module_cache.values():
                if module_name == candidate or module_name.endswith(dotted_suffix):
                    return module_name

        return clean_import

    def _build_dependency_links(self):
        """Build reverse dependency links (dependents)."""
        # NOTE(review): `dependencies` holds values returned by
        # `_resolve_import`, while the graph is keyed by chunk name, so a link
        # is only created when the two coincide - confirm this is intended.
        graph = self.dependency_graph
        for node in graph.values():
            for dep in node.dependencies:
                target = graph.get(dep)
                if target is not None:
                    target.dependents.add(node.chunk_name)

    def _detect_circular_dependencies(self):
        """Detect circular dependencies in the graph.

        Sets ``is_circular`` to ``True`` for every node that belongs to a
        strongly connected component of more than one node, or that lists
        itself as a dependency, and to ``False`` for all other nodes. Nodes
        that only lead into a cycle are not flagged. The traversal uses an
        explicit stack, so graph depth is not bounded by the recursion limit.
        """
        graph = self.dependency_graph
        order: Dict[str, int] = {}  # discovery index of each visited node
        lowlink: Dict[str, int] = {}  # smallest discovery index reachable
        open_nodes: List[str] = []  # nodes not yet assigned to a component
        on_stack: Set[str] = set()  # membership view of `open_nodes`
        circular: Set[str] = set()

        def successors(name: str) -> Iterator[str]:
            # Only edges pointing at known nodes can take part in a cycle.
            return iter([dep for dep in graph[name].dependencies if dep in graph])

        def discover(name: str) -> None:
            order[name] = lowlink[name] = len(order)
            open_nodes.append(name)
            on_stack.add(name)

        for start in graph:
            if start in order:
                continue
            discover(start)
            work = [(start, successors(start))]
            while work:
                name, pending = work[-1]
                descended = False
                for dep in pending:
                    if dep not in order:
                        discover(dep)
                        work.append((dep, successors(dep)))
                        descended = True
                        break
                    if dep in on_stack:
                        lowlink[name] = min(lowlink[name], order[dep])
                if descended:
                    continue

                # All successors handled: propagate to the parent and close
                # the component if this node is its root.
                work.pop()
                if work:
                    parent = work[-1][0]
                    lowlink[parent] = min(lowlink[parent], lowlink[name])
                if lowlink[name] != order[name]:
                    continue

                members = []
                while True:
                    member = open_nodes.pop()
                    on_stack.discard(member)
                    members.append(member)
                    if member == name:
                        break
                if len(members) > 1 or name in graph[name].dependencies:
                    circular.update(members)

        for name, node in graph.items():
            node.is_circular = name in circular

    def _collect_transitive(self, seeds: List[str], neighbours: Callable[[DependencyNode], Iterable[str]]) -> List[str]:
        """Breadth-first closure of ``seeds`` over the edges chosen by ``neighbours``.

        Args:
            seeds: Starting names; duplicates are collapsed.
            neighbours: Returns the names to follow from a node.

        Returns:
            The seeds in their given order followed by every further name in
            discovery order. Names without a graph node are kept but not
            expanded.
        """
        seen = dict.fromkeys(seeds)
        queue = deque(seen)
        while queue:
            node = self.dependency_graph.get(queue.popleft())
            if node is None:
                continue
            # Sorted so the result does not depend on set iteration order.
            for name in sorted(neighbours(node)):
                if name not in seen:
                    seen[name] = None
                    queue.append(name)
        return list(seen)

    def get_essential_chunks(self, selected_chunks: List[str]) -> List[str]:
        """Get all chunks needed including transitive dependencies.

        Args:
            selected_chunks: Names to start from.

        Returns:
            The selected names followed by everything reachable through
            ``dependencies``, without duplicates.
        """
        return self._collect_transitive(selected_chunks, lambda node: node.dependencies)

    def get_impacted_chunks(self, modified_chunks: List[str]) -> List[str]:
        """Get all chunks that depend on the modified chunks.

        Args:
            modified_chunks: Names to start from.

        Returns:
            The modified names followed by everything reachable through
            ``dependents``, without duplicates.
        """
        return self._collect_transitive(modified_chunks, lambda node: node.dependents)

    def get_dependency_stats(self) -> Dict[str, int]:
        """Get statistics about dependencies.

        Returns:
            Counts over the current graph: total nodes, nodes with at least
            one dependency / dependent, nodes flagged circular, the sum of all
            dependency-set sizes, and the largest dependency / dependent set.
        """
        stats = {
            "total_nodes": len(self.dependency_graph),
            "nodes_with_deps": 0,
            "nodes_with_dependents": 0,
            "circular_deps": 0,
            "total_edges": 0,
            "max_dependencies": 0,
            "max_dependents": 0,
        }

        for node in self.dependency_graph.values():
            dep_count = len(node.dependencies)
            dependent_count = len(node.dependents)

            stats["total_edges"] += dep_count
            if dep_count > 0:
                stats["nodes_with_deps"] += 1
            if dependent_count > 0:
                stats["nodes_with_dependents"] += 1
            if node.is_circular:
                stats["circular_deps"] += 1

            stats["max_dependencies"] = max(stats["max_dependencies"], dep_count)
            stats["max_dependents"] = max(stats["max_dependents"], dependent_count)

        return stats