diff --git a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx index 56a427ca0..ba5696aed 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx @@ -36,6 +36,7 @@ export default function ParameterComponent({ type, name = "", required = false, + optionalHandle = null, }: ParameterComponentType) { const ref = useRef(null); const refHtml = useRef(null); @@ -132,13 +133,14 @@ export default function ParameterComponent({ {required ? " *" : ""} {left && - (type === "str" || + ((type === "str" || type === "bool" || type === "float" || type === "code" || type === "prompt" || type === "file" || - type === "int") ? ( + type === "int") && !optionalHandle + ) ? ( <> ) : ( ) : ( <> diff --git a/src/frontend/src/types/api/index.ts b/src/frontend/src/types/api/index.ts index ecd79240e..d54a8165c 100644 --- a/src/frontend/src/types/api/index.ts +++ b/src/frontend/src/types/api/index.ts @@ -12,6 +12,7 @@ export type APIClassType = { description: string; template: APITemplateType; display_name: string; + input_types?: Array; [key: string]: Array | string | APITemplateType; }; diff --git a/src/frontend/src/types/components/index.ts b/src/frontend/src/types/components/index.ts index 090d3a49f..85f3ec1cd 100644 --- a/src/frontend/src/types/components/index.ts +++ b/src/frontend/src/types/components/index.ts @@ -36,6 +36,7 @@ export type ParameterComponentType = { name?: string; tooltipTitle: string; dataContext?: typesContextType; + optionalHandle?: Array; }; export type InputListComponentType = { value: string[]; diff --git a/src/frontend/src/utils.ts b/src/frontend/src/utils.ts index 6ff05292f..797f3d93e 100644 --- a/src/frontend/src/utils.ts +++ b/src/frontend/src/utils.ts @@ -141,6 +141,7 @@ export const nodeColors: { [char: string]: string } = { wrappers: "#E6277A", utilities: "#31A3CC", output_parsers: "#E6A627", + str: "#049524", unknown: "#9CA3AF", }; @@ -631,11 +632,11 @@ export function isValidConnection( reactFlowInstance: ReactFlowInstance ) { if ( - sourceHandle.split("|")[0] === targetHandle.split("|")[0] || + targetHandle.split("|")[0].split(";").some((n) => n === sourceHandle.split("|")[0]) || sourceHandle .split("|") .slice(2) - .some((t) => t === targetHandle.split("|")[0]) || + .some((t) => targetHandle.split("|")[0].split(";").some((n) => n === t)) || targetHandle.split("|")[0] === "str" ) { let targetNode = reactFlowInstance.getNode(target).data.node;