"""Type inference engine for Python code analysis."""

import ast
import inspect
import re
from pathlib import Path
from typing import (
    TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple
)

if TYPE_CHECKING:
    # These names are only used in (quoted) annotations below, so they are
    # imported for type checkers only and are not needed at runtime.
    from stubgen.parser import (
        FileInfo, Function, Variable, Class, TypeHint
    )


# ``bool`` must be tested before ``int`` because ``bool`` subclasses ``int``.
_SCALAR_TYPE_NAMES = (
    (bool, "bool"),
    (int, "int"),
    (float, "float"),
    (complex, "complex"),
    (str, "str"),
    (bytes, "bytes"),
)


def _constant_type(value: Any) -> str:
    """Return the type name for the Python value held by an ``ast.Constant``.

    Unrecognised values (e.g. ``Ellipsis``) map to ``"Any"``.
    """
    if value is None:
        return "None"
    for py_type, type_name in _SCALAR_TYPE_NAMES:
        if isinstance(value, py_type):
            return type_name
    return "Any"


def _scalar_literal_type(node: ast.AST) -> Optional[str]:
    """Return the type name of a scalar literal, or None if *node* is not one.

    Handles plain constants and signed numeric literals such as ``-1``, which
    the parser represents as ``UnaryOp(USub, Constant(1))``.
    """
    if isinstance(node, ast.Constant):
        return _constant_type(node.value)
    if (
        isinstance(node, ast.UnaryOp)
        and isinstance(node.op, (ast.UAdd, ast.USub))
        and isinstance(node.operand, ast.Constant)
    ):
        operand_type = _constant_type(node.operand.value)
        if operand_type in ("int", "float", "complex"):
            return operand_type
    return None


class TypeCandidate:
    """Represents a candidate type with confidence score."""

    def __init__(
        self,
        type_str: str,
        confidence: float,
        source: str = "inference"
    ) -> None:
        """Store the type string, its confidence and where it came from."""
        self.type_str = type_str
        self.confidence = confidence
        self.source = source

    def __repr__(self) -> str:
        return f"TypeCandidate({self.type_str}, confidence={self.confidence})"


class ReturnTypeAnalyzer:
    """Analyzes return statements to infer function return types."""

    # Nodes that open a new scope: ``return``/``yield`` inside them belong to
    # the nested definition, not to the function being analysed.
    _NESTED_SCOPES = (
        ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda
    )

    # Result type names for calls to well-known builtins.
    _CALL_RETURN_TYPES = {
        "len": "int",
        "str": "str",
        "int": "int",
        "float": "float",
        "bool": "bool",
        "list": "list",
        "dict": "dict",
        "set": "set",
        "tuple": "tuple",
        "range": "range",
        "enumerate": "enumerate",
        "zip": "zip",
        "map": "map",
        "filter": "filter",
        "sorted": "list",
        # ``reversed()`` returns a ``reversed`` iterator, not a list.
        "reversed": "reversed",
        "type": "type",
        "isinstance": "bool",
        "hasattr": "bool",
        "getattr": "Any",
        "open": "IO",
    }

    def __init__(self, func: "Function"):
        """Remember *func*; its ``body`` is read as a sequence of AST statements."""
        self.func = func

    def analyze(self) -> Optional[TypeCandidate]:
        """Infer the return type from the function body.

        Returns:
            * ``Generator`` (confidence 0.9) when the body contains a
              ``yield`` / ``yield from`` in the function's own scope;
            * the single inferred type when all ``return`` statements agree
              (confidence 1.0, or 0.7 for ``None`` and 0.5 for ``Any``);
            * ``Optional[T]`` (confidence 0.8) when the returns are ``None``
              and exactly one other known type;
            * ``Any`` (confidence 0.5) for any other mixture;
            * ``None`` when the body has neither ``return`` nor ``yield``.
        """
        body = list(self.func.body)

        # A yield makes the function a generator regardless of any ``return``.
        if self._has_yield(body):
            return TypeCandidate("Generator", 0.9, "yield")

        return_types: Set[str] = {
            self._return_type(ret) for ret in self._iter_returns(body)
        }
        if not return_types:
            return None

        if len(return_types) == 1:
            (only,) = return_types
            confidence = {"None": 0.7, "Any": 0.5}.get(only, 1.0)
            return TypeCandidate(only, confidence, "return")

        non_none = return_types - {"None"}
        if len(return_types) == 2 and len(non_none) == 1:
            (inner,) = non_none
            if inner != "Any":
                return TypeCandidate(f"Optional[{inner}]", 0.8, "return")

        return TypeCandidate("Any", 0.5, "return")

    def _walk_own_scope(self, statements: Iterable[ast.AST]) -> Iterator[ast.AST]:
        """Yield every node under *statements* in source order.

        Nested function, class and lambda nodes are yielded themselves but
        not descended into.
        """
        pending = list(statements)[::-1]
        while pending:
            node = pending.pop()
            yield node
            if not isinstance(node, self._NESTED_SCOPES):
                # Reversed so that the first child is popped first.
                pending.extend(reversed(list(ast.iter_child_nodes(node))))

    def _iter_returns(self, statements: Iterable[ast.AST]) -> Iterator[ast.Return]:
        """Yield the ``return`` statements that belong to this function."""
        for node in self._walk_own_scope(statements):
            if isinstance(node, ast.Return):
                yield node

    def _return_type(self, ret: ast.Return) -> str:
        """Infer the type of one ``return``; a bare ``return`` is ``None``."""
        if ret.value is None:
            return "None"
        return self._infer_type(ret.value)

    def _has_yield(self, statements) -> bool:
        """Return True if *statements* yield within the function's own scope."""
        return any(
            isinstance(node, (ast.Yield, ast.YieldFrom))
            for node in self._walk_own_scope(statements)
        )

    def _analyze_statement(self, stmt: ast.stmt) -> Optional[str]:
        """Return the type of the first ``return`` inside *stmt*, if any.

        Only ``return`` statements count: docstrings, other expression
        statements and assignments do not contribute a return type.
        """
        for ret in self._iter_returns([stmt]):
            return self._return_type(ret)
        return None

    def _element_type(self, elements: Iterable[ast.AST]) -> str:
        """Return the common type of *elements*, or ``Any`` if mixed or empty."""
        element_types = {self._infer_type(element) for element in elements}
        if len(element_types) == 1:
            return element_types.pop()
        return "Any"

    def _infer_type(self, node: ast.expr) -> str:
        """Infer a type name for the expression *node* (``Any`` if unknown)."""
        scalar = _scalar_literal_type(node)
        if scalar is not None:
            return scalar
        if isinstance(node, ast.JoinedStr):  # f-string
            return "str"
        if isinstance(node, ast.List):
            return f"List[{self._element_type(node.elts)}]"
        if isinstance(node, ast.Set):
            return f"Set[{self._element_type(node.elts)}]"
        if isinstance(node, ast.Tuple):
            if not node.elts:
                return "Tuple[()]"
            if any(isinstance(elt, ast.Starred) for elt in node.elts):
                # ``(*xs, 1)`` has no statically known length.
                return "Tuple[Any, ...]"
            return f"Tuple[{', '.join(self._infer_type(elt) for elt in node.elts)}]"
        if isinstance(node, ast.Dict):
            if any(key is None for key in node.keys):  # ``**mapping`` unpacking
                return "Dict[Any, Any]"
            key_type = self._element_type(node.keys)
            value_type = self._element_type(node.values)
            return f"Dict[{key_type}, {value_type}]"
        if isinstance(node, ast.Name):
            return self._infer_name_type(node)
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
            return "bool"
        if isinstance(node, ast.BinOp):
            return self._infer_binop_type(node)
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return self._infer_call_type(node.func.id, node)
        return "Any"

    def _infer_binop_type(self, node: ast.BinOp) -> str:
        """Infer the result type of a binary operation from its operand types."""
        left_type = self._infer_type(node.left)
        right_type = self._infer_type(node.right)
        numeric = ("int", "float")
        if left_type in numeric and right_type in numeric:
            # True division always produces a float, even for two ints.
            if isinstance(node.op, ast.Div) or "float" in (left_type, right_type):
                return "float"
            return "int"
        if left_type == right_type:
            return left_type
        return "Any"

    def _infer_name_type(self, node: ast.Name) -> str:
        """Return ``bool`` for the names ``True``/``False``, else ``Any``."""
        if node.id in ("True", "False"):
            return "bool"
        return "Any"

    def _infer_call_type(self, func_name: str, node: ast.Call) -> str:
        """Return the result type of a call to a known builtin, else ``Any``.

        *node* is accepted for interface compatibility and is not inspected.
        """
        return self._CALL_RETURN_TYPES.get(func_name, "Any")


class DefaultValueAnalyzer:
    """Analyzes default values to infer parameter types."""

    # Confidence per inferred type; types not listed here get 0.6.
    _CONFIDENCE = {
        "None": 0.7,
        "str": 0.8,
        "int": 0.8,
        "float": 0.8,
        "bool": 0.8,
        "list": 0.6,
        "dict": 0.6,
        "set": 0.6,
        "tuple": 0.6,
        "Any": 0.5,
    }

    # Container display nodes and the bare type name they stand for.
    _CONTAINER_TYPES = (
        (ast.List, "list"),
        (ast.Dict, "dict"),
        (ast.Tuple, "tuple"),
        (ast.Set, "set"),
    )

    def analyze(self, default: Optional[ast.expr]) -> Optional[TypeCandidate]:
        """Return a candidate type for *default*, or None if there is no default."""
        if default is None:
            return None

        default_type = self._infer_type(default)
        confidence = self._CONFIDENCE.get(default_type, 0.6)
        return TypeCandidate(default_type, confidence, "default")

    def _infer_type(self, node: ast.expr) -> str:
        """Infer a bare type name for a default-value expression."""
        scalar = _scalar_literal_type(node)
        if scalar is not None:
            return scalar
        for node_type, type_name in self._CONTAINER_TYPES:
            if isinstance(node, node_type):
                return type_name
        if isinstance(node, ast.Name):
            if node.id == "None":
                return "None"
            if node.id in ("True", "False"):
                return "bool"
        return "Any"


class AssignmentAnalyzer:
    """Analyzes variable assignments to infer types."""

    # Constructor calls whose result type is the constructor's own name.
    _CONSTRUCTOR_TYPES = {
        "list": "list",
        "dict": "dict",
        "set": "set",
        "tuple": "tuple",
        "int": "int",
        "str": "str",
        "float": "float",
        "bool": "bool",
    }

    def __init__(self, file_info: "FileInfo"):
        """Remember *file_info*, whose ``variables`` are analysed by ``analyze``."""
        self.file_info = file_info
        # Kept for interface compatibility; not populated by this class.
        self.assignments: Dict[str, str] = {}

    def analyze(self) -> Dict[str, TypeCandidate]:
        """Return a type candidate for each variable, keyed by variable name.

        An explicit annotation wins (confidence 1.0). Otherwise the type is
        inferred from the assigned value (0.8, or 0.5 when only ``Any`` could
        be inferred). Variables with neither are omitted.
        """
        results: Dict[str, TypeCandidate] = {}

        for var in self.file_info.variables:
            if var.annotation:
                # ``x: float = 0`` must be reported as float, not int.
                results[var.name] = TypeCandidate(
                    var.annotation.to_string(),
                    1.0,
                    "annotation"
                )
            elif var.value:
                inferred = self._infer_type(var.value)
                confidence = 0.8 if inferred != "Any" else 0.5
                results[var.name] = TypeCandidate(inferred, confidence, "assignment")

        return results

    def _infer_type(self, node: ast.expr) -> str:
        """Infer a type name for an assigned value (``Any`` if unknown)."""
        scalar = _scalar_literal_type(node)
        if scalar is not None:
            return scalar
        if isinstance(node, ast.List):
            element_types = {self._infer_type(elt) for elt in node.elts}
            if len(element_types) == 1:
                return f"List[{element_types.pop()}]"
            return "list"
        if isinstance(node, ast.Tuple):
            return "tuple"
        if isinstance(node, ast.Dict):
            return "dict"
        if isinstance(node, ast.Set):
            return "set"
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return self._CONSTRUCTOR_TYPES.get(node.func.id, "Any")
        return "Any"


class DocstringParser:
    """Parses docstrings for type hints (Google and NumPy style).

    Section detection expects section headers at column 0, so the docstring
    should be dedented (e.g. via ``inspect.cleandoc``) before parsing.
    """

    # A Google section ends at a blank line, a line that starts with a
    # capital letter at column 0, or the end of the docstring.
    _GOOGLE_SECTION_END = r'(?:\n\n|\n[A-Z]|\Z)'
    # A NumPy section ends at a blank line, the next underlined header, or
    # the end of the docstring. Capitalised names such as ``X`` do not end it.
    _NUMPY_SECTION_END = r'(?:\n\n|\n[^\n]*\n[ \t]*-{3,}[ \t]*(?:\n|\Z)|\Z)'

    # ``name (type): description``
    _GOOGLE_ENTRY = re.compile(r'(\w+)\s*\(([^)]+)\)\s*:\s*(.*)')
    # ``name : type`` (groups: indentation, name, type)
    _NUMPY_ENTRY = re.compile(r'^([ \t]*)(\w+)[ \t]*:[ \t]*(\S.*?)[ \t]*$')

    def __init__(self, docstring: str):
        """Store *docstring*; ``None`` or empty is treated as ``""``."""
        self.docstring = docstring or ""

    def parse_params(self) -> Dict[str, str]:
        """Return ``{parameter name: type string}`` from the docstring.

        Reads the Google ``Args:`` section and the NumPy ``Parameters``
        section; when both name a parameter the Google entry wins.
        """
        params = self._parse_google_entries("Args")
        for name, type_str in self._parse_numpy_entries("Parameters").items():
            params.setdefault(name, type_str)
        return params

    def parse_returns(self) -> Optional[str]:
        """Return the documented return type, or None if none is found.

        Google style reads the text before the colon on the first line under
        ``Returns:`` (``int: description`` gives ``int``). This is a
        heuristic: an untyped description that contains a colon is read as a
        type. NumPy style reads the first line under the underlined
        ``Returns`` header, dropping a leading ``name :`` if present.
        """
        google_match = re.search(
            r'Returns?:[ \t]*\n[ \t]*([^:\n]+?)[ \t]*:',
            self.docstring
        )
        if google_match:
            return google_match.group(1).strip()

        numpy_match = re.search(
            r'Returns?[ \t]*\n[ \t]*-+[ \t]*\n[ \t]*([^\n]+)',
            self.docstring
        )
        if numpy_match:
            first_line = numpy_match.group(1).strip()
            named = re.match(r'\w+\s*:\s*(.+)', first_line)
            return_type = named.group(1).strip() if named else first_line
            return return_type or None

        return None

    def parse_attributes(self) -> Dict[str, str]:
        """Return ``{attribute name: type string}`` from a Google ``Attributes:`` section."""
        return self._parse_google_entries("Attributes")

    def _parse_google_entries(self, header: str) -> Dict[str, str]:
        """Collect ``name (type): ...`` entries from the Google section *header*."""
        section = re.search(
            header + r':\s*\n(.*?)' + self._GOOGLE_SECTION_END,
            self.docstring,
            re.DOTALL
        )
        if section is None:
            return {}
        return {
            entry.group(1): entry.group(2).strip()
            for entry in self._GOOGLE_ENTRY.finditer(section.group(1))
        }

    def _parse_numpy_entries(self, header: str) -> Dict[str, str]:
        """Collect ``name : type`` entries from the NumPy section *header*.

        Only lines indented like the first entry are treated as entries, so
        indented description lines that contain a colon are skipped.
        """
        section = re.search(
            header + r'\s*\n\s*[-]+\s*\n(.*?)' + self._NUMPY_SECTION_END,
            self.docstring,
            re.DOTALL
        )
        if section is None:
            return {}

        entries: Dict[str, str] = {}
        entry_indent = None
        for line in section.group(1).splitlines():
            entry = self._NUMPY_ENTRY.match(line)
            if entry is None:
                continue
            indent, name, type_str = entry.groups()
            if entry_indent is None:
                entry_indent = indent
            if indent == entry_indent:
                entries.setdefault(name, type_str)
        return entries


class Inferrer:
    """Main type inference engine that combines all analyzers."""

    def __init__(self, infer_depth: int = 3):
        """Store *infer_depth*; it is not consulted by the methods of this class."""
        self.infer_depth = infer_depth

    def infer_function_signature(
        self, func: "Function", docstring: Optional[str] = None
    ) -> Dict[str, Any]:
        """Infer complete function signature including parameter and return types.

        Args:
            func: Function description exposing ``args`` (laid out like
                ``ast.arguments``), ``body``, ``returns``, ``is_async`` and
                ``decorators``.
            docstring: Docstring to mine for types. When omitted, the
                function's own docstring is used if one can be found.

        Returns:
            A dict with keys ``args`` (name -> ``{name, type, source,
            confidence}``), ``returns`` (``{type, source, confidence}``),
            ``is_async`` and ``decorators``.

            Parameter types are taken, in order of preference, from the
            annotation, the docstring, the default value, then ``Any``.
            The return type comes from the annotation, then return-statement
            analysis, then the docstring (which also replaces an
            uninformative ``Any``), and otherwise defaults to ``None``.
        """
        if docstring is None:
            docstring = self._own_docstring(func)
        docstring_parser = DocstringParser(docstring)
        doc_params = docstring_parser.parse_params()
        default_analyzer = DefaultValueAnalyzer()

        arguments = func.args
        args_info: Dict[str, Any] = {}

        positional = (
            list(getattr(arguments, "posonlyargs", None) or [])
            + list(arguments.args)
        )
        # ``defaults`` lines up with the LAST len(defaults) positional params.
        first_default = len(positional) - len(arguments.defaults)
        for position, arg in enumerate(positional):
            offset = position - first_default
            default = arguments.defaults[offset] if offset >= 0 else None
            args_info[arg.arg] = self._describe_argument(
                arg, doc_params, default, default_analyzer
            )

        # ``kw_defaults`` is parallel to ``kwonlyargs``; None means no default.
        kw_defaults = list(getattr(arguments, "kw_defaults", None) or [])
        for position, kwonly in enumerate(arguments.kwonlyargs):
            default = kw_defaults[position] if position < len(kw_defaults) else None
            args_info[kwonly.arg] = self._describe_argument(
                kwonly, doc_params, default, default_analyzer
            )

        if arguments.vararg:
            args_info[arguments.vararg.arg] = {
                "name": arguments.vararg.arg,
                "type": "*Tuple[Any, ...]",
                "source": "vararg",
                "confidence": 1.0
            }

        if arguments.kwarg:
            args_info[arguments.kwarg.arg] = {
                "name": arguments.kwarg.arg,
                "type": "**Dict[str, Any]",
                "source": "kwarg",
                "confidence": 1.0
            }

        return_type_info = {"type": "None", "source": "default", "confidence": 1.0}

        if func.returns:
            return_type_info = {
                "type": func.returns.to_string(),
                "source": "annotation",
                "confidence": 1.0
            }
        else:
            candidate = ReturnTypeAnalyzer(func).analyze()
            doc_returns = docstring_parser.parse_returns()
            if candidate is not None and candidate.type_str != "Any":
                return_type_info = self._candidate_info(candidate)
            elif doc_returns:
                # Used when analysis found nothing, or only ``Any``.
                return_type_info = {
                    "type": doc_returns,
                    "source": "docstring",
                    "confidence": 0.9
                }
            elif candidate is not None:
                return_type_info = self._candidate_info(candidate)

        return {
            "args": args_info,
            "returns": return_type_info,
            "is_async": func.is_async,
            "decorators": func.decorators
        }

    @staticmethod
    def _candidate_info(candidate: TypeCandidate) -> Dict[str, Any]:
        """Convert a ``TypeCandidate`` into the plain-dict result form."""
        return {
            "type": candidate.type_str,
            "source": candidate.source,
            "confidence": candidate.confidence
        }

    @staticmethod
    def _describe_argument(
        arg: Any,
        doc_params: Dict[str, str],
        default: Optional[ast.expr],
        default_analyzer: DefaultValueAnalyzer,
    ) -> Dict[str, Any]:
        """Build the info dict for one parameter.

        Preference order: annotation, docstring, default value, ``Any``.
        """
        if arg.annotation:
            # NOTE(review): the annotation object is stored as-is here, while
            # return annotations go through ``to_string()`` - confirm intent.
            type_value, source, confidence = arg.annotation, "annotation", 1.0
        elif arg.arg in doc_params:
            type_value, source, confidence = doc_params[arg.arg], "docstring", 0.9
        else:
            candidate = default_analyzer.analyze(default)
            if candidate:
                type_value = candidate.type_str
                source = "default"
                confidence = candidate.confidence
            else:
                type_value, source, confidence = "Any", "none", 0.5
        return {
            "name": arg.arg,
            "type": type_value,
            "source": source,
            "confidence": confidence,
        }

    @staticmethod
    def _own_docstring(func: Any) -> Optional[str]:
        """Return the docstring of *func* itself, or None if it has none.

        Uses a string ``docstring`` attribute when *func* has one; otherwise
        reads a leading string-literal statement from ``func.body``.
        """
        recorded = getattr(func, "docstring", None)
        if isinstance(recorded, str):
            return recorded
        body = list(getattr(func, "body", None) or [])
        if body and isinstance(body[0], ast.Expr):
            value = body[0].value
            if isinstance(value, ast.Constant) and isinstance(value.value, str):
                # Dedent so section headers sit at column 0 for the parser.
                return inspect.cleandoc(value.value)
        return None

    def infer_file(self, file_info: "FileInfo") -> Dict[str, Any]:
        """Infer types for all components in a file.

        Returns a dict with ``functions`` and ``classes`` (plain-dict
        signature info keyed by name) and ``variables`` (name ->
        ``TypeCandidate``).
        """
        results: Dict[str, Any] = {
            "functions": {},
            "classes": {},
            "variables": {}
        }

        for func in file_info.functions:
            # Each function is matched against its own docstring; the
            # file-level docstring does not describe its parameters.
            results["functions"][func.name] = self.infer_function_signature(func)

        for cls in file_info.classes:
            class_info = {
                "name": cls.name,
                "bases": cls.bases,
                "methods": {},
                "attributes": {}
            }

            for method in cls.methods:
                class_info["methods"][method.name] = self.infer_function_signature(
                    method
                )

            for attr in cls.attributes:
                if attr.annotation:
                    class_info["attributes"][attr.name] = {
                        "type": attr.annotation.to_string(),
                        "source": "annotation",
                        "confidence": 1.0
                    }
                else:
                    class_info["attributes"][attr.name] = {
                        "type": "Any",
                        "source": "none",
                        "confidence": 0.5
                    }

            results["classes"][cls.name] = class_info

        assignment_analyzer = AssignmentAnalyzer(file_info)
        results["variables"] = assignment_analyzer.analyze()

        return results