diff --git a/codechunk/core/dependency.py b/codechunk/core/dependency.py new file mode 100644 index 0000000..2cc1bc6 --- /dev/null +++ b/codechunk/core/dependency.py @@ -0,0 +1,215 @@ +from typing import List, Dict, Set, Optional +from dataclasses import dataclass, field +from pathlib import Path +from codechunk.core.chunking import ParsedChunk + + +@dataclass +class DependencyNode: + chunk_name: str + file_path: Path + module_name: str + dependencies: Set[str] = field(default_factory=set) + dependents: Set[str] = field(default_factory=set) + is_circular: bool = False + + +class DependencyAnalyzer: + def __init__(self): + self.dependency_graph: Dict[str, DependencyNode] = {} + self.module_to_chunks: Dict[str, List[str]] = {} + + def analyze_dependencies(self, chunks: List[ParsedChunk], + project_files: List[Path]) -> Dict[str, DependencyNode]: + """Analyze dependencies between chunks.""" + self.dependency_graph = {} + self.module_to_chunks = {} + + project_files_set = set(project_files) + module_cache = self._build_module_cache(project_files) + + for chunk in chunks: + node = DependencyNode( + chunk_name=chunk.name, + file_path=chunk.metadata.file_path, + module_name=self._get_module_name(chunk.metadata.file_path, project_files_set) + ) + + for imp in chunk.metadata.imports: + resolved = self._resolve_import(imp, chunk.metadata.file_path, + project_files_set, module_cache) + if resolved: + node.dependencies.add(resolved) + + self.dependency_graph[chunk.name] = node + + module = node.module_name + if module not in self.module_to_chunks: + self.module_to_chunks[module] = [] + self.module_to_chunks[module].append(chunk.name) + + self._build_dependency_links() + self._detect_circular_dependencies() + + return self.dependency_graph + + def _build_module_cache(self, project_files: List[Path]) -> Dict[Path, str]: + """Build cache of file to module name mappings.""" + cache = {} + for file_path in project_files: + module_name = self._get_module_name(file_path, set(project_files)) + cache[file_path] = module_name + return cache + + def _get_module_name(self, file_path: Path, project_root: Set[Path]) -> str: + """Get module name from file path.""" + try: + if project_root: + root = min(project_root, key=lambda p: len(p.parts)) + rel_path = file_path.relative_to(root) + else: + rel_path = file_path + + parts = list(rel_path.parts) + if not parts: + return file_path.stem + + if parts[-1] == "__init__.py": + parts = parts[:-1] + else: + parts[-1] = parts[-1].rsplit('.', 1)[0] + + return '.'.join(parts) + except (ValueError, IndexError): + return file_path.stem + + def _resolve_import(self, import_str: str, current_file: Path, + project_root: Set[Path], module_cache: Dict[Path, str]) -> Optional[str]: + """Resolve import string to module name.""" + clean_import = import_str.strip() + + parts = clean_import.split('.') + + base_module = parts[0] + + if base_module == 'self' or base_module == '.': + return self._get_module_name(current_file, project_root) + + if parts[0] in ['os', 'sys', 're', 'json', 'yaml', 'typing', 'collections', + 'itertools', 'functools', 'abc', 'dataclasses', 'enum', + 'pathlib', 'abc', 'threading', 'multiprocessing', 'asyncio', + 'requests', 'flask', 'django', 'fastapi', 'numpy', 'pandas', + 'torch', 'tensorflow', 'matplotlib', 'scipy', 'sklearn']: + return None + + for file_path, module_name in module_cache.items(): + if module_name.endswith(base_module) or module_name == base_module: + return module_name + + if len(parts) > 1: + parent_module = '.'.join(parts[:-1]) + if module_name.endswith(parent_module) or module_name == parent_module: + return module_name + + return clean_import + + def _build_dependency_links(self): + """Build reverse dependency links (dependents).""" + for node in self.dependency_graph.values(): + for dep in node.dependencies: + if dep in self.dependency_graph: + self.dependency_graph[dep].dependents.add(node.chunk_name) + + def _detect_circular_dependencies(self): + """Detect circular dependencies in the graph.""" + visited = set() + rec_stack = set() + + def detect_cycle(node_name: str, path: List[str]) -> bool: + visited.add(node_name) + rec_stack.add(node_name) + path.append(node_name) + + node = self.dependency_graph.get(node_name) + if node: + for dep in node.dependencies: + if dep not in visited: + if detect_cycle(dep, path): + return True + elif dep in rec_stack: + for n in path + [dep]: + if n in self.dependency_graph: + self.dependency_graph[n].is_circular = True + return True + + rec_stack.remove(node_name) + path.pop() + return False + + for node_name in self.dependency_graph: + if node_name not in visited: + detect_cycle(node_name, []) + + def get_essential_chunks(self, selected_chunks: List[str]) -> List[str]: + """Get all chunks needed including transitive dependencies.""" + essential = set(selected_chunks) + to_process = list(selected_chunks) + + while to_process: + chunk_name = to_process.pop(0) + node = self.dependency_graph.get(chunk_name) + if node: + for dep in node.dependencies: + if dep not in essential: + essential.add(dep) + to_process.append(dep) + + return list(essential) + + def get_impacted_chunks(self, modified_chunks: List[str]) -> List[str]: + """Get all chunks that depend on the modified chunks.""" + impacted = set(modified_chunks) + to_process = list(modified_chunks) + + while to_process: + chunk_name = to_process.pop(0) + node = self.dependency_graph.get(chunk_name) + if node: + for dependent in node.dependents: + if dependent not in impacted: + impacted.add(dependent) + to_process.append(dependent) + + return list(impacted) + + def get_dependency_stats(self) -> Dict[str, int]: + """Get statistics about dependencies.""" + stats = { + "total_nodes": len(self.dependency_graph), + "nodes_with_deps": 0, + "nodes_with_dependents": 0, + "circular_deps": 0, + "total_edges": 0, + "max_dependencies": 0, + "max_dependents": 0, + } + + for node in self.dependency_graph.values(): + dep_count = len(node.dependencies) + depent_count = len(node.dependents) + + stats["total_edges"] += dep_count + + if dep_count > 0: + stats["nodes_with_deps"] += 1 + + if depent_count > 0: + stats["nodes_with_dependents"] += 1 + + if node.is_circular: + stats["circular_deps"] += 1 + + stats["max_dependencies"] = max(stats["max_dependencies"], dep_count) + stats["max_dependents"] = max(stats["max_dependents"], depent_count) + + return stats