diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index d8cd4a325..26fec2be3 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -73,6 +73,11 @@ prompts: - PromptTemplate - FewShotPromptTemplate - ZeroShotPrompt + - ChatPromptTemplate + - SystemMessagePromptTemplate + - AIMessagePromptTemplate + - HumanMessagePromptTemplate + - ChatMessagePromptTemplate textsplitters: - CharacterTextSplitter - RecursiveCharacterTextSplitter diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index b6a8f9a91..775a2ebd2 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -180,11 +180,17 @@ class PromptVertex(Vertex): ] else: prompt_params = ["template"] - for param in prompt_params: - prompt_text = self.params[param] - variables = extract_input_variables_from_prompt(prompt_text) - self.params["input_variables"].extend(variables) - self.params["input_variables"] = list(set(self.params["input_variables"])) + + if "prompt" not in self.params and "messages" not in self.params: + for param in prompt_params: + prompt_text = self.params[param] + variables = extract_input_variables_from_prompt(prompt_text) + self.params["input_variables"].extend(variables) + self.params["input_variables"] = list( + set(self.params["input_variables"]) + ) + else: + self.params.pop("input_variables", None) self._build() return self._built_object diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 88b981f9d..235eabaff 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -97,6 +97,19 @@ def instantiate_prompt(node_type, class_object, params): if "tools" not in params: params["tools"] = [] return ZeroShotAgent.create_prompt(**params) + if "MessagePromptTemplate" in node_type: + # Then we only need the template + from_template_params = { + "template": params.pop("prompt", params.pop("template", "")) + } + + if not from_template_params.get("template"): + raise ValueError("Prompt template is required") + return class_object.from_template(**from_template_params) + + if node_type == "ChatPromptTemplate": + return class_object.from_messages(**params) + return class_object(**params)