diff --git a/src/backend/langflow/components/chains/ConversationChain.py b/src/backend/langflow/components/chains/ConversationChain.py index 76e6d8e25..f3a302fe8 100644 --- a/src/backend/langflow/components/chains/ConversationChain.py +++ b/src/backend/langflow/components/chains/ConversationChain.py @@ -1,7 +1,9 @@ -from langflow import CustomComponent +from typing import Callable, Optional, Union + from langchain.chains import ConversationChain -from typing import Optional, Union, Callable -from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain + +from langflow import CustomComponent +from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain, Text class ConversationChainComponent(CustomComponent): @@ -23,7 +25,17 @@ class ConversationChainComponent(CustomComponent): self, llm: BaseLanguageModel, memory: Optional[BaseMemory] = None, - ) -> Union[Chain, Callable]: + inputs: dict = {}, + ) -> Union[Chain, Callable, Text]: if memory is None: - return ConversationChain(llm=llm) - return ConversationChain(llm=llm, memory=memory) + chain = ConversationChain(llm=llm) + chain = ConversationChain(llm=llm, memory=memory) + result = chain.invoke(inputs) + # result is an AIMessage which is a subclass of BaseMessage + # We need to check if it is a string or a BaseMessage + if hasattr(result, "content") and isinstance(result.content, str): + return result.content + elif isinstance(result, str): + return result + + return str(result)