"""Go code analyzer using tree-sitter."""

import logging
from collections.abc import Iterator
from pathlib import Path
from typing import Optional

from tree_sitter import Language, Node, Parser
from tree_sitter_go import language as go_language

from . import BaseAnalyzer
from ..models import Class, Function, ImportStatement

logger = logging.getLogger(__name__)


class GoAnalyzer(BaseAnalyzer):
    """Analyzer for Go source files.

    Parses a file with the tree-sitter Go grammar and reports top-level
    functions, named type declarations and imports.
    """

    SUPPORTED_EXTENSIONS = {".go"}

    # Parameter-list children that declare parameters; the variadic form
    # (``args ...string``) is a distinct node type in the Go grammar.
    _PARAMETER_NODE_TYPES = ("parameter_declaration", "variadic_parameter_declaration")

    def can_analyze(self, path: Path) -> bool:
        """Return True when *path* has a supported (case-insensitive) suffix."""
        return path.suffix.lower() in self.SUPPORTED_EXTENSIONS

    def analyze(self, path: Path) -> dict:
        """Analyze a Go file and extract functions, types, and imports.

        Args:
            path: Location of the Go source file.

        Returns:
            A dict with the keys ``"functions"``, ``"classes"`` and
            ``"imports"``. All three lists are empty when the file content
            is empty/unavailable or when the parser cannot be set up or run.
        """
        content = self._get_file_content(path)
        if not content:
            return {"functions": [], "classes": [], "imports": []}

        # tree-sitter works on bytes and reports byte offsets, so every slice
        # below is taken from the encoded form rather than from the str.
        content_bytes = content.encode("utf-8")

        try:
            lang = Language(go_language())
            parser = Parser(language=lang)
            tree = parser.parse(content_bytes)
        except Exception:
            # Best-effort boundary: an unusable grammar/parser must not abort
            # the whole run, but the failure should stay visible in the logs.
            logger.warning("Failed to parse Go file %s", path, exc_info=True)
            return {"functions": [], "classes": [], "imports": []}

        root = tree.root_node
        return {
            "functions": self._extract_functions(root, content, content_bytes),
            "classes": self._extract_structs(root, content_bytes),
            "imports": self._extract_imports(root, content_bytes),
        }

    @staticmethod
    def _walk(root: Node) -> Iterator[Node]:
        """Yield *root* and all its descendants in document (pre-order) order.

        Uses an explicit stack instead of recursion so that deeply nested
        syntax trees cannot exhaust Python's recursion limit.
        """
        stack = [root]
        while stack:
            current = stack.pop()
            yield current
            # Reversed so the leftmost child is popped (visited) first.
            stack.extend(reversed(current.children))

    @staticmethod
    def _node_text(node: Node, content_bytes: bytes) -> str:
        """Return the source text covered by *node*."""
        return content_bytes[node.start_byte : node.end_byte].decode("utf-8")

    def _extract_functions(self, node: Node, content: str, content_bytes: bytes) -> list[Function]:
        """Extract function definitions found at or below *node*.

        NOTE: only ``function_declaration`` nodes are collected; methods with
        a receiver (``method_declaration``) are not reported.

        Args:
            node: Root of the subtree to search.
            content: Source text; forwarded unchanged to ``_parse_function``.
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            Functions in document order.
        """
        functions = []
        for current in self._walk(node):
            if current.type != "function_declaration":
                continue
            func = self._parse_function(current, content, content_bytes)
            if func is not None:
                functions.append(func)
        return functions

    def _extract_structs(self, node: Node, content_bytes: bytes) -> list[Class]:
        """Extract named type declarations found at or below *node*.

        Every ``type_spec`` is reported, whatever its underlying type, so
        interfaces and other named types are included alongside structs.

        Args:
            node: Root of the subtree to search.
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            One ``Class`` per ``type_spec`` that has a name, in document order.
        """
        structs = []
        for current in self._walk(node):
            if current.type != "type_spec":
                continue
            # The first type_identifier child is the declared name; a later
            # one (as in ``type A B``) is the underlying type.
            name_node = next(
                (child for child in current.children if child.type == "type_identifier"),
                None,
            )
            if name_node is None:
                continue
            structs.append(
                Class(
                    name=self._node_text(name_node, content_bytes),
                    line_number=current.start_point[0] + 1,
                )
            )
        return structs

    def _extract_imports(self, node: Node, content_bytes: bytes) -> list[ImportStatement]:
        """Extract import statements found at or below *node*.

        One entry is produced per imported package, so a grouped
        ``import ( ... )`` declaration yields one entry per line inside it.

        Args:
            node: Root of the subtree to search.
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            Import statements in document order.
        """
        imports = []
        for current in self._walk(node):
            if current.type != "import_spec":
                continue
            imp = self._parse_import(current, content_bytes)
            if imp is not None:
                imports.append(imp)
        return imports

    def _parse_function(self, node: Node, content: str, content_bytes: bytes) -> Optional[Function]:
        """Parse a ``function_declaration`` node.

        Args:
            node: The function declaration node.
            content: Source text (unused; kept for interface compatibility).
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            A ``Function`` whose ``return_type`` is the source text of the
            result clause (e.g. ``"int"``, ``"*Foo"``, ``"(int, error)"``) or
            ``None`` when the function declares no result. The name falls
            back to ``"unknown"`` if the declaration has no name node.
        """
        # Field lookups tell the parameter list apart from a parenthesised
        # result list; both are ``parameter_list`` nodes in the Go grammar.
        name_node = node.child_by_field_name("name")
        params_node = node.child_by_field_name("parameters")
        result_node = node.child_by_field_name("result")

        name = self._node_text(name_node, content_bytes) if name_node is not None else None
        parameters = (
            self._parse_parameters(params_node, content_bytes) if params_node is not None else []
        )
        return_type = (
            self._node_text(result_node, content_bytes) if result_node is not None else None
        )

        return Function(
            name=name or "unknown",
            parameters=parameters,
            return_type=return_type,
            line_number=node.start_point[0] + 1,
        )

    def _parse_import(self, node: Node, content_bytes: bytes) -> Optional[ImportStatement]:
        """Parse a single import.

        Args:
            node: An ``import_spec`` node, or an ``import_declaration`` (in
                which case the first spec it contains is parsed).
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            The import with its path stripped of quotes/backticks and, when
            present, its alias (``"."`` and ``"_"`` are reported verbatim).
            ``None`` when no import spec or import path can be found.
        """
        spec = node
        if spec.type != "import_spec":
            spec = next(
                (candidate for candidate in self._walk(node) if candidate.type == "import_spec"),
                None,
            )
            if spec is None:
                return None

        # In the Go grammar the path is an interpreted or raw string literal
        # and the optional alias is a package_identifier, dot or
        # blank_identifier; the field names cover all of these variants.
        path_node = spec.child_by_field_name("path")
        if path_node is None:
            return None
        alias_node = spec.child_by_field_name("name")

        module = self._node_text(path_node, content_bytes).strip('"`')
        alias = self._node_text(alias_node, content_bytes) if alias_node is not None else None

        return ImportStatement(
            module=module or "",
            alias=alias,
            line_number=spec.start_point[0] + 1,
        )

    def _parse_parameters(self, node: Node, content_bytes: bytes) -> list[str]:
        """Parse the parameter names of a ``parameter_list`` node.

        Args:
            node: The parameter list node.
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            Parameter names in declaration order. Every name of a grouped
            declaration (``a, b int``) and variadic parameters are included;
            unnamed parameters contribute nothing.
        """
        return [
            self._node_text(grandchild, content_bytes)
            for child in node.children
            if child.type in self._PARAMETER_NODE_TYPES
            for grandchild in child.children
            # Parameter types are type nodes, never bare ``identifier``
            # nodes, so every identifier child here is a parameter name.
            if grandchild.type == "identifier"
        ]