- Fixed undefined 'tool' variable in display_history function
- Changed '[tool]' markup tag usage to proper Rich syntax
- All tests now pass (38/38 unit tests)
- Type checking passes with mypy --strict
113 lines
3.1 KiB
Python
113 lines
3.1 KiB
Python
from codesnap.core.extractor import FunctionExtractor
|
|
|
|
|
|
class TestFunctionExtractor:
    """Tests exercising ``FunctionExtractor`` on small Python snippets.

    Each test feeds an inline source snippet to the extractor and checks
    the attributes of the first reported item (or the number of items).
    """

    def setup_method(self) -> None:
        """Create a fresh extractor before every test."""
        self.extractor = FunctionExtractor()

    def test_extract_simple_function(self) -> None:
        """A parameterless function is found by name with no parameters."""
        snippet = (
            "\n"
            "def hello():\n"
            '    print("Hello, World!")\n'
        )
        found = self.extractor.extract_functions_python(snippet)
        assert len(found) >= 1
        first = found[0]
        assert first.name == "hello"
        assert len(first.parameters) == 0

    def test_extract_function_with_parameters(self) -> None:
        """Positional and defaulted parameter names are both reported."""
        # The f-string below is snippet text, not an f-string of this module.
        snippet = (
            "\n"
            'def greet(name, greeting="Hello"):\n'
            '    return f"{greeting}, {name}!"\n'
        )
        found = self.extractor.extract_functions_python(snippet)
        assert len(found) >= 1
        first = found[0]
        assert first.name == "greet"
        assert "name" in first.parameters
        assert "greeting" in first.parameters

    def test_extract_async_function(self) -> None:
        """An ``async def`` is found and flagged via ``is_async``."""
        snippet = (
            "\n"
            "async def fetch_data(url):\n"
            "    async with aiohttp.ClientSession() as session:\n"
            "        async with session.get(url) as response:\n"
            "            return await response.json()\n"
        )
        found = self.extractor.extract_functions_python(snippet)
        assert len(found) >= 1
        first = found[0]
        assert first.name == "fetch_data"
        assert first.is_async is True

    def test_extract_function_with_return_type(self) -> None:
        """An annotated function is still found under its own name."""
        snippet = (
            "\n"
            "def add(a: int, b: int) -> int:\n"
            "    return a + b\n"
        )
        found = self.extractor.extract_functions_python(snippet)
        assert len(found) >= 1
        first = found[0]
        assert first.name == "add"

    def test_extract_function_with_decorator(self) -> None:
        """A decorated function yields at least one extracted function."""
        snippet = (
            "\n"
            "@property\n"
            "def name(self):\n"
            "    return self._name\n"
        )
        found = self.extractor.extract_functions_python(snippet)
        assert len(found) >= 1

    def test_extract_classes(self) -> None:
        """A class with two methods is reported under its class name."""
        snippet = (
            "\n"
            "class MyClass:\n"
            "    def __init__(self):\n"
            "        self.value = 42\n"
            "\n"
            "    def get_value(self):\n"
            "        return self.value\n"
        )
        found = self.extractor.extract_classes_python(snippet)
        assert len(found) >= 1
        first = found[0]
        assert first.name == "MyClass"

    def test_extract_class_with_inheritance(self) -> None:
        """Every listed base class appears in ``base_classes``."""
        snippet = (
            "\n"
            "class ChildClass(ParentClass, MixinClass):\n"
            "    pass\n"
        )
        found = self.extractor.extract_classes_python(snippet)
        assert len(found) >= 1
        first = found[0]
        assert "ParentClass" in first.base_classes
        assert "MixinClass" in first.base_classes

    def test_extract_all_python(self) -> None:
        """``extract_all`` returns both functions and classes for Python."""
        snippet = (
            "\n"
            "def func1():\n"
            "    pass\n"
            "\n"
            "class MyClass:\n"
            "    def method1(self):\n"
            "        pass\n"
            "\n"
            "def func2():\n"
            "    pass\n"
        )
        found_functions, found_classes = self.extractor.extract_all(
            snippet, "python"
        )
        assert len(found_functions) >= 2
        assert len(found_classes) >= 1

    def test_extract_from_file(self) -> None:
        """``extract_from_file`` echoes the file name and lists functions."""
        snippet = (
            "\n"
            "def test_function(x, y):\n"
            "    return x + y\n"
        )
        report = self.extractor.extract_from_file("test.py", snippet, "python")
        assert report["file"] == "test.py"
        assert len(report["functions"]) >= 1
        assert report["functions"][0]["name"] == "test_function"
|