Refactor PromptRunner class to use langchain_core.messages.BaseMessage
This commit is contained in:
parent
d3c2d2f893
commit
7a3d057a50
1 changed files with 13 additions and 8 deletions
|
|
@ -1,8 +1,9 @@
|
|||
from langflow import CustomComponent
|
||||
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import Document
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import Text
|
||||
|
||||
|
||||
class PromptRunner(CustomComponent):
|
||||
|
|
@ -18,11 +19,15 @@ class PromptRunner(CustomComponent):
|
|||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}) -> Document:
|
||||
def build(self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}) -> Text:
|
||||
chain = prompt | llm
|
||||
# The input is an empty dict because the prompt is already filled
|
||||
result = chain.invoke(input=inputs)
|
||||
if hasattr(result, "content"):
|
||||
result = result.content
|
||||
result_message: BaseMessage = chain.invoke(input=inputs)
|
||||
if hasattr(result_message, "content"):
|
||||
result: str = result_message.content
|
||||
elif isinstance(result_message, str):
|
||||
result = result_message
|
||||
else:
|
||||
result = str(result_message)
|
||||
self.repr_value = result
|
||||
return Document(page_content=str(result))
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue