"""Stub generator module for creating .pyi files.""" import ast import re import sys from pathlib import Path from typing import Any, Dict, List, Optional, Set from stubgen.parser import FileInfo, Function, Class, Variable, Import from stubgen.inferrer import Inferrer def _split_type_args(s: str) -> List[str]: """Split type arguments handling nested brackets.""" result = [] current = "" depth = 0 for char in s: if char == '[': depth += 1 current += char elif char == ']': depth -= 1 current += char elif char == ',' and depth == 0: if current.strip(): result.append(current.strip()) current = "" else: current += char if current.strip(): result.append(current.strip()) return result class StubWriter: """Writes .pyi stub files from parsed and inferred information.""" def __init__(self, inferrer: Optional[Inferrer] = None): self.inferrer = inferrer or Inferrer() self.imports: Set[str] = set() def generate(self, file_info: FileInfo) -> str: """Generate a complete .pyi stub content.""" self.imports = set() lines = [] if file_info.docstring: lines.append(f'"""{file_info.docstring}"""') lines.append("") self._process_imports(file_info.imports) if self.imports: lines.extend(sorted(self.imports)) lines.append("") for var in file_info.variables: var_lines = self._format_variable(var) lines.extend(var_lines) if file_info.variables: lines.append("") for cls in file_info.classes: class_lines = self._format_class(cls) lines.extend(class_lines) lines.append("") for func in file_info.functions: func_lines = self._format_function(func) lines.extend(func_lines) lines.append("") return "\n".join(lines) def _process_imports(self, imports: List[Import]): for imp in imports: if imp.is_from: prefix = "." * imp.level if imp.module: prefix += imp.module if imp.names: parts = [] for name in imp.names: if name in imp.as_names: parts.append(f"{name} as {imp.as_names[name]}") else: parts.append(name) import_str = f"from {prefix} import {', '.join(parts)}" self.imports.add(import_str) else: if imp.module and imp.names: for name in imp.names: if name in imp.as_names: self.imports.add(f"import {imp.module} as {imp.as_names[name]}") else: self.imports.add(f"import {imp.module}.{name}") elif imp.module: self.imports.add(f"import {imp.module}") elif imp.names: for name in imp.names: if name in imp.as_names: self.imports.add(f"import {name} as {imp.as_names[name]}") else: self.imports.add(f"import {name}") def _format_variable(self, var: Variable) -> List[str]: lines = [] if var.annotation: annotation_str = var.annotation.to_string() self._add_needed_import(annotation_str) lines.append(f"{var.name}: {annotation_str}") else: lines.append(f"{var.name}: Any") return lines def _format_function(self, func: Function, inferred_info: Dict = None) -> List[str]: lines = [] for decorator in func.decorators: lines.append(f"@{decorator}") async_prefix = "async " if func.is_async else "" args_parts = [] defaults_offset = len(func.args.args) - len(func.args.defaults) for i, arg in enumerate(func.args.args): arg_str = arg.arg if arg.annotation: ann_str = ast.unparse(arg.annotation) self._add_needed_import(ann_str) arg_str += f": {ann_str}" elif inferred_info and arg.arg in inferred_info.get("args", {}): arg_info = inferred_info["args"][arg.arg] arg_type = arg_info.get("type", "Any") self._add_needed_import(arg_type) if "None" in str(arg_type): arg_str += f": {arg_type}" elif i >= defaults_offset and i - defaults_offset < len(func.args.defaults): default = func.args.defaults[i - defaults_offset] default_type = self._infer_default_type(default) if default_type != "Any": self._add_needed_import(default_type) arg_str += f": {default_type}" args_parts.append(arg_str) if func.args.vararg: vararg_type = "tuple[Any, ...]" self._add_needed_import(vararg_type) args_parts.append(f"*{func.args.vararg.arg}: {vararg_type}") for kwonly in func.args.kwonlyargs: kw_str = kwonly.arg if kwonly.annotation: kw_str += f": {ast.unparse(kwonly.annotation)}" args_parts.append(kw_str) if func.args.kwarg: kwarg_type = "dict[str, Any]" self._add_needed_import(kwarg_type) args_parts.append(f"**{func.args.kwarg.arg}: {kwarg_type}") args_str = ", ".join(args_parts) return_type = "None" if func.returns: if hasattr(func.returns, 'annotation') and func.returns.annotation: return_type = ast.unparse(func.returns.annotation) else: return_type = ast.unparse(func.returns) self._add_needed_import(return_type) elif inferred_info: ret_info = inferred_info.get("returns", {}) return_type = ret_info.get("type", "None") self._add_needed_import(return_type) lines.append(f"{async_prefix}def {func.name}({args_str}) -> {return_type}: ...") return lines def _format_class(self, cls: Class, inferred_info: Dict = None) -> List[str]: lines = [] for decorator in cls.decorators: lines.append(f"@{decorator}") bases = ", ".join(cls.bases) if cls.bases else "object" lines.append(f"class {cls.name}({bases}):") indent = " " for attr in cls.attributes: if attr.annotation: attr_type = attr.annotation.to_string() self._add_needed_import(attr_type) lines.append(f"{indent}{attr.name}: {attr_type}") else: lines.append(f"{indent}{attr.name}: Any") if cls.attributes and cls.methods: lines.append("") if inferred_info: methods = inferred_info.get("methods", {}) else: methods = {} for method in cls.methods: method_inferred = methods.get(method.name, {}) method_lines = self._format_function(method, method_inferred) for line in method_lines: lines.append(f"{indent}{line}") for nested_cls in cls.body: if isinstance(nested_cls, ast.ClassDef): nested_info = inferred_info.get("classes", {}).get(nested_cls.name, {}) if inferred_info else {} for line in lines: pass lines.append(f"{indent}class {nested_cls.name}:") lines.append(f"{indent} pass") if not cls.attributes and not cls.methods and not any(isinstance(n, ast.ClassDef) for n in cls.body): lines.append(f"{indent}pass") return lines def _infer_default_type(self, default: ast.expr) -> str: if default is None: return "Any" if isinstance(default, ast.Constant): value = default.value if value is None: return "None" elif isinstance(value, bool): return "bool" elif isinstance(value, int): return "int" elif isinstance(value, float): return "float" elif isinstance(value, str): return "str" elif isinstance(default, ast.List): return "list" elif isinstance(default, ast.Dict): return "dict" elif isinstance(default, ast.Tuple): return "tuple" elif isinstance(default, ast.Set): return "set" return "Any" def _add_needed_import(self, type_str: str): """Add import statements for types that need them.""" if not type_str or type_str == "Any": return type_str = type_str.strip() container_match = re.match(r'(List|Dict|Tuple|Set|FrozenSet|Callable|Type|Union|Optional)\[', type_str) if container_match: container_type = container_match.group(1) if container_type == "List": self.imports.add("from typing import List") elif container_type == "Dict": self.imports.add("from typing import Dict") elif container_type == "Tuple": self.imports.add("from typing import Tuple") elif container_type == "Set": self.imports.add("from typing import Set") elif container_type == "FrozenSet": self.imports.add("from typing import FrozenSet") elif container_type == "Callable": self.imports.add("from typing import Callable") elif container_type == "Type": self.imports.add("from typing import Type") elif container_type in ("Union", "Optional"): self.imports.add("from typing import Union, Optional") inner_types = type_str[container_match.end():-1] inner_type_list = _split_type_args(inner_types) for inner in inner_type_list: self._add_needed_import(inner) generic_match = re.match(r'([A-Z][a-zA-Z0-9_]*)\[', type_str) if generic_match: type_name = generic_match.group(1) builtin_types = {"str", "int", "float", "bool", "list", "dict", "set", "tuple", "type", "object"} if type_name not in builtin_types: self.imports.add(f"from typing import {type_name}") self.imports.add("from typing import Any") class StubGenerator: """Main stub generator class that orchestrates parsing and generation.""" def __init__(self, inferrer: Optional[Inferrer] = None, interactive: bool = False): self.inferrer = inferrer or Inferrer() self.writer = StubWriter(self.inferrer) self.interactive = interactive def confirm_types_interactively(self, inferred_types: Dict[str, Any]) -> Dict[str, Any]: """Allow user to confirm or modify inferred types.""" import click confirmed_types = inferred_types.copy() if not inferred_types: return confirmed_types for func_name, type_info in inferred_types.get("functions", {}).items(): click.echo(f"\nFunction: {func_name}") args_info = type_info.get("args", {}) for arg_name, arg_info in args_info.items(): if arg_info.get("source") != "annotation": current_type = arg_info.get("type", "Any") confidence = arg_info.get("confidence", 0.5) click.echo(f" {arg_name}: {current_type} (confidence: {confidence})") returns_info = type_info.get("returns", {}) if returns_info.get("source") != "annotation": current_type = returns_info.get("type", "None") confidence = returns_info.get("confidence", 0.5) click.echo(f" return: {current_type} (confidence: {confidence})") return confirmed_types def generate(self, tree, inferred_types: Optional[Dict[str, Any]] = None) -> str: """Generate stub content from AST tree and inferred types.""" from stubgen.parser import FileInfo file_info = FileInfo(path=Path(".")) file_info = self._extract_info_from_tree(tree, file_info) return self.writer.generate(file_info) def _extract_info_from_tree(self, tree, file_info: FileInfo) -> FileInfo: """Extract file info from AST tree.""" from stubgen.parser import Function, Class, Variable, TypeHint import ast for node in ast.walk(tree): if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef): if node.col_offset == 0: func = Function( name=node.name, args=node.args, returns=TypeHint(annotation=node.returns) if node.returns else None, decorators=[], body=list(node.body), is_async=isinstance(node, ast.AsyncFunctionDef) ) file_info.functions.append(func) for node in ast.walk(tree): if isinstance(node, ast.ClassDef): cls = Class( name=node.name, bases=[ast.unparse(base) for base in node.bases], methods=[], attributes=[], decorators=[], body=list(node.body) ) for item in node.body: if isinstance(item, ast.FunctionDef) or isinstance(item, ast.AsyncFunctionDef): method = Function( name=item.name, args=item.args, returns=TypeHint(annotation=item.returns) if item.returns else None, decorators=[], body=list(item.body), is_async=isinstance(item, ast.AsyncFunctionDef) ) cls.methods.append(method) elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name): var = Variable( name=item.target.id, annotation=TypeHint(annotation=item.annotation) if item.annotation else None ) cls.attributes.append(var) file_info.classes.append(cls) return file_info def generate_file(self, input_path: Path, output_path: Optional[Path] = None) -> Path: """Generate a stub file for a single Python file.""" from stubgen.parser import FileParser file_info = FileParser(input_path).parse() stub_content = self.writer.generate(file_info) if output_path is None: output_path = input_path.with_suffix(".pyi") output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, 'w', encoding='utf-8') as f: f.write(stub_content) return output_path def generate_directory( self, input_dir: Path, output_dir: Optional[Path] = None, recursive: bool = True ) -> List[Path]: """Generate stub files for all Python files in a directory.""" from stubgen.parser import parse_directory if output_dir is None: output_dir = input_dir results = [] file_infos = parse_directory(input_dir, recursive) for py_file, file_info in sorted(file_infos.items()): try: stub_content = self.writer.generate(file_info) rel_path = py_file.relative_to(input_dir) output_path = output_dir / rel_path.with_suffix(".pyi") output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, 'w', encoding='utf-8') as f: f.write(stub_content) results.append(output_path) except Exception as e: print(f"Error generating stub for {py_file}: {e}", file=sys.stderr) return results