Refactor PromptRunner class to use langchain_core.messages.BaseMessage

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-29 23:36:48 -03:00
commit 7a3d057a50

View file

@ -1,8 +1,9 @@
from langflow import CustomComponent
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_core.messages import BaseMessage
from langflow import CustomComponent
from langflow.field_typing import Text
class PromptRunner(CustomComponent):
@ -18,11 +19,15 @@ class PromptRunner(CustomComponent):
"code": {"show": False},
}
def build(self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}) -> Document:
def build(self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}) -> Text:
chain = prompt | llm
# The input is an empty dict because the prompt is already filled
result = chain.invoke(input=inputs)
if hasattr(result, "content"):
result = result.content
result_message: BaseMessage = chain.invoke(input=inputs)
if hasattr(result_message, "content"):
result: str = result_message.content
elif isinstance(result_message, str):
result = result_message
else:
result = str(result_message)
self.repr_value = result
return Document(page_content=str(result))
return result