diff --git a/src/parsers/rust.py b/src/parsers/rust.py new file mode 100644 index 0000000..ee7c637 --- /dev/null +++ b/src/parsers/rust.py @@ -0,0 +1,237 @@ +from pathlib import Path +import re + +from src.parsers.base import BaseParser, ParserResult, Entity, EntityType + + +class RustParser(BaseParser): + SUPPORTED_EXTENSIONS = [".rs"] + + def __init__(self): + pass + + def parse(self, file_path: Path, content: str) -> ParserResult: + result = ParserResult(file_path=file_path, language="rust") + try: + result.entities = self.extract_entities(content, file_path) + result.imports = self.extract_imports(content) + return result + except Exception as e: + result.errors.append(f"Parse error: {str(e)}") + return result + + def extract_entities(self, content: str, file_path: Path) -> list[Entity]: + entities = [] + entities.extend(self._extract_functions(content, file_path)) + entities.extend(self._extract_structs(content, file_path)) + entities.extend(self._extract_traits(content, file_path)) + entities.extend(self._extract_impls(content, file_path)) + return entities + + def _extract_functions(self, content: str, file_path: Path) -> list[Entity]: + functions = [] + lines = content.split('\n') + pattern = r'^fn\s+([a-zA-Z_][a-zA-Z0-9_]*)(\s*<[^>]*>)?\s*\(([^)]*)\)\s*(->\s*[\w\s<>,]+)?\s*\{?$' + + for i, line in enumerate(lines): + line = line.strip() + match = re.match(pattern, line) + if match: + func_name = match.group(1) + params = match.group(3) + + start_line = i + 1 + end_line = self._find_braces_end(lines, i) + + code_lines = lines[i:end_line] + code = '\n'.join(code_lines) + + entity = Entity( + name=func_name, + entity_type=EntityType.FUNCTION, + file_path=file_path, + start_line=start_line, + end_line=end_line, + code=code, + attributes={"parameters": self._parse_rust_params(params)}, + calls=self._extract_function_calls(code), + ) + functions.append(entity) + + return functions + + def _extract_structs(self, content: str, file_path: Path) -> list[Entity]: + structs = [] + lines = content.split('\n') + + for i, line in enumerate(lines): + line = line.strip() + match = re.match(r'struct\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*(\{|;|$)', line) + if match: + struct_name = match.group(1) + + start_line = i + 1 + if '{' in line: + end_line = self._find_braces_end(lines, i) + else: + end_line = i + 1 + + code_lines = lines[i:end_line] + code = '\n'.join(code_lines) + + entity = Entity( + name=struct_name, + entity_type=EntityType.CLASS, + file_path=file_path, + start_line=start_line, + end_line=end_line, + code=code, + ) + structs.append(entity) + + return structs + + def _extract_traits(self, content: str, file_path: Path) -> list[Entity]: + traits = [] + lines = content.split('\n') + + for i, line in enumerate(lines): + line = line.strip() + match = re.match(r'trait\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*(\{|$)', line) + if match: + trait_name = match.group(1) + + start_line = i + 1 + end_line = self._find_braces_end(lines, i) + + code_lines = lines[i:end_line] + code = '\n'.join(code_lines) + + entity = Entity( + name=trait_name, + entity_type=EntityType.CLASS, + file_path=file_path, + start_line=start_line, + end_line=end_line, + code=code, + ) + traits.append(entity) + + return traits + + def _extract_impls(self, content: str, file_path: Path) -> list[Entity]: + impls = [] + lines = content.split('\n') + + for i, line in enumerate(lines): + line = line.strip() + match = re.match(r'impl\s+(?:[a-zA-Z_][a-zA-Z0-9_]*\s+)?impl\s*\{', line) + if match: + start_line = i + 1 + end_line = self._find_braces_end(lines, i) + + code_lines = lines[i:end_line] + code = '\n'.join(code_lines) + + entity = Entity( + name="impl", + entity_type=EntityType.CLASS, + file_path=file_path, + start_line=start_line, + end_line=end_line, + code=code, + children=self._extract_impl_methods(code, file_path, start_line), + ) + impls.append(entity) + + return impls + + def _extract_impl_methods(self, content: str, file_path: Path, base_line: int) -> list[Entity]: + methods = [] + lines = content.split('\n') + pattern = r'fn\s+([a-zA-Z_][a-zA-Z0-9_]*)(\s*<[^>]*>)?\s*\(([^)]*)\)\s*(->\s*[\w\s<>,]+)?\s*\{?$' + + for i, line in enumerate(lines): + line = line.strip() + match = re.match(pattern, line) + if match: + method_name = match.group(1) + params = match.group(3) + + start_line = base_line + i + end_line = base_line + self._find_braces_end(lines, i) + + code_lines = lines[i:end_line] + code = '\n'.join(code_lines) + + entity = Entity( + name=method_name, + entity_type=EntityType.METHOD, + file_path=file_path, + start_line=start_line, + end_line=end_line, + code=code, + attributes={"parameters": self._parse_rust_params(params)}, + calls=self._extract_function_calls(code), + ) + methods.append(entity) + + return methods + + def _parse_rust_params(self, params: str) -> list[str]: + param_list = [] + for param in params.split(','): + param = param.strip() + if param: + parts = param.split() + if len(parts) >= 2: + param_list.append(parts[-1]) + return param_list + + def _extract_function_calls(self, code: str) -> list[str]: + calls = [] + pattern = r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\([^)]*\)' + for match in re.finditer(pattern, code): + func_name = match.group(1) + if func_name not in ['if', 'while', 'for', 'match', 'return', 'println', 'print', 'panic']: + calls.append(func_name) + return list(set(calls)) + + def _find_braces_end(self, lines: list[str], start_index: int) -> int: + brace_count = 0 + in_string = False + string_char = None + + for i, line in enumerate(lines[start_index:], start_index): + for j, char in enumerate(line): + if char in ['"', "'"] and (j == 0 or line[j-1] != '\\'): + if not in_string: + in_string = True + string_char = char + elif char == string_char: + in_string = False + string_char = None + elif not in_string and char == '{': + brace_count += 1 + elif not in_string and char == '}': + brace_count -= 1 + if brace_count == 0: + return i + 1 + + return len(lines) + + def extract_imports(self, content: str) -> list[str]: + imports = [] + lines = content.split('\n') + + for line in lines: + line = line.strip() + match = re.match(r'use\s+([^;]+);', line) + if match: + import_path = match.group(1).strip() + imports.append(import_path) + + return imports + + def extract_calls(self, content: str) -> list[str]: + return self._extract_function_calls(content)