🚀 feat(langflow): add new chains to config.yaml and custom chains to interface/chains/custom.py
✨ feat(langflow): add new chains to config.yaml and custom chains to interface/chains/custom.py
The following chains were added to the config.yaml file: RetrievalQA, RetrievalQAWithSourcesChain, QAWithSourcesChain, ConversationalRetrievalChain, and CombineDocsChain. These chains were added to improve the functionality of the application and provide more options for users.
In addition, custom chains were added to the interface/chains/custom.py file. The CombineDocsChain was added to allow users to combine multiple documents into a single document for use in the question answering chains. The QA_CHAIN_TYPES constant was also added to the frontend_node/constants.py file to provide a list of available question answering chain types.
This commit is contained in:
parent
c8125b3386
commit
f0975ddf63
5 changed files with 76 additions and 41 deletions
|
|
@ -16,6 +16,11 @@ chains:
|
|||
- MidJourneyPromptChain
|
||||
- TimeTravelGuideChain
|
||||
- SQLDatabaseChain
|
||||
- RetrievalQA
|
||||
- RetrievalQAWithSourcesChain
|
||||
- QAWithSourcesChain
|
||||
- ConversationalRetrievalChain
|
||||
- CombineDocsChain
|
||||
documentloaders:
|
||||
- AirbyteJSONLoader
|
||||
- CoNLLULoader
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from abc import ABC
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
|
|
@ -33,27 +32,10 @@ from langchain.memory.chat_memory import BaseChatMemory
|
|||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
from langflow.interface.base import CustomChain
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor, ABC):
|
||||
"""Custom agent executor"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CustomAgentExecutor"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class JsonAgent(CustomAgentExecutor):
|
||||
class JsonAgent(CustomChain):
|
||||
"""Json agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -89,7 +71,7 @@ class JsonAgent(CustomAgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class CSVAgent(CustomAgentExecutor):
|
||||
class CSVAgent(CustomChain):
|
||||
"""CSV agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -137,7 +119,7 @@ class CSVAgent(CustomAgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreAgent(CustomAgentExecutor):
|
||||
class VectorStoreAgent(CustomChain):
|
||||
"""Vector Store agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -175,7 +157,7 @@ class VectorStoreAgent(CustomAgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class SQLAgent(CustomAgentExecutor):
|
||||
class SQLAgent(CustomChain):
|
||||
"""SQL agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -247,7 +229,7 @@ class SQLAgent(CustomAgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class VectorStoreRouterAgent(CustomAgentExecutor):
|
||||
class VectorStoreRouterAgent(CustomChain):
|
||||
"""Vector Store Router Agent"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -286,7 +268,7 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class InitializeAgent(CustomAgentExecutor):
|
||||
class InitializeAgent(CustomChain):
|
||||
"""Implementation of initialize_agent function"""
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import chain_type_to_cls_dict
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.template.frontend_node.chains import ChainFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
from langchain import chains
|
||||
|
||||
# Assuming necessary imports for Field, Template, and FrontendNode classes
|
||||
|
||||
|
|
@ -18,10 +19,16 @@ class ChainCreator(LangChainTypeCreator):
|
|||
def frontend_node_class(self) -> Type[ChainFrontendNode]:
|
||||
return ChainFrontendNode
|
||||
|
||||
#! We need to find a better solution for this
|
||||
from_method_nodes = {"ConversationalRetrievalChain": "from_llm"}
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict = chain_type_to_cls_dict
|
||||
self.type_dict: dict[str, Any] = {
|
||||
chain_name: import_class(f"langchain.chains.{chain_name}")
|
||||
for chain_name in chains.__all__
|
||||
}
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
|
||||
self.type_dict.update(CUSTOM_CHAINS)
|
||||
|
|
@ -37,20 +44,38 @@ class ChainCreator(LangChainTypeCreator):
|
|||
try:
|
||||
if name in get_custom_nodes(self.type_name).keys():
|
||||
return get_custom_nodes(self.type_name)[name]
|
||||
elif name in self.from_method_nodes.keys():
|
||||
return build_template_from_method(
|
||||
name,
|
||||
type_to_cls_dict=self.type_to_loader_dict,
|
||||
method_name=self.from_method_nodes[name],
|
||||
add_function=True,
|
||||
)
|
||||
return build_template_from_class(
|
||||
name, self.type_to_loader_dict, add_function=True
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Chain not found") from exc
|
||||
raise ValueError(f"Chain {name} not found: {exc}") from exc
|
||||
except AttributeError as exc:
|
||||
logger.error(f"Chain {name} not loaded: {exc}")
|
||||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
custom_chains = list(get_custom_nodes("chains").keys())
|
||||
default_chains = list(self.type_to_loader_dict.keys())
|
||||
# def to_list(self) -> List[str]:
|
||||
# custom_chains = list(get_custom_nodes("chains").keys())
|
||||
# default_chains = list(self.type_to_loader_dict.keys())
|
||||
|
||||
return default_chains + custom_chains
|
||||
# return default_chains + custom_chains
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
names = []
|
||||
for _, chain in self.type_to_loader_dict.items():
|
||||
chain_name = (
|
||||
chain.function_name()
|
||||
if hasattr(chain, "function_name")
|
||||
else chain.__name__
|
||||
)
|
||||
names.append(chain_name)
|
||||
return names
|
||||
|
||||
|
||||
chain_creator = ChainCreator()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
from typing import Dict, Optional, Type
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
from langchain.memory.buffer import ConversationBufferMemory
|
||||
from langchain.schema import BaseMemory
|
||||
from langflow.interface.base import CustomChain
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
DEFAULT_SUFFIX = """"
|
||||
Current conversation:
|
||||
|
|
@ -14,7 +16,7 @@ Human: {input}
|
|||
{ai_prefix}"""
|
||||
|
||||
|
||||
class BaseCustomChain(ConversationChain):
|
||||
class BaseCustomConversationChain(ConversationChain):
|
||||
"""BaseCustomChain is a chain you can use to have a conversation with a custom character."""
|
||||
|
||||
template: Optional[str]
|
||||
|
|
@ -47,7 +49,7 @@ class BaseCustomChain(ConversationChain):
|
|||
return values
|
||||
|
||||
|
||||
class SeriesCharacterChain(BaseCustomChain):
|
||||
class SeriesCharacterChain(BaseCustomConversationChain):
|
||||
"""SeriesCharacterChain is a chain you can use to have a conversation with a character from a series."""
|
||||
|
||||
character: str
|
||||
|
|
@ -66,7 +68,7 @@ Human: {input}
|
|||
"""Default memory store."""
|
||||
|
||||
|
||||
class MidJourneyPromptChain(BaseCustomChain):
|
||||
class MidJourneyPromptChain(BaseCustomConversationChain):
|
||||
"""MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."""
|
||||
|
||||
template: Optional[
|
||||
|
|
@ -84,7 +86,7 @@ class MidJourneyPromptChain(BaseCustomChain):
|
|||
AI:""" # noqa: E501
|
||||
|
||||
|
||||
class TimeTravelGuideChain(BaseCustomChain):
|
||||
class TimeTravelGuideChain(BaseCustomConversationChain):
|
||||
template: Optional[
|
||||
str
|
||||
] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
|
||||
|
|
@ -94,7 +96,26 @@ class TimeTravelGuideChain(BaseCustomChain):
|
|||
AI:""" # noqa: E501
|
||||
|
||||
|
||||
CUSTOM_CHAINS: Dict[str, Type[ConversationChain]] = {
|
||||
class CombineDocsChain(CustomChain):
|
||||
"""Implementation of initialize_agent function"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "load_qa_chain"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, llm: BaseLanguageModel, chain_type: str):
|
||||
return load_qa_chain(llm=llm, chain_type=chain_type)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
CUSTOM_CHAINS: Dict[str, Type[Union[ConversationChain, CustomChain]]] = {
|
||||
"CombineDocsChain": CombineDocsChain,
|
||||
"SeriesCharacterChain": SeriesCharacterChain,
|
||||
"MidJourneyPromptChain": MidJourneyPromptChain,
|
||||
"TimeTravelGuideChain": TimeTravelGuideChain,
|
||||
|
|
|
|||
|
|
@ -30,3 +30,5 @@ You are a good listener and you can talk about anything.
|
|||
"""
|
||||
|
||||
HUMAN_PROMPT = "{input}"
|
||||
|
||||
QA_CHAIN_TYPES = ["stuff", "map_reduce", "map_rerank", "refine"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue