Refactor model.py to support chat models
This commit is contained in:
parent
60139e9c87
commit
d59f996130
1 changed files with 24 additions and 4 deletions
|
|
@ -1,4 +1,8 @@
|
|||
from langchain_core.runnables import Runnable
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
|
@ -7,7 +11,7 @@ class LCModelComponent(CustomComponent):
|
|||
display_name: str = "Model Name"
|
||||
description: str = "Model Description"
|
||||
|
||||
def get_result(self, output: Runnable, stream: bool, input_value: str):
|
||||
def get_result(self, runnable: LLM, stream: bool, input_value: str):
|
||||
"""
|
||||
Retrieves the result from the output of a Runnable object.
|
||||
|
||||
|
|
@ -20,9 +24,25 @@ class LCModelComponent(CustomComponent):
|
|||
The result obtained from the output object.
|
||||
"""
|
||||
if stream:
|
||||
result = output.stream(input_value)
|
||||
result = runnable.stream(input_value)
|
||||
else:
|
||||
message = output.invoke(input_value)
|
||||
message = runnable.invoke(input_value)
|
||||
result = message.content if hasattr(message, "content") else message
|
||||
self.status = result
|
||||
return result
|
||||
|
||||
def get_chat_result(
|
||||
self, runnable: BaseChatModel, stream: bool, input_value: str, system_message: Optional[str] = None
|
||||
):
|
||||
messages = []
|
||||
if input_value:
|
||||
messages.append(HumanMessage(input_value))
|
||||
if system_message:
|
||||
messages.append(SystemMessage(system_message))
|
||||
if stream:
|
||||
result = runnable.stream(messages)
|
||||
else:
|
||||
message = runnable.invoke(messages)
|
||||
result = message.content
|
||||
self.status = result
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue