diff --git a/src/backend/base/langflow/base/models/model.py b/src/backend/base/langflow/base/models/model.py index 690dc01ba..72f7ca0a7 100644 --- a/src/backend/base/langflow/base/models/model.py +++ b/src/backend/base/langflow/base/models/model.py @@ -1,10 +1,13 @@ +import warnings from typing import Optional, Union from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import LLM +from langchain_core.load import load from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langflow.custom import CustomComponent +from langflow.schema.schema import Record class LCModelComponent(CustomComponent): @@ -82,7 +85,7 @@ class LCModelComponent(CustomComponent): return status_message def get_chat_result( - self, runnable: BaseChatModel, stream: bool, input_value: str, system_message: Optional[str] = None + self, runnable: BaseChatModel, stream: bool, input_value: str | Record, system_message: Optional[str] = None ): messages: list[Union[HumanMessage, SystemMessage]] = [] if not input_value and not system_message: @@ -90,7 +93,16 @@ class LCModelComponent(CustomComponent): if system_message: messages.append(SystemMessage(content=system_message)) if input_value: - messages.append(HumanMessage(content=input_value)) + if isinstance(input_value, Record): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if "prompt" in input_value: + prompt = load(input_value.prompt) + runnable = prompt | runnable + else: + messages.append(input_value.to_lc_message()) + else: + messages.append(HumanMessage(content=input_value)) if stream: return runnable.stream(messages) else: