From e2c53f1166d54590414c2bd7ef762bc7e2db2238 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 26 Feb 2024 20:37:10 -0300 Subject: [PATCH] Refactor ConversationChain and LLMCheckerChain components --- .../components/chains/ConversationChain.py | 6 ++--- .../components/chains/LLMCheckerChain.py | 19 ++++++++++----- .../components/chains/LLMMathChain.py | 23 +++++++++++++++---- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/backend/langflow/components/chains/ConversationChain.py b/src/backend/langflow/components/chains/ConversationChain.py index 43f71e67b..3183954a3 100644 --- a/src/backend/langflow/components/chains/ConversationChain.py +++ b/src/backend/langflow/components/chains/ConversationChain.py @@ -1,9 +1,9 @@ -from typing import Callable, Optional, Union +from typing import Optional from langchain.chains import ConversationChain from langflow import CustomComponent -from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain, Text +from langflow.field_typing import BaseLanguageModel, BaseMemory, Text class ConversationChainComponent(CustomComponent): @@ -26,7 +26,7 @@ class ConversationChainComponent(CustomComponent): inputs: str, llm: BaseLanguageModel, memory: Optional[BaseMemory] = None, - ) -> Union[Chain, Callable, Text]: + ) -> Text: if memory is None: chain = ConversationChain(llm=llm) else: diff --git a/src/backend/langflow/components/chains/LLMCheckerChain.py b/src/backend/langflow/components/chains/LLMCheckerChain.py index 527cafbb7..bfee0b5a9 100644 --- a/src/backend/langflow/components/chains/LLMCheckerChain.py +++ b/src/backend/langflow/components/chains/LLMCheckerChain.py @@ -1,14 +1,15 @@ -from typing import Callable, Union - from langchain.chains import LLMCheckerChain + from langflow import CustomComponent -from langflow.field_typing import BaseLanguageModel, Chain +from langflow.field_typing import BaseLanguageModel, Text class LLMCheckerChainComponent(CustomComponent): display_name = "LLMCheckerChain" description = "" - documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_checker" + documentation = ( + "https://python.langchain.com/docs/modules/chains/additional/llm_checker" + ) def build_config(self): return { @@ -17,6 +18,12 @@ class LLMCheckerChainComponent(CustomComponent): def build( self, + inputs: str, llm: BaseLanguageModel, - ) -> Union[Chain, Callable]: - return LLMCheckerChain.from_llm(llm=llm) + ) -> Text: + + chain = LLMCheckerChain.from_llm(llm=llm) + response = chain.invoke({chain.input_key: inputs}) + result = response.get(chain.output_key) + self.status = result + return result diff --git a/src/backend/langflow/components/chains/LLMMathChain.py b/src/backend/langflow/components/chains/LLMMathChain.py index 28f430e6d..919de34e6 100644 --- a/src/backend/langflow/components/chains/LLMMathChain.py +++ b/src/backend/langflow/components/chains/LLMMathChain.py @@ -1,15 +1,17 @@ -from typing import Callable, Optional, Union +from typing import Optional from langchain.chains import LLMChain, LLMMathChain from langflow import CustomComponent -from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain +from langflow.field_typing import BaseLanguageModel, BaseMemory, Text class LLMMathChainComponent(CustomComponent): display_name = "LLMMathChain" description = "Chain that interprets a prompt and executes python code to do math." - documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_math" + documentation = ( + "https://python.langchain.com/docs/modules/chains/additional/llm_math" + ) def build_config(self): return { @@ -22,10 +24,21 @@ class LLMMathChainComponent(CustomComponent): def build( self, + inputs: Text, llm: BaseLanguageModel, llm_chain: LLMChain, input_key: str = "question", output_key: str = "answer", memory: Optional[BaseMemory] = None, - ) -> Union[LLMMathChain, Callable, Chain]: - return LLMMathChain(llm=llm, llm_chain=llm_chain, input_key=input_key, output_key=output_key, memory=memory) + ) -> Text: + chain = LLMMathChain( + llm=llm, + llm_chain=llm_chain, + input_key=input_key, + output_key=output_key, + memory=memory, + ) + response = chain.invoke({input_key: inputs}) + result = response.get(output_key) + self.status = result + return result