"""Python code analyzer using tree-sitter.""" from pathlib import Path from typing import Optional from tree_sitter import Language, Node, Parser from tree_sitter_python import language as python_language from . import BaseAnalyzer from ..models import Function, Class, ImportStatement class PythonAnalyzer(BaseAnalyzer): """Analyzer for Python source files.""" SUPPORTED_EXTENSIONS = {".py", ".pyi"} def can_analyze(self, path: Path) -> bool: """Check if this analyzer can handle the file.""" return path.suffix.lower() in self.SUPPORTED_EXTENSIONS def analyze(self, path: Path) -> dict: """Analyze a Python file and extract functions, classes, and imports.""" content = self._get_file_content(path) if not content: return {"functions": [], "classes": [], "imports": []} content_bytes = content.encode("utf-8") try: lang = Language(python_language()) parser = Parser(language=lang) tree = parser.parse(content_bytes) except Exception: return {"functions": [], "classes": [], "imports": []} functions = self._extract_functions(tree.root_node, content, content_bytes) classes = self._extract_classes(tree.root_node, content, content_bytes) imports = self._extract_imports(tree.root_node, content_bytes) return { "functions": functions, "classes": classes, "imports": imports, } def _extract_functions(self, node: Node, content: str, content_bytes: bytes) -> list[Function]: """Extract function definitions from the AST.""" functions = [] if node.type == "function_definition": func = self._parse_function(node, content, content_bytes) if func: functions.append(func) for child in node.children: funcs = self._extract_functions(child, content, content_bytes) functions.extend(funcs) return functions def _extract_classes(self, node: Node, content: str, content_bytes: bytes) -> list[Class]: """Extract class definitions from the AST.""" classes = [] if node.type == "class_definition": cls = self._parse_class(node, content, content_bytes) if cls: classes.append(cls) for child in node.children: classes.extend(self._extract_classes(child, content, content_bytes)) return classes def _extract_imports(self, node: Node, content_bytes: bytes) -> list[ImportStatement]: """Extract import statements from the AST.""" imports = [] if node.type in ("import_statement", "import_from_statement"): imp = self._parse_import(node, content_bytes) if imp: imports.append(imp) for child in node.children: imports.extend(self._extract_imports(child, content_bytes)) return imports def _parse_function(self, node: Node, content: str, content_bytes: bytes) -> Optional[Function]: """Parse a function definition node.""" name = None parameters = [] docstring = None line_number = node.start_point[0] + 1 for child in node.children: if child.type == "identifier": name = content_bytes[child.start_byte : child.end_byte].decode("utf-8") elif child.type == "parameters": parameters = self._parse_parameters(child, content_bytes) elif child.type == "block": docstring = self._parse_docstring(child, content_bytes) return Function( name=name or "unknown", parameters=parameters, docstring=docstring, line_number=line_number, ) def _parse_class(self, node: Node, content: str, content_bytes: bytes) -> Optional[Class]: """Parse a class definition node.""" name = None base_classes = [] docstring = None methods = [] line_number = node.start_point[0] + 1 for child in node.children: if child.type == "identifier": name = content_bytes[child.start_byte : child.end_byte].decode("utf-8") elif child.type == "argument_list": base_classes = self._parse_base_classes(child, content_bytes) elif child.type == "block": docstring = self._parse_docstring(child, content_bytes) methods = self._extract_functions(child, content, content_bytes) return Class( name=name or "Unknown", base_classes=base_classes, docstring=docstring, methods=methods, line_number=line_number, ) def _parse_import(self, node: Node, content_bytes: bytes) -> Optional[ImportStatement]: """Parse an import statement node.""" line_number = node.start_point[0] + 1 if node.type == "import_statement": module = content_bytes[node.start_byte : node.end_byte].decode("utf-8") return ImportStatement( module=module, line_number=line_number, ) if node.type == "import_from_statement": module = None items = [] for child in node.children: if child.type == "module": module = content_bytes[child.start_byte : child.end_byte].decode("utf-8") elif child.type == "dotted_name": module = content_bytes[child.start_byte : child.end_byte].decode("utf-8") elif child.type == "import_as_names" or child.type == "import_as_name": name = content_bytes[child.start_byte : child.end_byte].decode("utf-8") items.append(name) elif child.type == "wildcard_import": return ImportStatement( module=module or "", line_number=line_number, is_from=True, ) return ImportStatement( module=module or "", items=items, line_number=line_number, is_from=True, ) return None def _parse_parameters(self, node: Node, content_bytes: bytes) -> list[str]: """Parse function parameters.""" params = [] for child in node.children: if child.type == "identifier": param = content_bytes[child.start_byte : child.end_byte].decode("utf-8") params.append(param) elif child.type in ("default_parameter", "typed_parameter"): for grandchild in child.children: if grandchild.type == "identifier": param = content_bytes[grandchild.start_byte : grandchild.end_byte].decode("utf-8") params.append(param) break return params def _parse_base_classes(self, node: Node, content_bytes: bytes) -> list[str]: """Parse class base classes.""" bases = [] for child in node.children: if child.type == "attribute" or child.type == "identifier": name = content_bytes[child.start_byte : child.end_byte].decode("utf-8") bases.append(name) return bases def _parse_docstring(self, node: Node, content_bytes: bytes) -> Optional[str]: """Parse a docstring from a block.""" if node.children and node.children[0].type == "expression_statement": expr = node.children[0] if expr.children and expr.children[0].type == "string": string_content = content_bytes[expr.children[0].start_byte : expr.children[0].end_byte].decode("utf-8") if string_content.startswith('"""') or string_content.startswith("'''"): return string_content.strip('"""').strip("'''").strip() return None