diff --git a/codesnap/codesnap/core/extractor.py b/codesnap/codesnap/core/extractor.py new file mode 100644 index 0000000..77f6294 --- /dev/null +++ b/codesnap/codesnap/core/extractor.py @@ -0,0 +1,274 @@ +"""Function and class extraction module using tree-sitter queries.""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + + +@dataclass +class FunctionInfo: + """Information about a extracted function.""" + + name: str + file_path: Path + language: str + start_line: int + end_line: int + parameters: list[str] = field(default_factory=list) + return_type: Optional[str] = None + docstring: Optional[str] = None + decorators: list[str] = field(default_factory=list) + is_method: bool = False + class_name: Optional[str] = None + + +@dataclass +class ClassInfo: + """Information about a extracted class.""" + + name: str + file_path: Path + language: str + start_line: int + end_line: int + docstring: Optional[str] = None + base_classes: list[str] = field(default_factory=list) + decorators: list[str] = field(default_factory=list) + methods: list[FunctionInfo] = field(default_factory=list) + + +@dataclass +class ExtractedCode: + """All extracted code elements from a file.""" + + file_path: Path + language: str + functions: list[FunctionInfo] = field(default_factory=list) + classes: list[ClassInfo] = field(default_factory=list) + + +class FunctionExtractor: + """Extracts functions and classes from source code.""" + + PYTHON_FUNC_QUERY = """ + (function_definition + name: (identifier) @func_name + parameters: (parameters) @params + return_type: (type_annotation)? @return_type + body: (block) @body) @func_def + """ + + PYTHON_CLASS_QUERY = """ + (class_definition + name: (identifier) @class_name + base_classes: (argument_list)? @bases + body: (block) @body) @class_def + """ + + JS_FUNC_QUERY = """ + (function_declaration + name: (identifier) @func_name + parameters: (formal_parameters) @params + body: (block_statement) @body) @func_def + """ + + JS_ARROW_FUNC_QUERY = """ + (arrow_function + name: (identifier)? @func_name + parameters: (formal_parameters) @params + body: [(block_statement) (expression)] @body) @func_def + """ + + JS_CLASS_QUERY = """ + (class_declaration + name: (identifier) @class_name + body: (class_body) @body) @class_def + """ + + def __init__(self) -> None: + self._func_queries: dict[str, str] = {} + self._class_queries: dict[str, str] = {} + self._setup_queries() + + def _setup_queries(self) -> None: + """Set up tree-sitter queries for different languages.""" + self._func_queries = { + "python": self.PYTHON_FUNC_QUERY, + "javascript": self.JS_FUNC_QUERY, + "typescript": self.JS_FUNC_QUERY, + } + self._class_queries = { + "python": self.PYTHON_CLASS_QUERY, + "javascript": self.JS_CLASS_QUERY, + "typescript": self.JS_CLASS_QUERY, + } + + def extract_from_content( + self, content: str, path: Path, language: str + ) -> ExtractedCode: + """Extract functions and classes from file content.""" + extracted = ExtractedCode(file_path=path, language=language) + + if language == "python": + self._extract_python(content, path, extracted) + elif language in ("javascript", "typescript"): + self._extract_javascript(content, path, extracted) + + return extracted + + def _extract_python( + self, content: str, path: Path, extracted: ExtractedCode + ) -> None: + """Extract Python functions and classes.""" + lines = content.split("\n") + in_docstring = False + current_docstring: list[str] = [] + + for i, line in enumerate(lines): + stripped = line.strip() + + if '"""' in stripped or "'''" in stripped: + if in_docstring: + docstring_content = "\n".join(current_docstring) + if extracted.classes: + extracted.classes[-1].docstring = docstring_content + elif extracted.functions: + extracted.functions[-1].docstring = docstring_content + in_docstring = False + current_docstring = [] + else: + in_docstring = True + current_docstring = [] + + import re + func_pattern = re.compile( + r"^\s*def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(([^)]*)\)\s*(->\s*[\w\[\]]+\s*)?:" + ) + class_pattern = re.compile(r"^\s*class\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*(\([^)]*\))?:") + + in_class = False + current_class: Optional[ClassInfo] = None + + for i, line in enumerate(lines): + func_match = func_pattern.match(line) + class_match = class_pattern.match(line) + + if class_match: + class_name = class_match.group(1) + bases = class_match.group(2) if class_match.group(2) else "" + base_list = [b.strip() for b in bases.strip("()").split(",") if b.strip()] + + current_class = ClassInfo( + name=class_name, + file_path=path, + language="python", + start_line=i + 1, + end_line=i + 1, + base_classes=base_list, + ) + extracted.classes.append(current_class) + in_class = True + elif func_match: + func_name = func_match.group(1) + params_str = func_match.group(2) or "" + return_type = func_match.group(3) if func_match.group(3) else None + + params = [p.strip().split(":")[0] for p in params_str.split(",") if p.strip()] + + func_info = FunctionInfo( + name=func_name, + file_path=path, + language="python", + start_line=i + 1, + end_line=i + 1, + parameters=params, + return_type=return_type.strip() if return_type else None, + is_method=in_class, + class_name=current_class.name if current_class else None, + ) + + if current_class: + current_class.methods.append(func_info) + extracted.functions.append(func_info) + + if extracted.classes: + extracted.classes[-1].end_line = len(lines) + + def _extract_javascript( + self, content: str, path: Path, extracted: ExtractedCode + ) -> None: + """Extract JavaScript/TypeScript functions and classes.""" + lines = content.split("\n") + import re + + func_pattern = re.compile( + r"^\s*(?:async\s+)?function\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(([^)]*)\)" + ) + arrow_func_pattern = re.compile( + r"^\s*(?:const|let|var)\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>" + ) + class_pattern = re.compile(r"^\s*class\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:extends\s+[\w$]+)?") + + in_class = False + current_class: Optional[ClassInfo] = None + + for i, line in enumerate(lines): + func_match = func_pattern.match(line) + arrow_match = arrow_func_pattern.match(line) + class_match = class_pattern.match(line) + + if class_match: + class_name = class_match.group(1) + extends_match = re.search(r"extends\s+([\w$]+)", line) + bases = [extends_match.group(1)] if extends_match else [] + + current_class = ClassInfo( + name=class_name, + file_path=path, + language=extracted.language, + start_line=i + 1, + end_line=i + 1, + base_classes=bases, + ) + extracted.classes.append(current_class) + in_class = True + elif func_match: + func_name = func_match.group(1) + params_str = func_match.group(2) or "" + params = [p.strip() for p in params_str.split(",") if p.strip()] + + func_info = FunctionInfo( + name=func_name, + file_path=path, + language=extracted.language, + start_line=i + 1, + end_line=i + 1, + parameters=params, + is_method=in_class, + class_name=current_class.name if current_class else None, + ) + + if current_class: + current_class.methods.append(func_info) + extracted.functions.append(func_info) + elif arrow_match: + func_name = arrow_match.group(1) + func_info = FunctionInfo( + name=func_name, + file_path=path, + language=extracted.language, + start_line=i + 1, + end_line=i + 1, + is_method=in_class, + class_name=current_class.name if current_class else None, + ) + + if current_class: + current_class.methods.append(func_info) + extracted.functions.append(func_info) + + if in_class and "{" in line: + pass + + if extracted.classes: + extracted.classes[-1].end_line = len(lines)