🚀 feat(langflow): add new chains to config.yaml and custom chains to interface/chains/custom.py

 feat(langflow): add new chains to config.yaml and custom chains to interface/chains/custom.py
The following chains were added to the config.yaml file: RetrievalQA, RetrievalQAWithSourcesChain, QAWithSourcesChain, ConversationalRetrievalChain, and CombineDocsChain. These chains were added to improve the functionality of the application and provide more options for users.

In addition, custom chains were added to the interface/chains/custom.py file. The CombineDocsChain was added to allow users to combine multiple documents into a single document for use in the question answering chains. The QA_CHAIN_TYPES constant was also added to the frontend_node/constants.py file to provide a list of available question answering chain types.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-07 19:48:51 -03:00
commit f0975ddf63
5 changed files with 76 additions and 41 deletions

View file

@ -16,6 +16,11 @@ chains:
- MidJourneyPromptChain
- TimeTravelGuideChain
- SQLDatabaseChain
- RetrievalQA
- RetrievalQAWithSourcesChain
- QAWithSourcesChain
- ConversationalRetrievalChain
- CombineDocsChain
documentloaders:
- AirbyteJSONLoader
- CoNLLULoader

View file

@ -1,4 +1,3 @@
from abc import ABC
from typing import Any, List, Optional
from langchain import LLMChain
@ -33,27 +32,10 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.sql_database import SQLDatabase
from langchain.tools.python.tool import PythonAstREPLTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
from langflow.interface.base import CustomChain
class CustomAgentExecutor(AgentExecutor, ABC):
"""Custom agent executor"""
@staticmethod
def function_name():
return "CustomAgentExecutor"
@classmethod
def initialize(cls, *args, **kwargs):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
class JsonAgent(CustomAgentExecutor):
class JsonAgent(CustomChain):
"""Json agent"""
@staticmethod
@ -89,7 +71,7 @@ class JsonAgent(CustomAgentExecutor):
return super().run(*args, **kwargs)
class CSVAgent(CustomAgentExecutor):
class CSVAgent(CustomChain):
"""CSV agent"""
@staticmethod
@ -137,7 +119,7 @@ class CSVAgent(CustomAgentExecutor):
return super().run(*args, **kwargs)
class VectorStoreAgent(CustomAgentExecutor):
class VectorStoreAgent(CustomChain):
"""Vector Store agent"""
@staticmethod
@ -175,7 +157,7 @@ class VectorStoreAgent(CustomAgentExecutor):
return super().run(*args, **kwargs)
class SQLAgent(CustomAgentExecutor):
class SQLAgent(CustomChain):
"""SQL agent"""
@staticmethod
@ -247,7 +229,7 @@ class SQLAgent(CustomAgentExecutor):
return super().run(*args, **kwargs)
class VectorStoreRouterAgent(CustomAgentExecutor):
class VectorStoreRouterAgent(CustomChain):
"""Vector Store Router Agent"""
@staticmethod
@ -286,7 +268,7 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
return super().run(*args, **kwargs)
class InitializeAgent(CustomAgentExecutor):
class InitializeAgent(CustomChain):
"""Implementation of initialize_agent function"""
@staticmethod

View file

@ -1,12 +1,13 @@
from typing import Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type
from langflow.custom.customs import get_custom_nodes
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.custom_lists import chain_type_to_cls_dict
from langflow.interface.importing.utils import import_class
from langflow.settings import settings
from langflow.template.frontend_node.chains import ChainFrontendNode
from langflow.utils.logger import logger
from langflow.utils.util import build_template_from_class
from langflow.utils.util import build_template_from_class, build_template_from_method
from langchain import chains
# Assuming necessary imports for Field, Template, and FrontendNode classes
@ -18,10 +19,16 @@ class ChainCreator(LangChainTypeCreator):
def frontend_node_class(self) -> Type[ChainFrontendNode]:
return ChainFrontendNode
#! We need to find a better solution for this
from_method_nodes = {"ConversationalRetrievalChain": "from_llm"}
@property
def type_to_loader_dict(self) -> Dict:
if self.type_dict is None:
self.type_dict = chain_type_to_cls_dict
self.type_dict: dict[str, Any] = {
chain_name: import_class(f"langchain.chains.{chain_name}")
for chain_name in chains.__all__
}
from langflow.interface.chains.custom import CUSTOM_CHAINS
self.type_dict.update(CUSTOM_CHAINS)
@ -37,20 +44,38 @@ class ChainCreator(LangChainTypeCreator):
try:
if name in get_custom_nodes(self.type_name).keys():
return get_custom_nodes(self.type_name)[name]
elif name in self.from_method_nodes.keys():
return build_template_from_method(
name,
type_to_cls_dict=self.type_to_loader_dict,
method_name=self.from_method_nodes[name],
add_function=True,
)
return build_template_from_class(
name, self.type_to_loader_dict, add_function=True
)
except ValueError as exc:
raise ValueError("Chain not found") from exc
raise ValueError(f"Chain {name} not found: {exc}") from exc
except AttributeError as exc:
logger.error(f"Chain {name} not loaded: {exc}")
return None
def to_list(self) -> List[str]:
custom_chains = list(get_custom_nodes("chains").keys())
default_chains = list(self.type_to_loader_dict.keys())
# def to_list(self) -> List[str]:
# custom_chains = list(get_custom_nodes("chains").keys())
# default_chains = list(self.type_to_loader_dict.keys())
return default_chains + custom_chains
# return default_chains + custom_chains
def to_list(self) -> List[str]:
names = []
for _, chain in self.type_to_loader_dict.items():
chain_name = (
chain.function_name()
if hasattr(chain, "function_name")
else chain.__name__
)
names.append(chain_name)
return names
chain_creator = ChainCreator()

View file

@ -1,11 +1,13 @@
from typing import Dict, Optional, Type
from typing import Dict, Optional, Type, Union
from langchain.chains import ConversationChain
from langchain.memory.buffer import ConversationBufferMemory
from langchain.schema import BaseMemory
from langflow.interface.base import CustomChain
from pydantic import Field, root_validator
from langchain.chains.question_answering import load_qa_chain
from langflow.interface.utils import extract_input_variables_from_prompt
from langchain.base_language import BaseLanguageModel
DEFAULT_SUFFIX = """"
Current conversation:
@ -14,7 +16,7 @@ Human: {input}
{ai_prefix}"""
class BaseCustomChain(ConversationChain):
class BaseCustomConversationChain(ConversationChain):
"""BaseCustomChain is a chain you can use to have a conversation with a custom character."""
template: Optional[str]
@ -47,7 +49,7 @@ class BaseCustomChain(ConversationChain):
return values
class SeriesCharacterChain(BaseCustomChain):
class SeriesCharacterChain(BaseCustomConversationChain):
"""SeriesCharacterChain is a chain you can use to have a conversation with a character from a series."""
character: str
@ -66,7 +68,7 @@ Human: {input}
"""Default memory store."""
class MidJourneyPromptChain(BaseCustomChain):
class MidJourneyPromptChain(BaseCustomConversationChain):
"""MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."""
template: Optional[
@ -84,7 +86,7 @@ class MidJourneyPromptChain(BaseCustomChain):
AI:""" # noqa: E501
class TimeTravelGuideChain(BaseCustomChain):
class TimeTravelGuideChain(BaseCustomConversationChain):
template: Optional[
str
] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
@ -94,7 +96,26 @@ class TimeTravelGuideChain(BaseCustomChain):
AI:""" # noqa: E501
CUSTOM_CHAINS: Dict[str, Type[ConversationChain]] = {
class CombineDocsChain(CustomChain):
"""Implementation of initialize_agent function"""
@staticmethod
def function_name():
return "load_qa_chain"
@classmethod
def initialize(cls, llm: BaseLanguageModel, chain_type: str):
return load_qa_chain(llm=llm, chain_type=chain_type)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
CUSTOM_CHAINS: Dict[str, Type[Union[ConversationChain, CustomChain]]] = {
"CombineDocsChain": CombineDocsChain,
"SeriesCharacterChain": SeriesCharacterChain,
"MidJourneyPromptChain": MidJourneyPromptChain,
"TimeTravelGuideChain": TimeTravelGuideChain,

View file

@ -30,3 +30,5 @@ You are a good listener and you can talk about anything.
"""
HUMAN_PROMPT = "{input}"
QA_CHAIN_TYPES = ["stuff", "map_reduce", "map_rerank", "refine"]