diff --git a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx index 653248763..e26ae589f 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx @@ -15,6 +15,8 @@ import InputFileComponent from "../../../../components/inputFileComponent"; import { TabsContext } from "../../../../contexts/tabsContext"; import IntComponent from "../../../../components/intComponent"; import PromptAreaComponent from "../../../../components/promptComponent"; +import { nodeNames, nodeIcons } from "../../../../utils"; +import React from "react"; export default function ParameterComponent({ left, @@ -28,6 +30,8 @@ export default function ParameterComponent({ required = false, }: ParameterComponentType) { const ref = useRef(null); + const refParent = useRef(""); + const refParentIcon = useRef(null); const updateNodeInternals = useUpdateNodeInternals(); const [position, setPosition] = useState(0); useEffect(() => { @@ -48,6 +52,19 @@ export default function ParameterComponent({ let disabled = reactFlowInstance?.getEdges().some((e) => e.targetHandle === id) ?? false; const { save } = useContext(TabsContext); + const [myData, setMyData] = useState(useContext(typesContext).data); + + useEffect(() => { + Object.keys(myData).forEach((d) => { + let keys = Object.keys(myData[d]).filter( + (nd) => nd.toLowerCase() == data.type.toLowerCase() + ); + if (keys.length > 0) { + refParent.current = d; + refParentIcon.current = nodeIcons[d]; + } + }); + }, []); return (