410 lines
13 KiB
Python
410 lines
13 KiB
Python
"""AST Parser module for analyzing Python source code."""
|
|
|
|
import ast
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
|
|
class TypeHint:
    """A type annotation lifted out of the AST.

    The raw annotation node is stored as-is; ``to_string`` turns it back into
    source text.
    """

    def __init__(
        self,
        annotation: Optional[ast.expr],
        is_noneable: bool = False,
        source: str = "annotation"
    ):
        # Raw annotation expression; None means "no annotation available".
        self.annotation = annotation
        # Flag stored for callers; not consulted by to_string.
        self.is_noneable = is_noneable
        # Free-form label describing where the hint came from.
        self.source = source

    def to_string(self) -> str:
        """Return the annotation as source text, or ``"Any"`` when it is missing."""
        node = self.annotation
        return "Any" if node is None else ast.unparse(node)
|
|
|
|
|
|
class Variable:
    """A named variable together with its optional TypeHint and value node."""

    def __init__(
        self,
        name: str,
        annotation: Optional[TypeHint] = None,
        value: Optional[ast.expr] = None
    ):
        self.name, self.annotation, self.value = name, annotation, value

    def to_string(self) -> str:
        """Render ``name: <type>``, using ``Any`` when no hint is attached."""
        rendered_type = self.annotation.to_string() if self.annotation else "Any"
        return f"{self.name}: {rendered_type}"
|
|
|
|
|
|
class Function:
    """Represents a function (or method) definition extracted from the AST."""

    def __init__(
        self,
        name: str,
        args: ast.arguments,
        returns: Optional["TypeHint"] = None,
        decorators: Optional[List[str]] = None,
        body: Optional[List[ast.stmt]] = None,
        is_async: bool = False
    ):
        self.name = name
        # Raw ``ast.arguments`` node; rendered on demand by ``_format_args``.
        self.args = args
        self.returns = returns
        # ``or []`` gives every instance its own list (no shared mutable default).
        self.decorators = decorators or []
        self.body = body or []
        self.is_async = is_async

    def to_string(self) -> str:
        """Render the function as a stub: one ``@dec`` line per decorator, then
        ``[async ]def name(params) -> T: ...``.

        A missing return annotation is rendered as ``-> None``.
        """
        returns_str = f" -> {self.returns.to_string()}" if self.returns else " -> None"
        decorator_str = "".join(f"@{dec}\n" for dec in self.decorators)
        async_prefix = "async " if self.is_async else ""
        return f"{decorator_str}{async_prefix}def {self.name}({self._format_args()}){returns_str}: ..."

    def _format_args(self) -> str:
        """Render the parameter list, keeping ``/`` and ``*`` markers.

        Annotations are unparsed back to source; default values are elided to
        ``...`` as is conventional in stub files.
        """
        args = self.args
        parts: List[str] = []

        # Positional-only and regular parameters share one ``defaults`` list,
        # which lines up with the *last* len(defaults) of them combined.
        positional = list(args.posonlyargs) + list(args.args)
        first_default = len(positional) - len(args.defaults)
        for i, arg in enumerate(positional):
            default = args.defaults[i - first_default] if i >= first_default else None
            parts.append(self._format_param(arg, default))
            if i == len(args.posonlyargs) - 1:
                parts.append("/")

        if args.vararg:
            parts.append("*" + self._format_param(args.vararg, None))
        elif args.kwonlyargs:
            # A bare ``*`` keeps the following parameters keyword-only.
            parts.append("*")

        # kw_defaults is parallel to kwonlyargs, holding None for "no default".
        for kwarg, default in zip(args.kwonlyargs, args.kw_defaults):
            parts.append(self._format_param(kwarg, default))

        if args.kwarg:
            parts.append("**" + self._format_param(args.kwarg, None))

        return ", ".join(parts)

    @staticmethod
    def _format_param(arg: ast.arg, default: Optional[ast.expr]) -> str:
        """Render one parameter: name, annotation (if any), and an elided default.

        An unannotated parameter whose default is the literal ``None`` is given a
        ``None`` annotation (kept from the previous behaviour).
        """
        text = arg.arg
        annotated = True
        if arg.annotation is not None:
            text += f": {ast.unparse(arg.annotation)}"
        elif isinstance(default, ast.Constant) and default.value is None:
            text += ": None"
        else:
            annotated = False
        if default is not None:
            # PEP 8: spaces around '=' only for annotated parameters.
            text += " = ..." if annotated else "=..."
        return text
|
|
|
|
|
|
class Class:
    """Represents a class definition: name, bases, members and decorators."""

    def __init__(
        self,
        name: str,
        bases: List[str] = None,
        methods: List[Function] = None,
        attributes: List[Variable] = None,
        decorators: List[str] = None,
        body: List[ast.stmt] = None
    ):
        self.name = name
        # Fresh per-instance lists whenever the caller passes nothing (or an empty list).
        self.bases = bases if bases else []
        self.methods = methods if methods else []
        self.attributes = attributes if attributes else []
        self.decorators = decorators if decorators else []
        self.body = body if body else []

    def to_string(self) -> str:
        """Render decorator lines plus a ``class Name(Bases):`` header with a
        ``pass`` body.

        NOTE(review): ``methods`` and ``attributes`` are collected but not
        rendered here, so the emitted stub class is empty — confirm intended.
        """
        header = f"class {self.name}"
        if self.bases:
            header += "(" + ", ".join(self.bases) + ")"
        decorator_lines = "".join(f"@{dec}\n" for dec in self.decorators)
        return decorator_lines + header + ":\n pass"
|
|
|
|
|
|
class Import:
    """Represents a single ``import`` or ``from ... import`` statement."""

    def __init__(
        self,
        module: Optional[str] = None,
        names: Optional[List[str]] = None,
        as_names: Optional[Dict[str, str]] = None,
        is_from: bool = False,
        level: int = 0
    ):
        self.module = module
        self.names = names or []
        # Maps an imported name to its ``as`` alias; only aliased names appear.
        self.as_names = as_names or {}
        self.is_from = is_from
        # Number of leading dots of a relative ``from`` import.
        self.level = level

    def _with_alias(self, name: str, target: Optional[str] = None) -> str:
        """Return *target* (default: *name*) followed by `` as alias`` if *name* is aliased."""
        target = name if target is None else target
        alias = self.as_names.get(name)
        return f"{target} as {alias}" if alias else target

    def _qualify(self, name: str) -> str:
        """Return the dotted module path to import for *name* in a plain ``import``.

        If *name* already is ``module`` (or a submodule of it) it is used as-is;
        otherwise it is treated as relative to ``module``.
        """
        if not self.module or name == self.module or name.startswith(self.module + "."):
            return name
        return f"{self.module}.{name}"

    def to_string(self) -> str:
        """Render the statement as source text; ``""`` if there is nothing to import."""
        if self.is_from:
            module = self.module or ""
            # ``module`` may already carry its leading dots while ``level`` is
            # also set; count the dots once so they are not doubled.
            bare = module.lstrip(".")
            dots = max(self.level, len(module) - len(bare))
            prefix = "." * dots + bare
            if not self.names:
                return f"from {prefix} import *"
            return f"from {prefix} import " + ", ".join(self._with_alias(n) for n in self.names)

        if self.names:
            return "import " + ", ".join(
                self._with_alias(n, self._qualify(n)) for n in self.names
            )
        if self.module:
            return f"import {self.module}"
        return ""
|
|
|
|
|
|
class FileInfo:
    """Aggregates all information parsed from a Python file."""

    def __init__(
        self,
        path: Path,
        imports: Optional[List["Import"]] = None,
        functions: Optional[List["Function"]] = None,
        classes: Optional[List["Class"]] = None,
        variables: Optional[List["Variable"]] = None,
        docstring: Optional[str] = None
    ):
        self.path = path
        # ``or []`` gives every instance its own lists.
        self.imports = imports or []
        self.functions = functions or []
        self.classes = classes or []
        self.variables = variables or []
        self.docstring = docstring

    def to_string(self) -> str:
        """Render the collected items as stub source, one item per line.

        Order: module docstring, imports (those rendering to ``""`` are
        skipped), variables, classes, functions.
        """
        parts: List[str] = []
        if self.docstring:
            parts.append(f'"""{self.docstring}"""')

        for imp in self.imports:
            imp_str = imp.to_string()
            if imp_str:
                parts.append(imp_str)

        # Module-level variables must start at column 0: an indented top-level
        # line makes the generated stub fail to parse ("unexpected indent").
        parts.extend(var.to_string() for var in self.variables)
        parts.extend(cls.to_string() for cls in self.classes)
        parts.extend(func.to_string() for func in self.functions)

        return "\n".join(parts)
|
|
|
|
|
|
class FileParser:
    """Parses Python source files and extracts information for stub generation.

    Only module-scope definitions go into the FileInfo. Nodes nested inside a
    function or class body are skipped: local imports, local annotated
    variables and nested classes. A class's own methods and annotated
    attributes are still gathered by ``_extract_class``.
    """

    # Nodes whose bodies open a new scope; module-scope traversal stops at them.
    _SCOPE_NODES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

    def __init__(self, path: Path):
        self.path = path
        self.tree: Optional[ast.Module] = None
        self.file_info: Optional["FileInfo"] = None

    def parse(self) -> "FileInfo":
        """Parse the Python file and return FileInfo.

        Returns:
            The populated FileInfo (also stored on ``self.file_info``).

        Raises:
            ValueError: if the file cannot be opened/read, is not valid UTF-8,
                or does not parse as Python.
        """
        try:
            with open(self.path, 'r', encoding='utf-8') as f:
                source = f.read()
        except (OSError, UnicodeDecodeError) as e:
            # OSError covers FileNotFoundError as well as PermissionError,
            # IsADirectoryError, ...; callers only catch ValueError.
            raise ValueError(f"Cannot read file {self.path}: {e}") from e

        try:
            self.tree = ast.parse(source, filename=str(self.path))
        except (SyntaxError, ValueError) as e:
            # ast.parse may also raise ValueError for some malformed sources
            # (e.g. null bytes on older Pythons); report both the same way.
            raise ValueError(f"Syntax error in {self.path}: {e}") from e

        self.file_info = FileInfo(path=self.path)
        self._extract_docstring()
        self._extract_imports()
        self._extract_classes_and_functions()
        self._extract_variables()

        return self.file_info

    def _module_scope_nodes(self) -> List[ast.AST]:
        """Return all nodes of ``self.tree`` that are not inside a def/class body.

        Module-level compound statements (``if``, ``try``, ``with``, ...) are
        descended into. FunctionDef/AsyncFunctionDef/ClassDef nodes are
        included themselves, but their contents are not. Order is source order
        (pre-order traversal).
        """
        nodes: List[ast.AST] = []
        stack: List[ast.AST] = list(reversed(self.tree.body))
        while stack:
            node = stack.pop()
            nodes.append(node)
            if not isinstance(node, self._SCOPE_NODES):
                stack.extend(reversed(list(ast.iter_child_nodes(node))))
        return nodes

    def _extract_docstring(self):
        """Store the module docstring (raw, not dedented) on ``file_info``, if any."""
        docstring = ast.get_docstring(self.tree, clean=False)
        if docstring is not None:
            self.file_info.docstring = docstring

    def _extract_imports(self):
        """Record every module-scope ``import`` / ``from ... import`` statement."""
        for node in self._module_scope_nodes():
            if isinstance(node, ast.Import):
                # One Import object per alias: ``import a, b`` yields two entries.
                for alias in node.names:
                    self.file_info.imports.append(Import(
                        module=alias.name if alias.name else None,
                        names=[alias.name] if alias.name else [],
                        as_names={alias.name: alias.asname} if alias.asname else {}
                    ))

            elif isinstance(node, ast.ImportFrom):
                names = [alias.name for alias in node.names]
                as_names = {
                    alias.name: alias.asname
                    for alias in node.names
                    if alias.asname
                }
                # Relative imports keep their leading dots in ``module`` (and
                # ``level`` is set as well).
                module = node.module or ""
                if node.level > 0:
                    module = "." * node.level + module

                self.file_info.imports.append(Import(
                    module=module if module else None,
                    names=names,
                    as_names=as_names,
                    is_from=True,
                    level=node.level
                ))

    def _extract_classes_and_functions(self):
        """Record module-scope classes and column-0 (top-level) functions."""
        for node in self._module_scope_nodes():
            if isinstance(node, ast.ClassDef):
                self._extract_class(node)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if node.col_offset == 0:
                    self._extract_function(node)

    def _extract_class(self, node: ast.ClassDef):
        """Build a Class from *node* and append it to ``file_info.classes``.

        Direct-child defs become methods. Direct-child annotated assignments
        with a simple-name target become attributes.
        """
        # unparse handles every base expression form (Name, Attribute,
        # Subscript such as Generic[T], Call, ...).
        bases = [ast.unparse(base) for base in node.bases]
        decorators = self._extract_decorators(node.decorator_list)

        cls = Class(
            name=node.name,
            bases=bases,
            decorators=decorators,
            body=list(node.body)
        )

        for item in node.body:
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                method = self._extract_function(item)
                if method:
                    cls.methods.append(method)
            elif isinstance(item, ast.AnnAssign):
                var = self._extract_variable(item)
                if var:
                    cls.attributes.append(var)

        self.file_info.classes.append(cls)

    def _extract_function(self, node: ast.FunctionDef) -> Optional["Function"]:
        """Build a Function from a FunctionDef/AsyncFunctionDef node.

        Defs at column 0 (module level) are also appended to
        ``file_info.functions``; methods are only returned.
        Returns None if *node* is not a function definition.
        """
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            return None

        returns = TypeHint(annotation=node.returns) if node.returns else None

        func = Function(
            name=node.name,
            args=node.args,
            returns=returns,
            decorators=self._extract_decorators(node.decorator_list),
            body=list(node.body),
            is_async=isinstance(node, ast.AsyncFunctionDef)
        )

        if node.col_offset == 0:
            self.file_info.functions.append(func)

        return func

    def _extract_decorators(self, decorator_list: List[ast.expr]) -> List[str]:
        """Return each decorator expression as source text (without the ``@``)."""
        # Any expression is a valid decorator (PEP 614); unparse covers them all.
        return [ast.unparse(dec) for dec in decorator_list]

    def _extract_variables(self):
        """Collect module-scope annotated variables into ``file_info.variables``.

        Only simple-name targets whose first character is lowercase are kept,
        so names such as ``MAX_SIZE`` or ``_private`` are skipped.
        """
        for node in self._module_scope_nodes():
            if not (isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name)):
                continue
            # NOTE(review): this filter drops public UPPER_CASE constants from
            # the stub — confirm that is intended.
            if not node.target.id[0].islower():
                continue
            var = self._extract_variable(node)
            if var:
                self.file_info.variables.append(var)

    def _extract_variable(self, node: ast.AnnAssign) -> Optional["Variable"]:
        """Build a Variable from an annotated assignment with a simple-name target.

        Returns None for attribute/subscript targets (e.g. ``self.x: int``).
        """
        if not isinstance(node.target, ast.Name):
            return None

        annotation = TypeHint(annotation=node.annotation) if node.annotation else None
        return Variable(name=node.target.id, annotation=annotation, value=node.value)
|
|
|
|
|
|
def parse_file(path: Path) -> FileInfo:
    """Parse the Python file at *path* and return its FileInfo.

    Thin wrapper around ``FileParser(path).parse()``; any ValueError raised
    there (unreadable file, syntax error) propagates to the caller.
    """
    return FileParser(path).parse()
|
|
|
|
|
|
def parse_directory(path: Path, recursive: bool = True) -> Dict[Path, "FileInfo"]:
    """Parse all Python files in a directory.

    Args:
        path: Directory to scan for ``*.py`` files.
        recursive: If true (default), also descend into sub-directories;
            otherwise only files directly inside *path* are considered.

    Returns:
        Mapping of file path to FileInfo, inserted in sorted path order.
        Files that cannot be read or parsed are reported on stderr as a
        warning and left out of the result.
    """
    candidates = path.rglob("*.py") if recursive else path.glob("*.py")
    results: Dict[Path, "FileInfo"] = {}

    for py_file in sorted(candidates):
        # glob also matches directories whose name ends in ".py".
        if not py_file.is_file():
            continue
        try:
            results[py_file] = parse_file(py_file)
        except (ValueError, OSError) as e:
            # OSError too, so one unreadable file cannot abort the whole scan.
            print(f"Warning: {e}", file=sys.stderr)

    return results
|