diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index 92e1fc2d8..e315d27fd 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -2,7 +2,9 @@ from langflow.template import frontend_node # These should always be instantiated CUSTOM_NODES = { - "prompts": {"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode()}, + "prompts": { + "ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode(), + }, "tools": { "PythonFunctionTool": frontend_node.tools.PythonFunctionToolNode(), "PythonFunction": frontend_node.tools.PythonFunctionNode(), @@ -23,6 +25,7 @@ CUSTOM_NODES = { "SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(), "TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(), "MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(), + "load_qa_chain": frontend_node.chains.CombineDocsChainNode(), }, } diff --git a/src/backend/langflow/interface/base.py b/src/backend/langflow/interface/base.py index 08cbc6681..df03950af 100644 --- a/src/backend/langflow/interface/base.py +++ b/src/backend/langflow/interface/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Type, Union +from langchain.chains.base import Chain from pydantic import BaseModel @@ -81,5 +82,24 @@ class LangChainTypeCreator(BaseModel, ABC): ) signature.add_extra_fields() + signature.add_extra_base_classes() return signature + + +class CustomChain(Chain, ABC): + """Custom chain""" + + @staticmethod + def function_name(): + return "CustomChain" + + @classmethod + def initialize(cls, *args, **kwargs): + pass + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def run(self, *args, **kwargs): + return super().run(*args, **kwargs) diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index 34bc0103e..8bde1565c 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -2,7 +2,6 @@ import inspect from typing import Any from langchain import ( - chains, document_loaders, embeddings, llms, @@ -15,6 +14,8 @@ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI from langchain.chat_models import ChatAnthropic from langflow.interface.importing.utils import import_class +from langflow.interface.agents.custom import CUSTOM_AGENTS +from langflow.interface.chains.custom import CUSTOM_CHAINS ## LLMs llm_type_to_cls_dict = llms.type_to_cls_dict @@ -22,11 +23,6 @@ llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore -## Chains -chain_type_to_cls_dict: dict[str, Any] = { - chain_name: import_class(f"langchain.chains.{chain_name}") - for chain_name in chains.__all__ -} ## Toolkits toolkit_type_to_loader_dict: dict[str, Any] = { @@ -73,3 +69,6 @@ documentloaders_type_to_cls_dict: dict[str, Any] = { textsplitter_type_to_cls_dict: dict[str, Any] = dict( inspect.getmembers(text_splitter, inspect.isclass) ) + +# merge CUSTOM_AGENTS and CUSTOM_CHAINS +CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index f8d26ffca..80f451f03 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -19,9 +19,10 @@ from langchain.chains.loading import load_chain_from_config from langchain.llms.loading import load_llm_from_config from pydantic import ValidationError -from langflow.interface.agents.custom import CUSTOM_AGENTS +from langflow.interface.custom_lists import CUSTOM_NODES from langflow.interface.importing.utils import get_function, import_by_type from langflow.interface.toolkits.base import toolkits_creator +from langflow.interface.chains.base import chain_creator from langflow.interface.types import get_type_list from langflow.interface.utils import load_file_into_dict from langflow.utils import util, validate @@ -31,8 +32,8 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any: """Instantiate class from module type and key, and params""" params = convert_params_to_sets(params) params = convert_kwargs(params) - if node_type in CUSTOM_AGENTS: - if custom_agent := CUSTOM_AGENTS.get(node_type): + if node_type in CUSTOM_NODES: + if custom_agent := CUSTOM_NODES.get(node_type): return custom_agent.initialize(**params) class_object = import_by_type(_type=base_type, name=node_type) @@ -77,10 +78,24 @@ def instantiate_based_on_type(class_object, base_type, node_type, params): return instantiate_textsplitter(class_object, params) elif base_type == "utilities": return instantiate_utility(node_type, class_object, params) + elif base_type == "chains": + return instantiate_chains(node_type, class_object, params) else: return class_object(**params) +def instantiate_chains(node_type, class_object, params): + if "retriever" in params and hasattr(params["retriever"], "as_retriever"): + params["retriever"] = params["retriever"].as_retriever() + if node_type in chain_creator.from_method_nodes: + method = chain_creator.from_method_nodes[node_type] + if class_method := getattr(class_object, method, None): + return class_method(**params) + raise ValueError(f"Method {method} not found in {class_object}") + + return class_object(**params) + + def instantiate_agent(class_object, params): return load_agent_executor(class_object, params) @@ -141,6 +156,14 @@ def instantiate_vectorstore(class_object, params): "The source you provided did not load correctly or was empty." "This may cause an error in the vectorstore." ) + # Chroma requires all metadata values to not be None + if class_object.__name__ == "Chroma": + for doc in params["documents"]: + if doc.metadata is None: + doc.metadata = {} + for key, value in doc.metadata.items(): + if value is None: + doc.metadata[key] = "" return class_object.from_documents(**params)