diff --git a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx index d63a91ea2..f50a559c7 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx @@ -21,9 +21,11 @@ import { classNames, getRandomKeyByssmm, groupByFamily, - isValidConnection, } from "../../../../utils"; -import { cleanEdges } from "../../../../utils/reactflowUtils"; +import { + cleanEdges, + isValidConnection, +} from "../../../../utils/reactflowUtils"; import { nodeColors, nodeIconsLucide, diff --git a/src/frontend/src/pages/FlowPage/components/PageComponent/index.tsx b/src/frontend/src/pages/FlowPage/components/PageComponent/index.tsx index a6ac0f4df..163e3d9bc 100644 --- a/src/frontend/src/pages/FlowPage/components/PageComponent/index.tsx +++ b/src/frontend/src/pages/FlowPage/components/PageComponent/index.tsx @@ -26,7 +26,7 @@ import { typesContext } from "../../../../contexts/typesContext"; import { undoRedoContext } from "../../../../contexts/undoRedoContext"; import { APIClassType } from "../../../../types/api"; import { FlowType, NodeType } from "../../../../types/flow"; -import { isValidConnection } from "../../../../utils"; +import { isValidConnection } from "../../../../utils/reactflowUtils"; import ConnectionLineComponent from "../ConnectionLineComponent"; import ExtraSidebar from "../extraSidebarComponent"; diff --git a/src/frontend/src/utils.ts b/src/frontend/src/utils.ts index de42e7819..95890a3c8 100644 --- a/src/frontend/src/utils.ts +++ b/src/frontend/src/utils.ts @@ -1,6 +1,6 @@ import clsx, { ClassValue } from "clsx"; import _ from "lodash"; -import { Connection, ReactFlowInstance } from "reactflow"; +import { ReactFlowInstance } from "reactflow"; import { twMerge } from "tailwind-merge"; import { ADJECTIVES, DESCRIPTIONS, NOUNS } from "./flow_constants"; import { APITemplateType } from "./types/api"; @@ -74,48 +74,6 @@ export function roundNumber(x: number, decimals: number) { return Math.round(x * Math.pow(10, decimals)) / Math.pow(10, decimals); } -export function isValidConnection( - { source, target, sourceHandle, targetHandle }: Connection, - reactFlowInstance: ReactFlowInstance -) { - if ( - targetHandle - .split("|")[0] - .split(";") - .some((n) => n === sourceHandle.split("|")[0]) || - sourceHandle - .split("|") - .slice(2) - .some((t) => - targetHandle - .split("|")[0] - .split(";") - .some((n) => n === t) - ) || - targetHandle.split("|")[0] === "str" - ) { - let targetNode = reactFlowInstance?.getNode(target)?.data?.node; - if (!targetNode) { - if ( - !reactFlowInstance - .getEdges() - .find((e) => e.targetHandle === targetHandle) - ) { - return true; - } - } else if ( - (!targetNode.template[targetHandle.split("|")[1]].list && - !reactFlowInstance - .getEdges() - .find((e) => e.targetHandle === targetHandle)) || - targetNode.template[targetHandle.split("|")[1]].list - ) { - return true; - } - } - return false; -} - export function removeApiKeys(flow: FlowType): FlowType { let cleanFLow = _.cloneDeep(flow); cleanFLow.data.nodes.forEach((node) => { diff --git a/src/frontend/src/utils/reactflowUtils.ts b/src/frontend/src/utils/reactflowUtils.ts index 2f9c58343..c5fe272e7 100644 --- a/src/frontend/src/utils/reactflowUtils.ts +++ b/src/frontend/src/utils/reactflowUtils.ts @@ -1,4 +1,5 @@ import _ from "lodash"; +import { Connection, ReactFlowInstance } from "reactflow"; import { cleanEdgesType } from "../types/utils/reactflowUtils"; export function cleanEdges({ @@ -44,3 +45,45 @@ export function cleanEdges({ }); updateEdge(newEdges); } + +export function isValidConnection( + { source, target, sourceHandle, targetHandle }: Connection, + reactFlowInstance: ReactFlowInstance +) { + if ( + targetHandle + .split("|")[0] + .split(";") + .some((n) => n === sourceHandle.split("|")[0]) || + sourceHandle + .split("|") + .slice(2) + .some((t) => + targetHandle + .split("|")[0] + .split(";") + .some((n) => n === t) + ) || + targetHandle.split("|")[0] === "str" + ) { + let targetNode = reactFlowInstance?.getNode(target)?.data?.node; + if (!targetNode) { + if ( + !reactFlowInstance + .getEdges() + .find((e) => e.targetHandle === targetHandle) + ) { + return true; + } + } else if ( + (!targetNode.template[targetHandle.split("|")[1]].list && + !reactFlowInstance + .getEdges() + .find((e) => e.targetHandle === targetHandle)) || + targetNode.template[targetHandle.split("|")[1]].list + ) { + return true; + } + } + return false; +}