diff --git a/tests/unit/test_graph_builder.py b/tests/unit/test_graph_builder.py new file mode 100644 index 0000000..8f77cec --- /dev/null +++ b/tests/unit/test_graph_builder.py @@ -0,0 +1,138 @@ +import pytest +from pathlib import Path +from src.graph.builder import GraphBuilder, GraphType, GraphNode, NodeType, GraphEdge +from src.parsers.base import Entity, EntityType + + +class TestGraphBuilder: + def setup_method(self): + self.builder = GraphBuilder(GraphType.DIRECTED) + + def test_add_node(self): + node = GraphNode( + node_id="test_node", + node_type=NodeType.FILE, + name="test.py", + file_path=Path("/test/test.py"), + ) + node_id = self.builder.add_node(node) + assert node_id == "test_node" + assert len(self.builder.nodes) == 1 + + def test_add_multiple_nodes(self): + node1 = GraphNode(node_id="node1", node_type=NodeType.FILE, name="file1.py") + node2 = GraphNode(node_id="node2", node_type=NodeType.FILE, name="file2.py") + self.builder.add_node(node1) + self.builder.add_node(node2) + assert len(self.builder.nodes) == 2 + + def test_add_edge(self): + source = GraphNode(node_id="source", node_type=NodeType.FILE, name="source.py") + target = GraphNode(node_id="target", node_type=NodeType.FILE, name="target.py") + self.builder.add_node(source) + self.builder.add_node(target) + + edge = GraphEdge(source="source", target="target", edge_type="imports") + self.builder.add_edge(edge) + assert len(self.builder.edges) == 1 + + def test_get_graph(self): + graph = self.builder.get_graph() + assert graph is not None + + def test_get_node_by_id(self): + node = GraphNode(node_id="test", node_type=NodeType.FILE, name="test.py") + self.builder.add_node(node) + retrieved = self.builder.get_node_by_id("test") + assert retrieved is not None + assert retrieved.name == "test.py" + + def test_get_nodes_by_type(self): + file_node = GraphNode(node_id="file1", node_type=NodeType.FILE, name="file.py") + func_node = GraphNode(node_id="func1", node_type=NodeType.FUNCTION, name="func") + self.builder.add_node(file_node) + self.builder.add_node(func_node) + + file_nodes = self.builder.get_nodes_by_type(NodeType.FILE) + func_nodes = self.builder.get_nodes_by_type(NodeType.FUNCTION) + assert len(file_nodes) == 1 + assert len(func_nodes) == 1 + + def test_serialize(self): + node = GraphNode( + node_id="test", + node_type=NodeType.FILE, + name="test.py", + file_path=Path("/test.py"), + start_line=1, + end_line=10, + ) + self.builder.add_node(node) + + data = self.builder.serialize() + assert "nodes" in data + assert "edges" in data + assert len(data["nodes"]) == 1 + assert data["nodes"][0]["name"] == "test.py" + + def test_deserialize(self): + data = { + "nodes": [ + { + "id": "test", + "type": "file", + "name": "test.py", + "file_path": "/test.py", + "start_line": 1, + "end_line": 10, + } + ], + "edges": [], + } + self.builder.deserialize(data) + assert len(self.builder.nodes) == 1 + + def test_get_subgraph(self): + node1 = GraphNode(node_id="n1", node_type=NodeType.FILE, name="f1.py") + node2 = GraphNode(node_id="n2", node_type=NodeType.FILE, name="f2.py") + node3 = GraphNode(node_id="n3", node_type=NodeType.FILE, name="f3.py") + self.builder.add_node(node1) + self.builder.add_node(node2) + self.builder.add_node(node3) + + subgraph = self.builder.get_subgraph(["n1", "n2"]) + assert subgraph.number_of_nodes() == 2 + + +class TestGraphNode: + def test_default_label(self): + node = GraphNode(node_id="test", node_type=NodeType.FILE, name="test.py") + assert node.label == "test.py" + + def test_custom_label(self): + node = GraphNode( + node_id="test", + node_type=NodeType.FILE, + name="test.py", + label="Custom Label", + ) + assert node.label == "Custom Label" + + def test_default_style_and_shape(self): + node = GraphNode(node_id="test", node_type=NodeType.FILE, name="test.py") + assert node.style == "filled" + assert node.shape == "ellipse" + + +class TestGraphEdge: + def test_default_edge_type(self): + edge = GraphEdge(source="a", target="b") + assert edge.edge_type == "depends" + + def test_custom_edge_type(self): + edge = GraphEdge(source="a", target="b", edge_type="imports") + assert edge.edge_type == "imports" + + def test_edge_with_label(self): + edge = GraphEdge(source="a", target="b", label="imports os") + assert edge.label == "imports os"