diff --git a/tests/unit/test_graph_builder.py b/tests/unit/test_graph_builder.py index 8f77cec..59c66fa 100644 --- a/tests/unit/test_graph_builder.py +++ b/tests/unit/test_graph_builder.py @@ -1,7 +1,5 @@ -import pytest from pathlib import Path from src.graph.builder import GraphBuilder, GraphType, GraphNode, NodeType, GraphEdge -from src.parsers.base import Entity, EntityType class TestGraphBuilder: @@ -123,6 +121,19 @@ class TestGraphNode: assert node.style == "filled" assert node.shape == "ellipse" + def test_class_node_shape(self): + node = GraphNode(node_id="test", node_type=NodeType.CLASS, name="TestClass") + assert node.shape == "ellipse" + + def test_class_node_shape_in_builder(self): + from src.graph.builder import GraphBuilder, GraphType + builder = GraphBuilder(GraphType.DIRECTED) + node = GraphNode(node_id="test", node_type=NodeType.CLASS, name="TestClass") + node.shape = "diamond" + builder.add_node(node) + added_node = builder.get_node_by_id("test") + assert added_node.shape == "diamond" + class TestGraphEdge: def test_default_edge_type(self):