🔀 refactor(customs.py): add load_qa_chain to CUSTOM_NODES

🔀 refactor(custom_lists.py): merge CUSTOM_AGENTS and CUSTOM_CHAINS into CUSTOM_NODES
🔀 refactor(loading.py): add instantiate_chains function to instantiate chains
🔀 refactor(base.py): add CustomChain class
These changes register a new load_qa_chain entry in CUSTOM_NODES, merge CUSTOM_AGENTS and CUSTOM_CHAINS into a single CUSTOM_NODES mapping, add an instantiate_chains function that handles chain construction, and introduce a CustomChain base class in base.py for defining custom chains.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-07 19:48:16 -03:00
commit c8125b3386
4 changed files with 55 additions and 10 deletions

View file

@ -2,7 +2,9 @@ from langflow.template import frontend_node
# These should always be instantiated
CUSTOM_NODES = {
"prompts": {"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode()},
"prompts": {
"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode(),
},
"tools": {
"PythonFunctionTool": frontend_node.tools.PythonFunctionToolNode(),
"PythonFunction": frontend_node.tools.PythonFunctionNode(),
@ -23,6 +25,7 @@ CUSTOM_NODES = {
"SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(),
"TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(),
"MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(),
"load_qa_chain": frontend_node.chains.CombineDocsChainNode(),
},
}

View file

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Union
from langchain.chains.base import Chain
from pydantic import BaseModel
@ -81,5 +82,24 @@ class LangChainTypeCreator(BaseModel, ABC):
)
signature.add_extra_fields()
signature.add_extra_base_classes()
return signature
class CustomChain(Chain, ABC):
    """Abstract base for user-defined chains layered on LangChain's ``Chain``.

    The default implementation adds no behavior of its own: construction and
    execution are forwarded to ``Chain``, and :meth:`initialize` is a no-op
    hook that concrete custom chains can override.
    """

    def __init__(self, *args, **kwargs):
        # Construction is delegated untouched to the underlying Chain.
        super().__init__(*args, **kwargs)

    @staticmethod
    def function_name():
        # Name used to identify this chain type.
        return "CustomChain"

    @classmethod
    def initialize(cls, *args, **kwargs):
        # Intentionally a no-op; subclasses supply their own setup here.
        pass

    def run(self, *args, **kwargs):
        # Execution is forwarded unchanged to Chain.run.
        return super().run(*args, **kwargs)

View file

@ -2,7 +2,6 @@ import inspect
from typing import Any
from langchain import (
chains,
document_loaders,
embeddings,
llms,
@ -15,6 +14,8 @@ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.chat_models import ChatAnthropic
from langflow.interface.importing.utils import import_class
from langflow.interface.agents.custom import CUSTOM_AGENTS
from langflow.interface.chains.custom import CUSTOM_CHAINS
## LLMs
llm_type_to_cls_dict = llms.type_to_cls_dict
@ -22,11 +23,6 @@ llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Chains
chain_type_to_cls_dict: dict[str, Any] = {
chain_name: import_class(f"langchain.chains.{chain_name}")
for chain_name in chains.__all__
}
## Toolkits
toolkit_type_to_loader_dict: dict[str, Any] = {
@ -73,3 +69,6 @@ documentloaders_type_to_cls_dict: dict[str, Any] = {
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
inspect.getmembers(text_splitter, inspect.isclass)
)
# merge CUSTOM_AGENTS and CUSTOM_CHAINS into one registry of custom nodes;
# with dict unpacking, entries from CUSTOM_CHAINS win on duplicate keys.
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS}

View file

@ -19,9 +19,10 @@ from langchain.chains.loading import load_chain_from_config
from langchain.llms.loading import load_llm_from_config
from pydantic import ValidationError
from langflow.interface.agents.custom import CUSTOM_AGENTS
from langflow.interface.custom_lists import CUSTOM_NODES
from langflow.interface.importing.utils import get_function, import_by_type
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.types import get_type_list
from langflow.interface.utils import load_file_into_dict
from langflow.utils import util, validate
@ -31,8 +32,8 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
"""Instantiate class from module type and key, and params"""
params = convert_params_to_sets(params)
params = convert_kwargs(params)
if node_type in CUSTOM_AGENTS:
if custom_agent := CUSTOM_AGENTS.get(node_type):
if node_type in CUSTOM_NODES:
if custom_agent := CUSTOM_NODES.get(node_type):
return custom_agent.initialize(**params)
class_object = import_by_type(_type=base_type, name=node_type)
@ -77,10 +78,24 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
return instantiate_textsplitter(class_object, params)
elif base_type == "utilities":
return instantiate_utility(node_type, class_object, params)
elif base_type == "chains":
return instantiate_chains(node_type, class_object, params)
else:
return class_object(**params)
def instantiate_chains(node_type, class_object, params):
    """Build a chain instance of ``class_object`` for ``node_type`` from ``params``.

    If ``params`` holds an object under ``"retriever"`` that exposes
    ``as_retriever``, it is converted to a retriever first. Node types
    registered in ``chain_creator.from_method_nodes`` are constructed via the
    named factory classmethod; all others use the plain constructor.

    Raises:
        ValueError: if the registered factory method does not exist on
            ``class_object``.
    """
    retriever = params.get("retriever")
    if hasattr(retriever, "as_retriever"):
        # e.g. a vectorstore was passed where a retriever is expected.
        params["retriever"] = retriever.as_retriever()

    from_method = chain_creator.from_method_nodes
    if node_type not in from_method:
        return class_object(**params)

    method = from_method[node_type]
    factory = getattr(class_object, method, None)
    if not factory:
        raise ValueError(f"Method {method} not found in {class_object}")
    return factory(**params)
def instantiate_agent(class_object, params):
    """Create an agent executor for the given agent class and parameters."""
    executor = load_agent_executor(class_object, params)
    return executor
@ -141,6 +156,14 @@ def instantiate_vectorstore(class_object, params):
"The source you provided did not load correctly or was empty."
"This may cause an error in the vectorstore."
)
# Chroma requires all metadata values to not be None
if class_object.__name__ == "Chroma":
for doc in params["documents"]:
if doc.metadata is None:
doc.metadata = {}
for key, value in doc.metadata.items():
if value is None:
doc.metadata[key] = ""
return class_object.from_documents(**params)