diff --git a/stubgen/generator.py b/stubgen/generator.py new file mode 100644 index 0000000..18613d9 --- /dev/null +++ b/stubgen/generator.py @@ -0,0 +1,449 @@ +"""Stub generator module for creating .pyi files.""" + +import ast +import re +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +from stubgen.parser import FileInfo, Function, Class, Variable, Import +from stubgen.inferrer import Inferrer + + +def _split_type_args(s: str) -> List[str]: + """Split type arguments handling nested brackets.""" + result = [] + current = "" + depth = 0 + + for char in s: + if char == '[': + depth += 1 + current += char + elif char == ']': + depth -= 1 + current += char + elif char == ',' and depth == 0: + if current.strip(): + result.append(current.strip()) + current = "" + else: + current += char + + if current.strip(): + result.append(current.strip()) + + return result + + +class StubWriter: + """Writes .pyi stub files from parsed and inferred information.""" + + def __init__(self, inferrer: Optional[Inferrer] = None): + self.inferrer = inferrer or Inferrer() + self.imports: Set[str] = set() + + def generate(self, file_info: FileInfo) -> str: + """Generate a complete .pyi stub content.""" + self.imports = set() + lines = [] + + if file_info.docstring: + lines.append(f'"""{file_info.docstring}"""') + lines.append("") + + self._process_imports(file_info.imports) + + if self.imports: + lines.extend(sorted(self.imports)) + lines.append("") + + for var in file_info.variables: + var_lines = self._format_variable(var) + lines.extend(var_lines) + + if file_info.variables: + lines.append("") + + for cls in file_info.classes: + class_lines = self._format_class(cls) + lines.extend(class_lines) + lines.append("") + + for func in file_info.functions: + func_lines = self._format_function(func) + lines.extend(func_lines) + lines.append("") + + return "\n".join(lines) + + def _process_imports(self, imports: List[Import]): + for imp in imports: + if imp.is_from: + prefix = "." * imp.level + if imp.module: + prefix += imp.module + if imp.names: + parts = [] + for name in imp.names: + if name in imp.as_names: + parts.append(f"{name} as {imp.as_names[name]}") + else: + parts.append(name) + import_str = f"from {prefix} import {', '.join(parts)}" + self.imports.add(import_str) + else: + if imp.module and imp.names: + for name in imp.names: + if name in imp.as_names: + self.imports.add(f"import {imp.module} as {imp.as_names[name]}") + else: + self.imports.add(f"import {imp.module}.{name}") + elif imp.module: + self.imports.add(f"import {imp.module}") + elif imp.names: + for name in imp.names: + if name in imp.as_names: + self.imports.add(f"import {name} as {imp.as_names[name]}") + else: + self.imports.add(f"import {name}") + + def _format_variable(self, var: Variable) -> List[str]: + lines = [] + + if var.annotation: + annotation_str = var.annotation.to_string() + self._add_needed_import(annotation_str) + lines.append(f"{var.name}: {annotation_str}") + else: + lines.append(f"{var.name}: Any") + + return lines + + def _format_function(self, func: Function, inferred_info: Dict = None) -> List[str]: + lines = [] + + for decorator in func.decorators: + lines.append(f"@{decorator}") + + async_prefix = "async " if func.is_async else "" + + args_parts = [] + + defaults_offset = len(func.args.args) - len(func.args.defaults) + + for i, arg in enumerate(func.args.args): + arg_str = arg.arg + + if arg.annotation: + ann_str = ast.unparse(arg.annotation) + self._add_needed_import(ann_str) + arg_str += f": {ann_str}" + elif inferred_info and arg.arg in inferred_info.get("args", {}): + arg_info = inferred_info["args"][arg.arg] + arg_type = arg_info.get("type", "Any") + self._add_needed_import(arg_type) + if "None" in str(arg_type): + arg_str += f": {arg_type}" + elif i >= defaults_offset and i - defaults_offset < len(func.args.defaults): + default = func.args.defaults[i - defaults_offset] + default_type = self._infer_default_type(default) + if default_type != "Any": + self._add_needed_import(default_type) + arg_str += f": {default_type}" + + args_parts.append(arg_str) + + if func.args.vararg: + vararg_type = "tuple[Any, ...]" + self._add_needed_import(vararg_type) + args_parts.append(f"*{func.args.vararg.arg}: {vararg_type}") + + for kwonly in func.args.kwonlyargs: + kw_str = kwonly.arg + if kwonly.annotation: + kw_str += f": {ast.unparse(kwonly.annotation)}" + args_parts.append(kw_str) + + if func.args.kwarg: + kwarg_type = "dict[str, Any]" + self._add_needed_import(kwarg_type) + args_parts.append(f"**{func.args.kwarg.arg}: {kwarg_type}") + + args_str = ", ".join(args_parts) + + return_type = "None" + if func.returns: + if hasattr(func.returns, 'annotation') and func.returns.annotation: + return_type = ast.unparse(func.returns.annotation) + else: + return_type = ast.unparse(func.returns) + self._add_needed_import(return_type) + elif inferred_info: + ret_info = inferred_info.get("returns", {}) + return_type = ret_info.get("type", "None") + self._add_needed_import(return_type) + + lines.append(f"{async_prefix}def {func.name}({args_str}) -> {return_type}: ...") + + return lines + + def _format_class(self, cls: Class, inferred_info: Dict = None) -> List[str]: + lines = [] + + for decorator in cls.decorators: + lines.append(f"@{decorator}") + + bases = ", ".join(cls.bases) if cls.bases else "object" + lines.append(f"class {cls.name}({bases}):") + + indent = " " + + for attr in cls.attributes: + if attr.annotation: + attr_type = attr.annotation.to_string() + self._add_needed_import(attr_type) + lines.append(f"{indent}{attr.name}: {attr_type}") + else: + lines.append(f"{indent}{attr.name}: Any") + + if cls.attributes and cls.methods: + lines.append("") + + if inferred_info: + methods = inferred_info.get("methods", {}) + else: + methods = {} + + for method in cls.methods: + method_inferred = methods.get(method.name, {}) + method_lines = self._format_function(method, method_inferred) + for line in method_lines: + lines.append(f"{indent}{line}") + + for nested_cls in cls.body: + if isinstance(nested_cls, ast.ClassDef): + nested_info = inferred_info.get("classes", {}).get(nested_cls.name, {}) if inferred_info else {} + for line in lines: + pass + lines.append(f"{indent}class {nested_cls.name}:") + lines.append(f"{indent} pass") + + if not cls.attributes and not cls.methods and not any(isinstance(n, ast.ClassDef) for n in cls.body): + lines.append(f"{indent}pass") + + return lines + + def _infer_default_type(self, default: ast.expr) -> str: + if default is None: + return "Any" + + if isinstance(default, ast.Constant): + value = default.value + if value is None: + return "None" + elif isinstance(value, bool): + return "bool" + elif isinstance(value, int): + return "int" + elif isinstance(value, float): + return "float" + elif isinstance(value, str): + return "str" + elif isinstance(default, ast.List): + return "list" + elif isinstance(default, ast.Dict): + return "dict" + elif isinstance(default, ast.Tuple): + return "tuple" + elif isinstance(default, ast.Set): + return "set" + + return "Any" + + def _add_needed_import(self, type_str: str): + """Add import statements for types that need them.""" + if not type_str or type_str == "Any": + return + + type_str = type_str.strip() + + container_match = re.match(r'(List|Dict|Tuple|Set|FrozenSet|Callable|Type|Union|Optional)\[', type_str) + if container_match: + container_type = container_match.group(1) + if container_type == "List": + self.imports.add("from typing import List") + elif container_type == "Dict": + self.imports.add("from typing import Dict") + elif container_type == "Tuple": + self.imports.add("from typing import Tuple") + elif container_type == "Set": + self.imports.add("from typing import Set") + elif container_type == "FrozenSet": + self.imports.add("from typing import FrozenSet") + elif container_type == "Callable": + self.imports.add("from typing import Callable") + elif container_type == "Type": + self.imports.add("from typing import Type") + elif container_type in ("Union", "Optional"): + self.imports.add("from typing import Union, Optional") + + inner_types = type_str[container_match.end():-1] + inner_type_list = _split_type_args(inner_types) + for inner in inner_type_list: + self._add_needed_import(inner) + + generic_match = re.match(r'([A-Z][a-zA-Z0-9_]*)\[', type_str) + if generic_match: + type_name = generic_match.group(1) + builtin_types = {"str", "int", "float", "bool", "list", "dict", "set", "tuple", "type", "object"} + if type_name not in builtin_types: + self.imports.add(f"from typing import {type_name}") + + self.imports.add("from typing import Any") + + +class StubGenerator: + """Main stub generator class that orchestrates parsing and generation.""" + + def __init__(self, inferrer: Optional[Inferrer] = None, interactive: bool = False): + self.inferrer = inferrer or Inferrer() + self.writer = StubWriter(self.inferrer) + self.interactive = interactive + + def confirm_types_interactively(self, inferred_types: Dict[str, Any]) -> Dict[str, Any]: + """Allow user to confirm or modify inferred types.""" + import click + + confirmed_types = inferred_types.copy() + + if not inferred_types: + return confirmed_types + + for func_name, type_info in inferred_types.get("functions", {}).items(): + click.echo(f"\nFunction: {func_name}") + + args_info = type_info.get("args", {}) + for arg_name, arg_info in args_info.items(): + if arg_info.get("source") != "annotation": + current_type = arg_info.get("type", "Any") + confidence = arg_info.get("confidence", 0.5) + click.echo(f" {arg_name}: {current_type} (confidence: {confidence})") + + returns_info = type_info.get("returns", {}) + if returns_info.get("source") != "annotation": + current_type = returns_info.get("type", "None") + confidence = returns_info.get("confidence", 0.5) + click.echo(f" return: {current_type} (confidence: {confidence})") + + return confirmed_types + + def generate(self, tree, inferred_types: Optional[Dict[str, Any]] = None) -> str: + """Generate stub content from AST tree and inferred types.""" + from stubgen.parser import FileInfo + + file_info = FileInfo(path=Path(".")) + file_info = self._extract_info_from_tree(tree, file_info) + + return self.writer.generate(file_info) + + def _extract_info_from_tree(self, tree, file_info: FileInfo) -> FileInfo: + """Extract file info from AST tree.""" + from stubgen.parser import Function, Class, Variable, TypeHint + import ast + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef): + if node.col_offset == 0: + func = Function( + name=node.name, + args=node.args, + returns=TypeHint(annotation=node.returns) if node.returns else None, + decorators=[], + body=list(node.body), + is_async=isinstance(node, ast.AsyncFunctionDef) + ) + file_info.functions.append(func) + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + cls = Class( + name=node.name, + bases=[ast.unparse(base) for base in node.bases], + methods=[], + attributes=[], + decorators=[], + body=list(node.body) + ) + + for item in node.body: + if isinstance(item, ast.FunctionDef) or isinstance(item, ast.AsyncFunctionDef): + method = Function( + name=item.name, + args=item.args, + returns=TypeHint(annotation=item.returns) if item.returns else None, + decorators=[], + body=list(item.body), + is_async=isinstance(item, ast.AsyncFunctionDef) + ) + cls.methods.append(method) + elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name): + var = Variable( + name=item.target.id, + annotation=TypeHint(annotation=item.annotation) if item.annotation else None + ) + cls.attributes.append(var) + + file_info.classes.append(cls) + + return file_info + + def generate_file(self, input_path: Path, output_path: Optional[Path] = None) -> Path: + """Generate a stub file for a single Python file.""" + from stubgen.parser import FileParser + + file_info = FileParser(input_path).parse() + stub_content = self.writer.generate(file_info) + + if output_path is None: + output_path = input_path.with_suffix(".pyi") + + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + f.write(stub_content) + + return output_path + + def generate_directory( + self, + input_dir: Path, + output_dir: Optional[Path] = None, + recursive: bool = True + ) -> List[Path]: + """Generate stub files for all Python files in a directory.""" + from stubgen.parser import parse_directory + + if output_dir is None: + output_dir = input_dir + + results = [] + + file_infos = parse_directory(input_dir, recursive) + + for py_file, file_info in sorted(file_infos.items()): + try: + stub_content = self.writer.generate(file_info) + rel_path = py_file.relative_to(input_dir) + output_path = output_dir / rel_path.with_suffix(".pyi") + + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + f.write(stub_content) + + results.append(output_path) + except Exception as e: + print(f"Error generating stub for {py_file}: {e}", file=sys.stderr) + + return results