diff --git a/src/backend/base/langflow/components/prompts/LangChainHubPrompt.py b/src/backend/base/langflow/components/prompts/LangChainHubPrompt.py index e3db93c7e..a7fafc7de 100644 --- a/src/backend/base/langflow/components/prompts/LangChainHubPrompt.py +++ b/src/backend/base/langflow/components/prompts/LangChainHubPrompt.py @@ -5,6 +5,7 @@ from langflow.inputs import StrInput, SecretStrInput, DefaultPromptField from langflow.io import Output from langflow.schema.message import Message +from langchain_core.prompts import HumanMessagePromptTemplate import re @@ -40,9 +41,15 @@ class LangChainHubPromptComponent(Component): if field_name == "langchain_hub_prompt": template = self._fetch_langchain_hub_template() + # Get the template's messages + if hasattr(template, "messages"): + template_messages = template.messages + else: + template_messages = [HumanMessagePromptTemplate(prompt=template)] + # Extract the messages from the prompt data prompt_template = [] - for message_data in template.messages: + for message_data in template_messages: prompt_template.append(message_data.prompt) # Regular expression to find all instances of {} @@ -103,6 +110,10 @@ class LangChainHubPromptComponent(Component): def _fetch_langchain_hub_template(self): import langchain.hub + # Check if the api key is provided + if not self.langchain_api_key: + raise ValueError("Please provide a LangChain API Key") + # Pull the prompt from LangChain Hub prompt_data = langchain.hub.pull(self.langchain_hub_prompt, api_key=self.langchain_api_key)