diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 8a42f1cbb..d0adb66ff 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -34,8 +34,6 @@ class Graph: def __init__( self, - nodes: List[Dict], - edges: List[Dict[str, str]], flow_id: Optional[str] = None, flow_name: Optional[str] = None, user_id: Optional[str] = None, @@ -48,9 +46,6 @@ class Graph: edges (List[Dict[str, str]]): A list of dictionaries representing the edges of the graph. flow_id (Optional[str], optional): The ID of the flow. Defaults to None. """ - self._vertices = nodes - self._edges = edges - self.raw_graph_data = {"nodes": nodes, "edges": edges} self._runs = 0 self._updates = 0 self.flow_id = flow_id @@ -63,7 +58,26 @@ class Graph: self._sorted_vertices_layers: List[List[str]] = [] self._run_id = "" self._start_time = datetime.now(timezone.utc) + self.inactivated_vertices: set = set() + self.activated_vertices: List[str] = [] + self.vertices_layers: List[List[str]] = [] + self.vertices_to_run: set[str] = set() + self.stop_vertex: Optional[str] = None + self.inactive_vertices: set = set() + self.edges: List[ContractEdge] = [] + self.vertices: List[Vertex] = [] + self.run_manager = RunnableVerticesManager() + self.state_manager = GraphStateManager() + try: + self.tracing_service: "TracingService" | None = get_tracing_service() + except Exception as exc: + logger.error(f"Error getting tracing service: {exc}") + self.tracing_service = None + def add_nodes_and_edges(self, nodes: List[Dict], edges: List[Dict[str, str]]): + self._vertices = nodes + self._edges = edges + self.raw_graph_data = {"nodes": nodes, "edges": edges} self.top_level_vertices = [] for vertex in self._vertices: if vertex_id := vertex.get("id"): @@ -72,25 +86,20 @@ class Graph: self._vertices = self._graph_data["nodes"] self._edges = self._graph_data["edges"] - self.inactivated_vertices: set = set() - self.activated_vertices: List[str] = [] - self.vertices_layers: List[List[str]] = [] - self.vertices_to_run: set[str] = set() - self.stop_vertex: Optional[str] = None + self.initialize() - self.inactive_vertices: set = set() - self.edges: List[ContractEdge] = [] - self.vertices: List[Vertex] = [] - self.run_manager = RunnableVerticesManager() + # TODO: Create a TypedDict to represente the node + def add_node(self, node: dict): + self._vertices.append(node) + + # TODO: Create a TypedDict to represente the edge + def add_edge(self, edge: dict): + self._edges.append(edge) + + def initialize(self): self._build_graph() self.build_graph_maps(self.edges) self.define_vertices_lists() - self.state_manager = GraphStateManager() - try: - self.tracing_service: "TracingService" | None = get_tracing_service() - except Exception as exc: - logger.error(f"Error getting tracing service: {exc}") - self.tracing_service = None def get_state(self, name: str) -> Optional[Data]: """ @@ -638,7 +647,9 @@ class Graph: try: vertices = payload["nodes"] edges = payload["edges"] - return cls(vertices, edges, flow_id, flow_name, user_id) + graph = cls(flow_id, flow_name, user_id) + graph.add_nodes_and_edges(vertices, edges) + return graph except KeyError as exc: logger.exception(exc) if "nodes" not in payload and "edges" not in payload: @@ -1188,20 +1199,24 @@ class Graph: """Builds the vertices of the graph.""" vertices: List[Vertex] = [] for vertex in self._vertices: - vertex_data = vertex["data"] - vertex_type: str = vertex_data["type"] # type: ignore - vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - if "id" not in vertex_data: - raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id") - - VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) - - vertex_instance = VertexClass(vertex, graph=self) - vertex_instance.set_top_level(self.top_level_vertices) + vertex_instance = self._create_vertex(vertex) vertices.append(vertex_instance) return vertices + def _create_vertex(self, vertex: dict): + vertex_data = vertex["data"] + vertex_type: str = vertex_data["type"] # type: ignore + vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore + if "id" not in vertex_data: + raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id") + + VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) + + vertex_instance = VertexClass(vertex, graph=self) + vertex_instance.set_top_level(self.top_level_vertices) + return vertex_instance + def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: """Returns the children of a vertex based on the vertex type.""" children = [] diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index ab77315f3..11a3fa2b7 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -158,7 +158,9 @@ def get_graph(_type="basic"): data_graph = flow_graph["data"] nodes = data_graph["nodes"] edges = data_graph["edges"] - return Graph(nodes, edges) + graph = Graph() + graph.add_nodes_and_edges(nodes, edges) + return graph @pytest.fixture diff --git a/src/backend/tests/unit/test_graph.py b/src/backend/tests/unit/graph/test_graph.py similarity index 99% rename from src/backend/tests/unit/test_graph.py rename to src/backend/tests/unit/graph/test_graph.py index ec69051d2..f6b85aa4e 100644 --- a/src/backend/tests/unit/test_graph.py +++ b/src/backend/tests/unit/graph/test_graph.py @@ -116,7 +116,8 @@ def test_invalid_node_types(): "edges": [], } with pytest.raises(Exception): - Graph(graph_data["nodes"], graph_data["edges"]) + g = Graph() + g.add_nodes_and_edges(graph_data["nodes"], graph_data["edges"]) def test_get_vertices_with_target(basic_graph):