450 lines
17 KiB
Python
450 lines
17 KiB
Python
"""Stub generator module for creating .pyi files."""
|
|
|
|
import ast
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Set
|
|
|
|
from stubgen.parser import FileInfo, Function, Class, Variable, Import
|
|
from stubgen.inferrer import Inferrer
|
|
|
|
|
|
def _split_type_args(s: str) -> List[str]:
|
|
"""Split type arguments handling nested brackets."""
|
|
result = []
|
|
current = ""
|
|
depth = 0
|
|
|
|
for char in s:
|
|
if char == '[':
|
|
depth += 1
|
|
current += char
|
|
elif char == ']':
|
|
depth -= 1
|
|
current += char
|
|
elif char == ',' and depth == 0:
|
|
if current.strip():
|
|
result.append(current.strip())
|
|
current = ""
|
|
else:
|
|
current += char
|
|
|
|
if current.strip():
|
|
result.append(current.strip())
|
|
|
|
return result
|
|
|
|
|
|
class StubWriter:
    """Render ``.pyi`` stub text from parsed (and optionally inferred) information.

    The writer collects import statements in ``self.imports`` while it formats
    variables, classes and functions, and ``generate`` places them ahead of the
    formatted body.
    """

    # Names that have to be imported from ``typing`` when an annotation uses them.
    _TYPING_NAMES = frozenset({
        "Any", "Callable", "Dict", "FrozenSet", "List",
        "Optional", "Set", "Tuple", "Type", "Union",
    })

    # An identifier that is not the tail of a dotted name, so the ``List`` in
    # ``typing.List`` is skipped (it needs no ``from typing import``).
    _IDENTIFIER_RE = re.compile(r"(?<![\w.])[A-Za-z_]\w*")

    # Indentation of class bodies in the generated stub.
    _INDENT = "    "

    def __init__(self, inferrer: "Optional[Inferrer]" = None):
        """Create a writer, building a default ``Inferrer`` when none is given."""
        self.inferrer = inferrer if inferrer is not None else Inferrer()
        self.imports: Set[str] = set()

    def generate(self, file_info: "FileInfo") -> str:
        """Generate the complete ``.pyi`` content for ``file_info``.

        Layout: optional module docstring, sorted import lines, module
        variables, classes, then functions, with a blank line closing each
        section or definition.
        """
        self.imports = set()
        self._process_imports(file_info.imports)

        # Format the body first: formatting registers the typing imports the
        # annotations need, and those must be known before the import section
        # is written.
        body: List[str] = []

        for var in file_info.variables:
            body.extend(self._format_variable(var))
        if file_info.variables:
            body.append("")

        for cls in file_info.classes:
            body.extend(self._format_class(cls))
            body.append("")

        for func in file_info.functions:
            body.extend(self._format_function(func))
            body.append("")

        lines: List[str] = []
        if file_info.docstring:
            lines.append(f'"""{file_info.docstring}"""')
            lines.append("")

        if self.imports:
            lines.extend(sorted(self.imports))
            lines.append("")

        lines.extend(body)
        return "\n".join(lines)

    @staticmethod
    def _aliased(name: str, as_names: Dict[str, str]) -> str:
        """Return ``name as alias`` when ``as_names`` holds an alias, else ``name``."""
        return f"{name} as {as_names[name]}" if name in as_names else name

    def _process_imports(self, imports: "List[Import]") -> None:
        """Add one import statement per parsed import to ``self.imports``."""
        for imp in imports:
            if imp.is_from:
                if not imp.names:
                    continue
                source = "." * (imp.level or 0) + (imp.module or "")
                names = ", ".join(self._aliased(n, imp.as_names) for n in imp.names)
                self.imports.add(f"from {source} import {names}")
            elif imp.module and imp.names:
                for name in imp.names:
                    if name in imp.as_names:
                        # NOTE(review): the aliased form drops ``name`` while the
                        # unaliased form keeps it; which one is right depends on
                        # how the parser fills ``module``/``names`` -- confirm.
                        self.imports.add(f"import {imp.module} as {imp.as_names[name]}")
                    else:
                        self.imports.add(f"import {imp.module}.{name}")
            elif imp.module:
                self.imports.add(f"import {imp.module}")
            elif imp.names:
                for name in imp.names:
                    self.imports.add(f"import {self._aliased(name, imp.as_names)}")

    def _format_variable(self, var: "Variable") -> List[str]:
        """Format a module-level variable as ``name: type`` (``Any`` if unannotated)."""
        annotation_str = var.annotation.to_string() if var.annotation else "Any"
        self._add_needed_import(annotation_str)
        return [f"{var.name}: {annotation_str}"]

    def _format_arg(self, arg: ast.arg, default: Optional[ast.expr],
                    inferred_args: Dict[str, Any]) -> str:
        """Format one named parameter, including its annotation and default marker.

        The annotation comes from, in order of preference: the source
        annotation, the inferred type, or the type of a literal default.
        """
        annotation = None
        if arg.annotation:
            annotation = ast.unparse(arg.annotation)
        elif arg.arg in inferred_args:
            annotation = str(inferred_args[arg.arg].get("type", "Any"))
        elif default is not None:
            guessed = self._infer_default_type(default)
            # A ``None`` default says nothing about the accepted type, so the
            # parameter is left unannotated rather than typed as ``None``.
            if guessed not in ("Any", "None"):
                annotation = guessed

        text = arg.arg
        if annotation:
            self._add_needed_import(annotation)
            text += f": {annotation}"
        if default is not None:
            # Stubs keep the parameter optional but elide the actual value.
            text += " = ..." if annotation else "=..."
        return text

    def _format_star_arg(self, arg: ast.arg) -> str:
        """Format the name and element type of a ``*args`` / ``**kwargs`` parameter."""
        # The annotation of a variadic parameter describes each element, not
        # the tuple/dict that collects them.
        annotation = ast.unparse(arg.annotation) if arg.annotation else "Any"
        self._add_needed_import(annotation)
        return f"{arg.arg}: {annotation}"

    def _format_function(self, func: "Function",
                         inferred_info: Optional[Dict[str, Any]] = None) -> List[str]:
        """Format a function or method as decorator lines plus a one-line stub.

        Args:
            func: Parsed function; ``func.args`` is read like ``ast.arguments``.
            inferred_info: Optional mapping with ``"args"`` (name -> info dict
                holding ``"type"``) and ``"returns"`` (info dict holding
                ``"type"``), used where the source has no annotation.

        Returns:
            The decorator lines followed by the ``def ...: ...`` line.
        """
        inferred_info = inferred_info or {}
        inferred_args = inferred_info.get("args", {})

        lines = [f"@{decorator}" for decorator in func.decorators]
        async_prefix = "async " if func.is_async else ""

        args = func.args
        posonly = list(getattr(args, "posonlyargs", None) or [])
        positional = posonly + list(args.args)
        # ``defaults`` aligns with the tail of positional-only + regular args.
        first_default = len(positional) - len(args.defaults)

        parts: List[str] = []
        for index, arg in enumerate(positional):
            default = args.defaults[index - first_default] if index >= first_default else None
            parts.append(self._format_arg(arg, default, inferred_args))
            if posonly and index == len(posonly) - 1:
                parts.append("/")

        if args.vararg:
            parts.append("*" + self._format_star_arg(args.vararg))
        elif args.kwonlyargs:
            # Without ``*args`` a bare ``*`` is required to keep what follows
            # keyword-only.
            parts.append("*")

        kw_defaults = list(getattr(args, "kw_defaults", None) or [])
        for index, kwonly in enumerate(args.kwonlyargs):
            default = kw_defaults[index] if index < len(kw_defaults) else None
            parts.append(self._format_arg(kwonly, default, inferred_args))

        if args.kwarg:
            parts.append("**" + self._format_star_arg(args.kwarg))

        if func.returns:
            # ``returns`` is either a wrapper exposing ``.annotation`` or the
            # annotation node itself.
            node = getattr(func.returns, "annotation", None) or func.returns
            return_type = ast.unparse(node)
        else:
            return_type = inferred_info.get("returns", {}).get("type", "None")
        self._add_needed_import(return_type)

        signature = ", ".join(parts)
        lines.append(f"{async_prefix}def {func.name}({signature}) -> {return_type}: ...")
        return lines

    def _format_class(self, cls: "Class",
                      inferred_info: Optional[Dict[str, Any]] = None) -> List[str]:
        """Format a class: decorators, header, attributes, methods, nested classes.

        Nested classes found in ``cls.body`` are emitted as empty placeholders.
        ``inferred_info["methods"]`` (method name -> inferred info) is passed
        through to ``_format_function``.
        """
        indent = self._INDENT
        lines = [f"@{decorator}" for decorator in cls.decorators]

        bases = ", ".join(cls.bases) if cls.bases else "object"
        lines.append(f"class {cls.name}({bases}):")

        for attr in cls.attributes:
            attr_type = attr.annotation.to_string() if attr.annotation else "Any"
            self._add_needed_import(attr_type)
            lines.append(f"{indent}{attr.name}: {attr_type}")

        if cls.attributes and cls.methods:
            lines.append("")

        inferred_methods = (inferred_info or {}).get("methods", {})
        for method in cls.methods:
            method_lines = self._format_function(method, inferred_methods.get(method.name, {}))
            lines.extend(f"{indent}{line}" for line in method_lines)

        nested = [node for node in cls.body if isinstance(node, ast.ClassDef)]
        for nested_cls in nested:
            lines.append(f"{indent}class {nested_cls.name}:")
            lines.append(f"{indent * 2}pass")

        if not cls.attributes and not cls.methods and not nested:
            lines.append(f"{indent}pass")

        return lines

    def _infer_default_type(self, default: Optional[ast.expr]) -> str:
        """Return the type name implied by a literal default value, or ``"Any"``."""
        if default is None:
            return "Any"

        if isinstance(default, ast.Constant):
            value = default.value
            if value is None:
                return "None"
            # ``bool`` is tested before ``int`` because it is a subclass of it.
            for py_type in (bool, int, float, complex, str, bytes):
                if isinstance(value, py_type):
                    return py_type.__name__
            return "Any"

        # Signed numeric literals such as ``-1`` parse as a unary operation.
        if isinstance(default, ast.UnaryOp) and isinstance(default.op, (ast.USub, ast.UAdd)):
            operand_type = self._infer_default_type(default.operand)
            return operand_type if operand_type in ("int", "float", "complex") else "Any"

        literal_types = (
            (ast.List, "list"),
            (ast.Dict, "dict"),
            (ast.Tuple, "tuple"),
            (ast.Set, "set"),
        )
        for node_type, name in literal_types:
            if isinstance(default, node_type):
                return name

        return "Any"

    def _add_needed_import(self, type_str: str) -> None:
        """Register a ``from typing import X`` line for every typing name used.

        The whole annotation is scanned, so nested uses such as the
        ``Optional`` in ``Dict[str, Optional[int]]`` are covered. Builtins,
        project classes and dotted references (``typing.List``) add nothing.
        """
        if not type_str:
            return

        for name in self._IDENTIFIER_RE.findall(str(type_str)):
            if name in self._TYPING_NAMES:
                self.imports.add(f"from typing import {name}")
|
|
|
|
|
|
class StubGenerator:
    """Orchestrate parsing of Python sources and generation of ``.pyi`` stubs."""

    def __init__(self, inferrer: "Optional[Inferrer]" = None, interactive: bool = False):
        """Create a generator whose writer shares the given (or a default) inferrer."""
        self.inferrer = inferrer if inferrer is not None else Inferrer()
        self.writer = StubWriter(self.inferrer)
        self.interactive = interactive

    def confirm_types_interactively(self, inferred_types: Dict[str, Any]) -> Dict[str, Any]:
        """Show the inferred (non-annotated) types and return a copy of the mapping.

        For every entry under ``inferred_types["functions"]`` the argument and
        return types whose ``"source"`` is not ``"annotation"`` are echoed
        together with their confidence. The returned dict is a shallow copy of
        the input; nothing is modified here.
        """
        # Nothing to show: return before touching ``click`` so that empty or
        # missing input works even where ``click`` is not installed.
        if not inferred_types:
            return {}

        import click

        confirmed_types = inferred_types.copy()

        for func_name, type_info in inferred_types.get("functions", {}).items():
            click.echo(f"\nFunction: {func_name}")

            for arg_name, arg_info in type_info.get("args", {}).items():
                if arg_info.get("source") != "annotation":
                    current_type = arg_info.get("type", "Any")
                    confidence = arg_info.get("confidence", 0.5)
                    click.echo(f" {arg_name}: {current_type} (confidence: {confidence})")

            returns_info = type_info.get("returns", {})
            if returns_info.get("source") != "annotation":
                current_type = returns_info.get("type", "None")
                confidence = returns_info.get("confidence", 0.5)
                click.echo(f" return: {current_type} (confidence: {confidence})")

        return confirmed_types

    def generate(self, tree: ast.AST, inferred_types: Optional[Dict[str, Any]] = None) -> str:
        """Generate stub content from an AST tree.

        NOTE(review): ``inferred_types`` is accepted but not used by this
        method; only information present in ``tree`` reaches the writer.
        """
        file_info = FileInfo(path=Path("."))
        file_info = self._extract_info_from_tree(tree, file_info)
        return self.writer.generate(file_info)

    @staticmethod
    def _function_from_node(node: ast.AST) -> "Function":
        """Build a ``Function`` record from a (possibly async) function definition."""
        from stubgen.parser import Function, TypeHint

        return Function(
            name=node.name,
            args=node.args,
            returns=TypeHint(annotation=node.returns) if node.returns else None,
            # Keep decorators: ``@property``/``@staticmethod``/``@classmethod``
            # change what a stub means.
            decorators=[ast.unparse(decorator) for decorator in node.decorator_list],
            body=list(node.body),
            is_async=isinstance(node, ast.AsyncFunctionDef),
        )

    def _extract_info_from_tree(self, tree: ast.AST, file_info: "FileInfo") -> "FileInfo":
        """Fill ``file_info`` with the functions and classes defined directly in ``tree``.

        Only direct children of ``tree`` are considered, so classes nested in
        other classes or defined inside functions are not reported as
        module-level definitions. For each class, its methods and annotated
        attributes are collected.
        """
        from stubgen.parser import Class, TypeHint, Variable

        function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
        top_level = list(ast.iter_child_nodes(tree))

        for node in top_level:
            if isinstance(node, function_nodes):
                file_info.functions.append(self._function_from_node(node))

        for node in top_level:
            if not isinstance(node, ast.ClassDef):
                continue

            cls = Class(
                name=node.name,
                bases=[ast.unparse(base) for base in node.bases],
                methods=[],
                attributes=[],
                decorators=[ast.unparse(decorator) for decorator in node.decorator_list],
                body=list(node.body),
            )

            for item in node.body:
                if isinstance(item, function_nodes):
                    cls.methods.append(self._function_from_node(item))
                elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
                    cls.attributes.append(
                        Variable(
                            name=item.target.id,
                            annotation=TypeHint(annotation=item.annotation) if item.annotation else None,
                        )
                    )

            file_info.classes.append(cls)

        return file_info

    @staticmethod
    def _write_stub(output_path: Path, stub_content: str) -> None:
        """Write ``stub_content`` to ``output_path`` as UTF-8, creating parent directories."""
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(stub_content)

    def generate_file(self, input_path: Path, output_path: Optional[Path] = None) -> Path:
        """Generate a stub file for a single Python file.

        Args:
            input_path: Python source file to parse.
            output_path: Destination; defaults to ``input_path`` with a
                ``.pyi`` suffix.

        Returns:
            The path the stub was written to.
        """
        from stubgen.parser import FileParser

        file_info = FileParser(input_path).parse()
        stub_content = self.writer.generate(file_info)

        if output_path is None:
            output_path = input_path.with_suffix(".pyi")

        self._write_stub(output_path, stub_content)
        return output_path

    def generate_directory(
        self,
        input_dir: Path,
        output_dir: Optional[Path] = None,
        recursive: bool = True
    ) -> List[Path]:
        """Generate stub files for all Python files in a directory.

        Stubs mirror the layout of ``input_dir`` under ``output_dir`` (which
        defaults to ``input_dir``). A failure on one file is reported on
        stderr and does not stop the remaining files.

        Returns:
            The paths of the stubs that were written, in sorted source order.
        """
        from stubgen.parser import parse_directory

        if output_dir is None:
            output_dir = input_dir

        results: List[Path] = []
        file_infos = parse_directory(input_dir, recursive)

        for py_file, file_info in sorted(file_infos.items(), key=lambda item: item[0]):
            # Best effort per file: one bad source must not abort the batch.
            try:
                stub_content = self.writer.generate(file_info)
                rel_path = py_file.relative_to(input_dir)
                output_path = output_dir / rel_path.with_suffix(".pyi")
                self._write_stub(output_path, stub_content)
                results.append(output_path)
            except Exception as e:
                print(f"Error generating stub for {py_file}: {e}", file=sys.stderr)

        return results
|