Initial commit: Add python-stub-generator project
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
This commit is contained in:
533
stubgen/inferrer.py
Normal file
533
stubgen/inferrer.py
Normal file
@@ -0,0 +1,533 @@
|
|||||||
|
"""Type inference engine for Python code analysis."""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from stubgen.parser import (
|
||||||
|
FileInfo, Function, Variable, Class, TypeHint
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TypeCandidate:
    """A proposed type string together with a confidence score and its origin.

    Attributes are set directly from the constructor arguments:
    ``type_str`` (the type as text), ``confidence`` (a float score) and
    ``source`` (a label for where the guess came from, ``"inference"`` by
    default).
    """

    def __init__(self, type_str: str, confidence: float, source: str = "inference"):
        self.type_str, self.confidence, self.source = type_str, confidence, source

    def __repr__(self):
        # The source label is intentionally not part of the repr.
        shown_type, shown_confidence = self.type_str, self.confidence
        return f"TypeCandidate({shown_type}, confidence={shown_confidence})"
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnTypeAnalyzer:
    """Infer a function's return type from its ``return`` and ``yield`` statements.

    ``func`` only needs a ``body`` attribute holding the function's statements
    as ``ast`` nodes. Only explicit ``return`` statements are considered;
    falling off the end of the function is not modelled.
    """

    # Nodes that open a new scope: a ``return``/``yield`` found inside one of
    # them belongs to the nested callable, not to the function being analysed.
    _NESTED_SCOPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)

    _NUMERIC_TYPES = ("int", "float")

    # Result types of well-known builtin calls.
    _CALL_RETURN_TYPES = {
        "len": "int",
        "str": "str",
        "int": "int",
        "float": "float",
        "bool": "bool",
        "list": "list",
        "dict": "dict",
        "set": "set",
        "tuple": "tuple",
        "range": "range",
        "enumerate": "enumerate",
        "zip": "zip",
        "map": "map",
        "filter": "filter",
        "sorted": "list",
        # reversed() returns an iterator object, not a list.
        "reversed": "reversed",
        "type": "type",
        "isinstance": "bool",
        "hasattr": "bool",
        "getattr": "Any",
        "open": "IO",
    }

    def __init__(self, func: "Function"):
        self.func = func

    def analyze(self) -> "Optional[TypeCandidate]":
        """Return the best return-type candidate, or ``None`` if nothing is found.

        A body containing ``yield`` / ``yield from`` is reported as
        ``Generator`` regardless of any ``return`` statements, because such a
        function always produces a generator when called. Otherwise every
        ``return`` in the function's own scope is collected and merged.
        """
        body = list(self.func.body)

        if self._has_yield(body):
            return TypeCandidate("Generator", 0.9, "yield")

        return_types = {
            self._return_type(node)
            for node in self._walk_scope(body)
            if isinstance(node, ast.Return)
        }
        combined = self._combine_return_types(return_types)
        if combined is None:
            return None
        type_str, confidence = combined
        return TypeCandidate(type_str, confidence, "return")

    def _combine_return_types(self, return_types: Set[str]) -> Optional[Tuple[str, float]]:
        """Merge the distinct return types into one ``(type, confidence)`` pair.

        Returns ``None`` for an empty set. A single type is reported as-is
        (with lower confidence for a lone ``None``), ``{None, X}`` becomes
        ``Optional[X]``, and any other mix degrades to ``Any``.
        """
        if not return_types:
            return None

        if len(return_types) == 1:
            (only,) = return_types
            return only, (0.7 if only == "None" else 1.0)

        if len(return_types) == 2 and "None" in return_types:
            (other,) = return_types - {"None"}
            if other != "Any":
                return f"Optional[{other}]", 0.8

        return "Any", 0.5

    def _walk_scope(self, statements):
        """Yield nodes under ``statements`` in source order, staying in this scope.

        Nested function, class and lambda nodes are yielded themselves but
        their contents are not visited.
        """
        stack = list(statements)[::-1]
        while stack:
            node = stack.pop()
            yield node
            if isinstance(node, self._NESTED_SCOPES):
                continue
            # Reverse so that popping from the end visits children in order.
            stack.extend(reversed(list(ast.iter_child_nodes(node))))

    def _has_yield(self, statements) -> bool:
        """Return True if ``statements`` contain a yield belonging to this scope."""
        return any(
            isinstance(node, (ast.Yield, ast.YieldFrom))
            for node in self._walk_scope(statements)
        )

    def _return_type(self, node: ast.Return) -> str:
        """Infer the type produced by a single ``return`` statement."""
        if node.value is None:
            # A bare ``return`` yields None.
            return "None"
        return self._infer_type(node.value)

    def _analyze_statement(self, stmt: ast.stmt) -> Optional[str]:
        """Return the type of the first ``return`` inside ``stmt``, or ``None``.

        Only real ``return`` statements count: docstrings, other expression
        statements and assignments do not contribute a return type.
        """
        for node in self._walk_scope([stmt]):
            if isinstance(node, ast.Return):
                return self._return_type(node)
        return None

    def _infer_type(self, node: ast.expr) -> str:
        """Infer a type string for an expression node; ``"Any"`` when unknown."""
        if isinstance(node, ast.Constant):
            value = node.value
            # bool must be tested before int because bool subclasses int.
            if isinstance(value, bool):
                return "bool"
            if isinstance(value, int):
                return "int"
            if isinstance(value, float):
                return "float"
            if isinstance(value, str):
                return "str"
            if value is None:
                return "None"
            return "Any"

        if isinstance(node, ast.List):
            return f"List[{self._common_element_type(node.elts)}]"

        if isinstance(node, ast.Tuple):
            if not node.elts:
                return "Tuple[()]"
            element_types = [self._infer_type(elt) for elt in node.elts]
            return f"Tuple[{', '.join(element_types)}]"

        if isinstance(node, ast.Dict):
            return "Dict[Any, Any]"

        if isinstance(node, ast.Set):
            return f"Set[{self._common_element_type(node.elts)}]"

        if isinstance(node, ast.Name):
            return self._infer_name_type(node)

        if isinstance(node, ast.BinOp):
            left_type = self._infer_type(node.left)
            right_type = self._infer_type(node.right)
            if left_type in self._NUMERIC_TYPES and right_type in self._NUMERIC_TYPES:
                # True division always produces a float, even for two ints.
                if isinstance(node.op, ast.Div) or "float" in (left_type, right_type):
                    return "float"
                return "int"
            if left_type == right_type:
                return left_type
            return "Any"

        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return self._infer_call_type(node.func.id, node)

        # Attribute access, method calls and everything else are not resolved.
        return "Any"

    def _common_element_type(self, elements) -> str:
        """Return the single element type shared by ``elements``, else ``"Any"``."""
        element_types = {self._infer_type(elt) for elt in elements}
        if len(element_types) == 1:
            return element_types.pop()
        return "Any"

    def _infer_name_type(self, node: ast.Name) -> str:
        """Infer the type of a bare name; only ``True``/``False`` are recognised."""
        if node.id in ("True", "False"):
            return "bool"
        return "Any"

    def _infer_call_type(self, func_name: str, node: ast.Call) -> str:
        """Return the result type of a call to the builtin ``func_name``.

        ``node`` is accepted for interface compatibility and is not inspected.
        Unknown callables map to ``"Any"``.
        """
        return self._CALL_RETURN_TYPES.get(func_name, "Any")
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultValueAnalyzer:
    """Infer a parameter's type from the expression used as its default value."""

    # Confidence assigned to each inferred type; unlisted types get 0.6.
    _CONFIDENCE = {
        "None": 0.7,
        "str": 0.8,
        "int": 0.8,
        "float": 0.8,
        "bool": 0.8,
        "list": 0.6,
        "dict": 0.6,
        "set": 0.6,
        "tuple": 0.6,
        "Any": 0.5,
    }

    # Container display nodes and the builtin type each one creates.
    _CONTAINER_TYPES = (
        (ast.List, "list"),
        (ast.Dict, "dict"),
        (ast.Tuple, "tuple"),
        (ast.Set, "set"),
    )

    def analyze(self, default: Optional[ast.expr]) -> "Optional[TypeCandidate]":
        """Return a candidate for ``default``, or ``None`` when there is no default."""
        if default is None:
            return None

        default_type = self._infer_type(default)
        confidence = self._CONFIDENCE.get(default_type, 0.6)
        return TypeCandidate(default_type, confidence, "default")

    def _infer_type(self, node: ast.expr) -> str:
        """Map a default-value expression to a type string; ``"Any"`` if unknown."""
        if isinstance(node, ast.Constant):
            value = node.value
            # bool must be tested before int because bool subclasses int.
            if isinstance(value, bool):
                return "bool"
            if isinstance(value, int):
                return "int"
            if isinstance(value, float):
                return "float"
            if isinstance(value, str):
                return "str"
            if value is None:
                return "None"
            return "Any"

        for node_class, type_name in self._CONTAINER_TYPES:
            if isinstance(node, node_class):
                return type_name

        # The parser emits None/True/False as ast.Constant; this branch only
        # matters for hand-built trees that spell them as Name nodes.
        if isinstance(node, ast.Name) and node.id in ("None", "True", "False"):
            return "None" if node.id == "None" else "bool"

        return "Any"
|
||||||
|
|
||||||
|
|
||||||
|
class AssignmentAnalyzer:
    """Infer types for the variables recorded on a ``FileInfo``.

    Each entry of ``file_info.variables`` is read through its ``name``,
    ``value`` and ``annotation`` attributes.
    """

    # Builtin constructors whose call result type is the constructor itself.
    _CONSTRUCTOR_TYPES = {
        "list": "list",
        "dict": "dict",
        "set": "set",
        "tuple": "tuple",
        "int": "int",
        "str": "str",
        "float": "float",
        "bool": "bool",
    }

    def __init__(self, file_info: "FileInfo"):
        self.file_info = file_info
        self.assignments: Dict[str, str] = {}

    def analyze(self) -> "Dict[str, TypeCandidate]":
        """Return a candidate per variable name; unresolvable variables are omitted."""
        results = {}
        for var in self.file_info.variables:
            resolved = self._resolve_variable(var)
            if resolved is not None:
                results[var.name] = TypeCandidate(*resolved)
        return results

    def _resolve_variable(self, var) -> Optional[Tuple[str, float, str]]:
        """Return ``(type, confidence, source)`` for ``var``, or ``None``.

        An explicit annotation always wins over a type inferred from the
        assigned value, since it states the author's intent directly.
        """
        if var.annotation:
            return var.annotation.to_string(), 1.0, "annotation"

        if var.value:
            inferred = self._infer_type(var.value)
            confidence = 0.5 if inferred == "Any" else 0.8
            return inferred, confidence, "assignment"

        return None

    def _infer_type(self, node: ast.expr) -> str:
        """Map an assigned expression to a type string; ``"Any"`` if unknown."""
        if isinstance(node, ast.Constant):
            value = node.value
            # bool must be tested before int because bool subclasses int.
            if isinstance(value, bool):
                return "bool"
            if isinstance(value, int):
                return "int"
            if isinstance(value, float):
                return "float"
            if isinstance(value, str):
                return "str"
            if value is None:
                return "None"
            return "Any"

        if isinstance(node, ast.List):
            element_types = {self._infer_type(elt) for elt in node.elts}
            # Only a non-empty, homogeneous list gets a parameterised type.
            if len(element_types) == 1:
                return f"List[{element_types.pop()}]"
            return "list"

        if isinstance(node, ast.Tuple):
            return "tuple"

        if isinstance(node, ast.Dict):
            return "dict"

        if isinstance(node, ast.Set):
            return "set"

        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return self._CONSTRUCTOR_TYPES.get(node.func.id, "Any")

        return "Any"
|
||||||
|
|
||||||
|
|
||||||
|
class DocstringParser:
    """Extract type information from Google- and NumPy-style docstrings.

    Google entries look like ``name (type): description`` under an ``Args:``
    or ``Attributes:`` header; NumPy entries look like ``name : type`` under
    an underlined ``Parameters`` header. A section ends at the first blank
    line or at a line that starts with a capital letter in column 0.
    """

    # ``name (type): description`` as used by Google-style sections.
    _GOOGLE_ENTRY = re.compile(r'(\w+)\s*\(([^)]+)\)\s*:\s*(.*)')
    # ``name : type`` as used by NumPy-style sections.
    _NUMPY_ENTRY = re.compile(r'(\w+)\s*:\s*(.+)')
    # Trailing ", optional" marker that NumPy style appends to a type.
    _NUMPY_OPTIONAL = re.compile(r',\s*optional\s*$')
    # What terminates the body of a section.
    _SECTION_END = r'(?:\n\n|\n[A-Z]|\Z)'

    def __init__(self, docstring: str):
        # None is normalised to an empty string so the regex searches are safe.
        self.docstring = docstring or ""

    def parse_params(self) -> Dict[str, str]:
        """Return ``{parameter name: type string}`` from the docstring.

        Google ``Args:`` entries take precedence; NumPy ``Parameters``
        entries only fill in names that are still missing.
        """
        params = self._google_section("Args")

        numpy_match = re.search(
            r'Parameters\s*\n\s*[-]+\s*\n(.*?)' + self._SECTION_END,
            self.docstring,
            re.DOTALL,
        )
        if numpy_match:
            for name, type_str in self._numpy_entries(numpy_match.group(1)):
                params.setdefault(name, type_str)

        return params

    def parse_returns(self) -> Optional[str]:
        """Return the documented return type, or ``None`` if none is given.

        Google style is ``Returns:`` followed by ``type: description``; a
        Google section with no ``type:`` prefix yields ``None``. NumPy style
        is an underlined ``Returns`` header followed by either ``type`` or
        ``name : type``.
        """
        google_match = re.search(
            r'Returns?:[ \t]*\n[ \t]*([^\s:][^:\n]*):',
            self.docstring,
        )
        if google_match:
            return google_match.group(1).strip()

        numpy_match = re.search(
            r'Returns?[ \t]*\n[ \t]*-+[ \t]*\n[ \t]*(\S[^\n]*)',
            self.docstring,
        )
        if numpy_match:
            first_line = numpy_match.group(1).strip()
            named = re.match(r'\w+\s*:\s*(.+)', first_line)
            # ``name : type`` -> keep the type; a bare ``type`` line is used whole.
            return named.group(1).strip() if named else first_line

        return None

    def parse_attributes(self) -> Dict[str, str]:
        """Return ``{attribute name: type string}`` from a Google ``Attributes:`` section."""
        return self._google_section("Attributes")

    def _google_section(self, header: str) -> Dict[str, str]:
        """Parse the ``name (type): ...`` entries under ``header:``."""
        section = re.search(
            re.escape(header) + r':\s*\n(.*?)' + self._SECTION_END,
            self.docstring,
            re.DOTALL,
        )
        if not section:
            return {}
        return {
            entry.group(1): entry.group(2).strip()
            for entry in self._GOOGLE_ENTRY.finditer(section.group(1))
        }

    def _numpy_entries(self, section: str):
        """Yield ``(name, type)`` for each ``name : type`` line in ``section``.

        Entry lines share the indentation of the section's first line;
        description lines are indented further and are skipped, so a colon
        inside a description is never mistaken for an entry.
        """
        lines = [line for line in section.splitlines() if line.strip()]
        if not lines:
            return
        base_indent = len(lines[0]) - len(lines[0].lstrip())
        for line in lines:
            if len(line) - len(line.lstrip()) != base_indent:
                continue
            entry = self._NUMPY_ENTRY.fullmatch(line.strip())
            if entry:
                type_str = self._NUMPY_OPTIONAL.sub("", entry.group(2).strip())
                yield entry.group(1), type_str
|
||||||
|
|
||||||
|
|
||||||
|
class Inferrer:
    """Main type inference engine that combines all analyzers."""

    def __init__(self, infer_depth: int = 3):
        # Stored for callers; nothing in this class reads it.
        self.infer_depth = infer_depth

    def infer_function_signature(self, func: "Function", docstring: str = None) -> Dict[str, Any]:
        """Infer complete function signature including parameter and return types.

        ``func.args`` is read through ``args``, ``defaults``, ``kwonlyargs``,
        ``vararg`` and ``kwarg`` (presumably an ``ast.arguments`` node —
        confirm against the parser); ``posonlyargs`` and ``kw_defaults`` are
        used when present.

        Returns a dict with keys ``"args"`` (per-parameter dicts keyed by
        name, each holding ``name``/``type``/``source``/``confidence``),
        ``"returns"``, ``"is_async"`` and ``"decorators"``.
        """
        args_info = {}

        docstring_parser = DocstringParser(docstring or "")
        doc_params = docstring_parser.parse_params()

        arguments = func.args

        # Defaults cover the trailing positional parameters, counting
        # positional-only ones as well.
        positional = list(getattr(arguments, "posonlyargs", None) or []) + list(arguments.args)
        for arg, default in self._pair_with_defaults(positional, list(arguments.defaults)):
            args_info[arg.arg] = self._describe_argument(arg, default, doc_params)

        # Keyword-only defaults line up one-to-one with the parameters,
        # with None marking a parameter that has no default.
        kw_defaults = list(getattr(arguments, "kw_defaults", None) or [])
        for position, kwonly in enumerate(arguments.kwonlyargs):
            default = kw_defaults[position] if position < len(kw_defaults) else None
            args_info[kwonly.arg] = self._describe_argument(kwonly, default, doc_params)

        if arguments.vararg:
            args_info[arguments.vararg.arg] = {
                "name": arguments.vararg.arg,
                "type": "*Tuple[Any, ...]",
                "source": "vararg",
                "confidence": 1.0,
            }

        if arguments.kwarg:
            args_info[arguments.kwarg.arg] = {
                "name": arguments.kwarg.arg,
                "type": "**Dict[str, Any]",
                "source": "kwarg",
                "confidence": 1.0,
            }

        # Used when neither an annotation, the body nor the docstring says otherwise.
        return_type_info = {"type": "None", "source": "default", "confidence": 1.0}

        if func.returns:
            return_type_info = {
                "type": func.returns.to_string(),
                "source": "annotation",
                "confidence": 1.0,
            }
        else:
            candidate = ReturnTypeAnalyzer(func).analyze()
            if candidate:
                return_type_info = {
                    "type": candidate.type_str,
                    "source": candidate.source,
                    "confidence": candidate.confidence,
                }
            elif docstring:
                doc_returns = docstring_parser.parse_returns()
                if doc_returns:
                    return_type_info = {
                        "type": doc_returns,
                        "source": "docstring",
                        "confidence": 0.9,
                    }

        return {
            "args": args_info,
            "returns": return_type_info,
            "is_async": func.is_async,
            "decorators": func.decorators,
        }

    @staticmethod
    def _pair_with_defaults(positional, defaults):
        """Pair each positional parameter with its default, or ``None``.

        Defaults belong to the *last* ``len(defaults)`` parameters, so the
        first parameter that has one sits at ``len(positional) - len(defaults)``.
        """
        first_with_default = len(positional) - len(defaults)
        return [
            (arg, defaults[position - first_with_default] if position >= first_with_default else None)
            for position, arg in enumerate(positional)
        ]

    @staticmethod
    def _describe_argument(arg, default, doc_params) -> Dict[str, Any]:
        """Build the info dict for one parameter.

        Sources are tried in order: explicit annotation, docstring entry,
        type of the default value; otherwise the parameter is ``Any``.
        """
        info = {"name": arg.arg}

        if arg.annotation:
            # NOTE(review): the annotation object is stored as-is, while other
            # entries store type strings — confirm what consumers expect.
            info["type"] = arg.annotation
            info["source"] = "annotation"
            info["confidence"] = 1.0
            return info

        if arg.arg in doc_params:
            info["type"] = doc_params[arg.arg]
            info["source"] = "docstring"
            info["confidence"] = 0.9
            return info

        candidate = DefaultValueAnalyzer().analyze(default)
        if candidate:
            info["type"] = candidate.type_str
            info["source"] = "default"
            info["confidence"] = candidate.confidence
        else:
            info["type"] = "Any"
            info["source"] = "none"
            info["confidence"] = 0.5
        return info

    @staticmethod
    def _docstring_for(func, fallback=None):
        """Return ``func``'s own docstring if it exposes one, else ``fallback``.

        NOTE(review): assumes the parser may expose a ``docstring`` attribute
        on functions; when it does not, behaviour is unchanged.
        """
        return getattr(func, "docstring", None) or fallback

    def infer_file(self, file_info: "FileInfo") -> Dict[str, Any]:
        """Infer types for all components in a file.

        Returns a dict with ``"functions"`` and ``"classes"`` (plain dicts
        keyed by name) and ``"variables"`` (the mapping produced by
        ``AssignmentAnalyzer.analyze``).
        """
        results = {
            "functions": {},
            "classes": {},
            "variables": {},
        }

        for func in file_info.functions:
            # Prefer the function's own docstring; the file-level docstring is
            # kept as the fallback.
            results["functions"][func.name] = self.infer_function_signature(
                func, self._docstring_for(func, file_info.docstring)
            )

        for cls in file_info.classes:
            class_info = {
                "name": cls.name,
                "bases": cls.bases,
                "methods": {},
                "attributes": {},
            }

            for method in cls.methods:
                class_info["methods"][method.name] = self.infer_function_signature(
                    method, self._docstring_for(method)
                )

            for attr in cls.attributes:
                if attr.annotation:
                    class_info["attributes"][attr.name] = {
                        "type": attr.annotation.to_string(),
                        "source": "annotation",
                        "confidence": 1.0,
                    }
                else:
                    class_info["attributes"][attr.name] = {
                        "type": "Any",
                        "source": "none",
                        "confidence": 0.5,
                    }

            results["classes"][cls.name] = class_info

        results["variables"] = AssignmentAnalyzer(file_info).analyze()

        return results
|
||||||
Reference in New Issue
Block a user