Refactor ConversationChainComponent to handle inputs and return result as string if applicable
This commit is contained in:
parent
5eb2e7f979
commit
d3c2d2f893
1 changed files with 18 additions and 6 deletions
|
|
@ -1,7 +1,9 @@
|
|||
from langflow import CustomComponent
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
from typing import Optional, Union, Callable
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain, Text
|
||||
|
||||
|
||||
class ConversationChainComponent(CustomComponent):
|
||||
|
|
@ -23,7 +25,17 @@ class ConversationChainComponent(CustomComponent):
|
|||
self,
|
||||
llm: BaseLanguageModel,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
) -> Union[Chain, Callable]:
|
||||
inputs: dict = {},
|
||||
) -> Union[Chain, Callable, Text]:
|
||||
if memory is None:
|
||||
return ConversationChain(llm=llm)
|
||||
return ConversationChain(llm=llm, memory=memory)
|
||||
chain = ConversationChain(llm=llm)
|
||||
chain = ConversationChain(llm=llm, memory=memory)
|
||||
result = chain.invoke(inputs)
|
||||
# result is an AIMessage which is a subclass of BaseMessage
|
||||
# We need to check if it is a string or a BaseMessage
|
||||
if hasattr(result, "content") and isinstance(result.content, str):
|
||||
return result.content
|
||||
elif isinstance(result, str):
|
||||
return result
|
||||
|
||||
return str(result)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue