Refactor ConversationChain and LLMCheckerChain components

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-26 20:37:10 -03:00
commit e2c53f1166
3 changed files with 34 additions and 14 deletions

View file

@ -1,9 +1,9 @@
from typing import Callable, Optional, Union
from typing import Optional
from langchain.chains import ConversationChain
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain, Text
from langflow.field_typing import BaseLanguageModel, BaseMemory, Text
class ConversationChainComponent(CustomComponent):
@ -26,7 +26,7 @@ class ConversationChainComponent(CustomComponent):
inputs: str,
llm: BaseLanguageModel,
memory: Optional[BaseMemory] = None,
) -> Union[Chain, Callable, Text]:
) -> Text:
if memory is None:
chain = ConversationChain(llm=llm)
else:

View file

@ -1,14 +1,15 @@
from typing import Callable, Union
from langchain.chains import LLMCheckerChain
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, Chain
from langflow.field_typing import BaseLanguageModel, Text
class LLMCheckerChainComponent(CustomComponent):
display_name = "LLMCheckerChain"
description = ""
documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_checker"
documentation = (
"https://python.langchain.com/docs/modules/chains/additional/llm_checker"
)
def build_config(self):
return {
@ -17,6 +18,12 @@ class LLMCheckerChainComponent(CustomComponent):
def build(
self,
inputs: str,
llm: BaseLanguageModel,
) -> Union[Chain, Callable]:
return LLMCheckerChain.from_llm(llm=llm)
) -> Text:
chain = LLMCheckerChain.from_llm(llm=llm)
response = chain.invoke({chain.input_key: inputs})
result = response.get(chain.output_key)
self.status = result
return result

View file

@ -1,15 +1,17 @@
from typing import Callable, Optional, Union
from typing import Optional
from langchain.chains import LLMChain, LLMMathChain
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain
from langflow.field_typing import BaseLanguageModel, BaseMemory, Text
class LLMMathChainComponent(CustomComponent):
display_name = "LLMMathChain"
description = "Chain that interprets a prompt and executes python code to do math."
documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_math"
documentation = (
"https://python.langchain.com/docs/modules/chains/additional/llm_math"
)
def build_config(self):
return {
@ -22,10 +24,21 @@ class LLMMathChainComponent(CustomComponent):
def build(
self,
inputs: Text,
llm: BaseLanguageModel,
llm_chain: LLMChain,
input_key: str = "question",
output_key: str = "answer",
memory: Optional[BaseMemory] = None,
) -> Union[LLMMathChain, Callable, Chain]:
return LLMMathChain(llm=llm, llm_chain=llm_chain, input_key=input_key, output_key=output_key, memory=memory)
) -> Text:
chain = LLMMathChain(
llm=llm,
llm_chain=llm_chain,
input_key=input_key,
output_key=output_key,
memory=memory,
)
response = chain.invoke({input_key: inputs})
result = response.get(output_key)
self.status = result
return result