"""Rust code analyzer using tree-sitter."""

from pathlib import Path
from typing import Optional

from tree_sitter import Language, Node, Parser
from tree_sitter_rust import language as rust_language

from . import BaseAnalyzer
from ..models import Class, Function, ImportStatement


class RustAnalyzer(BaseAnalyzer):
    """Analyzer for Rust source files.

    Extracts free functions, structs/enums (reported as ``Class`` models),
    and ``use`` declarations (reported as ``ImportStatement`` models) by
    walking the tree-sitter-rust syntax tree.
    """

    SUPPORTED_EXTENSIONS = {".rs"}

    def can_analyze(self, path: Path) -> bool:
        """Return True if this analyzer can handle *path* (``.rs`` files)."""
        return path.suffix.lower() in self.SUPPORTED_EXTENSIONS

    def analyze(self, path: Path) -> dict:
        """Analyze a Rust file and extract functions, structs/enums, and imports.

        Returns a dict with keys ``functions``, ``classes``, ``imports``;
        on unreadable input or parser failure, all three lists are empty.
        """
        content = self._get_file_content(path)
        if not content:
            return {"functions": [], "classes": [], "imports": []}

        content_bytes = content.encode("utf-8")

        try:
            lang = Language(rust_language())
            parser = Parser(language=lang)
            tree = parser.parse(content_bytes)
        except Exception:
            # Best-effort: a grammar/bindings version mismatch should not
            # crash the whole analysis run.
            return {"functions": [], "classes": [], "imports": []}

        return {
            "functions": self._extract_functions(tree.root_node, content, content_bytes),
            "classes": self._extract_structs(tree.root_node, content_bytes),
            "imports": self._extract_imports(tree.root_node, content_bytes),
        }

    @staticmethod
    def _text(node: Node, content_bytes: bytes) -> str:
        """Decode the exact source text spanned by *node*."""
        return content_bytes[node.start_byte : node.end_byte].decode("utf-8")

    def _extract_functions(self, node: Node, content: str, content_bytes: bytes) -> list[Function]:
        """Recursively extract function definitions from the AST."""
        functions = []

        # `function_signature_item` covers body-less declarations in traits
        # and extern blocks; `function_signature` kept for older grammars.
        if node.type in ("function_item", "function_signature", "function_signature_item"):
            func = self._parse_function(node, content, content_bytes)
            if func:
                functions.append(func)

        for child in node.children:
            functions.extend(self._extract_functions(child, content, content_bytes))

        return functions

    def _extract_structs(self, node: Node, content_bytes: bytes) -> list[Class]:
        """Recursively extract struct/enum definitions from the AST."""
        structs = []

        if node.type == "struct_item":
            parsed = self._parse_struct(node, content_bytes)
            if parsed:
                structs.append(parsed)
        elif node.type == "enum_item":
            parsed = self._parse_enum(node, content_bytes)
            if parsed:
                structs.append(parsed)

        for child in node.children:
            structs.extend(self._extract_structs(child, content_bytes))

        return structs

    def _parse_struct(self, node: Node, content_bytes: bytes) -> Optional[Class]:
        """Parse a struct_item node; struct fields become Class attributes."""
        # The struct name is a `type_identifier` under the "name" field
        # (NOT a plain `identifier`), so use field-based access.
        name_node = node.child_by_field_name("name")
        if name_node is None:
            return None

        fields = []
        body = node.child_by_field_name("body")
        if body is not None and body.type == "field_declaration_list":
            # Each field is a `field_declaration` whose "name" field is the
            # `field_identifier` — it is not a direct child of the list.
            for decl in body.children:
                if decl.type == "field_declaration":
                    fname = decl.child_by_field_name("name")
                    if fname is not None:
                        fields.append(self._text(fname, content_bytes))

        return Class(
            name=self._text(name_node, content_bytes),
            attributes=fields,
            line_number=node.start_point[0] + 1,
        )

    def _parse_enum(self, node: Node, content_bytes: bytes) -> Optional[Class]:
        """Parse an enum_item node; enum variants become Class attributes."""
        # Like structs, the enum name is a `type_identifier` "name" field.
        name_node = node.child_by_field_name("name")
        if name_node is None:
            return None

        variants = []
        body = node.child_by_field_name("body")
        if body is not None and body.type == "enum_variant_list":
            # Variants are `enum_variant` nodes wrapping the identifier.
            for variant in body.children:
                if variant.type == "enum_variant":
                    vname = variant.child_by_field_name("name")
                    if vname is not None:
                        variants.append(self._text(vname, content_bytes))

        return Class(
            name=self._text(name_node, content_bytes),
            attributes=variants,
            line_number=node.start_point[0] + 1,
        )

    def _extract_imports(self, node: Node, content_bytes: bytes) -> list[ImportStatement]:
        """Recursively extract ``use`` declarations from the AST."""
        imports = []

        if node.type == "use_declaration":
            imp = self._parse_import(node, content_bytes)
            if imp:
                imports.append(imp)

        for child in node.children:
            imports.extend(self._extract_imports(child, content_bytes))

        return imports

    def _parse_function(self, node: Node, content: str, content_bytes: bytes) -> Optional[Function]:
        """Parse a function definition node into a Function model."""
        line_number = node.start_point[0] + 1

        name_node = node.child_by_field_name("name")
        name = self._text(name_node, content_bytes) if name_node else None

        params_node = node.child_by_field_name("parameters")
        parameters = self._parse_parameters(params_node, content_bytes) if params_node else []

        # The return type (after `->`) is the type node itself under the
        # "return_type" field — there is no wrapper node named "return_type",
        # so scanning child.type would never match. Field access also picks
        # up `primitive_type`, `reference_type`, generics, etc.
        ret_node = node.child_by_field_name("return_type")
        return_type = self._text(ret_node, content_bytes) if ret_node else None

        # Any visibility modifier (`pub`, `pub(crate)`, `pub(super)`, ...)
        # means "not private"; normalize all of them to "public".
        visibility = "private"
        for child in node.children:
            if child.type == "visibility_modifier":
                visibility = "public"
                break

        return Function(
            name=name or "unknown",
            parameters=parameters,
            return_type=return_type,
            line_number=line_number,
            visibility=visibility,
        )

    def _parse_import(self, node: Node, content_bytes: bytes) -> Optional[ImportStatement]:
        """Parse a use_declaration node into an ImportStatement model."""
        line_number = node.start_point[0] + 1

        # The path after `use` is the "argument" field: a `scoped_identifier`,
        # bare `identifier`, `use_as_clause`, `use_wildcard`, or
        # `scoped_use_list` (there is no `use_path` node type).
        arg = node.child_by_field_name("argument")
        if arg is not None and arg.type == "use_as_clause":
            # For `use foo::bar as baz`, record only the imported path.
            path = arg.child_by_field_name("path")
            if path is not None:
                arg = path

        module = self._text(arg, content_bytes) if arg is not None else ""

        return ImportStatement(
            module=module,
            line_number=line_number,
        )

    def _parse_parameters(self, node: Node, content_bytes: bytes) -> list[str]:
        """Collect parameter names from a ``parameters`` node."""
        params = []
        for child in node.children:
            if child.type != "parameter":
                # Skips `self_parameter`, commas, and parentheses.
                continue
            # The binding is the "pattern" field; a simple param is a bare
            # `identifier`, while `mut x` wraps it in a `mut_pattern`, so
            # descend to the first identifier found.
            pattern = child.child_by_field_name("pattern")
            ident = self._first_identifier(pattern) if pattern is not None else None
            if ident is not None:
                params.append(self._text(ident, content_bytes))
        return params

    @staticmethod
    def _first_identifier(node: Node) -> Optional[Node]:
        """Return the first `identifier` node in a pattern subtree, if any."""
        if node.type == "identifier":
            return node
        for child in node.children:
            found = RustAnalyzer._first_identifier(child)
            if found is not None:
                return found
        return None