Add code analyzers (Python, JS, Go, Rust) with tree-sitter
This commit is contained in:
187
src/auto_readme/analyzers/rust_analyzer.py
Normal file
187
src/auto_readme/analyzers/rust_analyzer.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Rust code analyzer using tree-sitter."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from tree_sitter import Language, Node, Parser
|
||||
|
||||
from tree_sitter_rust import language as rust_language
|
||||
|
||||
from . import BaseAnalyzer
|
||||
from ..models import Function, Class, ImportStatement
|
||||
|
||||
|
||||
class RustAnalyzer(BaseAnalyzer):
    """Analyzer for Rust source files using tree-sitter.

    Extracts function definitions, struct/enum declarations (mapped onto the
    shared ``Class`` model) and ``use`` imports from ``.rs`` files.
    """

    SUPPORTED_EXTENSIONS = {".rs"}

    def can_analyze(self, path: Path) -> bool:
        """Check if this analyzer can handle the file."""
        return path.suffix.lower() in self.SUPPORTED_EXTENSIONS

    def analyze(self, path: Path) -> dict:
        """Analyze a Rust file and extract functions, structs, and imports.

        Returns a dict with ``functions``, ``classes`` and ``imports`` lists.
        All three lists are empty when the file is unreadable or the parser
        cannot be constructed.
        """
        content = self._get_file_content(path)
        if not content:
            return {"functions": [], "classes": [], "imports": []}

        content_bytes = content.encode("utf-8")

        try:
            lang = Language(rust_language())
            parser = Parser(language=lang)
            tree = parser.parse(content_bytes)
        except Exception:
            # Grammar/parser setup failure: degrade gracefully rather than
            # aborting the whole analysis run.
            return {"functions": [], "classes": [], "imports": []}

        functions = self._extract_functions(tree.root_node, content, content_bytes)
        classes = self._extract_structs(tree.root_node, content_bytes)
        imports = self._extract_imports(tree.root_node, content_bytes)

        return {
            "functions": functions,
            "classes": classes,
            "imports": imports,
        }

    @staticmethod
    def _node_text(node: Node, content_bytes: bytes) -> str:
        """Decode the source text covered by *node*."""
        return content_bytes[node.start_byte : node.end_byte].decode("utf-8")

    def _extract_functions(self, node: Node, content: str, content_bytes: bytes) -> list[Function]:
        """Recursively extract function definitions from the AST."""
        functions = []

        # Trait method signatures are "function_signature_item" in current
        # tree-sitter-rust grammars; the old "function_signature" spelling is
        # kept for tolerance of older grammar versions.
        if node.type in ("function_item", "function_signature", "function_signature_item"):
            func = self._parse_function(node, content, content_bytes)
            if func:
                functions.append(func)

        for child in node.children:
            functions.extend(self._extract_functions(child, content, content_bytes))

        return functions

    def _extract_structs(self, node: Node, content_bytes: bytes) -> list[Class]:
        """Recursively extract struct/enum definitions from the AST."""
        structs = []

        if node.type == "struct_item":
            # BUGFIX: struct names are "type_identifier" nodes in
            # tree-sitter-rust; the original "identifier"-only check never
            # matched, so no struct was ever reported.
            name = None
            fields = []
            for child in node.children:
                if child.type in ("identifier", "type_identifier"):
                    name = self._node_text(child, content_bytes)
                elif child.type == "field_declaration_list":
                    # Named fields live inside "field_declaration" nodes, one
                    # level below the declaration list.
                    for member in child.children:
                        if member.type == "field_declaration":
                            for part in member.children:
                                if part.type == "field_identifier":
                                    fields.append(self._node_text(part, content_bytes))
                        elif member.type == "field_identifier":
                            # Fallback for grammars that flatten the list.
                            fields.append(self._node_text(member, content_bytes))
            if name:
                structs.append(
                    Class(
                        name=name,
                        attributes=fields,
                        line_number=node.start_point[0] + 1,
                    )
                )
        elif node.type == "enum_item":
            name = None
            variants = []
            for child in node.children:
                if child.type in ("identifier", "type_identifier"):
                    name = self._node_text(child, content_bytes)
                elif child.type == "enum_variant_list":
                    # BUGFIX: variants are wrapped in "enum_variant" nodes
                    # whose name is an "identifier" child; the original looked
                    # for bare identifiers and found none.
                    for variant in child.children:
                        if variant.type == "enum_variant":
                            for part in variant.children:
                                if part.type == "identifier":
                                    variants.append(self._node_text(part, content_bytes))
                                    break
                        elif variant.type == "identifier":
                            variants.append(self._node_text(variant, content_bytes))
            if name:
                structs.append(
                    Class(
                        name=name,
                        attributes=variants,
                        line_number=node.start_point[0] + 1,
                    )
                )

        for child in node.children:
            structs.extend(self._extract_structs(child, content_bytes))

        return structs

    def _extract_imports(self, node: Node, content_bytes: bytes) -> list[ImportStatement]:
        """Recursively extract ``use`` statements from the AST."""
        imports = []

        if node.type == "use_declaration":
            imp = self._parse_import(node, content_bytes)
            if imp:
                imports.append(imp)

        for child in node.children:
            imports.extend(self._extract_imports(child, content_bytes))

        return imports

    def _parse_function(self, node: Node, content: str, content_bytes: bytes) -> Optional[Function]:
        """Parse a function definition node into a ``Function`` model."""
        line_number = node.start_point[0] + 1
        visibility = "private"  # Rust items are private unless marked `pub`.

        # Prefer tree-sitter's named fields ("name", "parameters",
        # "return_type"); they are stable across grammar revisions.
        name_node = node.child_by_field_name("name")
        name = self._node_text(name_node, content_bytes) if name_node else None

        params_node = node.child_by_field_name("parameters")
        parameters = self._parse_parameters(params_node, content_bytes) if params_node else []

        # BUGFIX: the return type is exposed as the "return_type" *field*,
        # not as a node whose type is "return_type" — the original
        # `child.type == "return_type"` check never matched, so return types
        # were always None.
        return_type = None
        ret_node = node.child_by_field_name("return_type")
        if ret_node is not None:
            return_type = self._node_text(ret_node, content_bytes)

        for child in node.children:
            if child.type == "visibility_modifier":
                visibility = self._node_text(child, content_bytes)
            elif name is None and child.type == "identifier":
                # Fallback child scan for grammars without the "name" field.
                name = self._node_text(child, content_bytes)
            elif not parameters and child.type == "parameters":
                parameters = self._parse_parameters(child, content_bytes)

        return Function(
            name=name or "unknown",
            parameters=parameters,
            return_type=return_type,
            line_number=line_number,
            # BUGFIX: normalize every `pub` / `pub(...)` form to "public".
            # The original only normalized "pub(crate)" (crate-visible) while
            # leaving bare `pub` (truly public) as the raw string "pub".
            visibility="public" if visibility.startswith("pub") else visibility,
        )

    def _parse_import(self, node: Node, content_bytes: bytes) -> Optional[ImportStatement]:
        """Parse a ``use`` declaration node into an ``ImportStatement``."""
        line_number = node.start_point[0] + 1
        module = None

        # BUGFIX: tree-sitter-rust has no "use_path"/"use_as_path" node
        # types; the path after `use` is the declaration's "argument" field
        # (a "use_as_clause" for `use a as b`), so the original always
        # produced an empty module string.
        arg = node.child_by_field_name("argument")
        if arg is not None:
            if arg.type == "use_as_clause":
                path = arg.child_by_field_name("path")
                module = self._node_text(path if path is not None else arg, content_bytes)
            else:
                module = self._node_text(arg, content_bytes)
        else:
            # Fallback: first non-keyword, non-punctuation child.
            for child in node.children:
                if child.type not in ("use", ";"):
                    module = self._node_text(child, content_bytes)
                    break

        return ImportStatement(
            module=module or "",
            line_number=line_number,
        )

    def _parse_parameters(self, node: Node, content_bytes: bytes) -> list[str]:
        """Parse parameter names from a "parameters" node."""
        params = []
        for child in node.children:
            if child.type == "parameter":
                # Simple bindings appear as an "identifier" pattern; prefer
                # the "pattern" field, then fall back to a child scan.
                pattern = child.child_by_field_name("pattern")
                if pattern is not None and pattern.type == "identifier":
                    params.append(self._node_text(pattern, content_bytes))
                    continue
                for grandchild in child.children:
                    if grandchild.type == "identifier":
                        params.append(self._node_text(grandchild, content_bytes))
                        break
                    elif grandchild.type == "pattern":
                        for ggchild in grandchild.children:
                            if ggchild.type == "identifier":
                                params.append(self._node_text(ggchild, content_bytes))
                                break
            elif child.type == "self_parameter":
                # Methods: record the receiver so method signatures are not
                # reported as parameterless.
                params.append("self")
        return params
|
||||
Reference in New Issue
Block a user