fix: Handle group nodes in graph sorting (#3929)

* Fix: Handle group nodes in graph sorting

- Added `get_root_of_group_node` function to identify the root of a group node.
- Updated `sort_up_to_vertex` to use `get_root_of_group_node` for handling group nodes.
- Modified `__filter_vertices` to pass `parent_node_map` to `sort_up_to_vertex`.

* Refactor: Update NodeStatus component to handle group nodes and improve build status handling

* [autofix.ci] apply automated fixes

* Update type hint for parent_node_map in sort_up_to_vertex function

---------

Co-authored-by: anovazzi1 <otavio2204@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-09-27 12:26:20 -03:00 committed by GitHub
commit b34a7c7f02
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 75 additions and 14 deletions

View file

@ -1885,7 +1885,10 @@ class Graph:
def __filter_vertices(self, vertex_id: str, is_start: bool = False):
dictionaryized_graph = self.__to_dict()
vertex_ids = sort_up_to_vertex(dictionaryized_graph, vertex_id, is_start)
parent_node_map = {vertex.id: vertex.parent_node_id for vertex in self.vertices}
vertex_ids = sort_up_to_vertex(
graph=dictionaryized_graph, vertex_id=vertex_id, parent_node_map=parent_node_map, is_start=is_start
)
return [self.get_vertex(vertex_id) for vertex_id in vertex_ids]
def sort_vertices(

View file

@ -246,17 +246,46 @@ def get_successors(graph: dict[str, dict[str, list[str]]], vertex_id: str) -> li
if current_id in visited:
continue
visited.add(current_id)
successors_result.append(current_id)
if current_id != vertex_id:
successors_result.append(current_id)
stack.extend(graph[current_id]["successors"])
return successors_result
def sort_up_to_vertex(graph: dict[str, dict[str, list[str]]], vertex_id: str, is_start: bool = False) -> list[str]:
def get_root_of_group_node(
graph: dict[str, dict[str, list[str]]], vertex_id: str, parent_node_map: dict[str, str | None]
) -> str:
"""Returns the root of a group node."""
if vertex_id in parent_node_map.values():
# Get all vertices with vertex_id as their parent node
child_vertices = [v_id for v_id, parent_id in parent_node_map.items() if parent_id == vertex_id]
# Now go through successors of the child vertices
# and get the one that none of its successors is in child_vertices
for child_id in child_vertices:
successors = get_successors(graph, child_id)
if not any(successor in child_vertices for successor in successors):
return child_id
raise ValueError(f"Vertex {vertex_id} is not a top level vertex or no root vertex found")
def sort_up_to_vertex(
graph: dict[str, dict[str, list[str]]],
vertex_id: str,
parent_node_map: dict[str, str | None] | None = None,
is_start: bool = False,
) -> list[str]:
"""Cuts the graph up to a given vertex and sorts the resulting subgraph."""
try:
stop_or_start_vertex = graph[vertex_id]
except KeyError:
raise ValueError(f"Vertex {vertex_id} not found into graph")
if parent_node_map is None:
raise ValueError("Parent node map is required to find the root of a group node")
vertex_id = get_root_of_group_node(graph=graph, vertex_id=vertex_id, parent_node_map=parent_node_map)
if vertex_id not in graph:
raise ValueError(f"Vertex {vertex_id} not found into graph")
stop_or_start_vertex = graph[vertex_id]
visited, excluded = set(), set()
stack = [vertex_id]

View file

@ -45,7 +45,7 @@ def test_get_successors_a(graph):
result = utils.get_successors(graph, vertex_id)
assert set(result) == {"A", "B", "D", "E", "F", "H", "G"}
assert set(result) == {"B", "D", "E", "F", "H", "G"}
def test_get_successors_z(graph):
@ -53,7 +53,7 @@ def test_get_successors_z(graph):
result = utils.get_successors(graph, vertex_id)
assert set(result) == {"Z"}
assert len(result) == 0
def test_sort_up_to_vertex_n_is_start(graph):

View file

@ -16,6 +16,8 @@ import { useDarkStore } from "@/stores/darkStore";
import useFlowStore from "@/stores/flowStore";
import { useShortcutsStore } from "@/stores/shortcuts";
import { VertexBuildTypeAPI } from "@/types/api";
import { NodeDataType } from "@/types/flow";
import { findLastNode } from "@/utils/reactflowUtils";
import { classNames } from "@/utils/utils";
import { useEffect, useState } from "react";
import { useHotkeys } from "react-hotkeys-hook";
@ -28,6 +30,7 @@ export default function NodeStatus({
setBorderColor,
frozen,
showNode,
data,
}: {
nodeId: string;
display_name: string;
@ -35,15 +38,39 @@ export default function NodeStatus({
setBorderColor: (color: string) => void;
frozen?: boolean;
showNode: boolean;
data: NodeDataType;
}) {
const nodeId_ = data.node?.flow?.data
? (findLastNode(data.node?.flow.data!)?.id ?? nodeId)
: nodeId;
const [validationString, setValidationString] = useState<string>("");
const [validationStatus, setValidationStatus] =
useState<VertexBuildTypeAPI | null>(null);
const buildStatus = useFlowStore(
(state) => state.flowBuildStatus[nodeId]?.status,
);
const buildStatus = useFlowStore((state) => {
if (data.node?.flow && data.node.flow.data?.nodes) {
const flow = data.node.flow;
const nodes = flow.data?.nodes; // check all the build status of the nodes in the flow
const buildStatus_: BuildStatus[] = [];
//@ts-ignore
for (const node of nodes) {
buildStatus_.push(state.flowBuildStatus[node.id]?.status);
}
if (buildStatus_.every((status) => status === BuildStatus.BUILT)) {
return BuildStatus.BUILT;
}
if (buildStatus_.some((status) => status === BuildStatus.BUILDING)) {
return BuildStatus.BUILDING;
}
if (buildStatus_.some((status) => status === BuildStatus.ERROR)) {
return BuildStatus.ERROR;
} else {
return BuildStatus.TO_BUILD;
}
}
return state.flowBuildStatus[nodeId]?.status;
});
const lastRunTime = useFlowStore(
(state) => state.flowBuildStatus[nodeId]?.timestamp,
(state) => state.flowBuildStatus[nodeId_]?.timestamp,
);
const iconStatus = useIconStatus(buildStatus, validationStatus);
const buildFlow = useFlowStore((state) => state.buildFlow);
@ -62,7 +89,7 @@ export default function NodeStatus({
const flowPool = useFlowStore((state) => state.flowPool);
useHotkeys(play, handlePlayWShortcut, { preventDefault: true });
useValidationStatusString(validationStatus, setValidationString);
useUpdateValidationStatus(nodeId, flowPool, setValidationStatus);
useUpdateValidationStatus(nodeId_, flowPool, setValidationStatus);
const getBaseBorderClass = (selected) => {
let className = selected

View file

@ -334,6 +334,7 @@ export default function GenericNode({
</div>
{showNode && (
<NodeStatus
data={data}
frozen={data.node?.frozen}
showNode={showNode}
display_name={data.node?.display_name!}

View file

@ -10,9 +10,10 @@ const useIconStatus = (
validationStatus: VertexBuildTypeAPI | null,
) => {
const conditionSuccess =
!(!buildStatus || buildStatus === BuildStatus.TO_BUILD) &&
validationStatus &&
validationStatus.valid;
buildStatus === BuildStatus.BUILT ||
(!(!buildStatus || buildStatus === BuildStatus.TO_BUILD) &&
validationStatus &&
validationStatus.valid);
const conditionError = buildStatus === BuildStatus.ERROR;
const conditionInactive = buildStatus === BuildStatus.INACTIVE;