🔀 refactor(customs.py): add load_qa_chain to CUSTOM_NODES
🔀 refactor(custom_lists.py): merge CUSTOM_AGENTS and CUSTOM_CHAINS into CUSTOM_NODES 🔀 refactor(loading.py): add instantiate_chains function to instantiate chains 🔀 refactor(base.py): add CustomChain class The changes add support for a new chain called load_qa_chain to CUSTOM_NODES. CUSTOM_AGENTS and CUSTOM_CHAINS are merged into CUSTOM_NODES. A new function called instantiate_chains is added to instantiate chains. A new class called CustomChain is added to the base.py file. This class is used to define a custom chain.
This commit is contained in:
parent
b5ab95c4cc
commit
c8125b3386
4 changed files with 55 additions and 10 deletions
|
|
@ -2,7 +2,9 @@ from langflow.template import frontend_node
|
|||
|
||||
# These should always be instantiated
|
||||
CUSTOM_NODES = {
|
||||
"prompts": {"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode()},
|
||||
"prompts": {
|
||||
"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode(),
|
||||
},
|
||||
"tools": {
|
||||
"PythonFunctionTool": frontend_node.tools.PythonFunctionToolNode(),
|
||||
"PythonFunction": frontend_node.tools.PythonFunctionNode(),
|
||||
|
|
@ -23,6 +25,7 @@ CUSTOM_NODES = {
|
|||
"SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(),
|
||||
"TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(),
|
||||
"MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(),
|
||||
"load_qa_chain": frontend_node.chains.CombineDocsChainNode(),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -81,5 +82,24 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
)
|
||||
|
||||
signature.add_extra_fields()
|
||||
signature.add_extra_base_classes()
|
||||
|
||||
return signature
|
||||
|
||||
|
||||
class CustomChain(Chain, ABC):
|
||||
"""Custom chain"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CustomChain"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import inspect
|
|||
from typing import Any
|
||||
|
||||
from langchain import (
|
||||
chains,
|
||||
document_loaders,
|
||||
embeddings,
|
||||
llms,
|
||||
|
|
@ -15,6 +14,8 @@ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
|||
from langchain.chat_models import ChatAnthropic
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
|
||||
## LLMs
|
||||
llm_type_to_cls_dict = llms.type_to_cls_dict
|
||||
|
|
@ -22,11 +23,6 @@ llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
|
|||
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
|
||||
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
|
||||
|
||||
## Chains
|
||||
chain_type_to_cls_dict: dict[str, Any] = {
|
||||
chain_name: import_class(f"langchain.chains.{chain_name}")
|
||||
for chain_name in chains.__all__
|
||||
}
|
||||
|
||||
## Toolkits
|
||||
toolkit_type_to_loader_dict: dict[str, Any] = {
|
||||
|
|
@ -73,3 +69,6 @@ documentloaders_type_to_cls_dict: dict[str, Any] = {
|
|||
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
|
||||
inspect.getmembers(text_splitter, inspect.isclass)
|
||||
)
|
||||
|
||||
# merge CUSTOM_AGENTS and CUSTOM_CHAINS
|
||||
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS}
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ from langchain.chains.loading import load_chain_from_config
|
|||
from langchain.llms.loading import load_llm_from_config
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.custom_lists import CUSTOM_NODES
|
||||
from langflow.interface.importing.utils import get_function, import_by_type
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.types import get_type_list
|
||||
from langflow.interface.utils import load_file_into_dict
|
||||
from langflow.utils import util, validate
|
||||
|
|
@ -31,8 +32,8 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
|||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
if node_type in CUSTOM_AGENTS:
|
||||
if custom_agent := CUSTOM_AGENTS.get(node_type):
|
||||
if node_type in CUSTOM_NODES:
|
||||
if custom_agent := CUSTOM_NODES.get(node_type):
|
||||
return custom_agent.initialize(**params)
|
||||
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
|
|
@ -77,10 +78,24 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
|
|||
return instantiate_textsplitter(class_object, params)
|
||||
elif base_type == "utilities":
|
||||
return instantiate_utility(node_type, class_object, params)
|
||||
elif base_type == "chains":
|
||||
return instantiate_chains(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_chains(node_type, class_object, params):
|
||||
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
|
||||
params["retriever"] = params["retriever"].as_retriever()
|
||||
if node_type in chain_creator.from_method_nodes:
|
||||
method = chain_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
return class_method(**params)
|
||||
raise ValueError(f"Method {method} not found in {class_object}")
|
||||
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_agent(class_object, params):
|
||||
return load_agent_executor(class_object, params)
|
||||
|
||||
|
|
@ -141,6 +156,14 @@ def instantiate_vectorstore(class_object, params):
|
|||
"The source you provided did not load correctly or was empty."
|
||||
"This may cause an error in the vectorstore."
|
||||
)
|
||||
# Chroma requires all metadata values to not be None
|
||||
if class_object.__name__ == "Chroma":
|
||||
for doc in params["documents"]:
|
||||
if doc.metadata is None:
|
||||
doc.metadata = {}
|
||||
for key, value in doc.metadata.items():
|
||||
if value is None:
|
||||
doc.metadata[key] = ""
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue