diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index d45221be7..2b2ccc43f 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -15,6 +15,11 @@ CUSTOM_NODES = { "utilities": { "SQLDatabase": nodes.SQLDatabaseNode(), }, + "chains": { + "SeriesCharacterChain": nodes.SeriesCharacterChainNode(), + "TimeTravelGuideChain": nodes.TimeTravelGuideChainNode(), + "MidJourneyPromptChain": nodes.MidJourneyPromptChainNode(), + }, } diff --git a/src/backend/langflow/template/base.py b/src/backend/langflow/template/base.py index ecadde108..860d2ac1e 100644 --- a/src/backend/langflow/template/base.py +++ b/src/backend/langflow/template/base.py @@ -23,6 +23,7 @@ class TemplateFieldCreator(BaseModel, ABC): options: list[str] = [] name: str = "" display_name: Optional[str] = None + advanced: bool = True def to_dict(self): result = self.dict() @@ -232,3 +233,12 @@ class FrontendNode(BaseModel): # other conditions are to make sure that it is not an input or output variable if "api" in key.lower() and "key" in key.lower(): field.required = False + + if "kwargs" in field.name.lower(): + field.advanced = True + field.required = False + field.show = False + # If the field.name contains api or api and key, then it might be an api key + # other conditions are to make sure that it is not an input or output variable + if "api" in key.lower() and "key" in key.lower(): + field.required = False diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index b174a9363..9e33e2a79 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -101,6 +101,107 @@ class PythonFunctionNode(FrontendNode): return super().to_dict() +class MidJourneyPromptChainNode(FrontendNode): + name: str = "MidJourneyPromptChain" + template: Template = Template( + type_name="MidJourneyPromptChain", + fields=[ + TemplateField( + field_type="BaseLanguageModel", + required=True, + placeholder="", + is_list=False, + show=True, + advanced=False, + multiline=False, + name="llm", + ), + ], + ) + description: str = "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts." + base_classes: list[str] = [ + "LLMChain", + "BaseCustomChain", + "Chain", + "ConversationChain", + "MidJourneyPromptChain", + ] + + +class TimeTravelGuideChainNode(FrontendNode): + name: str = "TimeTravelGuideChain" + template: Template = Template( + type_name="TimeTravelGuideChain", + fields=[ + TemplateField( + field_type="BaseLanguageModel", + required=True, + placeholder="", + is_list=False, + show=True, + advanced=False, + multiline=False, + name="llm", + ), + ], + ) + description: str = "Time travel guide chain to be used in the flow." + base_classes: list[str] = [ + "LLMChain", + "BaseCustomChain", + "TimeTravelGuideChain", + "Chain", + "ConversationChain", + ] + + +class SeriesCharacterChainNode(FrontendNode): + name: str = "SeriesCharacterChain" + template: Template = Template( + type_name="SeriesCharacterChain", + fields=[ + TemplateField( + field_type="str", + required=True, + placeholder="", + is_list=False, + show=True, + advanced=False, + multiline=False, + name="character", + ), + TemplateField( + field_type="str", + required=True, + placeholder="", + is_list=False, + show=True, + advanced=False, + multiline=False, + name="series", + ), + TemplateField( + field_type="BaseLanguageModel", + required=True, + placeholder="", + is_list=False, + show=True, + advanced=False, + multiline=False, + name="llm", + ), + ], + ) + description: str = "SeriesCharacterChain is a chain you can use to have a conversation with a character from a series." # noqa + base_classes: list[str] = [ + "LLMChain", + "BaseCustomChain", + "Chain", + "ConversationChain", + "SeriesCharacterChain", + ] + + class ToolNode(FrontendNode): name: str = "Tool" template: Template = Template( @@ -418,17 +519,29 @@ class ChainFrontendNode(FrontendNode): def format_field(field: TemplateField, name: Optional[str] = None) -> None: FrontendNode.format_field(field, name) + field.advanced = False if "key" in field.name: field.password = False field.show = False if field.name in ["input_key", "output_key"]: field.required = True field.show = True + field.advanced = True + # Separated for possible future changes if field.name == "prompt": # if no prompt is provided, use the default prompt field.required = False field.show = True + field.advanced = False + if field.name == "memory": + field.required = False + field.show = True + field.advanced = False + if field.name == "verbose": + field.required = False + field.show = True + field.advanced = True class LLMFrontendNode(FrontendNode): @@ -438,7 +551,7 @@ class LLMFrontendNode(FrontendNode): "huggingfacehub_api_token": "HuggingFace Hub API Token", } FrontendNode.format_field(field, name) - SHOW_FIELDS = ["repo_id", "task", "model_kwargs"] + SHOW_FIELDS = ["repo_id"] if field.name in SHOW_FIELDS: field.show = True @@ -448,14 +561,21 @@ class LLMFrontendNode(FrontendNode): # Required should be False to support # loading the API key from environment variables field.required = False + field.advanced = False if field.name == "task": field.required = True field.show = True field.is_list = True field.options = ["text-generation", "text2text-generation"] + field.advanced = True if display_name := display_names_dict.get(field.name): field.display_name = display_name if field.name == "model_kwargs": field.field_type = "code" + field.advanced = True + field.show = True + elif field.name in ["model_name", "temperature"]: + field.advanced = False + field.show = True diff --git a/src/frontend/src/CustomNodes/GenericNode/index.tsx b/src/frontend/src/CustomNodes/GenericNode/index.tsx index ff13af901..bfaa2ba24 100644 --- a/src/frontend/src/CustomNodes/GenericNode/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/index.tsx @@ -1,4 +1,4 @@ -import { TrashIcon } from "@heroicons/react/24/outline"; +import { Cog6ToothIcon, TrashIcon } from "@heroicons/react/24/outline"; import { classNames, nodeColors, @@ -10,6 +10,8 @@ import { typesContext } from "../../contexts/typesContext"; import { useContext, useRef } from "react"; import { NodeDataType } from "../../types/flow"; import { alertContext } from "../../contexts/alertContext"; +import { PopUpContext } from "../../contexts/popUpContext"; +import NodeModal from "../../modals/NodeModal"; export default function GenericNode({ data, @@ -21,6 +23,7 @@ export default function GenericNode({ const { setErrorData } = useContext(alertContext); const showError = useRef(true); const { types, deleteNode } = useContext(typesContext); + const { openPopUp } = useContext(PopUpContext); const Icon = nodeIcons[types[data.type]]; if (!Icon) { if (showError.current) { @@ -51,17 +54,46 @@ export default function GenericNode({ />
{data.type}
- +
+ + +
-
+
{data.node.description}
@@ -70,7 +102,7 @@ export default function GenericNode({ .filter((t) => t.charAt(0) !== "_") .map((t: string, idx) => (
- {idx === 0 ? ( + {/* {idx === 0 ? (
) : ( <> - )} - {data.node.template[t].show ? ( + )} */} + {data.node.template[t].show && !data.node.template[t].advanced ? ( ))} -
- Output +
+ {" "}
+ {/*
+ Output +
*/} , t) => errors.concat( - (template[t].required && template[t].show) && - (!template[t].value || template[t].value === "") && + (template[t].required) && + (template[t].value===undefined || template[t].value === "") && !reactFlowInstance .getEdges() .some( diff --git a/src/frontend/src/components/codeAreaComponent/index.tsx b/src/frontend/src/components/codeAreaComponent/index.tsx index 756296970..2086b467c 100644 --- a/src/frontend/src/components/codeAreaComponent/index.tsx +++ b/src/frontend/src/components/codeAreaComponent/index.tsx @@ -19,9 +19,24 @@ export default function CodeAreaComponent({ } }, [disabled, onChange]); return ( -
+
{ + openPopUp( + { + setMyValue(t); + onChange(t); + }} + /> + ); + }} className={ "truncate block w-full text-gray-500 px-3 py-2 rounded-md border border-gray-300 dark:border-gray-700 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm" + (disabled ? " bg-gray-200" : "") diff --git a/src/frontend/src/components/inputFileComponent/index.tsx b/src/frontend/src/components/inputFileComponent/index.tsx index 36d13e918..cb915bce5 100644 --- a/src/frontend/src/components/inputFileComponent/index.tsx +++ b/src/frontend/src/components/inputFileComponent/index.tsx @@ -69,6 +69,7 @@ export default function InputFileComponent({ >
{ - if (disabled) { - setMyValue(""); - onChange(""); - } - }, [disabled, onChange]); - return ( -
-
- - {myValue !== "" ? myValue : 'Text empty'} - - -
-
- ); +export default function PromptAreaComponent({ + value, + onChange, + disabled, +}: TextAreaComponentType) { + const [myValue, setMyValue] = useState(value); + const { openPopUp } = useContext(PopUpContext); + useEffect(() => { + if (disabled) { + setMyValue(""); + onChange(""); + } + }, [disabled, onChange]); + return ( +
+
+ { + openPopUp( + { + setMyValue(t); + onChange(t); + }} + /> + ); + }} + className={ + "truncate block w-full text-gray-500 px-3 py-2 rounded-md border border-gray-300 dark:border-gray-700 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm" + + (disabled ? " bg-gray-200" : "") + } + > + {myValue !== "" ? myValue : "Text empty"} + + +
+
+ ); } diff --git a/src/frontend/src/components/textAreaComponent/index.tsx b/src/frontend/src/components/textAreaComponent/index.tsx index df61e1ae1..153de4ffa 100644 --- a/src/frontend/src/components/textAreaComponent/index.tsx +++ b/src/frontend/src/components/textAreaComponent/index.tsx @@ -1,7 +1,6 @@ import { ArrowTopRightOnSquareIcon } from "@heroicons/react/24/outline"; import { useContext, useEffect, useState } from "react"; import { PopUpContext } from "../../contexts/popUpContext"; -import CodeAreaModal from "../../modals/codeAreaModal"; import TextAreaModal from "../../modals/textAreaModal"; import { TextAreaComponentType } from "../../types/components"; @@ -17,7 +16,7 @@ export default function TextAreaComponent({ value, onChange, disabled }:TextArea return (
- {openPopUp( {setMyValue(t); onChange(t);}}/>)}} className={ "truncate block w-full text-gray-500 px-3 py-2 rounded-md border border-gray-300 dark:border-gray-700 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm" + (disabled ? " bg-gray-200" : "") diff --git a/src/frontend/src/contexts/index.tsx b/src/frontend/src/contexts/index.tsx index 310606ea5..783bd108c 100644 --- a/src/frontend/src/contexts/index.tsx +++ b/src/frontend/src/contexts/index.tsx @@ -12,15 +12,13 @@ export default function ContextWrapper({ children }: { children: ReactNode }) { <> - - - - - {children} - - - - + + + + {children} + + + diff --git a/src/frontend/src/contexts/popUpContext.tsx b/src/frontend/src/contexts/popUpContext.tsx index 1d4d28bc7..efa263146 100644 --- a/src/frontend/src/contexts/popUpContext.tsx +++ b/src/frontend/src/contexts/popUpContext.tsx @@ -1,33 +1,33 @@ import { createContext } from "react"; import React, { useState } from "react"; -//context to set JSX element on the DOM +// context to set JSX element on the DOM export const PopUpContext = createContext({ openPopUp: (popUpElement: JSX.Element) => {}, - closePopUp: () => {}, + closePopUp: () => {}, }); interface PopUpProviderProps { - children: React.ReactNode; + children: React.ReactNode; } const PopUpProvider = ({ children }: PopUpProviderProps) => { - const [popUpElement, setPopUpElement] = useState(null); + const [popUpElements, setPopUpElements] = useState([]); - const openPopUp = (element: JSX.Element) => { - setPopUpElement(element); - }; + const openPopUp = (element: JSX.Element) => { + setPopUpElements(prevPopUps => [element, ...prevPopUps]); + }; - const closePopUp = () => { - setPopUpElement(null); - }; + const closePopUp = () => { + setPopUpElements(prevPopUps => prevPopUps.slice(1)); + }; - return ( - - {children} - {popUpElement} - - ); + return ( + + {children} + {popUpElements[0]} + + ); }; export default PopUpProvider; diff --git a/src/frontend/src/modals/NodeModal/components/ModalField/index.tsx b/src/frontend/src/modals/NodeModal/components/ModalField/index.tsx new file mode 100644 index 000000000..a6f56ca6d --- /dev/null +++ b/src/frontend/src/modals/NodeModal/components/ModalField/index.tsx @@ -0,0 +1,166 @@ +import { useContext, useState } from "react"; +import { TabsContext } from "../../../../contexts/tabsContext"; +import InputListComponent from "../../../../components/inputListComponent"; +import Dropdown from "../../../../components/dropdownComponent"; +import TextAreaComponent from "../../../../components/textAreaComponent"; +import InputComponent from "../../../../components/inputComponent"; +import ToggleComponent from "../../../../components/toggleComponent"; +import FloatComponent from "../../../../components/floatComponent"; +import IntComponent from "../../../../components/intComponent"; +import InputFileComponent from "../../../../components/inputFileComponent"; +import PromptAreaComponent from "../../../../components/promptComponent"; +import CodeAreaComponent from "../../../../components/codeAreaComponent"; +import { classNames } from "../../../../utils"; + +export default function ModalField({ data, title, required, id, name, type }) { + const { save } = useContext(TabsContext); + const [enabled, setEnabled] = useState( + data.node.template[name]?.value ?? false + ); + const display = + type === "str" || + type === "int" || + type === "prompt" || + type === "bool" || + type === "float" || + type === "file" || + type === "code"; + + return ( +
+ {display && ( +
+ {title} + {required ? " *" : ""} +
+ )} + + {type === "str" && !data.node.template[name].options ? ( +
+ {data.node.template[name].list ? ( + { + data.node.template[name].value = t; + save(); + }} + /> + ) : data.node.template[name].multiline ? ( + { + data.node.template[name].value = t; + save(); + }} + /> + ) : ( + { + data.node.template[name].value = t; + save(); + }} + /> + )} +
+ ) : type === "bool" ? ( +
+ {" "} + { + data.node.template[name].value = t; + setEnabled(t); + save(); + }} + /> +
+ ) : type === "float" ? ( +
+ { + data.node.template[name].value = t; + save(); + }} + /> +
+ ) : type === "str" && data.node.template[name].options ? ( +
+ (data.node.template[name].value = newValue)} + value={data.node.template[name].value ?? "Choose an option"} + > +
+ ) : type === "int" ? ( +
+ { + data.node.template[name].value = t; + save(); + }} + /> +
+ ) : type === "file" ? ( +
+ { + data.node.template[name].value = t; + }} + fileTypes={data.node.template[name].fileTypes} + suffixes={data.node.template[name].suffixes} + onFileChange={(t: string) => { + data.node.template[name].content = t; + save(); + }} + > +
+ ) : type === "prompt" ? ( +
+ { + data.node.template[name].value = t; + save(); + }} + /> +
+ ) : type === "code" ? ( +
+ { + data.node.template[name].value = t; + save(); + }} + /> +
+ ) : ( +
+ )} +
+ ); +} diff --git a/src/frontend/src/modals/NodeModal/index.tsx b/src/frontend/src/modals/NodeModal/index.tsx new file mode 100644 index 000000000..31786553b --- /dev/null +++ b/src/frontend/src/modals/NodeModal/index.tsx @@ -0,0 +1,144 @@ +import { Dialog, Transition } from "@headlessui/react"; +import { XMarkIcon } from "@heroicons/react/24/outline"; +import { Fragment, useContext, useRef, useState } from "react"; +import { PopUpContext } from "../../contexts/popUpContext"; +import { NodeDataType } from "../../types/flow"; +import { nodeColors, nodeIcons, snakeToNormalCase } from "../../utils"; +import { typesContext } from "../../contexts/typesContext"; +import ModalField from "./components/ModalField"; + +export default function NodeModal({ data }: { data: NodeDataType }) { + const [open, setOpen] = useState(true); + const { closePopUp } = useContext(PopUpContext); + const { types } = useContext(typesContext); + const ref = useRef(); + function setModalOpen(x: boolean) { + setOpen(x); + if (x === false) { + setTimeout(() => { + closePopUp(); + }, 300); + } + } + const Icon = nodeIcons[types[data.type]]; + return ( + + + +
+ + +
+
+ + +
+ +
+
+
+ +
+ + {data.type} + +
+
+
+
+
+
+ { + Object.keys(data.node.template) + .filter((t) => t.charAt(0) !== "_"&& data.node.template[t].advanced && data.node.template[t].show) + .map((t: string, idx) => { + return ( + + ); + }) + } +
+
+
+
+
+ +
+
+
+
+
+
+
+
+ ); +} diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py index d63365fdc..a546454eb 100644 --- a/tests/test_agents_template.py +++ b/tests/test_agents_template.py @@ -36,6 +36,7 @@ def test_zero_shot_agent(client: TestClient): "name": "llm_chain", "type": "LLMChain", "list": False, + "advanced": True, } assert template["allowed_tools"] == { "required": False, @@ -46,6 +47,7 @@ def test_zero_shot_agent(client: TestClient): "name": "allowed_tools", "type": "Tool", "list": True, + "advanced": True, } @@ -68,6 +70,7 @@ def test_json_agent(client: TestClient): "name": "toolkit", "type": "BaseToolkit", "list": False, + "advanced": True, } assert template["llm"] == { "required": True, @@ -78,6 +81,7 @@ def test_json_agent(client: TestClient): "name": "llm", "type": "BaseLanguageModel", "list": False, + "advanced": True, } @@ -104,6 +108,7 @@ def test_csv_agent(client: TestClient): "type": "file", "list": False, "content": None, + "advanced": True, } assert template["llm"] == { "required": True, @@ -114,6 +119,7 @@ def test_csv_agent(client: TestClient): "name": "llm", "type": "BaseLanguageModel", "list": False, + "advanced": True, } @@ -143,6 +149,7 @@ def test_initialize_agent(client: TestClient): "name": "agent", "type": "str", "list": True, + "advanced": True, } assert template["memory"] == { "required": False, @@ -153,6 +160,7 @@ def test_initialize_agent(client: TestClient): "name": "memory", "type": "BaseChatMemory", "list": False, + "advanced": True, } assert template["tools"] == { "required": False, @@ -163,6 +171,7 @@ def test_initialize_agent(client: TestClient): "name": "tools", "type": "Tool", "list": True, + "advanced": True, } assert template["llm"] == { "required": True, @@ -173,4 +182,5 @@ def test_initialize_agent(client: TestClient): "name": "llm", "type": "BaseLanguageModel", "list": False, + "advanced": True, } diff --git a/tests/test_chains_template.py b/tests/test_chains_template.py index 510437797..cff844c90 100644 --- a/tests/test_chains_template.py +++ b/tests/test_chains_template.py @@ -31,16 +31,18 @@ def test_conversation_chain(client: TestClient): "name": "memory", "type": "BaseMemory", "list": False, + "advanced": False, } assert template["verbose"] == { "required": False, "placeholder": "", - "show": False, + "show": True, "multiline": False, "password": False, "name": "verbose", "type": "bool", "list": False, + "advanced": True, } assert template["llm"] == { "required": True, @@ -51,6 +53,7 @@ def test_conversation_chain(client: TestClient): "name": "llm", "type": "BaseLanguageModel", "list": False, + "advanced": False, } assert template["input_key"] == { "required": True, @@ -62,6 +65,7 @@ def test_conversation_chain(client: TestClient): "name": "input_key", "type": "str", "list": False, + "advanced": True, } assert template["output_key"] == { "required": True, @@ -73,6 +77,7 @@ def test_conversation_chain(client: TestClient): "name": "output_key", "type": "str", "list": False, + "advanced": True, } assert template["_type"] == "ConversationChain" @@ -102,17 +107,19 @@ def test_llm_chain(client: TestClient): "name": "memory", "type": "BaseMemory", "list": False, + "advanced": False, } assert template["verbose"] == { "required": False, "placeholder": "", - "show": False, + "show": True, "multiline": False, "value": False, "password": False, "name": "verbose", "type": "bool", "list": False, + "advanced": True, } assert template["llm"] == { "required": True, @@ -123,6 +130,7 @@ def test_llm_chain(client: TestClient): "name": "llm", "type": "BaseLanguageModel", "list": False, + "advanced": False, } assert template["output_key"] == { "required": True, @@ -134,6 +142,7 @@ def test_llm_chain(client: TestClient): "name": "output_key", "type": "str", "list": False, + "advanced": True, } @@ -156,17 +165,19 @@ def test_llm_checker_chain(client: TestClient): "name": "memory", "type": "BaseMemory", "list": False, + "advanced": False, } assert template["verbose"] == { "required": False, "placeholder": "", - "show": False, + "show": True, "multiline": False, "value": False, "password": False, "name": "verbose", "type": "bool", "list": False, + "advanced": True, } assert template["llm"] == { "required": True, @@ -177,6 +188,7 @@ def test_llm_checker_chain(client: TestClient): "name": "llm", "type": "BaseLLM", "list": False, + "advanced": False, } assert template["input_key"] == { "required": True, @@ -188,6 +200,7 @@ def test_llm_checker_chain(client: TestClient): "name": "input_key", "type": "str", "list": False, + "advanced": True, } assert template["output_key"] == { "required": True, @@ -199,6 +212,7 @@ def test_llm_checker_chain(client: TestClient): "name": "output_key", "type": "str", "list": False, + "advanced": True, } assert template["_type"] == "LLMCheckerChain" @@ -228,17 +242,19 @@ def test_llm_math_chain(client: TestClient): "name": "memory", "type": "BaseMemory", "list": False, + "advanced": False, } assert template["verbose"] == { "required": False, "placeholder": "", - "show": False, + "show": True, "multiline": False, "value": False, "password": False, "name": "verbose", "type": "bool", "list": False, + "advanced": True, } assert template["llm"] == { "required": True, @@ -249,6 +265,7 @@ def test_llm_math_chain(client: TestClient): "name": "llm", "type": "BaseLanguageModel", "list": False, + "advanced": False, } assert template["input_key"] == { "required": True, @@ -260,6 +277,7 @@ def test_llm_math_chain(client: TestClient): "name": "input_key", "type": "str", "list": False, + "advanced": True, } assert template["output_key"] == { "required": True, @@ -271,6 +289,7 @@ def test_llm_math_chain(client: TestClient): "name": "output_key", "type": "str", "list": False, + "advanced": True, } assert template["_type"] == "LLMMathChain" @@ -298,35 +317,7 @@ def test_series_character_chain(client: TestClient): "SeriesCharacterChain", } template = chain["template"] - assert template["memory"] == { - "required": False, - "placeholder": "", - "show": True, - "multiline": False, - "value": { - "chat_memory": {"messages": []}, - "output_key": None, - "input_key": None, - "return_messages": False, - "human_prefix": "Human", - "ai_prefix": "AI", - "memory_key": "history", - }, - "password": False, - "name": "memory", - "type": "BaseMemory", - "list": False, - } - assert template["verbose"] == { - "required": False, - "placeholder": "", - "show": False, - "multiline": False, - "password": False, - "name": "verbose", - "type": "bool", - "list": False, - } + assert template["llm"] == { "required": True, "placeholder": "", @@ -336,50 +327,7 @@ def test_series_character_chain(client: TestClient): "name": "llm", "type": "BaseLanguageModel", "list": False, - } - assert template["input_key"] == { - "required": True, - "placeholder": "", - "show": True, - "multiline": False, - "value": "input", - "password": False, - "name": "input_key", - "type": "str", - "list": False, - } - assert template["output_key"] == { - "required": True, - "placeholder": "", - "show": True, - "multiline": False, - "value": "response", - "password": False, - "name": "output_key", - "type": "str", - "list": False, - } - assert template["template"] == { - "required": False, - "placeholder": "", - "show": False, - "multiline": True, - "value": "I want you to act like {character} from {series}.\nI want you to respond and answer like {character}. do not write any explanations. only answer like {character}.\nYou must know all of the knowledge of {character}.\nCurrent conversation:\n{history}\nHuman: {input}\n{character}:", # noqa: E501 - "password": False, - "name": "template", - "type": "str", - "list": False, - } - assert template["ai_prefix_value"] == { - "required": False, - "placeholder": "", - "show": False, - "multiline": False, - "value": "character", - "password": False, - "name": "ai_prefix_value", - "type": "str", - "list": False, + "advanced": False, } assert template["character"] == { "required": True, @@ -390,6 +338,7 @@ def test_series_character_chain(client: TestClient): "name": "character", "type": "str", "list": False, + "advanced": False, } assert template["series"] == { "required": True, @@ -400,6 +349,7 @@ def test_series_character_chain(client: TestClient): "name": "series", "type": "str", "list": False, + "advanced": False, } assert template["_type"] == "SeriesCharacterChain" @@ -429,55 +379,7 @@ def test_mid_journey_prompt_chain(client: TestClient): # Test the template object template = chain["template"] - assert template["memory"] == { - "required": False, - "placeholder": "", - "show": True, - "multiline": False, - "value": { - "chat_memory": {"messages": []}, - "output_key": None, - "input_key": None, - "return_messages": False, - "human_prefix": "Human", - "ai_prefix": "AI", - "memory_key": "history", - }, - "password": False, - "name": "memory", - "type": "BaseMemory", - "list": False, - } - assert template["verbose"] == { - "required": False, - "placeholder": "", - "show": False, - "multiline": False, - "password": False, - "name": "verbose", - "type": "bool", - "list": False, - } - # Continue with other template object assertions - assert template["prompt"] == { - "required": False, - "placeholder": "", - "show": True, - "multiline": False, - "value": { - "input_variables": ["history", "input"], - "output_parser": None, - "partial_variables": {}, - "template": "The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n\nCurrent conversation:\n{history}\nHuman: {input}\nAI:", # noqa: E501 - "template_format": "f-string", - "validate_template": True, - "_type": "prompt", - }, - "password": False, - "name": "prompt", - "type": "BasePromptTemplate", - "list": False, - } + assert template["llm"] == { "required": True, "placeholder": "", @@ -487,49 +389,7 @@ def test_mid_journey_prompt_chain(client: TestClient): "name": "llm", "type": "BaseLanguageModel", "list": False, - } - assert template["output_key"] == { - "required": True, - "placeholder": "", - "show": True, - "multiline": False, - "value": "response", - "password": False, - "name": "output_key", - "type": "str", - "list": False, - } - assert template["input_key"] == { - "required": True, - "placeholder": "", - "show": True, - "multiline": False, - "value": "input", - "password": False, - "name": "input_key", - "type": "str", - "list": False, - } - assert template["template"] == { - "required": False, - "placeholder": "", - "show": False, - "multiline": True, - "value": 'I want you to act as a prompt generator for Midjourney\'s artificial intelligence program.\n Your job is to provide detailed and creative descriptions that will inspire unique and interesting images from the AI.\n Keep in mind that the AI is capable of understanding a wide range of language and can interpret abstract concepts, so feel free to be as imaginative and descriptive as possible.\n For example, you could describe a scene from a futuristic city, or a surreal landscape filled with strange creatures.\n The more detailed and imaginative your description, the more interesting the resulting image will be. Here is your first prompt:\n "A field of wildflowers stretches out as far as the eye can see, each one a different color and shape. In the distance, a massive tree towers over the landscape, its branches reaching up to the sky like tentacles."\n\n Current conversation:\n {history}\n Human: {input}\n AI:', # noqa: E501 - "password": False, - "name": "template", - "type": "str", - "list": False, - } - assert template["ai_prefix_value"] == { - "required": False, - "placeholder": "", - "show": False, - "multiline": False, - "password": False, - "name": "ai_prefix_value", - "type": "str", - "list": False, + "advanced": False, } # Test the description object assert ( @@ -557,55 +417,7 @@ def test_time_travel_guide_chain(client: TestClient): # Test the template object template = chain["template"] - assert template["memory"] == { - "required": False, - "placeholder": "", - "show": True, - "multiline": False, - "value": { - "chat_memory": {"messages": []}, - "output_key": None, - "input_key": None, - "return_messages": False, - "human_prefix": "Human", - "ai_prefix": "AI", - "memory_key": "history", - }, - "password": False, - "name": "memory", - "type": "BaseMemory", - "list": False, - } - assert template["verbose"] == { - "required": False, - "placeholder": "", - "show": False, - "multiline": False, - "password": False, - "name": "verbose", - "type": "bool", - "list": False, - } - assert template["prompt"] == { - "required": False, - "placeholder": "", - "show": True, - "multiline": False, - "value": { - "input_variables": ["history", "input"], - "output_parser": None, - "partial_variables": {}, - "template": "The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n\nCurrent conversation:\n{history}\nHuman: {input}\nAI:", # noqa: E501 - "template_format": "f-string", - "validate_template": True, - "_type": "prompt", - }, - "password": False, - "name": "prompt", - "type": "BasePromptTemplate", - "list": False, - } assert template["llm"] == { "required": True, "placeholder": "", @@ -615,50 +427,7 @@ def test_time_travel_guide_chain(client: TestClient): "name": "llm", "type": "BaseLanguageModel", "list": False, - } - assert template["output_key"] == { - "required": True, - "placeholder": "", - "show": True, - "multiline": False, - "value": "response", - "password": False, - "name": "output_key", - "type": "str", - "list": False, + "advanced": False, } - assert template["input_key"] == { - "required": True, - "placeholder": "", - "show": True, - "multiline": False, - "value": "input", - "password": False, - "name": "input_key", - "type": "str", - "list": False, - } - - assert template["template"] == { - "required": False, - "placeholder": "", - "show": False, - "multiline": True, - "value": "I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.\n Current conversation:\n {history}\n Human: {input}\n AI:", # noqa: E501 - "password": False, - "name": "template", - "type": "str", - "list": False, - } - assert template["ai_prefix_value"] == { - "required": False, - "placeholder": "", - "show": False, - "multiline": False, - "password": False, - "name": "ai_prefix_value", - "type": "str", - "list": False, - } - assert chain["description"] == "" + assert chain["description"] == "Time travel guide chain to be used in the flow." diff --git a/tests/test_llms_template.py b/tests/test_llms_template.py index 934ec7e05..45b3b032d 100644 --- a/tests/test_llms_template.py +++ b/tests/test_llms_template.py @@ -28,6 +28,7 @@ def test_hugging_face_hub(client: TestClient): "name": "cache", "type": "bool", "list": False, + "advanced": True, } assert template["verbose"] == { "required": False, @@ -39,6 +40,7 @@ def test_hugging_face_hub(client: TestClient): "name": "verbose", "type": "bool", "list": False, + "advanced": True, } assert template["client"] == { "required": False, @@ -49,6 +51,7 @@ def test_hugging_face_hub(client: TestClient): "name": "client", "type": "Any", "list": False, + "advanced": True, } assert template["repo_id"] == { "required": False, @@ -60,6 +63,7 @@ def test_hugging_face_hub(client: TestClient): "name": "repo_id", "type": "str", "list": False, + "advanced": True, } assert template["task"] == { "required": True, @@ -71,6 +75,7 @@ def test_hugging_face_hub(client: TestClient): "name": "task", "type": "str", "list": True, + "advanced": True, } assert template["model_kwargs"] == { "required": False, @@ -81,6 +86,7 @@ def test_hugging_face_hub(client: TestClient): "name": "model_kwargs", "type": "code", "list": False, + "advanced": True, } assert template["huggingfacehub_api_token"] == { "required": False, @@ -92,6 +98,7 @@ def test_hugging_face_hub(client: TestClient): "display_name": "HuggingFace Hub API Token", "type": "str", "list": False, + "advanced": False, } @@ -113,6 +120,7 @@ def test_openai(client: TestClient): "name": "cache", "type": "bool", "list": False, + "advanced": True, } assert template["verbose"] == { "required": False, @@ -123,6 +131,7 @@ def test_openai(client: TestClient): "name": "verbose", "type": "bool", "list": False, + "advanced": True, } assert template["client"] == { "required": False, @@ -133,6 +142,7 @@ def test_openai(client: TestClient): "name": "client", "type": "Any", "list": False, + "advanced": True, } assert template["model_name"] == { "required": False, @@ -151,6 +161,7 @@ def test_openai(client: TestClient): "name": "model_name", "type": "str", "list": True, + "advanced": False, } # Add more assertions for other properties here assert template["temperature"] == { @@ -163,6 +174,7 @@ def test_openai(client: TestClient): "name": "temperature", "type": "float", "list": False, + "advanced": False, } assert template["max_tokens"] == { "required": False, @@ -174,6 +186,7 @@ def test_openai(client: TestClient): "name": "max_tokens", "type": "int", "list": False, + "advanced": True, } assert template["top_p"] == { "required": False, @@ -185,6 +198,7 @@ def test_openai(client: TestClient): "name": "top_p", "type": "float", "list": False, + "advanced": True, } assert template["frequency_penalty"] == { "required": False, @@ -196,6 +210,7 @@ def test_openai(client: TestClient): "name": "frequency_penalty", "type": "float", "list": False, + "advanced": True, } assert template["presence_penalty"] == { "required": False, @@ -207,6 +222,7 @@ def test_openai(client: TestClient): "name": "presence_penalty", "type": "float", "list": False, + "advanced": True, } assert template["n"] == { "required": False, @@ -218,6 +234,7 @@ def test_openai(client: TestClient): "name": "n", "type": "int", "list": False, + "advanced": True, } assert template["best_of"] == { "required": False, @@ -229,6 +246,7 @@ def test_openai(client: TestClient): "name": "best_of", "type": "int", "list": False, + "advanced": True, } assert template["model_kwargs"] == { "required": False, @@ -239,6 +257,7 @@ def test_openai(client: TestClient): "name": "model_kwargs", "type": "code", "list": False, + "advanced": True, } assert template["openai_api_key"] == { "required": False, @@ -251,6 +270,7 @@ def test_openai(client: TestClient): "display_name": "OpenAI API Key", "type": "str", "list": False, + "advanced": False, } assert template["batch_size"] == { "required": False, @@ -262,6 +282,7 @@ def test_openai(client: TestClient): "name": "batch_size", "type": "int", "list": False, + "advanced": True, } assert template["request_timeout"] == { "required": False, @@ -272,6 +293,7 @@ def test_openai(client: TestClient): "name": "request_timeout", "type": "Union[float, Tuple[float, float], NoneType]", "list": False, + "advanced": True, } assert template["logit_bias"] == { "required": False, @@ -282,6 +304,7 @@ def test_openai(client: TestClient): "name": "logit_bias", "type": "code", "list": False, + "advanced": True, } assert template["max_retries"] == { "required": False, @@ -293,6 +316,7 @@ def test_openai(client: TestClient): "name": "max_retries", "type": "int", "list": False, + "advanced": True, } assert template["streaming"] == { "required": False, @@ -304,6 +328,7 @@ def test_openai(client: TestClient): "name": "streaming", "type": "bool", "list": False, + "advanced": True, } @@ -326,6 +351,7 @@ def test_chat_open_ai(client: TestClient): "name": "verbose", "type": "bool", "list": False, + "advanced": True, } assert template["client"] == { "required": False, @@ -336,6 +362,7 @@ def test_chat_open_ai(client: TestClient): "name": "client", "type": "Any", "list": False, + "advanced": True, } assert template["model_name"] == { "required": False, @@ -348,6 +375,7 @@ def test_chat_open_ai(client: TestClient): "name": "model_name", "type": "str", "list": True, + "advanced": False, } assert template["temperature"] == { "required": False, @@ -359,6 +387,7 @@ def test_chat_open_ai(client: TestClient): "name": "temperature", "type": "float", "list": False, + "advanced": False, } assert template["model_kwargs"] == { "required": False, @@ -369,6 +398,7 @@ def test_chat_open_ai(client: TestClient): "name": "model_kwargs", "type": "code", "list": False, + "advanced": True, } assert template["openai_api_key"] == { "required": False, @@ -381,6 +411,7 @@ def test_chat_open_ai(client: TestClient): "display_name": "OpenAI API Key", "type": "str", "list": False, + "advanced": False, } assert template["request_timeout"] == { "required": False, @@ -392,6 +423,7 @@ def test_chat_open_ai(client: TestClient): "name": "request_timeout", "type": "int", "list": False, + "advanced": True, } assert template["max_retries"] == { "required": False, @@ -403,6 +435,7 @@ def test_chat_open_ai(client: TestClient): "name": "max_retries", "type": "int", "list": False, + "advanced": True, } assert template["streaming"] == { "required": False, @@ -414,6 +447,7 @@ def test_chat_open_ai(client: TestClient): "name": "streaming", "type": "bool", "list": False, + "advanced": True, } assert template["n"] == { "required": False, @@ -425,6 +459,7 @@ def test_chat_open_ai(client: TestClient): "name": "n", "type": "int", "list": False, + "advanced": True, } assert template["max_tokens"] == { @@ -436,6 +471,7 @@ def test_chat_open_ai(client: TestClient): "name": "max_tokens", "type": "int", "list": False, + "advanced": True, } assert template["_type"] == "ChatOpenAI" assert ( diff --git a/tests/test_prompts_template.py b/tests/test_prompts_template.py index caa30821c..e5506d4c7 100644 --- a/tests/test_prompts_template.py +++ b/tests/test_prompts_template.py @@ -27,6 +27,7 @@ def test_prompt_template(client: TestClient): "name": "input_variables", "type": "str", "list": True, + "advanced": True, } assert template["output_parser"] == { "required": False, @@ -37,6 +38,7 @@ def test_prompt_template(client: TestClient): "name": "output_parser", "type": "BaseOutputParser", "list": False, + "advanced": True, } assert template["partial_variables"] == { "required": False, @@ -47,6 +49,7 @@ def test_prompt_template(client: TestClient): "name": "partial_variables", "type": "code", "list": False, + "advanced": True, } assert template["template"] == { "required": True, @@ -57,6 +60,7 @@ def test_prompt_template(client: TestClient): "name": "template", "type": "prompt", "list": False, + "advanced": True, } assert template["template_format"] == { "required": False, @@ -68,6 +72,7 @@ def test_prompt_template(client: TestClient): "name": "template_format", "type": "str", "list": False, + "advanced": True, } assert template["validate_template"] == { "required": False, @@ -79,6 +84,7 @@ def test_prompt_template(client: TestClient): "name": "validate_template", "type": "bool", "list": False, + "advanced": True, } @@ -100,6 +106,7 @@ def test_few_shot_prompt_template(client: TestClient): "name": "examples", "type": "prompt", "list": True, + "advanced": True, } assert template["example_selector"] == { "required": False, @@ -110,6 +117,7 @@ def test_few_shot_prompt_template(client: TestClient): "name": "example_selector", "type": "BaseExampleSelector", "list": False, + "advanced": True, } assert template["example_prompt"] == { "required": True, @@ -120,6 +128,7 @@ def test_few_shot_prompt_template(client: TestClient): "name": "example_prompt", "type": "PromptTemplate", "list": False, + "advanced": True, } assert template["suffix"] == { "required": True, @@ -130,6 +139,7 @@ def test_few_shot_prompt_template(client: TestClient): "name": "suffix", "type": "prompt", "list": False, + "advanced": True, } assert template["example_separator"] == { "required": False, @@ -141,6 +151,7 @@ def test_few_shot_prompt_template(client: TestClient): "name": "example_separator", "type": "str", "list": False, + "advanced": True, } assert template["prefix"] == { "required": False, @@ -152,6 +163,7 @@ def test_few_shot_prompt_template(client: TestClient): "name": "prefix", "type": "prompt", "list": False, + "advanced": True, } @@ -172,6 +184,7 @@ def test_zero_shot_prompt(client: TestClient): "name": "prefix", "type": "str", "list": False, + "advanced": True, } assert template["suffix"] == { "required": True, @@ -183,6 +196,7 @@ def test_zero_shot_prompt(client: TestClient): "name": "suffix", "type": "str", "list": False, + "advanced": True, } assert template["format_instructions"] == { "required": False, @@ -194,4 +208,5 @@ def test_zero_shot_prompt(client: TestClient): "name": "format_instructions", "type": "str", "list": False, + "advanced": True, }