diff --git a/src/termdiagram/parser/symbol_extractor.py b/src/termdiagram/parser/symbol_extractor.py new file mode 100644 index 0000000..1e99ec5 --- /dev/null +++ b/src/termdiagram/parser/symbol_extractor.py @@ -0,0 +1,80 @@ +import re +from typing import Dict, List, Any +from ..models import ClassSymbol, FunctionSymbol, MethodSymbol + + +class SymbolExtractor: + def __init__(self): + self.patterns = { + "python": { + "classes": r"^class\s+(\w+)\s*(?:\(([^)]*)\))?:", + "functions": r"^def\s+(\w+)\s*\(([^)]*)\):", + "methods": r"^ def\s+(\w+)\s*\(([^)]*)\):", + "imports": r"^(?:from|import)\s+(.+)", + } + } + + def extract(self, content: str, language: str) -> Dict[str, List[Any]]: + if language not in self.patterns: + return {"classes": [], "functions": [], "imports": []} + + patterns = self.patterns[language] + lines = content.split("\n") + + classes = self._extract_classes(lines, patterns["classes"]) + functions = self._extract_functions(lines, patterns["functions"]) + imports = self._extract_imports(lines, patterns["imports"]) + + return { + "classes": classes, + "functions": functions, + "imports": imports, + } + + def _extract_classes(self, lines: List[str], pattern: str) -> List[ClassSymbol]: + classes = [] + current_class = None + + for i, line in enumerate(lines): + match = re.match(pattern, line) + if match: + name = match.group(1) + bases = [b.strip() for b in match.group(2).split(",")] if match.group(2) else [] + current_class = ClassSymbol(name=name, bases=bases, line_number=i + 1) + classes.append(current_class) + elif current_class and line.strip().startswith("def "): + method_match = re.match(r"^ def\s+(\w+)\s*\(([^)]*)\):", line) + if method_match: + method = MethodSymbol( + name=method_match.group(1), + params=[p.strip() for p in method_match.group(2).split(",") if p.strip()], + line_number=i + 1, + ) + current_class.methods.append(method) + + return classes + + def _extract_functions(self, lines: List[str], pattern: str) -> List[FunctionSymbol]: + functions = [] + + for i, line in enumerate(lines): + match = re.match(pattern, line) + if match: + func = FunctionSymbol( + name=match.group(1), + params=[p.strip() for p in match.group(2).split(",") if p.strip()], + line_number=i + 1, + ) + functions.append(func) + + return functions + + def _extract_imports(self, lines: List[str], pattern: str) -> List[str]: + imports = [] + + for line in lines: + match = re.match(pattern, line) + if match: + imports.append(match.group(1)) + + return imports