diff --git a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx index 9c1b4bbf1..09af96c6f 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx @@ -1,6 +1,6 @@ import { cloneDeep } from "lodash"; import React, { ReactNode, useEffect, useRef, useState } from "react"; -import { Handle, Position } from "reactflow"; +import { Handle, Position, useUpdateNodeInternals } from "reactflow"; import ShadTooltip from "../../../../components/ShadTooltipComponent"; import CodeAreaComponent from "../../../../components/codeAreaComponent"; import DictComponent from "../../../../components/dictComponent"; @@ -124,27 +124,32 @@ export default function ParameterComponent({ renderTooltips(); }; + const updateNodeInternals = useUpdateNodeInternals(); + const handleNodeClass = (newNodeClass: APIClassType, code?: string): void => { if (!data.node) return; if (data.node!.template[name].value !== code) { takeSnapshot(); } - + + setNode(data.id, (oldNode) => { let newNode = cloneDeep(oldNode); - + newNode.data = { ...newNode.data, node: newNodeClass, description: newNodeClass.description ?? data.node!.description, display_name: newNodeClass.display_name ?? data.node!.display_name, }; - + newNode.data.node.template[name].value = code; - + return newNode; }); - + + updateNodeInternals(data.id); + renderTooltips(); }; @@ -268,6 +273,9 @@ export default function ParameterComponent({ {