From 83b3347fc5ba6f80d17383f22861ab750f6d3e3e Mon Sep 17 00:00:00 2001 From: anovazzi1 Date: Tue, 2 May 2023 18:46:29 -0300 Subject: [PATCH 1/5] fixed AutoUpdate for each node --- src/frontend/src/contexts/tabsContext.tsx | 32 +- src/frontend/src/utils.ts | 709 +++++++++++----------- 2 files changed, 380 insertions(+), 361 deletions(-) diff --git a/src/frontend/src/contexts/tabsContext.tsx b/src/frontend/src/contexts/tabsContext.tsx index 8a77107ed..1d79c8282 100644 --- a/src/frontend/src/contexts/tabsContext.tsx +++ b/src/frontend/src/contexts/tabsContext.tsx @@ -8,10 +8,10 @@ import { } from "react"; import { FlowType } from "../types/flow"; import { LangFlowState, TabsContextType } from "../types/tabs"; -import { normalCaseToSnakeCase, updateObject } from "../utils"; +import { normalCaseToSnakeCase, updateObject, updateTemplate } from "../utils"; import { alertContext } from "./alertContext"; import { typesContext } from "./typesContext"; -import { TemplateVariableType } from "../types/api"; +import { APITemplateType, TemplateVariableType } from "../types/api"; const { v4: uuidv4 } = require("uuid"); const TabsContextInitialValue: TabsContextType = { @@ -64,12 +64,12 @@ export function TabsProvider({ children }: { children: ReactNode }) { cookieObject.flows.forEach((flow) => { flow.data.nodes.forEach((node) => { if (Object.keys(templates[node.data.type]["template"]).length > 0) { - node.data.node.template = updateObject( + node.data.node.template = updateTemplate( templates[node.data.type][ "template" - ] as unknown as TemplateVariableType, + ] as unknown as APITemplateType, - node.data.node.template as TemplateVariableType + node.data.node.template as APITemplateType ); } }); @@ -127,16 +127,6 @@ export function TabsProvider({ children }: { children: ReactNode }) { file.text().then((text) => { // parse the text into a JSON object let flow: FlowType = JSON.parse(text); - flow.data.nodes.forEach((node) => { - if (Object.keys(templates[node.data.type]["template"]).length > 0) { - node.data.node.template = updateObject( - templates[node.data.type][ - "template" - ] as unknown as TemplateVariableType, - node.data.node.template as TemplateVariableType - ); - } - }); addFlow(flow); }); @@ -176,6 +166,18 @@ export function TabsProvider({ children }: { children: ReactNode }) { const data = flow?.data ? flow.data : null; const description = flow?.description ? flow.description : ""; + if(data){ + data.nodes.forEach((node) => { + if (Object.keys(templates[node.data.type]["template"]).length > 0) { + node.data.node.template = updateTemplate( + templates[node.data.type][ + "template" + ] as unknown as APITemplateType, + node.data.node.template as APITemplateType + ); + } + }); + } // Create a new flow with a default name if no flow is provided. let newFlow: FlowType = { description, diff --git a/src/frontend/src/utils.ts b/src/frontend/src/utils.ts index 9e5bceab6..5890e0118 100644 --- a/src/frontend/src/utils.ts +++ b/src/frontend/src/utils.ts @@ -1,431 +1,448 @@ import { - RocketLaunchIcon, - LinkIcon, - CpuChipIcon, - LightBulbIcon, - CommandLineIcon, - WrenchScrewdriverIcon, - WrenchIcon, - ComputerDesktopIcon, - Bars3CenterLeftIcon, - GiftIcon, - PaperClipIcon, - QuestionMarkCircleIcon, - FingerPrintIcon, - ScissorsIcon, - CircleStackIcon, - Squares2X2Icon, + RocketLaunchIcon, + LinkIcon, + CpuChipIcon, + LightBulbIcon, + CommandLineIcon, + WrenchScrewdriverIcon, + WrenchIcon, + ComputerDesktopIcon, + Bars3CenterLeftIcon, + GiftIcon, + PaperClipIcon, + QuestionMarkCircleIcon, + FingerPrintIcon, + ScissorsIcon, + CircleStackIcon, + Squares2X2Icon, } from "@heroicons/react/24/outline"; import { Connection, Edge, Node, ReactFlowInstance } from "reactflow"; import { FlowType } from "./types/flow"; +import { APITemplateType, TemplateVariableType } from "./types/api"; var _ = require("lodash"); export function classNames(...classes: Array) { - return classes.filter(Boolean).join(" "); + return classes.filter(Boolean).join(" "); } export const textColors = { - white: "text-white", - red: "text-red-700", - orange: "text-orange-700", - amber: "text-amber-700", - yellow: "text-yellow-700", - lime: "text-lime-700", - green: "text-green-700", - emerald: "text-emerald-700", - teal: "text-teal-700", - cyan: "text-cyan-700", - sky: "text-sky-700", - blue: "text-blue-700", - indigo: "text-indigo-700", - violet: "text-violet-700", - purple: "text-purple-700", - fuchsia: "text-fuchsia-700", - pink: "text-pink-700", - rose: "text-rose-700", - black: "text-black-700", - gray: "text-gray-700", + white: "text-white", + red: "text-red-700", + orange: "text-orange-700", + amber: "text-amber-700", + yellow: "text-yellow-700", + lime: "text-lime-700", + green: "text-green-700", + emerald: "text-emerald-700", + teal: "text-teal-700", + cyan: "text-cyan-700", + sky: "text-sky-700", + blue: "text-blue-700", + indigo: "text-indigo-700", + violet: "text-violet-700", + purple: "text-purple-700", + fuchsia: "text-fuchsia-700", + pink: "text-pink-700", + rose: "text-rose-700", + black: "text-black-700", + gray: "text-gray-700", }; export const borderLColors = { - white: "border-l-white", - red: "border-l-red-500", - orange: "border-l-orange-500", - amber: "border-l-amber-500", - yellow: "border-l-yellow-500", - lime: "border-l-lime-500", - green: "border-l-green-500", - emerald: "border-l-emerald-500", - teal: "border-l-teal-500", - cyan: "border-l-cyan-500", - sky: "border-l-sky-500", - blue: "border-l-blue-500", - indigo: "border-l-indigo-500", - violet: "border-l-violet-500", - purple: "border-l-purple-500", - fuchsia: "border-l-fuchsia-500", - pink: "border-l-pink-500", - rose: "border-l-rose-500", - black: "border-l-black-500", - gray: "border-l-gray-500", + white: "border-l-white", + red: "border-l-red-500", + orange: "border-l-orange-500", + amber: "border-l-amber-500", + yellow: "border-l-yellow-500", + lime: "border-l-lime-500", + green: "border-l-green-500", + emerald: "border-l-emerald-500", + teal: "border-l-teal-500", + cyan: "border-l-cyan-500", + sky: "border-l-sky-500", + blue: "border-l-blue-500", + indigo: "border-l-indigo-500", + violet: "border-l-violet-500", + purple: "border-l-purple-500", + fuchsia: "border-l-fuchsia-500", + pink: "border-l-pink-500", + rose: "border-l-rose-500", + black: "border-l-black-500", + gray: "border-l-gray-500", }; export const nodeColors: { [char: string]: string } = { - prompts: "#4367BF", - llms: "#6344BE", - chains: "#FE7500", - agents: "#903BBE", - tools: "#FF3434", - memories: "#F5B85A", - advanced: "#000000", - chat: "#198BF6", - thought: "#272541", - embeddings: "#42BAA7", - documentloaders: "#7AAE42", - vectorstores: "#AA8742", - textsplitters: "#B47CB5", - toolkits: "#DB2C2C", - wrappers: "#E6277A", - utilities: "#31A3CC", - unknown: "#9CA3AF", + prompts: "#4367BF", + llms: "#6344BE", + chains: "#FE7500", + agents: "#903BBE", + tools: "#FF3434", + memories: "#F5B85A", + advanced: "#000000", + chat: "#198BF6", + thought: "#272541", + embeddings: "#42BAA7", + documentloaders: "#7AAE42", + vectorstores: "#AA8742", + textsplitters: "#B47CB5", + toolkits: "#DB2C2C", + wrappers: "#E6277A", + utilities: "#31A3CC", + unknown: "#9CA3AF", }; export const nodeNames: { [char: string]: string } = { - prompts: "Prompts", - llms: "LLMs", - chains: "Chains", - agents: "Agents", - tools: "Tools", - memories: "Memories", - advanced: "Advanced", - chat: "Chat", - embeddings: "Embeddings", - documentloaders: "Document Loaders", - vectorstores: "Vector Stores", - toolkits: "Toolkits", - wrappers: "Wrappers", - textsplitters: "Text Splitters", - utilities: "Utilities", - unknown: "Unknown", + prompts: "Prompts", + llms: "LLMs", + chains: "Chains", + agents: "Agents", + tools: "Tools", + memories: "Memories", + advanced: "Advanced", + chat: "Chat", + embeddings: "Embeddings", + documentloaders: "Document Loaders", + vectorstores: "Vector Stores", + toolkits: "Toolkits", + wrappers: "Wrappers", + textsplitters: "Text Splitters", + utilities: "Utilities", + unknown: "Unknown", }; export const nodeIcons: { - [char: string]: React.ForwardRefExoticComponent< - React.SVGProps - >; + [char: string]: React.ForwardRefExoticComponent< + React.SVGProps + >; } = { - agents: RocketLaunchIcon, - chains: LinkIcon, - memories: CpuChipIcon, - llms: LightBulbIcon, - prompts: CommandLineIcon, - tools: WrenchIcon, - advanced: ComputerDesktopIcon, - chat: Bars3CenterLeftIcon, - embeddings: FingerPrintIcon, - documentloaders: PaperClipIcon, - vectorstores: CircleStackIcon, - toolkits: WrenchScrewdriverIcon, - textsplitters: ScissorsIcon, - wrappers: GiftIcon, - utilities: Squares2X2Icon, - unknown: QuestionMarkCircleIcon, + agents: RocketLaunchIcon, + chains: LinkIcon, + memories: CpuChipIcon, + llms: LightBulbIcon, + prompts: CommandLineIcon, + tools: WrenchIcon, + advanced: ComputerDesktopIcon, + chat: Bars3CenterLeftIcon, + embeddings: FingerPrintIcon, + documentloaders: PaperClipIcon, + vectorstores: CircleStackIcon, + toolkits: WrenchScrewdriverIcon, + textsplitters: ScissorsIcon, + wrappers: GiftIcon, + utilities: Squares2X2Icon, + unknown: QuestionMarkCircleIcon, }; export const bgColors = { - white: "bg-white", - red: "bg-red-100", - orange: "bg-orange-100", - amber: "bg-amber-100", - yellow: "bg-yellow-100", - lime: "bg-lime-100", - green: "bg-green-100", - emerald: "bg-emerald-100", - teal: "bg-teal-100", - cyan: "bg-cyan-100", - sky: "bg-sky-100", - blue: "bg-blue-100", - indigo: "bg-indigo-100", - violet: "bg-violet-100", - purple: "bg-purple-100", - fuchsia: "bg-fuchsia-100", - pink: "bg-pink-100", - rose: "bg-rose-100", - black: "bg-black-100", - gray: "bg-gray-100", + white: "bg-white", + red: "bg-red-100", + orange: "bg-orange-100", + amber: "bg-amber-100", + yellow: "bg-yellow-100", + lime: "bg-lime-100", + green: "bg-green-100", + emerald: "bg-emerald-100", + teal: "bg-teal-100", + cyan: "bg-cyan-100", + sky: "bg-sky-100", + blue: "bg-blue-100", + indigo: "bg-indigo-100", + violet: "bg-violet-100", + purple: "bg-purple-100", + fuchsia: "bg-fuchsia-100", + pink: "bg-pink-100", + rose: "bg-rose-100", + black: "bg-black-100", + gray: "bg-gray-100", }; export const bgColorsHover = { - white: "hover:bg-white", - black: "hover:bg-black-50", - gray: "hover:bg-gray-50", - red: "hover:bg-red-50", - orange: "hover:bg-orange-50", - amber: "hover:bg-amber-50", - yellow: "hover:bg-yellow-50", - lime: "hover:bg-lime-50", - green: "hover:bg-green-50", - emerald: "hover:bg-emerald-50", - teal: "hover:bg-teal-50", - cyan: "hover:bg-cyan-50", - sky: "hover:bg-sky-50", - blue: "hover:bg-blue-50", - indigo: "hover:bg-indigo-50", - violet: "hover:bg-violet-50", - purple: "hover:bg-purple-50", - fuchsia: "hover:bg-fuchsia-50", - pink: "hover:bg-pink-50", - rose: "hover:bg-rose-50", + white: "hover:bg-white", + black: "hover:bg-black-50", + gray: "hover:bg-gray-50", + red: "hover:bg-red-50", + orange: "hover:bg-orange-50", + amber: "hover:bg-amber-50", + yellow: "hover:bg-yellow-50", + lime: "hover:bg-lime-50", + green: "hover:bg-green-50", + emerald: "hover:bg-emerald-50", + teal: "hover:bg-teal-50", + cyan: "hover:bg-cyan-50", + sky: "hover:bg-sky-50", + blue: "hover:bg-blue-50", + indigo: "hover:bg-indigo-50", + violet: "hover:bg-violet-50", + purple: "hover:bg-purple-50", + fuchsia: "hover:bg-fuchsia-50", + pink: "hover:bg-pink-50", + rose: "hover:bg-rose-50", }; export const textColorsHex = { - red: "rgb(185 28 28)", - orange: "rgb(194 65 12)", - amber: "rgb(180 83 9)", - yellow: "rgb(161 98 7)", - lime: "rgb(77 124 15)", - green: "rgb(21 128 61)", - emerald: "rgb(4 120 87)", - teal: "rgb(15 118 110)", - cyan: "rgb(14 116 144)", - sky: "rgb(3 105 161)", - blue: "rgb(29 78 216)", - indigo: "rgb(67 56 202)", - violet: "rgb(109 40 217)", - purple: "rgb(126 34 206)", - fuchsia: "rgb(162 28 175)", - pink: "rgb(190 24 93)", - rose: "rgb(190 18 60)", + red: "rgb(185 28 28)", + orange: "rgb(194 65 12)", + amber: "rgb(180 83 9)", + yellow: "rgb(161 98 7)", + lime: "rgb(77 124 15)", + green: "rgb(21 128 61)", + emerald: "rgb(4 120 87)", + teal: "rgb(15 118 110)", + cyan: "rgb(14 116 144)", + sky: "rgb(3 105 161)", + blue: "rgb(29 78 216)", + indigo: "rgb(67 56 202)", + violet: "rgb(109 40 217)", + purple: "rgb(126 34 206)", + fuchsia: "rgb(162 28 175)", + pink: "rgb(190 24 93)", + rose: "rgb(190 18 60)", }; export const bgColorsHex = { - red: "rgb(254 226 226)", - orange: "rgb(255 237 213)", - amber: "rgb(254 243 199)", - yellow: "rgb(254 249 195)", - lime: "rgb(236 252 203)", - green: "rgb(220 252 231)", - emerald: "rgb(209 250 229)", - teal: "rgb(204 251 241)", - cyan: "rgb(207 250 254)", - sky: "rgb(224 242 254)", - blue: "rgb(219 234 254)", - indigo: "rgb(224 231 255)", - violet: "rgb(237 233 254)", - purple: "rgb(243 232 255)", - fuchsia: "rgb(250 232 255)", - pink: "rgb(252 231 243)", - rose: "rgb(255 228 230)", + red: "rgb(254 226 226)", + orange: "rgb(255 237 213)", + amber: "rgb(254 243 199)", + yellow: "rgb(254 249 195)", + lime: "rgb(236 252 203)", + green: "rgb(220 252 231)", + emerald: "rgb(209 250 229)", + teal: "rgb(204 251 241)", + cyan: "rgb(207 250 254)", + sky: "rgb(224 242 254)", + blue: "rgb(219 234 254)", + indigo: "rgb(224 231 255)", + violet: "rgb(237 233 254)", + purple: "rgb(243 232 255)", + fuchsia: "rgb(250 232 255)", + pink: "rgb(252 231 243)", + rose: "rgb(255 228 230)", }; export const taskTypeMap: { [key: string]: string } = { - MULTICLASS_CLASSIFICATION: "Multiclass Classification", + MULTICLASS_CLASSIFICATION: "Multiclass Classification", }; const charWidths: { [char: string]: number } = { - " ": 0.2, - "!": 0.2, - '"': 0.3, - "#": 0.5, - $: 0.5, - "%": 0.5, - "&": 0.5, - "(": 0.2, - ")": 0.2, - "*": 0.5, - "+": 0.5, - ",": 0.2, - "-": 0.2, - ".": 0.1, - "/": 0.5, - ":": 0.2, - ";": 0.2, - "<": 0.5, - "=": 0.5, - ">": 0.5, - "?": 0.2, - "@": 0.5, - "[": 0.2, - "\\": 0.5, - "]": 0.2, - "^": 0.5, - _: 0.2, - "`": 0.5, - "{": 0.2, - "|": 0.2, - "}": 0.2, - "~": 0.5, + " ": 0.2, + "!": 0.2, + '"': 0.3, + "#": 0.5, + $: 0.5, + "%": 0.5, + "&": 0.5, + "(": 0.2, + ")": 0.2, + "*": 0.5, + "+": 0.5, + ",": 0.2, + "-": 0.2, + ".": 0.1, + "/": 0.5, + ":": 0.2, + ";": 0.2, + "<": 0.5, + "=": 0.5, + ">": 0.5, + "?": 0.2, + "@": 0.5, + "[": 0.2, + "\\": 0.5, + "]": 0.2, + "^": 0.5, + _: 0.2, + "`": 0.5, + "{": 0.2, + "|": 0.2, + "}": 0.2, + "~": 0.5, }; for (let i = 65; i <= 90; i++) { - charWidths[String.fromCharCode(i)] = 0.6; + charWidths[String.fromCharCode(i)] = 0.6; } for (let i = 97; i <= 122; i++) { - charWidths[String.fromCharCode(i)] = 0.5; + charWidths[String.fromCharCode(i)] = 0.5; } export function measureTextWidth(text: string, fontSize: number) { - let wordWidth = 0; - for (let j = 0; j < text.length; j++) { - let char = text[j]; - let charWidth = charWidths[char] || 0.5; - wordWidth += charWidth * fontSize; - } - return wordWidth; + let wordWidth = 0; + for (let j = 0; j < text.length; j++) { + let char = text[j]; + let charWidth = charWidths[char] || 0.5; + wordWidth += charWidth * fontSize; + } + return wordWidth; } export function measureTextHeight( - text: string, - width: number, - fontSize: number + text: string, + width: number, + fontSize: number ) { - const charHeight = fontSize; - const lineHeight = charHeight * 1.5; - const words = text.split(" "); - let lineWidth = 0; - let totalHeight = 0; - for (let i = 0; i < words.length; i++) { - let word = words[i]; - let wordWidth = measureTextWidth(word, fontSize); - if (lineWidth + wordWidth + charWidths[" "] * fontSize <= width) { - lineWidth += wordWidth + charWidths[" "] * fontSize; - } else { - totalHeight += lineHeight; - lineWidth = wordWidth; - } - } - totalHeight += lineHeight; - return totalHeight; + const charHeight = fontSize; + const lineHeight = charHeight * 1.5; + const words = text.split(" "); + let lineWidth = 0; + let totalHeight = 0; + for (let i = 0; i < words.length; i++) { + let word = words[i]; + let wordWidth = measureTextWidth(word, fontSize); + if (lineWidth + wordWidth + charWidths[" "] * fontSize <= width) { + lineWidth += wordWidth + charWidths[" "] * fontSize; + } else { + totalHeight += lineHeight; + lineWidth = wordWidth; + } + } + totalHeight += lineHeight; + return totalHeight; } export function toCamelCase(str: string) { - return str - .split(" ") - .map((word, index) => - index === 0 - ? word.toLowerCase() - : word[0].toUpperCase() + word.slice(1).toLowerCase() - ) - .join(""); + return str + .split(" ") + .map((word, index) => + index === 0 + ? word.toLowerCase() + : word[0].toUpperCase() + word.slice(1).toLowerCase() + ) + .join(""); } export function toFirstUpperCase(str: string) { - return str - .split(" ") - .map((word, index) => word[0].toUpperCase() + word.slice(1).toLowerCase()) - .join(""); + return str + .split(" ") + .map((word, index) => word[0].toUpperCase() + word.slice(1).toLowerCase()) + .join(""); } export function snakeToNormalCase(str: string) { - return str - .split("_") - .map((word, index) => { - if (index === 0) { - return word[0].toUpperCase() + word.slice(1).toLowerCase(); - } - return word.toLowerCase(); - }) - .join(" "); + return str + .split("_") + .map((word, index) => { + if (index === 0) { + return word[0].toUpperCase() + word.slice(1).toLowerCase(); + } + return word.toLowerCase(); + }) + .join(" "); } export function normalCaseToSnakeCase(str: string) { - return str - .split(" ") - .map((word, index) => { - if (index === 0) { - return word[0].toUpperCase() + word.slice(1).toLowerCase(); - } - return word.toLowerCase(); - }) - .join("_"); + return str + .split(" ") + .map((word, index) => { + if (index === 0) { + return word[0].toUpperCase() + word.slice(1).toLowerCase(); + } + return word.toLowerCase(); + }) + .join("_"); } export function roundNumber(x: number, decimals: number) { - return Math.round(x * Math.pow(10, decimals)) / Math.pow(10, decimals); + return Math.round(x * Math.pow(10, decimals)) / Math.pow(10, decimals); } export function getConnectedNodes(edge: Edge, nodes: Array): Array { - const sourceId = edge.source; - const targetId = edge.target; - const connectedNodes = nodes.filter( - (node) => node.id === targetId || node.id === sourceId - ); - return connectedNodes; + const sourceId = edge.source; + const targetId = edge.target; + const connectedNodes = nodes.filter( + (node) => node.id === targetId || node.id === sourceId + ); + return connectedNodes; } export function isValidConnection( - { source, target, sourceHandle, targetHandle }: Connection, - reactFlowInstance: ReactFlowInstance + { source, target, sourceHandle, targetHandle }: Connection, + reactFlowInstance: ReactFlowInstance ) { - if ( - sourceHandle.split("|")[0] === targetHandle.split("|")[0] || - sourceHandle - .split("|") - .slice(2) - .some((t) => t === targetHandle.split("|")[0]) || - targetHandle.split("|")[0] === "str" - ) { - let targetNode = reactFlowInstance.getNode(target).data.node; - if (!targetNode) { - if ( - !reactFlowInstance - .getEdges() - .find((e) => e.targetHandle === targetHandle) - ) { - return true; - } - } else if ( - (!targetNode.template[targetHandle.split("|")[1]].list && - !reactFlowInstance - .getEdges() - .find((e) => e.targetHandle === targetHandle)) || - targetNode.template[targetHandle.split("|")[1]].list - ) { - return true; - } - } - return false; + if ( + sourceHandle.split("|")[0] === targetHandle.split("|")[0] || + sourceHandle + .split("|") + .slice(2) + .some((t) => t === targetHandle.split("|")[0]) || + targetHandle.split("|")[0] === "str" + ) { + let targetNode = reactFlowInstance.getNode(target).data.node; + if (!targetNode) { + if ( + !reactFlowInstance + .getEdges() + .find((e) => e.targetHandle === targetHandle) + ) { + return true; + } + } else if ( + (!targetNode.template[targetHandle.split("|")[1]].list && + !reactFlowInstance + .getEdges() + .find((e) => e.targetHandle === targetHandle)) || + targetNode.template[targetHandle.split("|")[1]].list + ) { + return true; + } + } + return false; } export function removeApiKeys(flow: FlowType): FlowType { - let cleanFLow = _.cloneDeep(flow); - cleanFLow.data.nodes.forEach((node) => { - for (const key in node.data.node.template) { - if (key.includes("api")) { - console.log(node.data.node.template[key]); - node.data.node.template[key].value = ""; - } - } - }); - return cleanFLow; + let cleanFLow = _.cloneDeep(flow); + cleanFLow.data.nodes.forEach((node) => { + for (const key in node.data.node.template) { + if (key.includes("api")) { + console.log(node.data.node.template[key]); + node.data.node.template[key].value = ""; + } + } + }); + return cleanFLow; } export function updateObject>( - reference: T, - objectToUpdate: T + reference: T, + objectToUpdate: T ): T { - let clonedObject = _.cloneDeep(objectToUpdate); - // Loop through each key in the object to update - for (const key in clonedObject) { - // If the key is not in the reference object, delete it - if (!(key in reference)) { - delete clonedObject[key]; - } - } - // Loop through each key in the reference object - for (const key in reference) { - // If the key is not in the object to update, add it - if (!(key in clonedObject)) { - clonedObject[key] = reference[key]; - } - } - return clonedObject; + let clonedObject = _.cloneDeep(objectToUpdate); + // Loop through each key in the object to update + for (const key in clonedObject) { + // If the key is not in the reference object, delete it + if (!(key in reference)) { + delete clonedObject[key]; + } + } + // Loop through each key in the reference object + for (const key in reference) { + // If the key is not in the object to update, add it + if (!(key in clonedObject)) { + clonedObject[key] = reference[key]; + } + } + return clonedObject; } export function debounce(func, wait) { - let timeout; - return function (...args) { - const context = this; - clearTimeout(timeout); - timeout = setTimeout(() => func.apply(context, args), wait); - }; + let timeout; + return function (...args) { + const context = this; + clearTimeout(timeout); + timeout = setTimeout(() => func.apply(context, args), wait); + }; +} + +export function updateTemplate( + reference: APITemplateType, + objectToUpdate: APITemplateType +): APITemplateType { + let clonedObject:APITemplateType = _.cloneDeep(reference); + + // Loop through each key in the reference object + for (const key in clonedObject) { + // If the key is not in the object to update, add it + if (objectToUpdate[key] && objectToUpdate[key].value) { + clonedObject[key].value = objectToUpdate[key].value; + } + } + return clonedObject; } From a9d8c6fcb3b35ff166c5aedf1615e4fcdc5b369a Mon Sep 17 00:00:00 2001 From: anovazzi1 Date: Tue, 2 May 2023 18:47:22 -0300 Subject: [PATCH 2/5] fomrated code --- src/frontend/src/contexts/tabsContext.tsx | 424 +++++++++++----------- 1 file changed, 211 insertions(+), 213 deletions(-) diff --git a/src/frontend/src/contexts/tabsContext.tsx b/src/frontend/src/contexts/tabsContext.tsx index 1d79c8282..6dfa37324 100644 --- a/src/frontend/src/contexts/tabsContext.tsx +++ b/src/frontend/src/contexts/tabsContext.tsx @@ -1,10 +1,10 @@ import { - createContext, - useEffect, - useState, - useRef, - ReactNode, - useContext, + createContext, + useEffect, + useState, + useRef, + ReactNode, + useContext, } from "react"; import { FlowType } from "../types/flow"; import { LangFlowState, TabsContextType } from "../types/tabs"; @@ -15,223 +15,221 @@ import { APITemplateType, TemplateVariableType } from "../types/api"; const { v4: uuidv4 } = require("uuid"); const TabsContextInitialValue: TabsContextType = { - save: () => {}, - tabIndex: 0, - setTabIndex: (index: number) => {}, - flows: [], - removeFlow: (id: string) => {}, - addFlow: (flowData?: any) => {}, - updateFlow: (newFlow: FlowType) => {}, - incrementNodeId: () => 0, - downloadFlow: (flow: FlowType) => {}, - uploadFlow: () => {}, - hardReset: () => {}, + save: () => {}, + tabIndex: 0, + setTabIndex: (index: number) => {}, + flows: [], + removeFlow: (id: string) => {}, + addFlow: (flowData?: any) => {}, + updateFlow: (newFlow: FlowType) => {}, + incrementNodeId: () => 0, + downloadFlow: (flow: FlowType) => {}, + uploadFlow: () => {}, + hardReset: () => {}, }; export const TabsContext = createContext( - TabsContextInitialValue + TabsContextInitialValue ); export function TabsProvider({ children }: { children: ReactNode }) { - const { setNoticeData } = useContext(alertContext); - const [tabIndex, setTabIndex] = useState(0); - const [flows, setFlows] = useState>([]); - const [id, setId] = useState(""); - const { templates } = useContext(typesContext); + const { setNoticeData } = useContext(alertContext); + const [tabIndex, setTabIndex] = useState(0); + const [flows, setFlows] = useState>([]); + const [id, setId] = useState(""); + const { templates } = useContext(typesContext); - const newNodeId = useRef(0); - function incrementNodeId() { - newNodeId.current = newNodeId.current + 1; - return newNodeId.current; - } - function save() { - if (flows.length !== 0) - window.localStorage.setItem( - "tabsData", - JSON.stringify({ tabIndex, flows, id, nodeId: newNodeId.current }) - ); - } - useEffect(() => { - //save tabs locally - save(); - }, [flows, id, tabIndex, newNodeId]); - - useEffect(() => { - //get tabs locally saved - let cookie = window.localStorage.getItem("tabsData"); - if (cookie && Object.keys(templates).length > 0) { - let cookieObject: LangFlowState = JSON.parse(cookie); - cookieObject.flows.forEach((flow) => { - flow.data.nodes.forEach((node) => { - if (Object.keys(templates[node.data.type]["template"]).length > 0) { - node.data.node.template = updateTemplate( - templates[node.data.type][ - "template" - ] as unknown as APITemplateType, - - node.data.node.template as APITemplateType - ); - } - }); - }); - setTabIndex(cookieObject.tabIndex); - setFlows(cookieObject.flows); - setId(cookieObject.id); - newNodeId.current = cookieObject.nodeId; - } - }, [templates]); - function hardReset() { - newNodeId.current = 0; - setTabIndex(0); - setFlows([]); - setId(""); - } - - /** - * Downloads the current flow as a JSON file - */ - function downloadFlow(flow: FlowType) { - // create a data URI with the current flow data - const jsonString = `data:text/json;chatset=utf-8,${encodeURIComponent( - JSON.stringify(flow) - )}`; - - // create a link element and set its properties - const link = document.createElement("a"); - link.href = jsonString; - link.download = `${normalCaseToSnakeCase(flows[tabIndex].name)}.json`; - - // simulate a click on the link element to trigger the download - link.click(); - setNoticeData({ - title: "Warning: Critical data,JSON file may including API keys.", - }); - } - - /** - * Creates a file input and listens to a change event to upload a JSON flow file. - * If the file type is application/json, the file is read and parsed into a JSON object. - * The resulting JSON object is passed to the addFlow function. - */ - function uploadFlow() { - // create a file input - const input = document.createElement("input"); - input.type = "file"; - // add a change event listener to the file input - input.onchange = (e: Event) => { - // check if the file type is application/json - if ((e.target as HTMLInputElement).files[0].type === "application/json") { - // get the file from the file input - const file = (e.target as HTMLInputElement).files[0]; - // read the file as text - file.text().then((text) => { - // parse the text into a JSON object - let flow: FlowType = JSON.parse(text); - - addFlow(flow); - }); - } - }; - // trigger the file input click event to open the file dialog - input.click(); - } - /** - * Removes a flow from an array of flows based on its id. - * Updates the state of flows and tabIndex using setFlows and setTabIndex hooks. - * @param {string} id - The id of the flow to remove. - */ - function removeFlow(id: string) { - setFlows((prevState) => { - const newFlows = [...prevState]; - const index = newFlows.findIndex((flow) => flow.id === id); - if (index >= 0) { - if (index === tabIndex) { - setTabIndex(flows.length - 2); - newFlows.splice(index, 1); - } else { - let flowId = flows[tabIndex].id; - newFlows.splice(index, 1); - setTabIndex(newFlows.findIndex((flow) => flow.id === flowId)); - } - } - return newFlows; - }); - } - /** - * Add a new flow to the list of flows. - * @param flow Optional flow to add. - */ - function addFlow(flow?: FlowType) { - // Get data from the flow or set it to null if there's no flow provided. - const data = flow?.data ? flow.data : null; - const description = flow?.description ? flow.description : ""; - - if(data){ - data.nodes.forEach((node) => { - if (Object.keys(templates[node.data.type]["template"]).length > 0) { - node.data.node.template = updateTemplate( - templates[node.data.type][ - "template" - ] as unknown as APITemplateType, - node.data.node.template as APITemplateType - ); - } - }); + const newNodeId = useRef(0); + function incrementNodeId() { + newNodeId.current = newNodeId.current + 1; + return newNodeId.current; } - // Create a new flow with a default name if no flow is provided. - let newFlow: FlowType = { - description, - name: flow?.name ?? "New Flow", - id: id.toString(), - data, - }; + function save() { + if (flows.length !== 0) + window.localStorage.setItem( + "tabsData", + JSON.stringify({ tabIndex, flows, id, nodeId: newNodeId.current }) + ); + } + useEffect(() => { + //save tabs locally + save(); + }, [flows, id, tabIndex, newNodeId]); - // Increment the ID counter. - setId(uuidv4()); + useEffect(() => { + //get tabs locally saved + let cookie = window.localStorage.getItem("tabsData"); + if (cookie && Object.keys(templates).length > 0) { + let cookieObject: LangFlowState = JSON.parse(cookie); + cookieObject.flows.forEach((flow) => { + flow.data.nodes.forEach((node) => { + if (Object.keys(templates[node.data.type]["template"]).length > 0) { + node.data.node.template = updateTemplate( + templates[node.data.type][ + "template" + ] as unknown as APITemplateType, - // Add the new flow to the list of flows. - setFlows((prevState) => { - const newFlows = [...prevState, newFlow]; - return newFlows; - }); + node.data.node.template as APITemplateType + ); + } + }); + }); + setTabIndex(cookieObject.tabIndex); + setFlows(cookieObject.flows); + setId(cookieObject.id); + newNodeId.current = cookieObject.nodeId; + } + }, [templates]); + function hardReset() { + newNodeId.current = 0; + setTabIndex(0); + setFlows([]); + setId(""); + } - // Set the tab index to the new flow. - setTabIndex(flows.length); - } - /** - * Updates an existing flow with new data - * @param newFlow - The new flow object containing the updated data - */ - function updateFlow(newFlow: FlowType) { - setFlows((prevState) => { - const newFlows = [...prevState]; - const index = newFlows.findIndex((flow) => flow.id === newFlow.id); - if (index !== -1) { - newFlows[index].description = newFlow.description ?? ""; - newFlows[index].data = newFlow.data; - newFlows[index].name = newFlow.name; - } - return newFlows; - }); - } + /** + * Downloads the current flow as a JSON file + */ + function downloadFlow(flow: FlowType) { + // create a data URI with the current flow data + const jsonString = `data:text/json;chatset=utf-8,${encodeURIComponent( + JSON.stringify(flow) + )}`; - return ( - - {children} - - ); + // create a link element and set its properties + const link = document.createElement("a"); + link.href = jsonString; + link.download = `${normalCaseToSnakeCase(flows[tabIndex].name)}.json`; + + // simulate a click on the link element to trigger the download + link.click(); + setNoticeData({ + title: "Warning: Critical data,JSON file may including API keys.", + }); + } + + /** + * Creates a file input and listens to a change event to upload a JSON flow file. + * If the file type is application/json, the file is read and parsed into a JSON object. + * The resulting JSON object is passed to the addFlow function. + */ + function uploadFlow() { + // create a file input + const input = document.createElement("input"); + input.type = "file"; + // add a change event listener to the file input + input.onchange = (e: Event) => { + // check if the file type is application/json + if ((e.target as HTMLInputElement).files[0].type === "application/json") { + // get the file from the file input + const file = (e.target as HTMLInputElement).files[0]; + // read the file as text + file.text().then((text) => { + // parse the text into a JSON object + let flow: FlowType = JSON.parse(text); + + addFlow(flow); + }); + } + }; + // trigger the file input click event to open the file dialog + input.click(); + } + /** + * Removes a flow from an array of flows based on its id. + * Updates the state of flows and tabIndex using setFlows and setTabIndex hooks. + * @param {string} id - The id of the flow to remove. + */ + function removeFlow(id: string) { + setFlows((prevState) => { + const newFlows = [...prevState]; + const index = newFlows.findIndex((flow) => flow.id === id); + if (index >= 0) { + if (index === tabIndex) { + setTabIndex(flows.length - 2); + newFlows.splice(index, 1); + } else { + let flowId = flows[tabIndex].id; + newFlows.splice(index, 1); + setTabIndex(newFlows.findIndex((flow) => flow.id === flowId)); + } + } + return newFlows; + }); + } + /** + * Add a new flow to the list of flows. + * @param flow Optional flow to add. + */ + function addFlow(flow?: FlowType) { + // Get data from the flow or set it to null if there's no flow provided. + const data = flow?.data ? flow.data : null; + const description = flow?.description ? flow.description : ""; + + if (data) { + data.nodes.forEach((node) => { + if (Object.keys(templates[node.data.type]["template"]).length > 0) { + node.data.node.template = updateTemplate( + templates[node.data.type]["template"] as unknown as APITemplateType, + node.data.node.template as APITemplateType + ); + } + }); + } + // Create a new flow with a default name if no flow is provided. + let newFlow: FlowType = { + description, + name: flow?.name ?? "New Flow", + id: id.toString(), + data, + }; + + // Increment the ID counter. + setId(uuidv4()); + + // Add the new flow to the list of flows. + setFlows((prevState) => { + const newFlows = [...prevState, newFlow]; + return newFlows; + }); + + // Set the tab index to the new flow. + setTabIndex(flows.length); + } + /** + * Updates an existing flow with new data + * @param newFlow - The new flow object containing the updated data + */ + function updateFlow(newFlow: FlowType) { + setFlows((prevState) => { + const newFlows = [...prevState]; + const index = newFlows.findIndex((flow) => flow.id === newFlow.id); + if (index !== -1) { + newFlows[index].description = newFlow.description ?? ""; + newFlows[index].data = newFlow.data; + newFlows[index].name = newFlow.name; + } + return newFlows; + }); + } + + return ( + + {children} + + ); } From 9d3098f3e22a2b9f531685d9adda820aba171a83 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 2 May 2023 19:49:32 -0300 Subject: [PATCH 3/5] refactor(langflow): reorder imports in multiple files This commit reorders imports in multiple files to follow PEP8 guidelines and improve code readability. No functional changes were made. --- src/backend/langflow/__init__.py | 2 +- src/backend/langflow/api/callback.py | 1 + src/backend/langflow/api/chat_manager.py | 10 ++++++---- src/backend/langflow/api/schemas.py | 1 + src/backend/langflow/cache/base.py | 2 +- src/backend/langflow/cache/manager.py | 3 ++- src/backend/langflow/interface/agents/custom.py | 2 +- src/backend/langflow/interface/loading.py | 2 +- src/backend/langflow/interface/run.py | 1 + src/backend/langflow/interface/utils.py | 9 +++++---- src/backend/langflow/main.py | 2 +- src/backend/langflow/utils/util.py | 2 +- tests/conftest.py | 3 +-- tests/test_cache_manager.py | 7 ++++--- tests/test_websocket.py | 1 + 15 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/backend/langflow/__init__.py b/src/backend/langflow/__init__.py index fb06fe1a7..35fe814d2 100644 --- a/src/backend/langflow/__init__.py +++ b/src/backend/langflow/__init__.py @@ -1,4 +1,4 @@ -from langflow.interface.loading import load_flow_from_json from langflow.cache import cache_manager +from langflow.interface.loading import load_flow_from_json __all__ = ["load_flow_from_json", "cache_manager"] diff --git a/src/backend/langflow/api/callback.py b/src/backend/langflow/api/callback.py index cad4b1416..baab596a4 100644 --- a/src/backend/langflow/api/callback.py +++ b/src/backend/langflow/api/callback.py @@ -1,4 +1,5 @@ from typing import Any + from langchain.callbacks.base import AsyncCallbackHandler from langflow.api.schemas import ChatResponse diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index 5c490c43c..2dab12e34 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -1,9 +1,12 @@ import asyncio -from typing import Dict, List -from collections import defaultdict -from fastapi import WebSocket import json +from collections import defaultdict +from typing import Dict, List + +from fastapi import WebSocket + from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse +from langflow.cache import cache_manager from langflow.cache.manager import Subject from langflow.interface.run import ( get_result_and_steps, @@ -11,7 +14,6 @@ from langflow.interface.run import ( ) from langflow.interface.utils import pil_to_base64, try_setting_streaming_options from langflow.utils.logger import logger -from langflow.cache import cache_manager class ChatHistory(Subject): diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py index e4adacf9e..dd157d85f 100644 --- a/src/backend/langflow/api/schemas.py +++ b/src/backend/langflow/api/schemas.py @@ -1,4 +1,5 @@ from typing import Any, Union + from pydantic import BaseModel, validator diff --git a/src/backend/langflow/cache/base.py b/src/backend/langflow/cache/base.py index ede0fb06e..73439e9dd 100644 --- a/src/backend/langflow/cache/base.py +++ b/src/backend/langflow/cache/base.py @@ -2,13 +2,13 @@ import base64 import contextlib import functools import hashlib - import json import os import tempfile from collections import OrderedDict from pathlib import Path from typing import Any, Dict + import dill # type: ignore CACHE: Dict[str, Any] = {} diff --git a/src/backend/langflow/cache/manager.py b/src/backend/langflow/cache/manager.py index 971519230..947f5ce21 100644 --- a/src/backend/langflow/cache/manager.py +++ b/src/backend/langflow/cache/manager.py @@ -1,7 +1,8 @@ from contextlib import contextmanager from typing import Any, Awaitable, Callable, List, Optional -from PIL import Image + import pandas as pd +from PIL import Image class Subject: diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index 67827bd9d..84bf793f9 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -26,9 +26,9 @@ from langchain.agents.agent_toolkits.vectorstore.prompt import ( ) from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS as SQL_FORMAT_INSTRUCTIONS +from langchain.base_language import BaseLanguageModel from langchain.llms.base import BaseLLM from langchain.memory.chat_memory import BaseChatMemory -from langchain.base_language import BaseLanguageModel from langchain.sql_database import SQLDatabase from langchain.tools.python.tool import PythonAstREPLTool from langchain.tools.sql_database.prompt import QUERY_CHECKER diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 82405319a..49cf9f698 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -17,6 +17,7 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.chains.loading import load_chain_from_config from langchain.llms.base import BaseLLM from langchain.llms.loading import load_llm_from_config +from pydantic import ValidationError from langflow.interface.agents.custom import CUSTOM_AGENTS from langflow.interface.importing.utils import import_by_type @@ -25,7 +26,6 @@ from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.types import get_type_list from langflow.interface.utils import load_file_into_dict from langflow.utils import util, validate -from pydantic import ValidationError def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any: diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 89f26944e..68639785c 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -1,6 +1,7 @@ import contextlib import io from typing import Any, Dict + from chromadb.errors import NotEnoughElementsException # type: ignore from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict diff --git a/src/backend/langflow/interface/utils.py b/src/backend/langflow/interface/utils.py index 21f627b60..e8b3e417e 100644 --- a/src/backend/langflow/interface/utils.py +++ b/src/backend/langflow/interface/utils.py @@ -1,14 +1,15 @@ import base64 -from io import BytesIO import json import os -from PIL.Image import Image +from io import BytesIO + +import yaml from langchain.callbacks.manager import AsyncCallbackManager from langchain.chat_models import AzureChatOpenAI, ChatOpenAI from langchain.llms import AzureOpenAI, OpenAI -from langflow.api.callback import StreamingLLMCallbackHandler +from PIL.Image import Image -import yaml +from langflow.api.callback import StreamingLLMCallbackHandler def load_file_into_dict(file_path: str) -> dict: diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index c1f2decd5..56cc32e46 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -1,9 +1,9 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from langflow.api.chat import router as chat_router from langflow.api.endpoints import router as endpoints_router from langflow.api.validate import router as validate_router -from langflow.api.chat import router as chat_router def create_app(): diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index 4d2a281cb..e959b0103 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -1,7 +1,7 @@ -from functools import wraps import importlib import inspect import re +from functools import wraps from typing import Dict, Optional from docstring_parser import parse # type: ignore diff --git a/tests/conftest.py b/tests/conftest.py index 15da0d1ef..870c48a32 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,10 @@ import json from pathlib import Path from typing import AsyncGenerator -from httpx import AsyncClient - import pytest from fastapi.testclient import TestClient +from httpx import AsyncClient def pytest_configure(): diff --git a/tests/test_cache_manager.py b/tests/test_cache_manager.py index 8680a43cb..f3e65481e 100644 --- a/tests/test_cache_manager.py +++ b/tests/test_cache_manager.py @@ -1,8 +1,9 @@ -import pytest -from PIL import Image -import pandas as pd from io import StringIO + +import pandas as pd +import pytest from langflow.cache.manager import CacheManager +from PIL import Image @pytest.fixture diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 7ac646cf4..5b60d0fed 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,5 +1,6 @@ import json from unittest.mock import patch + from fastapi.testclient import TestClient From 1b10041730fd551761bb74dbc414c7e39fbaf0a4 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 2 May 2023 20:12:44 -0300 Subject: [PATCH 4/5] refactor(base.py): simplify type of Union fields to the first type in the Union fix(test_agents_template.py): set value of list field to an empty list fix(test_llms_template.py): change type of 'request_timeout' field to float --- src/backend/langflow/template/base.py | 13 +++++++++++-- tests/test_agents_template.py | 1 + tests/test_llms_template.py | 5 ++--- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/backend/langflow/template/base.py b/src/backend/langflow/template/base.py index fd34a19f2..5273782ed 100644 --- a/src/backend/langflow/template/base.py +++ b/src/backend/langflow/template/base.py @@ -162,14 +162,23 @@ class FrontendNode(BaseModel): _type = _type.replace("Optional[", "")[:-1] # Check for list type - if "List" in _type: - _type = _type.replace("List[", "")[:-1] + if "List" in _type or "Sequence" in _type: + _type = _type.replace("List[", "") + _type = _type.replace("Sequence[", "")[:-1] field.is_list = True # Replace 'Mapping' with 'dict' if "Mapping" in _type: _type = _type.replace("Mapping", "dict") + # {'type': 'Union[float, Tuple[float, float], NoneType]'} != {'type': 'float'} + if "Union" in _type: + _type = _type.replace("Union[", "")[:-1] + _type = _type.split(",")[0] + _type = _type.replace("]", "").replace("[", "") + + field.field_type = _type + # Change type from str to Tool field.field_type = "Tool" if key in {"allowed_tools"} else field.field_type diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py index 750050c25..db746a424 100644 --- a/tests/test_agents_template.py +++ b/tests/test_agents_template.py @@ -48,6 +48,7 @@ def test_zero_shot_agent(client: TestClient): "type": "Tool", "list": True, "advanced": False, + "value": [], } diff --git a/tests/test_llms_template.py b/tests/test_llms_template.py index 9f5fdfb4d..dc0149520 100644 --- a/tests/test_llms_template.py +++ b/tests/test_llms_template.py @@ -291,7 +291,7 @@ def test_openai(client: TestClient): "multiline": False, "password": False, "name": "request_timeout", - "type": "Union[float, Tuple[float, float], NoneType]", + "type": "float", "list": False, "advanced": False, } @@ -418,10 +418,9 @@ def test_chat_open_ai(client: TestClient): "placeholder": "", "show": False, "multiline": False, - "value": 60, "password": False, "name": "request_timeout", - "type": "int", + "type": "float", "list": False, "advanced": False, } From e0761cfe231491698d40cf62e716653771a83a07 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 2 May 2023 20:23:14 -0300 Subject: [PATCH 5/5] refactor(agents): change tool_names variable from list to set The `tool_names` variable was changed from a list to a set in the `JsonAgent`, `CSVAgent`, `VectorStoreAgent`, `SQLAgent`, and `MalfoyAgent` classes. This change was made to improve performance and avoid duplicates. --- src/backend/langflow/interface/agents/custom.py | 10 +++++----- src/backend/langflow/interface/agents/prebuilt.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index 84bf793f9..6b545ed3d 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -51,7 +51,7 @@ class JsonAgent(AgentExecutor): @classmethod def from_toolkit_and_llm(cls, toolkit: JsonToolkit, llm: BaseLanguageModel): tools = toolkit.get_tools() - tool_names = [tool.name for tool in tools] + tool_names = {tool.name for tool in tools} prompt = ZeroShotAgent.create_prompt( tools, prefix=JSON_PREFIX, @@ -109,7 +109,7 @@ class CSVAgent(AgentExecutor): llm=llm, prompt=partial_prompt, ) - tool_names = [tool.name for tool in tools] + tool_names = {tool.name for tool in tools} agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True) @@ -146,7 +146,7 @@ class VectorStoreAgent(AgentExecutor): llm=llm, prompt=prompt, ) - tool_names = [tool.name for tool in tools] + tool_names = {tool.name for tool in tools} agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) return AgentExecutor.from_agent_and_tools( agent=agent, tools=tools, verbose=True @@ -212,7 +212,7 @@ class SQLAgent(AgentExecutor): llm=llm, prompt=prompt, ) - tool_names = [tool.name for tool in tools] # type: ignore + tool_names = {tool.name for tool in tools} # type: ignore agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) return AgentExecutor.from_agent_and_tools( agent=agent, @@ -255,7 +255,7 @@ class VectorStoreRouterAgent(AgentExecutor): llm=llm, prompt=prompt, ) - tool_names = [tool.name for tool in tools] + tool_names = {tool.name for tool in tools} agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) return AgentExecutor.from_agent_and_tools( agent=agent, tools=tools, verbose=True diff --git a/src/backend/langflow/interface/agents/prebuilt.py b/src/backend/langflow/interface/agents/prebuilt.py index 957dd4f7d..58d8d561f 100644 --- a/src/backend/langflow/interface/agents/prebuilt.py +++ b/src/backend/langflow/interface/agents/prebuilt.py @@ -21,7 +21,7 @@ class MalfoyAgent(AgentExecutor): @classmethod def from_toolkit_and_llm(cls, toolkit: JsonToolkit, llm: BaseLanguageModel): tools = toolkit.get_tools() - tool_names = [tool.name for tool in tools] + tool_names = {tool.name for tool in tools} prompt = ZeroShotAgent.create_prompt( tools, prefix=JSON_PREFIX,