diff --git a/poetry.lock b/poetry.lock index 0826e49c8..92ded8aff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -232,14 +232,14 @@ files = [ [[package]] name = "beautifulsoup4" -version = "4.12.0" +version = "4.12.1" description = "Screen-scraping library" category = "main" optional = false python-versions = ">=3.6.0" files = [ - {file = "beautifulsoup4-4.12.0-py3-none-any.whl", hash = "sha256:2130a5ad7f513200fae61a17abb5e338ca980fa28c439c0571014bc0217e9591"}, - {file = "beautifulsoup4-4.12.0.tar.gz", hash = "sha256:c5fceeaec29d09c84970e47c65f2f0efe57872f7cff494c9691a26ec0ff13234"}, + {file = "beautifulsoup4-4.12.1-py3-none-any.whl", hash = "sha256:e44795bb4f156d94abb5fbc56efff871c1045bfef72e9efe77558db9f9616ac3"}, + {file = "beautifulsoup4-4.12.1.tar.gz", hash = "sha256:c7bdbfb20a0dbe09518b96a809d93351b2e2bcb8046c0809466fa6632a10c257"}, ] [package.dependencies] @@ -699,29 +699,30 @@ dev = ["flake8", "hypothesis", "ipython", "mypy (>=0.710)", "portray", "pytest ( [[package]] name = "debugpy" -version = "1.6.6" +version = "1.6.7" description = "An implementation of the Debug Adapter Protocol for Python" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "debugpy-1.6.6-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:0ea1011e94416e90fb3598cc3ef5e08b0a4dd6ce6b9b33ccd436c1dffc8cd664"}, - {file = "debugpy-1.6.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dff595686178b0e75580c24d316aa45a8f4d56e2418063865c114eef651a982e"}, - {file = "debugpy-1.6.6-cp310-cp310-win32.whl", hash = "sha256:87755e173fcf2ec45f584bb9d61aa7686bb665d861b81faa366d59808bbd3494"}, - {file = "debugpy-1.6.6-cp310-cp310-win_amd64.whl", hash = "sha256:72687b62a54d9d9e3fb85e7a37ea67f0e803aaa31be700e61d2f3742a5683917"}, - {file = "debugpy-1.6.6-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:78739f77c58048ec006e2b3eb2e0cd5a06d5f48c915e2fc7911a337354508110"}, - {file = "debugpy-1.6.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23c29e40e39ad7d869d408ded414f6d46d82f8a93b5857ac3ac1e915893139ca"}, - {file = "debugpy-1.6.6-cp37-cp37m-win32.whl", hash = "sha256:7aa7e103610e5867d19a7d069e02e72eb2b3045b124d051cfd1538f1d8832d1b"}, - {file = "debugpy-1.6.6-cp37-cp37m-win_amd64.whl", hash = "sha256:f6383c29e796203a0bba74a250615ad262c4279d398e89d895a69d3069498305"}, - {file = "debugpy-1.6.6-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:23363e6d2a04d726bbc1400bd4e9898d54419b36b2cdf7020e3e215e1dcd0f8e"}, - {file = "debugpy-1.6.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b5d1b13d7c7bf5d7cf700e33c0b8ddb7baf030fcf502f76fc061ddd9405d16c"}, - {file = "debugpy-1.6.6-cp38-cp38-win32.whl", hash = "sha256:70ab53918fd907a3ade01909b3ed783287ede362c80c75f41e79596d5ccacd32"}, - {file = "debugpy-1.6.6-cp38-cp38-win_amd64.whl", hash = "sha256:c05349890804d846eca32ce0623ab66c06f8800db881af7a876dc073ac1c2225"}, - {file = "debugpy-1.6.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a771739902b1ae22a120dbbb6bd91b2cae6696c0e318b5007c5348519a4211c6"}, - {file = "debugpy-1.6.6-cp39-cp39-win32.whl", hash = "sha256:549ae0cb2d34fc09d1675f9b01942499751d174381b6082279cf19cdb3c47cbe"}, - {file = "debugpy-1.6.6-cp39-cp39-win_amd64.whl", hash = "sha256:de4a045fbf388e120bb6ec66501458d3134f4729faed26ff95de52a754abddb1"}, - {file = "debugpy-1.6.6-py2.py3-none-any.whl", hash = "sha256:be596b44448aac14eb3614248c91586e2bc1728e020e82ef3197189aae556115"}, - {file = "debugpy-1.6.6.zip", hash = "sha256:b9c2130e1c632540fbf9c2c88341493797ddf58016e7cba02e311de9b0a96b67"}, + {file = "debugpy-1.6.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b3e7ac809b991006ad7f857f016fa92014445085711ef111fdc3f74f66144096"}, + {file = "debugpy-1.6.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3876611d114a18aafef6383695dfc3f1217c98a9168c1aaf1a02b01ec7d8d1e"}, + {file = "debugpy-1.6.7-cp310-cp310-win32.whl", hash = "sha256:33edb4afa85c098c24cc361d72ba7c21bb92f501104514d4ffec1fb36e09c01a"}, + {file = "debugpy-1.6.7-cp310-cp310-win_amd64.whl", hash = "sha256:ed6d5413474e209ba50b1a75b2d9eecf64d41e6e4501977991cdc755dc83ab0f"}, + {file = "debugpy-1.6.7-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:38ed626353e7c63f4b11efad659be04c23de2b0d15efff77b60e4740ea685d07"}, + {file = "debugpy-1.6.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:279d64c408c60431c8ee832dfd9ace7c396984fd7341fa3116aee414e7dcd88d"}, + {file = "debugpy-1.6.7-cp37-cp37m-win32.whl", hash = "sha256:dbe04e7568aa69361a5b4c47b4493d5680bfa3a911d1e105fbea1b1f23f3eb45"}, + {file = "debugpy-1.6.7-cp37-cp37m-win_amd64.whl", hash = "sha256:f90a2d4ad9a035cee7331c06a4cf2245e38bd7c89554fe3b616d90ab8aab89cc"}, + {file = "debugpy-1.6.7-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:5224eabbbeddcf1943d4e2821876f3e5d7d383f27390b82da5d9558fd4eb30a9"}, + {file = "debugpy-1.6.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bae1123dff5bfe548ba1683eb972329ba6d646c3a80e6b4c06cd1b1dd0205e9b"}, + {file = "debugpy-1.6.7-cp38-cp38-win32.whl", hash = "sha256:9cd10cf338e0907fdcf9eac9087faa30f150ef5445af5a545d307055141dd7a4"}, + {file = "debugpy-1.6.7-cp38-cp38-win_amd64.whl", hash = "sha256:aaf6da50377ff4056c8ed470da24632b42e4087bc826845daad7af211e00faad"}, + {file = "debugpy-1.6.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:0679b7e1e3523bd7d7869447ec67b59728675aadfc038550a63a362b63029d2c"}, + {file = "debugpy-1.6.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de86029696e1b3b4d0d49076b9eba606c226e33ae312a57a46dca14ff370894d"}, + {file = "debugpy-1.6.7-cp39-cp39-win32.whl", hash = "sha256:d71b31117779d9a90b745720c0eab54ae1da76d5b38c8026c654f4a066b0130a"}, + {file = "debugpy-1.6.7-cp39-cp39-win_amd64.whl", hash = "sha256:c0ff93ae90a03b06d85b2c529eca51ab15457868a377c4cc40a23ab0e4e552a3"}, + {file = "debugpy-1.6.7-py2.py3-none-any.whl", hash = "sha256:53f7a456bc50706a0eaabecf2d3ce44c4d5010e46dfc65b6b81a518b42866267"}, + {file = "debugpy-1.6.7.zip", hash = "sha256:c4c2f0810fa25323abfdfa36cbbbb24e5c3b1a42cb762782de64439c575d67f2"}, ] [[package]] @@ -985,14 +986,14 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"] [[package]] name = "google-api-python-client" -version = "2.83.0" +version = "2.84.0" description = "Google API Client Library for Python" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "google-api-python-client-2.83.0.tar.gz", hash = "sha256:d07509f1b2d2b2427363b454db996f7a15e1751a48cfcaf28427050560dd51cf"}, - {file = "google_api_python_client-2.83.0-py2.py3-none-any.whl", hash = "sha256:afa7fe2a5d77e8f136cdb8f40a120dd6660c2292f791c1b22734dfe786bd1dac"}, + {file = "google-api-python-client-2.84.0.tar.gz", hash = "sha256:c398fd6f9ead0be23aade3b2704c72c5146df0e3352d8ff9101286077e1b010a"}, + {file = "google_api_python_client-2.84.0-py2.py3-none-any.whl", hash = "sha256:83041bb895863225ecdd9c59dd58565fa48c57c2f10fe06f7c08da7c42c53abc"}, ] [package.dependencies] @@ -1004,14 +1005,14 @@ uritemplate = ">=3.0.1,<5" [[package]] name = "google-auth" -version = "2.17.1" +version = "2.17.2" description = "Google Authentication Library" category = "main" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" files = [ - {file = "google-auth-2.17.1.tar.gz", hash = "sha256:8f379b46bad381ad2a0b989dfb0c13ad28d3c2a79f27348213f8946a1d15d55a"}, - {file = "google_auth-2.17.1-py2.py3-none-any.whl", hash = "sha256:357ff22a75b4c0f6093470f21816a825d2adee398177569824e37b6c10069e19"}, + {file = "google-auth-2.17.2.tar.gz", hash = "sha256:295c80ebb95eac74003c07a696cf3ef6b414e9230ae8894f3843f8215fd2aa56"}, + {file = "google_auth-2.17.2-py2.py3-none-any.whl", hash = "sha256:544536a43d44dff0f64222e4d027d124989fcb9c10979687e589e1694fba9c94"}, ] [package.dependencies] @@ -1314,14 +1315,14 @@ socks = ["socksio (>=1.0.0,<2.0.0)"] [[package]] name = "huggingface-hub" -version = "0.13.3" +version = "0.13.4" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" category = "main" optional = false python-versions = ">=3.7.0" files = [ - {file = "huggingface_hub-0.13.3-py3-none-any.whl", hash = "sha256:f73a298a55028575334f9670d86b8171a4dd890b320315f3ad28a20b9eb3b5bc"}, - {file = "huggingface_hub-0.13.3.tar.gz", hash = "sha256:1f95f65c5e7aa76728701402f55b697ee8a8b50234adda91fbdbb81038fbcd21"}, + {file = "huggingface_hub-0.13.4-py3-none-any.whl", hash = "sha256:4d3d40593de6673d624a4baaaf249b9bf5165bfcafd1ad58de361931f0b4fda5"}, + {file = "huggingface_hub-0.13.4.tar.gz", hash = "sha256:db83d9c2f76aed8cf49893ffadd6be24e82074da2f64b1d36b8ba40eb255e115"}, ] [package.dependencies] @@ -1650,7 +1651,7 @@ tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] name = "markdown-it-py" version = "2.2.0" description = "Python port of markdown-it. Markdown parsing, done right!" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1786,7 +1787,7 @@ traitlets = "*" name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2694,7 +2695,7 @@ email = ["email-validator (>=1.0.3)"] name = "pygments" version = "2.14.0" description = "Pygments is a syntax highlighting package written in Python." -category = "dev" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3064,7 +3065,7 @@ idna2008 = ["idna"] name = "rich" version = "13.3.3" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -category = "dev" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -4312,4 +4313,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "a58931b38efad96240042e040a28584ebb30949f60249466c94440593a61052c" +content-hash = "70e86f7d3b5caed792e37ccf9e11ed95008e5078dd8830e4f8b96cc1d35c7b60" diff --git a/pyproject.toml b/pyproject.toml index 2ea704b7d..b6b30d805 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langflow" -version = "0.0.52" +version = "0.0.54" description = "A Python package with a built-in web application" authors = ["Logspace "] maintainers = [ @@ -35,6 +35,8 @@ types-pyyaml = "^6.0.12.8" dill = "^0.3.6" pandas = "^1.5.3" chromadb = "^0.3.21" +huggingface-hub = "^0.13.3" +rich = "^13.3.3" [tool.poetry.group.dev.dependencies] black = "^23.1.0" @@ -42,7 +44,6 @@ ipykernel = "^6.21.2" mypy = "^1.1.1" ruff = "^0.0.254" httpx = "^0.23.3" -rich = "^13.3.3" pytest = "^7.2.2" types-requests = "^2.28.11" requests = "^2.28.0" diff --git a/src/backend/langflow/api/base.py b/src/backend/langflow/api/base.py index 9096aff1b..084e04d65 100644 --- a/src/backend/langflow/api/base.py +++ b/src/backend/langflow/api/base.py @@ -1,5 +1,7 @@ from pydantic import BaseModel, validator +from langflow.graph.utils import extract_input_variables_from_prompt + class Code(BaseModel): code: str @@ -25,3 +27,54 @@ class CodeValidationResponse(BaseModel): class PromptValidationResponse(BaseModel): input_variables: list + + +INVALID_CHARACTERS = { + " ", + ",", + ".", + ":", + ";", + "!", + "?", + "/", + "\\", + "(", + ")", + "[", + "]", + "{", + "}", +} + + +def validate_prompt(template: str): + input_variables = extract_input_variables_from_prompt(template) + + # Check if there are invalid characters in the input_variables + input_variables = check_input_variables(input_variables) + + return PromptValidationResponse(input_variables=input_variables) + + +def check_input_variables(input_variables: list): + invalid_chars = [] + fixed_variables = [] + for variable in input_variables: + new_var = variable + for char in INVALID_CHARACTERS: + if char in variable: + invalid_chars.append(char) + new_var = new_var.replace(char, "") + fixed_variables.append(new_var) + if new_var != variable: + input_variables.remove(variable) + input_variables.append(new_var) + # If any of the input_variables is not in the fixed_variables, then it means that + # there are invalid characters in the input_variables + if any(var not in fixed_variables for var in input_variables): + raise ValueError( + f"Invalid input variables: {input_variables}. Please, use something like {fixed_variables} instead." + ) + + return input_variables diff --git a/src/backend/langflow/api/validate.py b/src/backend/langflow/api/validate.py index 6dea45df0..a60bcc506 100644 --- a/src/backend/langflow/api/validate.py +++ b/src/backend/langflow/api/validate.py @@ -5,8 +5,8 @@ from langflow.api.base import ( CodeValidationResponse, Prompt, PromptValidationResponse, + validate_prompt, ) -from langflow.graph.utils import extract_input_variables_from_prompt from langflow.utils.logger import logger from langflow.utils.validate import validate_code @@ -29,8 +29,7 @@ def post_validate_code(code: Code): @router.post("/prompt", status_code=200, response_model=PromptValidationResponse) def post_validate_prompt(prompt: Prompt): try: - input_variables = extract_input_variables_from_prompt(prompt.template) - return PromptValidationResponse(input_variables=input_variables) + return validate_prompt(prompt.template) except Exception as e: logger.exception(e) - return HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 8036dba25..be7e82099 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -26,8 +26,9 @@ prompts: llms: - OpenAI - - AzureOpenAI + # - AzureOpenAI - ChatOpenAI + - HuggingFaceHub tools: - Search diff --git a/src/backend/langflow/graph/graph.py b/src/backend/langflow/graph/graph.py index 9cbeb94a3..a7f908ef2 100644 --- a/src/backend/langflow/graph/graph.py +++ b/src/backend/langflow/graph/graph.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import Dict, List, Type, Union from langflow.graph.base import Edge, Node from langflow.graph.nodes import ( @@ -25,7 +25,6 @@ from langflow.interface.prompts.base import prompt_creator from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.tools.base import tool_creator from langflow.interface.tools.constants import FILE_TOOLS -from langflow.interface.tools.util import get_tools_dict from langflow.interface.vectorStore.base import vectorstore_creator from langflow.interface.wrappers.base import wrapper_creator from langflow.utils import payload @@ -114,6 +113,29 @@ class Graph: edges.append(Edge(source, target)) return edges + def _get_node_class(self, node_type: str, node_lc_type: str) -> Type[Node]: + node_type_map: Dict[str, Type[Node]] = { + **{t: PromptNode for t in prompt_creator.to_list()}, + **{t: AgentNode for t in agent_creator.to_list()}, + **{t: ChainNode for t in chain_creator.to_list()}, + **{t: ToolNode for t in tool_creator.to_list()}, + **{t: ToolkitNode for t in toolkits_creator.to_list()}, + **{t: WrapperNode for t in wrapper_creator.to_list()}, + **{t: LLMNode for t in llm_creator.to_list()}, + **{t: MemoryNode for t in memory_creator.to_list()}, + **{t: EmbeddingNode for t in embedding_creator.to_list()}, + **{t: VectorStoreNode for t in vectorstore_creator.to_list()}, + **{t: DocumentLoaderNode for t in documentloader_creator.to_list()}, + } + + if node_type in FILE_TOOLS: + return FileToolNode + if node_type in node_type_map: + return node_type_map[node_type] + if node_lc_type in node_type_map: + return node_type_map[node_lc_type] + return Node + def _build_nodes(self) -> List[Node]: nodes: List[Node] = [] for node in self._nodes: @@ -121,44 +143,9 @@ class Graph: node_type: str = node_data["type"] # type: ignore node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore - if node_type in prompt_creator.to_list(): - nodes.append(PromptNode(node)) - elif ( - node_type in agent_creator.to_list() - or node_lc_type in agent_creator.to_list() - ): - nodes.append(AgentNode(node)) - elif node_type in chain_creator.to_list(): - nodes.append(ChainNode(node)) - elif ( - node_type in tool_creator.to_list() - or node_lc_type in get_tools_dict().keys() - ): - if node_type in FILE_TOOLS: - nodes.append(FileToolNode(node)) - nodes.append(ToolNode(node)) - elif node_type in toolkits_creator.to_list(): - nodes.append(ToolkitNode(node)) - elif node_type in wrapper_creator.to_list(): - nodes.append(WrapperNode(node)) - elif ( - node_type in llm_creator.to_list() - or node_lc_type in llm_creator.to_list() - ): - nodes.append(LLMNode(node)) - elif node_type in embedding_creator.to_list(): - nodes.append(EmbeddingNode(node)) - elif node_type in vectorstore_creator.to_list(): - nodes.append(VectorStoreNode(node)) - elif node_type in documentloader_creator.to_list(): - nodes.append(DocumentLoaderNode(node)) - elif ( - node_type in memory_creator.to_list() - or node_lc_type in memory_creator.to_list() - ): - nodes.append(MemoryNode(node)) - else: - nodes.append(Node(node)) + NodeClass = self._get_node_class(node_type, node_lc_type) + nodes.append(NodeClass(node)) + return nodes def get_children_by_node_type(self, node: Node, node_type: str) -> List[Node]: diff --git a/src/backend/langflow/interface/chains/custom.py b/src/backend/langflow/interface/chains/custom.py index 98470d54b..cb76a53c8 100644 --- a/src/backend/langflow/interface/chains/custom.py +++ b/src/backend/langflow/interface/chains/custom.py @@ -19,7 +19,7 @@ class BaseCustomChain(ConversationChain): template: Optional[str] - ai_prefix_key: Optional[str] + ai_prefix_value: Optional[str] """Field to use as the ai_prefix. It needs to be set and has to be in the template""" @root_validator(pre=False) @@ -27,13 +27,13 @@ class BaseCustomChain(ConversationChain): format_dict = {} input_variables = extract_input_variables_from_prompt(values["template"]) - if values.get("ai_prefix_key", None) is None: - values["ai_prefix_key"] = values["memory"].ai_prefix + if values.get("ai_prefix_value", None) is None: + values["ai_prefix_value"] = values["memory"].ai_prefix for key in input_variables: new_value = values.get(key, f"{{{key}}}") format_dict[key] = new_value - if key == values.get("ai_prefix_key", None): + if key == values.get("ai_prefix_value", None): values["memory"].ai_prefix = new_value values["template"] = values["template"].format(**format_dict) @@ -62,7 +62,7 @@ Current conversation: Human: {input} {character}:""" memory: BaseMemory = Field(default_factory=ConversationBufferMemory) - ai_prefix_key: Optional[str] = "character" + ai_prefix_value: Optional[str] = "character" """Default memory store.""" diff --git a/src/backend/langflow/template/base.py b/src/backend/langflow/template/base.py index 4c053b19d..0fe665f67 100644 --- a/src/backend/langflow/template/base.py +++ b/src/backend/langflow/template/base.py @@ -178,12 +178,14 @@ class FrontendNode(BaseModel): field.show = bool( (field.required and key not in ["input_variables"]) or key in FORCE_SHOW_FIELDS - or "api_key" in key + or "api" in key + or ("key" in key and "input" not in key and "output" not in key) ) # Add password field - field.password = any( - text in key.lower() for text in {"password", "token", "api", "key"} + field.password = ( + any(text in key.lower() for text in {"password", "token", "api", "key"}) + and field.show ) # Add multline diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index 3777eb80f..e8466c3cb 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -309,13 +309,22 @@ class PromptFrontendNode(FrontendNode): def format_field(field: TemplateField, name: Optional[str] = None) -> None: # if field.field_type == "StringPromptTemplate" # change it to str + PROMPT_FIELDS = [ + "template", + "suffix", + "prefix", + "examples", + ] if field.field_type == "StringPromptTemplate" and "Message" in str(name): - field.field_type = "str" + field.field_type = "prompt" field.multiline = True field.value = HUMAN_PROMPT if "Human" in field.name else SYSTEM_PROMPT if field.name == "template" and field.value == "": field.value = DEFAULT_PROMPT + if field.name in PROMPT_FIELDS: + field.field_type = "prompt" + if ( "Union" in field.field_type and "BaseMessagePromptTemplate" in field.field_type diff --git a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx index 6bf038066..1a2eeb9e6 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx @@ -14,6 +14,7 @@ import CodeAreaComponent from "../../../../components/codeAreaComponent"; import InputFileComponent from "../../../../components/inputFileComponent"; import { TabsContext } from "../../../../contexts/tabsContext"; import IntComponent from "../../../../components/intComponent"; +import PromptAreaComponent from "../../../../components/promptComponent"; export default function ParameterComponent({ left, @@ -63,6 +64,7 @@ export default function ParameterComponent({ type === "bool" || type === "float" || type === "code" || + type === "prompt" || type === "file" || type === "int") ? ( <> @@ -187,9 +189,16 @@ export default function ParameterComponent({ save(); }} /> - ) : ( - <> - )} + ) : left === true && type === "prompt" ? ( + { + data.node.template[name].value = t; + save(); + }} + /> + ):(<>)} ); diff --git a/src/frontend/src/components/chatComponent/index.tsx b/src/frontend/src/components/chatComponent/index.tsx index 59cc54b4c..11516f004 100644 --- a/src/frontend/src/components/chatComponent/index.tsx +++ b/src/frontend/src/components/chatComponent/index.tsx @@ -14,10 +14,11 @@ import { } from "react"; import { sendAll } from "../../controllers/API"; import { alertContext } from "../../contexts/alertContext"; -import { classNames, nodeColors } from "../../utils"; +import { classNames, nodeColors, snakeToNormalCase } from "../../utils"; import { TabsContext } from "../../contexts/tabsContext"; import { ChatType } from "../../types/chat"; import ChatMessage from "./chatMessage"; +import { NodeType } from "../../types/flow"; const _ = require("lodash"); @@ -28,7 +29,7 @@ export default function Chat({ flow, reactFlowInstance }: ChatType) { const [open, setOpen] = useState(true); const [chatValue, setChatValue] = useState(""); const [chatHistory, setChatHistory] = useState(flow.chat); - const { setErrorData } = useContext(alertContext); + const { setErrorData, setNoticeData } = useContext(alertContext); const addChatHistory = ( message: string, isSend: boolean, @@ -73,36 +74,58 @@ export default function Chat({ flow, reactFlowInstance }: ChatType) { useEffect(() => { if (ref.current) ref.current.scrollIntoView({ behavior: "smooth" }); }, [chatHistory]); - function validateNodes() { - if ( - reactFlowInstance.getNodes().some( - (n) => - n.data.node && - Object.keys(n.data.node.template).some((t: any) => { - return ( - n.data.node.template[t].required && - (!n.data.node.template[t].value || - n.data.node.template[t].value === "") && - !reactFlowInstance - .getEdges() - .some( - (e) => - e.targetHandle.split("|")[1] === t && - e.targetHandle.split("|")[2] === n.id - ) - ); - }) - ) - ) { - return false; + + function validateNode(n: NodeType): Array { + if (!n.data?.node?.template || !Object.keys(n.data.node.template)) { + setNoticeData({ + title: + "We've noticed a potential issue with a node in the flow. Please review it and, if necessary, submit a bug report with your exported flow file. Thank you for your help!", + }); + return []; } - return true; + + const { + type, + node: { template }, + } = n.data; + + return Object.keys(template).reduce( + (errors: Array, t) => + errors.concat( + (template[t].required && template[t].show) && + (!template[t].value || template[t].value === "") && + !reactFlowInstance + .getEdges() + .some( + (e) => + e.targetHandle.split("|")[1] === t && + e.targetHandle.split("|")[2] === n.id + ) + ? [ + `${type} is missing ${ + template.display_name + ? template.display_name + : snakeToNormalCase(template[t].name) + }.`, + ] + : [] + ), + [] as string[] + ); } + + function validateNodes() { + return reactFlowInstance + .getNodes() + .flatMap((n: NodeType) => validateNode(n)); + } + const ref = useRef(null); function sendMessage() { if (chatValue !== "") { - if (validateNodes()) { + let nodeValidationErrors = validateNodes(); + if (nodeValidationErrors.length === 0) { setLockChat(true); let message = chatValue; setChatValue(""); @@ -136,10 +159,8 @@ export default function Chat({ flow, reactFlowInstance }: ChatType) { }); } else { setErrorData({ - title: "Error sending message", - list: [ - "Oops! Looks like you missed some required information. Please fill in all the required fields before continuing.", - ], + title: "Oops! Looks like you missed some required information:", + list: nodeValidationErrors, }); } } else { diff --git a/src/frontend/src/components/promptComponent/index.tsx b/src/frontend/src/components/promptComponent/index.tsx new file mode 100644 index 000000000..6ad51f87b --- /dev/null +++ b/src/frontend/src/components/promptComponent/index.tsx @@ -0,0 +1,35 @@ +import { ArrowTopRightOnSquareIcon } from "@heroicons/react/24/outline"; +import { useContext, useEffect, useState } from "react"; +import { PopUpContext } from "../../contexts/popUpContext"; +import CodeAreaModal from "../../modals/codeAreaModal"; +import TextAreaModal from "../../modals/textAreaModal"; +import { TextAreaComponentType } from "../../types/components"; +import PromptAreaModal from "../../modals/promptModal"; + +export default function PromptAreaComponent({ value, onChange, disabled }:TextAreaComponentType) { + const [myValue, setMyValue] = useState(value); + const { openPopUp } = useContext(PopUpContext); + useEffect(() => { + if (disabled) { + setMyValue(""); + onChange(""); + } + }, [disabled, onChange]); + return ( +
+
+ + {myValue !== "" ? myValue : 'Text empty'} + + +
+
+ ); +} diff --git a/src/frontend/src/controllers/API/index.ts b/src/frontend/src/controllers/API/index.ts index bad966ea9..8fa7ff527 100644 --- a/src/frontend/src/controllers/API/index.ts +++ b/src/frontend/src/controllers/API/index.ts @@ -1,4 +1,4 @@ -import { errorsTypeAPI } from './../../types/api/index'; +import { PromptTypeAPI, errorsTypeAPI } from './../../types/api/index'; import { APIObjectType, sendAllProps } from '../../types/api/index'; import axios, { AxiosResponse } from "axios"; @@ -13,4 +13,9 @@ export async function sendAll(data:sendAllProps) { export async function checkCode(code:string):Promise>{ return await axios.post('/validate/code',{code}) +} + +export async function checkPrompt(template:string):Promise>{ + + return await axios.post('/validate/prompt',{template}) } \ No newline at end of file diff --git a/src/frontend/src/modals/codeAreaModal/index.tsx b/src/frontend/src/modals/codeAreaModal/index.tsx index 72b40ed40..db4de60f4 100644 --- a/src/frontend/src/modals/codeAreaModal/index.tsx +++ b/src/frontend/src/modals/codeAreaModal/index.tsx @@ -96,7 +96,6 @@ export default function CodeAreaModal({
- {/* need to insert code editor */} void; + value: string; +}) { + const [open, setOpen] = useState(true); + const [myValue, setMyValue] = useState(value); + const { dark } = useContext(darkContext); + const { setErrorData, setSuccessData } = useContext(alertContext); + const { closePopUp } = useContext(PopUpContext); + const ref = useRef(); + function setModalOpen(x: boolean) { + setOpen(x); + if (x === false) { + setTimeout(() => { + closePopUp(); + }, 300); + } + } + return ( + + + +
+ + +
+
+ + +
+ +
+
+
+
+
+
+ + Edit Prompt + +
+
+
+
+
+