Refactor ConversationChainComponent to handle inputs and return result as string if applicable

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-29 23:36:26 -03:00
commit d3c2d2f893

View file

@ -1,7 +1,9 @@
from langflow import CustomComponent
from typing import Callable, Optional, Union
from langchain.chains import ConversationChain
from typing import Optional, Union, Callable
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain, Text
class ConversationChainComponent(CustomComponent):
@ -23,7 +25,17 @@ class ConversationChainComponent(CustomComponent):
self,
llm: BaseLanguageModel,
memory: Optional[BaseMemory] = None,
) -> Union[Chain, Callable]:
inputs: dict = {},
) -> Union[Chain, Callable, Text]:
if memory is None:
return ConversationChain(llm=llm)
return ConversationChain(llm=llm, memory=memory)
chain = ConversationChain(llm=llm)
chain = ConversationChain(llm=llm, memory=memory)
result = chain.invoke(inputs)
# result is an AIMessage which is a subclass of BaseMessage
# We need to check if it is a string or a BaseMessage
if hasattr(result, "content") and isinstance(result.content, str):
return result.content
elif isinstance(result, str):
return result
return str(result)