Initial commit: Add python-stub-generator project
Some checks failed
CI / test (push) Has been cancelled

This commit is contained in:
2026-01-30 04:51:45 +00:00
parent 085e7aa7d6
commit 78db650003

449
stubgen/generator.py Normal file
View File

@@ -0,0 +1,449 @@
"""Stub generator module for creating .pyi files."""
import ast
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from stubgen.parser import FileInfo, Function, Class, Variable, Import
from stubgen.inferrer import Inferrer
def _split_type_args(s: str) -> List[str]:
"""Split type arguments handling nested brackets."""
result = []
current = ""
depth = 0
for char in s:
if char == '[':
depth += 1
current += char
elif char == ']':
depth -= 1
current += char
elif char == ',' and depth == 0:
if current.strip():
result.append(current.strip())
current = ""
else:
current += char
if current.strip():
result.append(current.strip())
return result
class StubWriter:
"""Writes .pyi stub files from parsed and inferred information."""
def __init__(self, inferrer: Optional[Inferrer] = None):
self.inferrer = inferrer or Inferrer()
self.imports: Set[str] = set()
def generate(self, file_info: FileInfo) -> str:
"""Generate a complete .pyi stub content."""
self.imports = set()
lines = []
if file_info.docstring:
lines.append(f'"""{file_info.docstring}"""')
lines.append("")
self._process_imports(file_info.imports)
if self.imports:
lines.extend(sorted(self.imports))
lines.append("")
for var in file_info.variables:
var_lines = self._format_variable(var)
lines.extend(var_lines)
if file_info.variables:
lines.append("")
for cls in file_info.classes:
class_lines = self._format_class(cls)
lines.extend(class_lines)
lines.append("")
for func in file_info.functions:
func_lines = self._format_function(func)
lines.extend(func_lines)
lines.append("")
return "\n".join(lines)
def _process_imports(self, imports: List[Import]):
for imp in imports:
if imp.is_from:
prefix = "." * imp.level
if imp.module:
prefix += imp.module
if imp.names:
parts = []
for name in imp.names:
if name in imp.as_names:
parts.append(f"{name} as {imp.as_names[name]}")
else:
parts.append(name)
import_str = f"from {prefix} import {', '.join(parts)}"
self.imports.add(import_str)
else:
if imp.module and imp.names:
for name in imp.names:
if name in imp.as_names:
self.imports.add(f"import {imp.module} as {imp.as_names[name]}")
else:
self.imports.add(f"import {imp.module}.{name}")
elif imp.module:
self.imports.add(f"import {imp.module}")
elif imp.names:
for name in imp.names:
if name in imp.as_names:
self.imports.add(f"import {name} as {imp.as_names[name]}")
else:
self.imports.add(f"import {name}")
def _format_variable(self, var: Variable) -> List[str]:
lines = []
if var.annotation:
annotation_str = var.annotation.to_string()
self._add_needed_import(annotation_str)
lines.append(f"{var.name}: {annotation_str}")
else:
lines.append(f"{var.name}: Any")
return lines
def _format_function(self, func: Function, inferred_info: Dict = None) -> List[str]:
lines = []
for decorator in func.decorators:
lines.append(f"@{decorator}")
async_prefix = "async " if func.is_async else ""
args_parts = []
defaults_offset = len(func.args.args) - len(func.args.defaults)
for i, arg in enumerate(func.args.args):
arg_str = arg.arg
if arg.annotation:
ann_str = ast.unparse(arg.annotation)
self._add_needed_import(ann_str)
arg_str += f": {ann_str}"
elif inferred_info and arg.arg in inferred_info.get("args", {}):
arg_info = inferred_info["args"][arg.arg]
arg_type = arg_info.get("type", "Any")
self._add_needed_import(arg_type)
if "None" in str(arg_type):
arg_str += f": {arg_type}"
elif i >= defaults_offset and i - defaults_offset < len(func.args.defaults):
default = func.args.defaults[i - defaults_offset]
default_type = self._infer_default_type(default)
if default_type != "Any":
self._add_needed_import(default_type)
arg_str += f": {default_type}"
args_parts.append(arg_str)
if func.args.vararg:
vararg_type = "tuple[Any, ...]"
self._add_needed_import(vararg_type)
args_parts.append(f"*{func.args.vararg.arg}: {vararg_type}")
for kwonly in func.args.kwonlyargs:
kw_str = kwonly.arg
if kwonly.annotation:
kw_str += f": {ast.unparse(kwonly.annotation)}"
args_parts.append(kw_str)
if func.args.kwarg:
kwarg_type = "dict[str, Any]"
self._add_needed_import(kwarg_type)
args_parts.append(f"**{func.args.kwarg.arg}: {kwarg_type}")
args_str = ", ".join(args_parts)
return_type = "None"
if func.returns:
if hasattr(func.returns, 'annotation') and func.returns.annotation:
return_type = ast.unparse(func.returns.annotation)
else:
return_type = ast.unparse(func.returns)
self._add_needed_import(return_type)
elif inferred_info:
ret_info = inferred_info.get("returns", {})
return_type = ret_info.get("type", "None")
self._add_needed_import(return_type)
lines.append(f"{async_prefix}def {func.name}({args_str}) -> {return_type}: ...")
return lines
def _format_class(self, cls: Class, inferred_info: Dict = None) -> List[str]:
lines = []
for decorator in cls.decorators:
lines.append(f"@{decorator}")
bases = ", ".join(cls.bases) if cls.bases else "object"
lines.append(f"class {cls.name}({bases}):")
indent = " "
for attr in cls.attributes:
if attr.annotation:
attr_type = attr.annotation.to_string()
self._add_needed_import(attr_type)
lines.append(f"{indent}{attr.name}: {attr_type}")
else:
lines.append(f"{indent}{attr.name}: Any")
if cls.attributes and cls.methods:
lines.append("")
if inferred_info:
methods = inferred_info.get("methods", {})
else:
methods = {}
for method in cls.methods:
method_inferred = methods.get(method.name, {})
method_lines = self._format_function(method, method_inferred)
for line in method_lines:
lines.append(f"{indent}{line}")
for nested_cls in cls.body:
if isinstance(nested_cls, ast.ClassDef):
nested_info = inferred_info.get("classes", {}).get(nested_cls.name, {}) if inferred_info else {}
for line in lines:
pass
lines.append(f"{indent}class {nested_cls.name}:")
lines.append(f"{indent} pass")
if not cls.attributes and not cls.methods and not any(isinstance(n, ast.ClassDef) for n in cls.body):
lines.append(f"{indent}pass")
return lines
def _infer_default_type(self, default: ast.expr) -> str:
if default is None:
return "Any"
if isinstance(default, ast.Constant):
value = default.value
if value is None:
return "None"
elif isinstance(value, bool):
return "bool"
elif isinstance(value, int):
return "int"
elif isinstance(value, float):
return "float"
elif isinstance(value, str):
return "str"
elif isinstance(default, ast.List):
return "list"
elif isinstance(default, ast.Dict):
return "dict"
elif isinstance(default, ast.Tuple):
return "tuple"
elif isinstance(default, ast.Set):
return "set"
return "Any"
def _add_needed_import(self, type_str: str):
"""Add import statements for types that need them."""
if not type_str or type_str == "Any":
return
type_str = type_str.strip()
container_match = re.match(r'(List|Dict|Tuple|Set|FrozenSet|Callable|Type|Union|Optional)\[', type_str)
if container_match:
container_type = container_match.group(1)
if container_type == "List":
self.imports.add("from typing import List")
elif container_type == "Dict":
self.imports.add("from typing import Dict")
elif container_type == "Tuple":
self.imports.add("from typing import Tuple")
elif container_type == "Set":
self.imports.add("from typing import Set")
elif container_type == "FrozenSet":
self.imports.add("from typing import FrozenSet")
elif container_type == "Callable":
self.imports.add("from typing import Callable")
elif container_type == "Type":
self.imports.add("from typing import Type")
elif container_type in ("Union", "Optional"):
self.imports.add("from typing import Union, Optional")
inner_types = type_str[container_match.end():-1]
inner_type_list = _split_type_args(inner_types)
for inner in inner_type_list:
self._add_needed_import(inner)
generic_match = re.match(r'([A-Z][a-zA-Z0-9_]*)\[', type_str)
if generic_match:
type_name = generic_match.group(1)
builtin_types = {"str", "int", "float", "bool", "list", "dict", "set", "tuple", "type", "object"}
if type_name not in builtin_types:
self.imports.add(f"from typing import {type_name}")
self.imports.add("from typing import Any")
class StubGenerator:
"""Main stub generator class that orchestrates parsing and generation."""
def __init__(self, inferrer: Optional[Inferrer] = None, interactive: bool = False):
self.inferrer = inferrer or Inferrer()
self.writer = StubWriter(self.inferrer)
self.interactive = interactive
def confirm_types_interactively(self, inferred_types: Dict[str, Any]) -> Dict[str, Any]:
"""Allow user to confirm or modify inferred types."""
import click
confirmed_types = inferred_types.copy()
if not inferred_types:
return confirmed_types
for func_name, type_info in inferred_types.get("functions", {}).items():
click.echo(f"\nFunction: {func_name}")
args_info = type_info.get("args", {})
for arg_name, arg_info in args_info.items():
if arg_info.get("source") != "annotation":
current_type = arg_info.get("type", "Any")
confidence = arg_info.get("confidence", 0.5)
click.echo(f" {arg_name}: {current_type} (confidence: {confidence})")
returns_info = type_info.get("returns", {})
if returns_info.get("source") != "annotation":
current_type = returns_info.get("type", "None")
confidence = returns_info.get("confidence", 0.5)
click.echo(f" return: {current_type} (confidence: {confidence})")
return confirmed_types
def generate(self, tree, inferred_types: Optional[Dict[str, Any]] = None) -> str:
"""Generate stub content from AST tree and inferred types."""
from stubgen.parser import FileInfo
file_info = FileInfo(path=Path("."))
file_info = self._extract_info_from_tree(tree, file_info)
return self.writer.generate(file_info)
def _extract_info_from_tree(self, tree, file_info: FileInfo) -> FileInfo:
"""Extract file info from AST tree."""
from stubgen.parser import Function, Class, Variable, TypeHint
import ast
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef):
if node.col_offset == 0:
func = Function(
name=node.name,
args=node.args,
returns=TypeHint(annotation=node.returns) if node.returns else None,
decorators=[],
body=list(node.body),
is_async=isinstance(node, ast.AsyncFunctionDef)
)
file_info.functions.append(func)
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
cls = Class(
name=node.name,
bases=[ast.unparse(base) for base in node.bases],
methods=[],
attributes=[],
decorators=[],
body=list(node.body)
)
for item in node.body:
if isinstance(item, ast.FunctionDef) or isinstance(item, ast.AsyncFunctionDef):
method = Function(
name=item.name,
args=item.args,
returns=TypeHint(annotation=item.returns) if item.returns else None,
decorators=[],
body=list(item.body),
is_async=isinstance(item, ast.AsyncFunctionDef)
)
cls.methods.append(method)
elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
var = Variable(
name=item.target.id,
annotation=TypeHint(annotation=item.annotation) if item.annotation else None
)
cls.attributes.append(var)
file_info.classes.append(cls)
return file_info
def generate_file(self, input_path: Path, output_path: Optional[Path] = None) -> Path:
"""Generate a stub file for a single Python file."""
from stubgen.parser import FileParser
file_info = FileParser(input_path).parse()
stub_content = self.writer.generate(file_info)
if output_path is None:
output_path = input_path.with_suffix(".pyi")
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
f.write(stub_content)
return output_path
def generate_directory(
self,
input_dir: Path,
output_dir: Optional[Path] = None,
recursive: bool = True
) -> List[Path]:
"""Generate stub files for all Python files in a directory."""
from stubgen.parser import parse_directory
if output_dir is None:
output_dir = input_dir
results = []
file_infos = parse_directory(input_dir, recursive)
for py_file, file_info in sorted(file_infos.items()):
try:
stub_content = self.writer.generate(file_info)
rel_path = py_file.relative_to(input_dir)
output_path = output_dir / rel_path.with_suffix(".pyi")
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
f.write(stub_content)
results.append(output_path)
except Exception as e:
print(f"Error generating stub for {py_file}: {e}", file=sys.stderr)
return results