Refactor ConversationChainComponent build method

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-22 10:40:46 -03:00
commit c70a227ee4

View file

@ -23,19 +23,25 @@ class ConversationChainComponent(CustomComponent):
def build(
self,
inputs: str,
llm: BaseLanguageModel,
memory: Optional[BaseMemory] = None,
inputs: dict = {},
) -> Union[Chain, Callable, Text]:
if memory is None:
chain = ConversationChain(llm=llm)
chain = ConversationChain(llm=llm, memory=memory)
else:
chain = ConversationChain(llm=llm, memory=memory)
result = chain.invoke(inputs)
# result is an AIMessage which is a subclass of BaseMessage
# We need to check if it is a string or a BaseMessage
if hasattr(result, "content") and isinstance(result.content, str):
return result.content
self.status = "is message"
result = result.content
elif isinstance(result, str):
return result
return str(result)
self.status = "is_string"
result = result
else:
# is dict
result = result.get("response")
self.status = result
return result