From d59f99613077d2e37f82ab973b5e96d34704dc7f Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 25 Mar 2024 15:06:06 -0300 Subject: [PATCH] Refactor model.py to support chat models --- src/backend/langflow/base/models/model.py | 28 +++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/backend/langflow/base/models/model.py b/src/backend/langflow/base/models/model.py index 9f9ca7b36..e2ab4b6cf 100644 --- a/src/backend/langflow/base/models/model.py +++ b/src/backend/langflow/base/models/model.py @@ -1,4 +1,8 @@ -from langchain_core.runnables import Runnable +from typing import Optional + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.llms import LLM +from langchain_core.messages import HumanMessage, SystemMessage from langflow import CustomComponent @@ -7,7 +11,7 @@ class LCModelComponent(CustomComponent): display_name: str = "Model Name" description: str = "Model Description" - def get_result(self, output: Runnable, stream: bool, input_value: str): + def get_result(self, runnable: LLM, stream: bool, input_value: str): """ Retrieves the result from the output of a Runnable object. @@ -20,9 +24,25 @@ class LCModelComponent(CustomComponent): The result obtained from the output object. """ if stream: - result = output.stream(input_value) + result = runnable.stream(input_value) else: - message = output.invoke(input_value) + message = runnable.invoke(input_value) result = message.content if hasattr(message, "content") else message self.status = result return result + + def get_chat_result( + self, runnable: BaseChatModel, stream: bool, input_value: str, system_message: Optional[str] = None + ): + messages = [] + if input_value: + messages.append(HumanMessage(input_value)) + if system_message: + messages.append(SystemMessage(system_message)) + if stream: + result = runnable.stream(messages) + else: + message = runnable.invoke(messages) + result = message.content + self.status = result + return result