From b34a7c7f02c6f2841838cd0116febf04b4fb12ca Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 27 Sep 2024 12:26:20 -0300 Subject: [PATCH] fix: Handle group nodes in graph sorting (#3929) * Fix: Handle group nodes in graph sorting - Added `get_root_of_group_node` function to identify the root of a group node. - Updated `sort_up_to_vertex` to use `get_root_of_group_node` for handling group nodes. - Modified `__filter_vertices` to pass `parent_node_map` to `sort_up_to_vertex`. * Refactor: Update NodeStatus component to handle group nodes and improve build status handling * [autofix.ci] apply automated fixes * Update type hint for parent_node_map in sort_up_to_vertex function --------- Co-authored-by: anovazzi1 Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- src/backend/base/langflow/graph/graph/base.py | 5 ++- .../base/langflow/graph/graph/utils.py | 35 ++++++++++++++++-- .../tests/unit/graph/graph/test_utils.py | 4 +- .../components/NodeStatus/index.tsx | 37 ++++++++++++++++--- .../src/CustomNodes/GenericNode/index.tsx | 1 + .../CustomNodes/hooks/use-icons-status.tsx | 7 ++-- 6 files changed, 75 insertions(+), 14 deletions(-) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 269a76391..c9bb6a60f 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -1885,7 +1885,10 @@ class Graph: def __filter_vertices(self, vertex_id: str, is_start: bool = False): dictionaryized_graph = self.__to_dict() - vertex_ids = sort_up_to_vertex(dictionaryized_graph, vertex_id, is_start) + parent_node_map = {vertex.id: vertex.parent_node_id for vertex in self.vertices} + vertex_ids = sort_up_to_vertex( + graph=dictionaryized_graph, vertex_id=vertex_id, parent_node_map=parent_node_map, is_start=is_start + ) return [self.get_vertex(vertex_id) for vertex_id in vertex_ids] def sort_vertices( diff --git a/src/backend/base/langflow/graph/graph/utils.py b/src/backend/base/langflow/graph/graph/utils.py index 5ccc82387..4f2dc346b 100644 --- a/src/backend/base/langflow/graph/graph/utils.py +++ b/src/backend/base/langflow/graph/graph/utils.py @@ -246,17 +246,46 @@ def get_successors(graph: dict[str, dict[str, list[str]]], vertex_id: str) -> li if current_id in visited: continue visited.add(current_id) - successors_result.append(current_id) + if current_id != vertex_id: + successors_result.append(current_id) stack.extend(graph[current_id]["successors"]) return successors_result -def sort_up_to_vertex(graph: dict[str, dict[str, list[str]]], vertex_id: str, is_start: bool = False) -> list[str]: +def get_root_of_group_node( + graph: dict[str, dict[str, list[str]]], vertex_id: str, parent_node_map: dict[str, str | None] +) -> str: + """Returns the root of a group node.""" + if vertex_id in parent_node_map.values(): + # Get all vertices with vertex_id as their parent node + child_vertices = [v_id for v_id, parent_id in parent_node_map.items() if parent_id == vertex_id] + + # Now go through successors of the child vertices + # and get the one that none of its successors is in child_vertices + for child_id in child_vertices: + successors = get_successors(graph, child_id) + if not any(successor in child_vertices for successor in successors): + return child_id + + raise ValueError(f"Vertex {vertex_id} is not a top level vertex or no root vertex found") + + +def sort_up_to_vertex( + graph: dict[str, dict[str, list[str]]], + vertex_id: str, + parent_node_map: dict[str, str | None] | None = None, + is_start: bool = False, +) -> list[str]: """Cuts the graph up to a given vertex and sorts the resulting subgraph.""" try: stop_or_start_vertex = graph[vertex_id] except KeyError: - raise ValueError(f"Vertex {vertex_id} not found into graph") + if parent_node_map is None: + raise ValueError("Parent node map is required to find the root of a group node") + vertex_id = get_root_of_group_node(graph=graph, vertex_id=vertex_id, parent_node_map=parent_node_map) + if vertex_id not in graph: + raise ValueError(f"Vertex {vertex_id} not found into graph") + stop_or_start_vertex = graph[vertex_id] visited, excluded = set(), set() stack = [vertex_id] diff --git a/src/backend/tests/unit/graph/graph/test_utils.py b/src/backend/tests/unit/graph/graph/test_utils.py index 5f211604c..e224c58a7 100644 --- a/src/backend/tests/unit/graph/graph/test_utils.py +++ b/src/backend/tests/unit/graph/graph/test_utils.py @@ -45,7 +45,7 @@ def test_get_successors_a(graph): result = utils.get_successors(graph, vertex_id) - assert set(result) == {"A", "B", "D", "E", "F", "H", "G"} + assert set(result) == {"B", "D", "E", "F", "H", "G"} def test_get_successors_z(graph): @@ -53,7 +53,7 @@ def test_get_successors_z(graph): result = utils.get_successors(graph, vertex_id) - assert set(result) == {"Z"} + assert len(result) == 0 def test_sort_up_to_vertex_n_is_start(graph): diff --git a/src/frontend/src/CustomNodes/GenericNode/components/NodeStatus/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/NodeStatus/index.tsx index 670a9f9e4..32a826c2f 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/NodeStatus/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/NodeStatus/index.tsx @@ -16,6 +16,8 @@ import { useDarkStore } from "@/stores/darkStore"; import useFlowStore from "@/stores/flowStore"; import { useShortcutsStore } from "@/stores/shortcuts"; import { VertexBuildTypeAPI } from "@/types/api"; +import { NodeDataType } from "@/types/flow"; +import { findLastNode } from "@/utils/reactflowUtils"; import { classNames } from "@/utils/utils"; import { useEffect, useState } from "react"; import { useHotkeys } from "react-hotkeys-hook"; @@ -28,6 +30,7 @@ export default function NodeStatus({ setBorderColor, frozen, showNode, + data, }: { nodeId: string; display_name: string; @@ -35,15 +38,39 @@ export default function NodeStatus({ setBorderColor: (color: string) => void; frozen?: boolean; showNode: boolean; + data: NodeDataType; }) { + const nodeId_ = data.node?.flow?.data + ? (findLastNode(data.node?.flow.data!)?.id ?? nodeId) + : nodeId; const [validationString, setValidationString] = useState(""); const [validationStatus, setValidationStatus] = useState(null); - const buildStatus = useFlowStore( - (state) => state.flowBuildStatus[nodeId]?.status, - ); + const buildStatus = useFlowStore((state) => { + if (data.node?.flow && data.node.flow.data?.nodes) { + const flow = data.node.flow; + const nodes = flow.data?.nodes; // check all the build status of the nodes in the flow + const buildStatus_: BuildStatus[] = []; + //@ts-ignore + for (const node of nodes) { + buildStatus_.push(state.flowBuildStatus[node.id]?.status); + } + if (buildStatus_.every((status) => status === BuildStatus.BUILT)) { + return BuildStatus.BUILT; + } + if (buildStatus_.some((status) => status === BuildStatus.BUILDING)) { + return BuildStatus.BUILDING; + } + if (buildStatus_.some((status) => status === BuildStatus.ERROR)) { + return BuildStatus.ERROR; + } else { + return BuildStatus.TO_BUILD; + } + } + return state.flowBuildStatus[nodeId]?.status; + }); const lastRunTime = useFlowStore( - (state) => state.flowBuildStatus[nodeId]?.timestamp, + (state) => state.flowBuildStatus[nodeId_]?.timestamp, ); const iconStatus = useIconStatus(buildStatus, validationStatus); const buildFlow = useFlowStore((state) => state.buildFlow); @@ -62,7 +89,7 @@ export default function NodeStatus({ const flowPool = useFlowStore((state) => state.flowPool); useHotkeys(play, handlePlayWShortcut, { preventDefault: true }); useValidationStatusString(validationStatus, setValidationString); - useUpdateValidationStatus(nodeId, flowPool, setValidationStatus); + useUpdateValidationStatus(nodeId_, flowPool, setValidationStatus); const getBaseBorderClass = (selected) => { let className = selected diff --git a/src/frontend/src/CustomNodes/GenericNode/index.tsx b/src/frontend/src/CustomNodes/GenericNode/index.tsx index fc79ebd55..ee2b91258 100644 --- a/src/frontend/src/CustomNodes/GenericNode/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/index.tsx @@ -334,6 +334,7 @@ export default function GenericNode({ {showNode && ( { const conditionSuccess = - !(!buildStatus || buildStatus === BuildStatus.TO_BUILD) && - validationStatus && - validationStatus.valid; + buildStatus === BuildStatus.BUILT || + (!(!buildStatus || buildStatus === BuildStatus.TO_BUILD) && + validationStatus && + validationStatus.valid); const conditionError = buildStatus === BuildStatus.ERROR; const conditionInactive = buildStatus === BuildStatus.INACTIVE;