"""Python code analyzer using tree-sitter."""

from pathlib import Path
from typing import Optional

import tree_sitter
from tree_sitter import Language, Node, Parser
from tree_sitter_python import language as python_language

from . import BaseAnalyzer
from ..models import Function, Class, ImportStatement, SourceFile


class PythonAnalyzer(BaseAnalyzer):
    """Analyzer for Python source files.

    Walks a tree-sitter parse tree and extracts functions, classes, and
    import statements for README generation.
    """

    SUPPORTED_EXTENSIONS = {".py", ".pyi"}

    def can_analyze(self, path: Path) -> bool:
        """Return True if *path* has a Python source/stub extension."""
        return path.suffix.lower() in self.SUPPORTED_EXTENSIONS

    def _get_parser(self) -> Optional[Parser]:
        """Return a lazily built, cached tree-sitter parser, or None on failure.

        Hoisted out of analyze() so the Language/Parser pair is constructed
        once per analyzer instance instead of once per file.
        """
        parser = getattr(self, "_parser", None)
        if parser is None:
            try:
                parser = Parser(language=Language(python_language()))
            except Exception:
                # Grammar/binding version mismatches surface here; callers
                # degrade to an empty result instead of aborting the scan.
                return None
            self._parser = parser
        return parser

    def analyze(self, path: Path) -> dict:
        """Analyze a Python file and extract functions, classes, and imports.

        Returns a dict with "functions", "classes", and "imports" lists.
        All three lists are empty when the file is unreadable or unparsable.
        """
        content = self._get_file_content(path)
        if not content:
            return {"functions": [], "classes": [], "imports": []}

        content_bytes = content.encode("utf-8")

        parser = self._get_parser()
        if parser is None:
            return {"functions": [], "classes": [], "imports": []}
        try:
            tree = parser.parse(content_bytes)
        except Exception:
            return {"functions": [], "classes": [], "imports": []}

        root = tree.root_node
        # NOTE(review): the walkers below recurse into class bodies, so
        # methods also appear in the top-level "functions" list — confirm
        # whether downstream README rendering expects that duplication.
        return {
            "functions": self._extract_functions(root, content, content_bytes),
            "classes": self._extract_classes(root, content, content_bytes),
            "imports": self._extract_imports(root, content_bytes),
        }

    def _extract_functions(self, node: Node, content: str, content_bytes: bytes) -> list[Function]:
        """Recursively collect function definitions under *node*."""
        functions: list[Function] = []

        if node.type == "function_definition":
            func = self._parse_function(node, content, content_bytes)
            if func:
                functions.append(func)

        for child in node.children:
            functions.extend(self._extract_functions(child, content, content_bytes))

        return functions

    def _extract_classes(self, node: Node, content: str, content_bytes: bytes) -> list[Class]:
        """Recursively collect class definitions under *node*."""
        classes: list[Class] = []

        if node.type == "class_definition":
            cls = self._parse_class(node, content, content_bytes)
            if cls:
                classes.append(cls)

        for child in node.children:
            classes.extend(self._extract_classes(child, content, content_bytes))

        return classes

    def _extract_imports(self, node: Node, content_bytes: bytes) -> list[ImportStatement]:
        """Recursively collect import statements under *node*."""
        imports: list[ImportStatement] = []

        if node.type in ("import_statement", "import_from_statement"):
            imp = self._parse_import(node, content_bytes)
            if imp:
                imports.append(imp)

        for child in node.children:
            imports.extend(self._extract_imports(child, content_bytes))

        return imports

    def _node_text(self, node: Node, content_bytes: bytes) -> str:
        """Decode the source text spanned by *node*."""
        return content_bytes[node.start_byte : node.end_byte].decode("utf-8")

    def _parse_function(self, node: Node, content: str, content_bytes: bytes) -> Optional[Function]:
        """Parse one function_definition node into a Function model.

        Line numbers are 1-based (tree-sitter rows are 0-based).
        """
        name: Optional[str] = None
        parameters: list[str] = []
        docstring: Optional[str] = None
        line_number = node.start_point[0] + 1

        for child in node.children:
            if child.type == "identifier":
                name = self._node_text(child, content_bytes)
            elif child.type == "parameters":
                parameters = self._parse_parameters(child, content_bytes)
            elif child.type == "block":
                docstring = self._parse_docstring(child, content_bytes)

        return Function(
            name=name or "unknown",
            parameters=parameters,
            docstring=docstring,
            line_number=line_number,
        )

    def _parse_class(self, node: Node, content: str, content_bytes: bytes) -> Optional[Class]:
        """Parse one class_definition node into a Class model."""
        name: Optional[str] = None
        base_classes: list[str] = []
        docstring: Optional[str] = None
        methods: list[Function] = []
        line_number = node.start_point[0] + 1

        for child in node.children:
            if child.type == "identifier":
                name = self._node_text(child, content_bytes)
            elif child.type == "argument_list":
                # The superclass list (and keyword args like metaclass=...).
                base_classes = self._parse_base_classes(child, content_bytes)
            elif child.type == "block":
                docstring = self._parse_docstring(child, content_bytes)
                methods = self._extract_functions(child, content, content_bytes)

        return Class(
            name=name or "Unknown",
            base_classes=base_classes,
            docstring=docstring,
            methods=methods,
            line_number=line_number,
        )

    def _parse_import(self, node: Node, content_bytes: bytes) -> Optional[ImportStatement]:
        """Parse an ``import ...`` or ``from ... import ...`` statement.

        Fixes two defects in the previous version:
        - ``import X``: module was the *entire statement text* including
          the ``import `` keyword; now only the imported names are kept.
        - ``from X import a, b``: every dotted_name child overwrote
          ``module``, so it ended up as the last imported name and the
          plain (un-aliased) names never reached ``items``. Now the first
          name after ``from`` is the module and later names are items.
        """
        line_number = node.start_point[0] + 1

        if node.type == "import_statement":
            names = [
                self._node_text(child, content_bytes)
                for child in node.children
                if child.type in ("dotted_name", "aliased_import")
            ]
            # Fall back to raw text if the grammar yields no name children.
            module = ", ".join(names) if names else self._node_text(node, content_bytes)
            return ImportStatement(module=module, line_number=line_number)

        if node.type == "import_from_statement":
            module: Optional[str] = None
            items: list[str] = []

            for child in node.children:
                if child.type == "wildcard_import":
                    # ``from X import *`` — module is already seen by now.
                    return ImportStatement(
                        module=module or "",
                        line_number=line_number,
                        is_from=True,
                    )
                if child.type in ("module", "dotted_name", "relative_import"):
                    if module is None:
                        module = self._node_text(child, content_bytes)
                    else:
                        items.append(self._node_text(child, content_bytes))
                elif child.type in ("aliased_import", "import_as_names", "import_as_name"):
                    items.append(self._node_text(child, content_bytes))

            return ImportStatement(
                module=module or "",
                items=items,
                line_number=line_number,
                is_from=True,
            )

        return None

    def _parse_parameters(self, node: Node, content_bytes: bytes) -> list[str]:
        """Extract parameter names from a ``parameters`` node.

        Fix: ``typed_default_parameter`` (``x: int = 1``) and the splat
        patterns (``*args`` / ``**kwargs``) were previously skipped,
        silently dropping those parameters from the report.
        """
        wrapper_types = (
            "default_parameter",
            "typed_parameter",
            "typed_default_parameter",
            "list_splat_pattern",
            "dictionary_splat_pattern",
        )
        params: list[str] = []
        for child in node.children:
            if child.type == "identifier":
                params.append(self._node_text(child, content_bytes))
            elif child.type in wrapper_types:
                # The parameter name is the first identifier grandchild;
                # the rest is annotation/default/punctuation.
                for grandchild in child.children:
                    if grandchild.type == "identifier":
                        params.append(self._node_text(grandchild, content_bytes))
                        break
        return params

    def _parse_base_classes(self, node: Node, content_bytes: bytes) -> list[str]:
        """Extract base-class names from a class's argument_list node.

        Keyword arguments (e.g. ``metaclass=...``) are intentionally ignored.
        """
        return [
            self._node_text(child, content_bytes)
            for child in node.children
            if child.type in ("attribute", "identifier")
        ]

    def _parse_docstring(self, node: Node, content_bytes: bytes) -> Optional[str]:
        """Return the docstring of a ``block`` node, or None.

        Only a triple-quoted string as the block's first statement counts,
        matching the previous behavior. Fix: the old code used
        ``str.strip('\"\"\"')``, which strips a *character set* and mangled
        docstrings that begin or end with quote/apostrophe characters; the
        delimiters are now removed by exact-length slicing.
        """
        if not (node.children and node.children[0].type == "expression_statement"):
            return None
        expr = node.children[0]
        if not (expr.children and expr.children[0].type == "string"):
            return None

        string_content = self._node_text(expr.children[0], content_bytes)
        for delimiter in ('"""', "'''"):
            if (
                string_content.startswith(delimiter)
                and string_content.endswith(delimiter)
                and len(string_content) >= 2 * len(delimiter)
            ):
                return string_content[3:-3].strip()
        return None