Add New Chains: RetrievalQA, RetrievalQAWithSourcesChain, ConversationalRetrievalChain, CombineDocsChain

This commit is contained in:
gustavoschaedler 2023-06-13 17:05:40 +01:00
commit 144f2b470e
17 changed files with 761 additions and 66 deletions

View file

@ -1,4 +1,4 @@
from langflow.cache import cache_manager
from langflow.interface.loading import load_flow_from_json
from langflow.processing.process import load_flow_from_json
__all__ = ["load_flow_from_json", "cache_manager"]

View file

@ -16,6 +16,10 @@ chains:
- MidJourneyPromptChain
- TimeTravelGuideChain
- SQLDatabaseChain
- RetrievalQA
- RetrievalQAWithSourcesChain
- ConversationalRetrievalChain
- CombineDocsChain
documentloaders:
- AirbyteJSONLoader
- CoNLLULoader

View file

@ -22,6 +22,7 @@ CUSTOM_NODES = {
"SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(),
"TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(),
"MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(),
"load_qa_chain": frontend_node.chains.CombineDocsChainNode(),
},
}

View file

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Union
from langchain.chains.base import Chain
from pydantic import BaseModel
from langflow.template.field.base import TemplateField
@ -81,5 +81,24 @@ class LangChainTypeCreator(BaseModel, ABC):
)
signature.add_extra_fields()
signature.add_extra_base_classes()
return signature
class CustomChain(Chain, ABC):
"""Custom chain"""
@staticmethod
def function_name():
return "CustomChain"
@classmethod
def initialize(cls, *args, **kwargs):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)

View file

@ -1,12 +1,13 @@
from typing import Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type
from langflow.custom.customs import get_custom_nodes
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.custom_lists import chain_type_to_cls_dict
from langflow.interface.importing.utils import import_class
from langflow.settings import settings
from langflow.template.frontend_node.chains import ChainFrontendNode
from langflow.utils.logger import logger
from langflow.utils.util import build_template_from_class
from langflow.utils.util import build_template_from_class, build_template_from_method
from langchain import chains
# Assuming necessary imports for Field, Template, and FrontendNode classes
@ -18,10 +19,16 @@ class ChainCreator(LangChainTypeCreator):
def frontend_node_class(self) -> Type[ChainFrontendNode]:
return ChainFrontendNode
#! We need to find a better solution for this
from_method_nodes = {"ConversationalRetrievalChain": "from_llm"}
@property
def type_to_loader_dict(self) -> Dict:
if self.type_dict is None:
self.type_dict = chain_type_to_cls_dict
self.type_dict: dict[str, Any] = {
chain_name: import_class(f"langchain.chains.{chain_name}")
for chain_name in chains.__all__
}
from langflow.interface.chains.custom import CUSTOM_CHAINS
self.type_dict.update(CUSTOM_CHAINS)
@ -37,20 +44,32 @@ class ChainCreator(LangChainTypeCreator):
try:
if name in get_custom_nodes(self.type_name).keys():
return get_custom_nodes(self.type_name)[name]
elif name in self.from_method_nodes.keys():
return build_template_from_method(
name,
type_to_cls_dict=self.type_to_loader_dict,
method_name=self.from_method_nodes[name],
add_function=True,
)
return build_template_from_class(
name, self.type_to_loader_dict, add_function=True
)
except ValueError as exc:
raise ValueError("Chain not found") from exc
raise ValueError(f"Chain {name} not found: {exc}") from exc
except AttributeError as exc:
logger.error(f"Chain {name} not loaded: {exc}")
return None
def to_list(self) -> List[str]:
custom_chains = list(get_custom_nodes("chains").keys())
default_chains = list(self.type_to_loader_dict.keys())
return default_chains + custom_chains
names = []
for _, chain in self.type_to_loader_dict.items():
chain_name = (
chain.function_name()
if hasattr(chain, "function_name")
else chain.__name__
)
names.append(chain_name)
return names
chain_creator = ChainCreator()

View file

@ -1,11 +1,13 @@
from typing import Dict, Optional, Type
from typing import Dict, Optional, Type, Union
from langchain.chains import ConversationChain
from langchain.memory.buffer import ConversationBufferMemory
from langchain.schema import BaseMemory
from langflow.interface.base import CustomChain
from pydantic import Field, root_validator
from langflow.graph.utils import extract_input_variables_from_prompt
from langchain.chains.question_answering import load_qa_chain
from langflow.interface.utils import extract_input_variables_from_prompt
from langchain.base_language import BaseLanguageModel
DEFAULT_SUFFIX = """"
Current conversation:
@ -14,7 +16,7 @@ Human: {input}
{ai_prefix}"""
class BaseCustomChain(ConversationChain):
class BaseCustomConversationChain(ConversationChain):
"""BaseCustomChain is a chain you can use to have a conversation with a custom character."""
template: Optional[str]
@ -47,7 +49,7 @@ class BaseCustomChain(ConversationChain):
return values
class SeriesCharacterChain(BaseCustomChain):
class SeriesCharacterChain(BaseCustomConversationChain):
"""SeriesCharacterChain is a chain you can use to have a conversation with a character from a series."""
character: str
@ -66,7 +68,7 @@ Human: {input}
"""Default memory store."""
class MidJourneyPromptChain(BaseCustomChain):
class MidJourneyPromptChain(BaseCustomConversationChain):
"""MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."""
template: Optional[
@ -84,7 +86,7 @@ class MidJourneyPromptChain(BaseCustomChain):
AI:""" # noqa: E501
class TimeTravelGuideChain(BaseCustomChain):
class TimeTravelGuideChain(BaseCustomConversationChain):
template: Optional[
str
] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
@ -94,7 +96,26 @@ class TimeTravelGuideChain(BaseCustomChain):
AI:""" # noqa: E501
CUSTOM_CHAINS: Dict[str, Type[ConversationChain]] = {
class CombineDocsChain(CustomChain):
"""Implementation of AgentInitializer function"""
@staticmethod
def function_name():
return "load_qa_chain"
@classmethod
def initialize(cls, llm: BaseLanguageModel, chain_type: str):
return load_qa_chain(llm=llm, chain_type=chain_type)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
CUSTOM_CHAINS: Dict[str, Type[Union[ConversationChain, CustomChain]]] = {
"CombineDocsChain": CombineDocsChain,
"SeriesCharacterChain": SeriesCharacterChain,
"MidJourneyPromptChain": MidJourneyPromptChain,
"TimeTravelGuideChain": TimeTravelGuideChain,

View file

@ -14,18 +14,20 @@ from langchain.agents import agent_toolkits
from langchain.chat_models import ChatOpenAI
from langflow.interface.importing.utils import import_class
from langflow.interface.agents.custom import CUSTOM_AGENTS
from langflow.interface.chains.custom import CUSTOM_CHAINS
## LLMs
# LLMs
llm_type_to_cls_dict = llms.type_to_cls_dict
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Chains
# Chains
chain_type_to_cls_dict: dict[str, Any] = {
chain_name: import_class(f"langchain.chains.{chain_name}")
for chain_name in chains.__all__
}
## Toolkits
# Toolkits
toolkit_type_to_loader_dict: dict[str, Any] = {
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
# if toolkit_name is lower case it is a loader
@ -40,25 +42,25 @@ toolkit_type_to_cls_dict: dict[str, Any] = {
if not toolkit_name.islower()
}
## Memories
# Memories
memory_type_to_cls_dict: dict[str, Any] = {
memory_name: import_class(f"langchain.memory.{memory_name}")
for memory_name in memory.__all__
}
## Wrappers
# Wrappers
wrapper_type_to_cls_dict: dict[str, Any] = {
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]
}
## Embeddings
# Embeddings
embedding_type_to_cls_dict: dict[str, Any] = {
embedding_name: import_class(f"langchain.embeddings.{embedding_name}")
for embedding_name in embeddings.__all__
}
## Document Loaders
# Document Loaders
documentloaders_type_to_cls_dict: dict[str, Any] = {
documentloader_name: import_class(
f"langchain.document_loaders.{documentloader_name}"
@ -66,7 +68,10 @@ documentloaders_type_to_cls_dict: dict[str, Any] = {
for documentloader_name in document_loaders.__all__
}
## Text Splitters
# Text Splitters
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
inspect.getmembers(text_splitter, inspect.isclass)
)
# merge CUSTOM_AGENTS and CUSTOM_CHAINS
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore

View file

@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.tools import BaseTool
from langflow.utils import validate
def import_module(module_path: str) -> Any:
@ -147,3 +148,10 @@ def import_utility(utility: str) -> Any:
if utility == "SQLDatabase":
return import_class(f"langchain.sql_database.{utility}")
return import_class(f"langchain.utilities.{utility}")
def get_function(code):
"""Get the function"""
function_name = validate.extract_function_name(code)
return validate.create_function(code, function_name)

View file

@ -19,10 +19,10 @@ from langchain.chains.loading import load_chain_from_config
from langchain.llms.loading import load_llm_from_config
from pydantic import ValidationError
from langflow.interface.agents.custom import CUSTOM_AGENTS
from langflow.interface.custom_lists import CUSTOM_NODES
from langflow.interface.importing.utils import import_by_type
from langflow.interface.run import fix_memory_inputs
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.types import get_type_list
from langflow.interface.utils import load_file_into_dict
from langflow.utils import util, validate
@ -32,10 +32,11 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
"""Instantiate class from module type and key, and params"""
params = convert_params_to_sets(params)
params = convert_kwargs(params)
if node_type in CUSTOM_AGENTS:
custom_agent = CUSTOM_AGENTS.get(node_type)
if custom_agent:
return custom_agent.initialize(**params)
if node_type in CUSTOM_NODES:
if custom_node := CUSTOM_NODES.get(node_type):
if hasattr(custom_node, "initialize"):
return custom_node.initialize(**params)
return custom_node(**params)
class_object = import_by_type(_type=base_type, name=node_type)
return instantiate_based_on_type(class_object, base_type, node_type, params)
@ -79,10 +80,24 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
return instantiate_textsplitter(class_object, params)
elif base_type == "utilities":
return instantiate_utility(node_type, class_object, params)
elif base_type == "chains":
return instantiate_chains(node_type, class_object, params)
else:
return class_object(**params)
def instantiate_chains(node_type, class_object, params):
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
params["retriever"] = params["retriever"].as_retriever()
if node_type in chain_creator.from_method_nodes:
method = chain_creator.from_method_nodes[node_type]
if class_method := getattr(class_object, method, None):
return class_method(**params)
raise ValueError(f"Method {method} not found in {class_object}")
return class_object(**params)
def instantiate_agent(class_object, params):
return load_agent_executor(class_object, params)
@ -161,38 +176,6 @@ def instantiate_utility(node_type, class_object, params):
return class_object(**params)
def load_flow_from_json(path: str, build=True):
"""Load flow from json file"""
# This is done to avoid circular imports
from langflow.graph import Graph
with open(path, "r", encoding="utf-8") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
# Substitute ZeroShotPrompt with PromptTemplate
# nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
# Add input variables
# nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
if build:
langchain_object = graph.build()
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
fix_memory_inputs(langchain_object)
return langchain_object
return graph
def replace_zero_shot_prompt_with_prompt_template(nodes):
"""Replace ZeroShotPrompt with PromptTemplate"""
for node in nodes:

View file

@ -2,6 +2,7 @@ import base64
import json
import os
from io import BytesIO
import re
import yaml
from langchain.base_language import BaseLanguageModel
@ -52,3 +53,8 @@ def try_setting_streaming_options(langchain_object, websocket):
llm.stream = True
return langchain_object
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
"""Extract input variables from prompt."""
return re.findall(r"{(.*?)}", prompt)

View file

@ -0,0 +1,61 @@
import json
from langflow.interface.run import (
get_memory_key,
update_memory_keys,
)
from langflow.graph import Graph
def fix_memory_inputs(langchain_object):
"""
Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
get_memory_key function and updates the memory keys using the update_memory_keys function.
"""
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
try:
if langchain_object.memory.memory_key in langchain_object.input_variables:
return
except AttributeError:
input_variables = (
langchain_object.prompt.input_variables
if hasattr(langchain_object, "prompt")
else langchain_object.input_keys
)
if langchain_object.memory.memory_key in input_variables:
return
possible_new_mem_key = get_memory_key(langchain_object)
if possible_new_mem_key is not None:
update_memory_keys(langchain_object, possible_new_mem_key)
def load_flow_from_json(path: str, build=True):
"""Load flow from json file"""
# This is done to avoid circular imports
with open(path, "r", encoding="utf-8") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
# Substitute ZeroShotPrompt with PromptTemplate
# nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
# Add input variables
# nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
if build:
langchain_object = graph.build()
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
fix_memory_inputs(langchain_object)
return langchain_object
return graph

View file

@ -27,6 +27,9 @@ class FrontendNode(BaseModel):
def add_extra_fields(self) -> None:
pass
def add_extra_base_classes(self) -> None:
pass
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
"""Formats a given field based on its attributes and value."""

View file

@ -2,10 +2,24 @@ from typing import Optional
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.frontend_node.constants import QA_CHAIN_TYPES
from langflow.template.template.base import Template
class ChainFrontendNode(FrontendNode):
def add_extra_fields(self) -> None:
if self.template.type_name == "ConversationalRetrievalChain":
# add memory
self.template.add_field(
TemplateField(
field_type="BaseChatMemory",
required=False,
show=True,
name="memory",
advanced=False,
)
)
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
@ -155,3 +169,41 @@ class MidJourneyPromptChainNode(FrontendNode):
"ConversationChain",
"MidJourneyPromptChain",
]
class CombineDocsChainNode(FrontendNode):
name: str = "CombineDocsChain"
template: Template = Template(
type_name="load_qa_chain",
fields=[
TemplateField(
field_type="str",
required=True,
is_list=True,
show=True,
multiline=False,
options=QA_CHAIN_TYPES,
value=QA_CHAIN_TYPES[0],
name="chain_type",
advanced=False,
),
TemplateField(
field_type="BaseLanguageModel",
required=True,
show=True,
name="llm",
display_name="LLM",
advanced=False,
),
],
)
description: str = """Construct a zero shot agent from an LLM and tools."""
base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]
def to_dict(self):
return super().to_dict()
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
# do nothing and don't return anything
pass

View file

@ -30,3 +30,5 @@ You are a good listener and you can talk about anything.
"""
HUMAN_PROMPT = "{input}"
QA_CHAIN_TYPES = ["stuff", "map_reduce", "map_rerank", "refine"]

View file

@ -20,6 +20,9 @@ class VectorStoreFrontendNode(FrontendNode):
self.template.add_field(extra_field)
def add_extra_base_classes(self) -> None:
self.base_classes.append("BaseRetriever")
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)

File diff suppressed because one or more lines are too long