"""Rust code analyzer using tree-sitter."""

import logging
from pathlib import Path
from typing import Iterator, Optional

from tree_sitter import Language, Node, Parser
from tree_sitter_rust import language as rust_language

from . import BaseAnalyzer
from ..models import Function, Class, ImportStatement

logger = logging.getLogger(__name__)

# Node types that declare a function. ``function_signature_item`` is what the
# Rust grammar emits for body-less signatures (trait methods, extern blocks);
# ``function_signature`` is kept only so the previously accepted name still works.
_FUNCTION_NODE_TYPES = frozenset(
    {"function_item", "function_signature_item", "function_signature"}
)

# Maps a type-declaring node to the node type of the members listed in its body.
_MEMBER_NODE_TYPES = {
    "struct_item": "field_declaration",
    "enum_item": "enum_variant",
}

# Visibility modifiers (whitespace removed) that are reported as "public".
_PUBLIC_MODIFIERS = frozenset({"pub", "pub(crate)"})


def _node_text(node: Node, content_bytes: bytes) -> str:
    """Return the source text covered by *node*."""
    return content_bytes[node.start_byte : node.end_byte].decode("utf-8")


class RustAnalyzer(BaseAnalyzer):
    """Analyzer for Rust source files."""

    SUPPORTED_EXTENSIONS = {".rs"}

    def can_analyze(self, path: Path) -> bool:
        """Check if this analyzer can handle the file."""
        return path.suffix.lower() in self.SUPPORTED_EXTENSIONS

    def analyze(self, path: Path) -> dict:
        """Analyze a Rust file and extract functions, structs, and imports.

        Args:
            path: Location of the Rust source file.

        Returns:
            A dict with the keys ``"functions"``, ``"classes"`` and
            ``"imports"``. All three lists are empty when the file has no
            content or the parser cannot be set up / run.
        """
        content = self._get_file_content(path)
        if not content:
            return {"functions": [], "classes": [], "imports": []}

        # tree-sitter works on bytes, and node offsets are byte offsets.
        content_bytes = content.encode("utf-8")

        try:
            lang = Language(rust_language())
            parser = Parser(language=lang)
            tree = parser.parse(content_bytes)
        except Exception:
            # Best effort: an unusable grammar or parser must not abort the
            # whole analysis run, but the failure should be visible.
            logger.warning("Failed to parse Rust file %s", path, exc_info=True)
            return {"functions": [], "classes": [], "imports": []}

        functions = self._extract_functions(tree.root_node, content, content_bytes)
        classes = self._extract_structs(tree.root_node, content_bytes)
        imports = self._extract_imports(tree.root_node, content_bytes)

        return {
            "functions": functions,
            "classes": classes,
            "imports": imports,
        }

    @staticmethod
    def _walk(root: Node) -> Iterator[Node]:
        """Yield *root* and all of its descendants in document (pre-)order.

        Implemented with an explicit stack so that deeply nested syntax trees
        cannot exhaust Python's recursion limit.
        """
        stack = [root]
        while stack:
            current = stack.pop()
            yield current
            # Reversed so the left-most child is popped (and yielded) first.
            stack.extend(reversed(current.children))

    def _extract_functions(self, node: Node, content: str, content_bytes: bytes) -> list[Function]:
        """Extract function definitions from the AST.

        Args:
            node: Root of the subtree to search (the node itself is included).
            content: Source text; forwarded unchanged to ``_parse_function``.
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            One ``Function`` per function item or signature, in source order.
        """
        functions = []
        for candidate in self._walk(node):
            if candidate.type not in _FUNCTION_NODE_TYPES:
                continue
            func = self._parse_function(candidate, content, content_bytes)
            if func:
                functions.append(func)
        return functions

    def _extract_structs(self, node: Node, content_bytes: bytes) -> list[Class]:
        """Extract struct/enum definitions from the AST.

        Struct field names and enum variant names are reported as the
        ``attributes`` of the resulting ``Class``. Tuple and unit structs have
        no named fields and therefore get an empty attribute list.

        Args:
            node: Root of the subtree to search (the node itself is included).
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            One ``Class`` per named struct or enum, in source order.
        """
        structs = []
        for candidate in self._walk(node):
            member_type = _MEMBER_NODE_TYPES.get(candidate.type)
            if member_type is None:
                continue

            # The name is a ``type_identifier`` node; looking it up through the
            # ``name`` field avoids depending on the concrete node type.
            name_node = candidate.child_by_field_name("name")
            if name_node is None:
                continue

            attributes = []
            body = candidate.child_by_field_name("body")
            if body is not None:
                for member in body.children:
                    if member.type != member_type:
                        continue
                    member_name = member.child_by_field_name("name")
                    if member_name is not None:
                        attributes.append(_node_text(member_name, content_bytes))

            structs.append(
                Class(
                    name=_node_text(name_node, content_bytes),
                    attributes=attributes,
                    line_number=candidate.start_point[0] + 1,
                )
            )
        return structs

    def _extract_imports(self, node: Node, content_bytes: bytes) -> list[ImportStatement]:
        """Extract use statements from the AST.

        Args:
            node: Root of the subtree to search (the node itself is included).
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            One ``ImportStatement`` per ``use`` declaration, in source order.
        """
        imports = []
        for candidate in self._walk(node):
            if candidate.type != "use_declaration":
                continue
            imp = self._parse_import(candidate, content_bytes)
            if imp:
                imports.append(imp)
        return imports

    def _parse_function(self, node: Node, content: str, content_bytes: bytes) -> Optional[Function]:
        """Parse a function definition node.

        Args:
            node: A function item or function signature node.
            content: Source text (unused; kept for interface compatibility).
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            A ``Function`` whose name falls back to ``"unknown"`` when the node
            has no name. ``return_type`` is the full source text of the
            declared return type, or ``None`` when none is declared.
            ``visibility`` is ``"private"`` without a modifier, ``"public"``
            for ``pub`` and ``pub(crate)``, and the modifier's source text for
            any other restriction (for example ``pub(super)``).
        """
        name_node = node.child_by_field_name("name")
        name = _node_text(name_node, content_bytes) if name_node is not None else None

        visibility = "private"
        for child in node.children:
            if child.type == "visibility_modifier":
                visibility = _node_text(child, content_bytes)
                break
        # Compare without whitespace so that e.g. ``pub (crate)`` is recognised.
        if "".join(visibility.split()) in _PUBLIC_MODIFIERS:
            visibility = "public"

        parameters = []
        params_node = node.child_by_field_name("parameters")
        if params_node is not None:
            parameters = self._parse_parameters(params_node, content_bytes)

        # ``return_type`` is a field of the function node, not a node type, so
        # it has to be looked up by field name.
        return_type = None
        return_node = node.child_by_field_name("return_type")
        if return_node is not None:
            return_type = _node_text(return_node, content_bytes)

        return Function(
            name=name or "unknown",
            parameters=parameters,
            return_type=return_type,
            line_number=node.start_point[0] + 1,
            visibility=visibility,
        )

    def _parse_import(self, node: Node, content_bytes: bytes) -> Optional[ImportStatement]:
        """Parse a use declaration node.

        Args:
            node: A ``use_declaration`` node.
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            An ``ImportStatement`` whose ``module`` is the imported path with
            whitespace runs collapsed to single spaces. For ``use a::b as c``
            only the path (``a::b``) is reported. ``module`` is an empty string
            when the declaration has no argument (e.g. incomplete source).
        """
        module = ""
        argument = node.child_by_field_name("argument")
        if argument is not None:
            target = argument
            if argument.type == "use_as_clause":
                # Drop the alias and keep only the imported path.
                path_node = argument.child_by_field_name("path")
                if path_node is not None:
                    target = path_node
            # Grouped imports may span several lines; normalise the whitespace.
            module = " ".join(_node_text(target, content_bytes).split())

        return ImportStatement(
            module=module,
            line_number=node.start_point[0] + 1,
        )

    def _parse_parameters(self, node: Node, content_bytes: bytes) -> list[str]:
        """Parse function parameters.

        Args:
            node: The ``parameters`` node of a function.
            content_bytes: UTF-8 encoded source the node offsets refer to.

        Returns:
            The parameter names in declaration order. Receiver parameters
            (``self``, ``&self``, ...) are not included. For a parameter bound
            through a pattern instead of a plain identifier, the first
            identifier directly inside the pattern is used; parameters that
            bind no identifier there (such as ``_``) are skipped.
        """
        params = []
        for child in node.children:
            if child.type != "parameter":
                continue
            # ``pattern`` is a field of the parameter node, not a node type.
            pattern = child.child_by_field_name("pattern")
            if pattern is None:
                continue
            if pattern.type == "identifier":
                params.append(_node_text(pattern, content_bytes))
                continue
            for nested in pattern.children:
                if nested.type == "identifier":
                    params.append(_node_text(nested, content_bytes))
                    break
        return params