🚀 feat(langflow): add new Retrieval chains (#456)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-09 16:25:43 -03:00 committed by GitHub
commit c216c78b1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 226 additions and 71 deletions

View file

@ -51,7 +51,6 @@ build_and_run:
rm -rf dist
make build && poetry run pip install dist/*.tar.gz && poetry run langflow
build_frontend:
cd src/frontend && CI='' npm run build
cp -r src/frontend/build src/backend/langflow/frontend

View file

@ -16,6 +16,10 @@ chains:
- MidJourneyPromptChain
- TimeTravelGuideChain
- SQLDatabaseChain
- RetrievalQA
- RetrievalQAWithSourcesChain
- ConversationalRetrievalChain
- CombineDocsChain
documentloaders:
- AirbyteJSONLoader
- CoNLLULoader

View file

@ -2,7 +2,9 @@ from langflow.template import frontend_node
# These should always be instantiated
CUSTOM_NODES = {
"prompts": {"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode()},
"prompts": {
"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode(),
},
"tools": {
"PythonFunctionTool": frontend_node.tools.PythonFunctionToolNode(),
"PythonFunction": frontend_node.tools.PythonFunctionNode(),
@ -23,6 +25,7 @@ CUSTOM_NODES = {
"SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(),
"TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(),
"MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(),
"load_qa_chain": frontend_node.chains.CombineDocsChainNode(),
},
}

View file

@ -49,8 +49,11 @@ class Vertex:
template_dict = self.data["node"]["template"]
self.vertex_type = (
self.data["type"] if "Tool" not in self.output else template_dict["_type"]
self.data["type"]
if "Tool" not in self.output or template_dict["_type"].islower()
else template_dict["_type"]
)
if self.base_type is None:
for base_type, value in ALL_TYPES_DICT.items():
if self.vertex_type in value:

View file

@ -1,4 +1,3 @@
from abc import ABC
from typing import Any, List, Optional
from langchain import LLMChain
@ -33,24 +32,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.sql_database import SQLDatabase
from langchain.tools.python.tool import PythonAstREPLTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
class CustomAgentExecutor(AgentExecutor, ABC):
"""Custom agent executor"""
@staticmethod
def function_name():
return "CustomAgentExecutor"
@classmethod
def initialize(cls, *args, **kwargs):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
from langflow.interface.base import CustomAgentExecutor
class JsonAgent(CustomAgentExecutor):

View file

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Union
from langchain.chains.base import Chain
from langchain.agents import AgentExecutor
from pydantic import BaseModel
from langflow.template.field.base import TemplateField
@ -81,5 +82,42 @@ class LangChainTypeCreator(BaseModel, ABC):
)
signature.add_extra_fields()
signature.add_extra_base_classes()
return signature
class CustomChain(Chain, ABC):
"""Custom chain"""
@staticmethod
def function_name():
return "CustomChain"
@classmethod
def initialize(cls, *args, **kwargs):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
class CustomAgentExecutor(AgentExecutor, ABC):
"""Custom chain"""
@staticmethod
def function_name():
return "CustomChain"
@classmethod
def initialize(cls, *args, **kwargs):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)

View file

@ -1,12 +1,13 @@
from typing import Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type
from langflow.custom.customs import get_custom_nodes
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.custom_lists import chain_type_to_cls_dict
from langflow.interface.importing.utils import import_class
from langflow.settings import settings
from langflow.template.frontend_node.chains import ChainFrontendNode
from langflow.utils.logger import logger
from langflow.utils.util import build_template_from_class
from langflow.utils.util import build_template_from_class, build_template_from_method
from langchain import chains
# Assuming necessary imports for Field, Template, and FrontendNode classes
@ -18,10 +19,16 @@ class ChainCreator(LangChainTypeCreator):
def frontend_node_class(self) -> Type[ChainFrontendNode]:
return ChainFrontendNode
#! We need to find a better solution for this
from_method_nodes = {"ConversationalRetrievalChain": "from_llm"}
@property
def type_to_loader_dict(self) -> Dict:
if self.type_dict is None:
self.type_dict = chain_type_to_cls_dict
self.type_dict: dict[str, Any] = {
chain_name: import_class(f"langchain.chains.{chain_name}")
for chain_name in chains.__all__
}
from langflow.interface.chains.custom import CUSTOM_CHAINS
self.type_dict.update(CUSTOM_CHAINS)
@ -37,20 +44,32 @@ class ChainCreator(LangChainTypeCreator):
try:
if name in get_custom_nodes(self.type_name).keys():
return get_custom_nodes(self.type_name)[name]
elif name in self.from_method_nodes.keys():
return build_template_from_method(
name,
type_to_cls_dict=self.type_to_loader_dict,
method_name=self.from_method_nodes[name],
add_function=True,
)
return build_template_from_class(
name, self.type_to_loader_dict, add_function=True
)
except ValueError as exc:
raise ValueError("Chain not found") from exc
raise ValueError(f"Chain {name} not found: {exc}") from exc
except AttributeError as exc:
logger.error(f"Chain {name} not loaded: {exc}")
return None
def to_list(self) -> List[str]:
custom_chains = list(get_custom_nodes("chains").keys())
default_chains = list(self.type_to_loader_dict.keys())
return default_chains + custom_chains
names = []
for _, chain in self.type_to_loader_dict.items():
chain_name = (
chain.function_name()
if hasattr(chain, "function_name")
else chain.__name__
)
names.append(chain_name)
return names
chain_creator = ChainCreator()

View file

@ -1,11 +1,13 @@
from typing import Dict, Optional, Type
from typing import Dict, Optional, Type, Union
from langchain.chains import ConversationChain
from langchain.memory.buffer import ConversationBufferMemory
from langchain.schema import BaseMemory
from langflow.interface.base import CustomChain
from pydantic import Field, root_validator
from langchain.chains.question_answering import load_qa_chain
from langflow.interface.utils import extract_input_variables_from_prompt
from langchain.base_language import BaseLanguageModel
DEFAULT_SUFFIX = """"
Current conversation:
@ -14,7 +16,7 @@ Human: {input}
{ai_prefix}"""
class BaseCustomChain(ConversationChain):
class BaseCustomConversationChain(ConversationChain):
"""BaseCustomChain is a chain you can use to have a conversation with a custom character."""
template: Optional[str]
@ -47,7 +49,7 @@ class BaseCustomChain(ConversationChain):
return values
class SeriesCharacterChain(BaseCustomChain):
class SeriesCharacterChain(BaseCustomConversationChain):
"""SeriesCharacterChain is a chain you can use to have a conversation with a character from a series."""
character: str
@ -66,7 +68,7 @@ Human: {input}
"""Default memory store."""
class MidJourneyPromptChain(BaseCustomChain):
class MidJourneyPromptChain(BaseCustomConversationChain):
"""MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."""
template: Optional[
@ -84,7 +86,7 @@ class MidJourneyPromptChain(BaseCustomChain):
AI:""" # noqa: E501
class TimeTravelGuideChain(BaseCustomChain):
class TimeTravelGuideChain(BaseCustomConversationChain):
template: Optional[
str
] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
@ -94,7 +96,26 @@ class TimeTravelGuideChain(BaseCustomChain):
AI:""" # noqa: E501
CUSTOM_CHAINS: Dict[str, Type[ConversationChain]] = {
class CombineDocsChain(CustomChain):
"""Implementation of initialize_agent function"""
@staticmethod
def function_name():
return "load_qa_chain"
@classmethod
def initialize(cls, llm: BaseLanguageModel, chain_type: str):
return load_qa_chain(llm=llm, chain_type=chain_type)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
CUSTOM_CHAINS: Dict[str, Type[Union[ConversationChain, CustomChain]]] = {
"CombineDocsChain": CombineDocsChain,
"SeriesCharacterChain": SeriesCharacterChain,
"MidJourneyPromptChain": MidJourneyPromptChain,
"TimeTravelGuideChain": TimeTravelGuideChain,

View file

@ -2,7 +2,6 @@ import inspect
from typing import Any
from langchain import (
chains,
document_loaders,
embeddings,
llms,
@ -15,6 +14,8 @@ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.chat_models import ChatAnthropic
from langflow.interface.importing.utils import import_class
from langflow.interface.agents.custom import CUSTOM_AGENTS
from langflow.interface.chains.custom import CUSTOM_CHAINS
## LLMs
llm_type_to_cls_dict = llms.type_to_cls_dict
@ -22,11 +23,6 @@ llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Chains
chain_type_to_cls_dict: dict[str, Any] = {
chain_name: import_class(f"langchain.chains.{chain_name}")
for chain_name in chains.__all__
}
## Toolkits
toolkit_type_to_loader_dict: dict[str, Any] = {
@ -73,3 +69,6 @@ documentloaders_type_to_cls_dict: dict[str, Any] = {
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
inspect.getmembers(text_splitter, inspect.isclass)
)
# merge CUSTOM_AGENTS and CUSTOM_CHAINS
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore

View file

@ -19,9 +19,10 @@ from langchain.chains.loading import load_chain_from_config
from langchain.llms.loading import load_llm_from_config
from pydantic import ValidationError
from langflow.interface.agents.custom import CUSTOM_AGENTS
from langflow.interface.custom_lists import CUSTOM_NODES
from langflow.interface.importing.utils import get_function, import_by_type
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.types import get_type_list
from langflow.interface.utils import load_file_into_dict
from langflow.utils import util, validate
@ -31,10 +32,11 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
"""Instantiate class from module type and key, and params"""
params = convert_params_to_sets(params)
params = convert_kwargs(params)
if node_type in CUSTOM_AGENTS:
custom_agent = CUSTOM_AGENTS.get(node_type)
if custom_agent:
return custom_agent.initialize(**params)
if node_type in CUSTOM_NODES:
if custom_node := CUSTOM_NODES.get(node_type):
if hasattr(custom_node, "initialize"):
return custom_node.initialize(**params)
return custom_node(**params)
class_object = import_by_type(_type=base_type, name=node_type)
return instantiate_based_on_type(class_object, base_type, node_type, params)
@ -78,10 +80,24 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
return instantiate_textsplitter(class_object, params)
elif base_type == "utilities":
return instantiate_utility(node_type, class_object, params)
elif base_type == "chains":
return instantiate_chains(node_type, class_object, params)
else:
return class_object(**params)
def instantiate_chains(node_type, class_object, params):
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
params["retriever"] = params["retriever"].as_retriever()
if node_type in chain_creator.from_method_nodes:
method = chain_creator.from_method_nodes[node_type]
if class_method := getattr(class_object, method, None):
return class_method(**params)
raise ValueError(f"Method {method} not found in {class_object}")
return class_object(**params)
def instantiate_agent(class_object, params):
return load_agent_executor(class_object, params)
@ -142,6 +158,14 @@ def instantiate_vectorstore(class_object, params):
"The source you provided did not load correctly or was empty."
"This may cause an error in the vectorstore."
)
# Chroma requires all metadata values to not be None
if class_object.__name__ == "Chroma":
for doc in params["documents"]:
if doc.metadata is None:
doc.metadata = {}
for key, value in doc.metadata.items():
if value is None:
doc.metadata[key] = ""
return class_object.from_documents(**params)

View file

@ -71,7 +71,3 @@ Human: {input}
CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = {
"SeriesCharacterPrompt": SeriesCharacterPrompt
}
if __name__ == "__main__":
prompt = SeriesCharacterPrompt(character="Harry Potter", series="Harry Potter")
print(prompt.template)

View file

@ -154,9 +154,9 @@ class CSVAgentNode(FrontendNode):
class InitializeAgentNode(FrontendNode):
name: str = "initialize_agent"
name: str = "AgentInitializer"
template: Template = Template(
type_name="initailize_agent",
type_name="initialize_agent",
fields=[
TemplateField(
field_type="str",

View file

@ -27,6 +27,9 @@ class FrontendNode(BaseModel):
def add_extra_fields(self) -> None:
pass
def add_extra_base_classes(self) -> None:
pass
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
"""Formats a given field based on its attributes and value."""

View file

@ -2,10 +2,24 @@ from typing import Optional
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.frontend_node.constants import QA_CHAIN_TYPES
from langflow.template.template.base import Template
class ChainFrontendNode(FrontendNode):
def add_extra_fields(self) -> None:
if self.template.type_name == "ConversationalRetrievalChain":
# add memory
self.template.add_field(
TemplateField(
field_type="BaseChatMemory",
required=False,
show=True,
name="memory",
advanced=False,
)
)
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
@ -155,3 +169,41 @@ class MidJourneyPromptChainNode(FrontendNode):
"ConversationChain",
"MidJourneyPromptChain",
]
class CombineDocsChainNode(FrontendNode):
name: str = "CombineDocsChain"
template: Template = Template(
type_name="load_qa_chain",
fields=[
TemplateField(
field_type="str",
required=True,
is_list=True,
show=True,
multiline=False,
options=QA_CHAIN_TYPES,
value=QA_CHAIN_TYPES[0],
name="chain_type",
advanced=False,
),
TemplateField(
field_type="BaseLanguageModel",
required=True,
show=True,
name="llm",
display_name="LLM",
advanced=False,
),
],
)
description: str = """Construct a zero shot agent from an LLM and tools."""
base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]
def to_dict(self):
return super().to_dict()
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
# do nothing and don't return anything
pass

View file

@ -30,3 +30,5 @@ You are a good listener and you can talk about anything.
"""
HUMAN_PROMPT = "{input}"
QA_CHAIN_TYPES = ["stuff", "map_reduce", "map_rerank", "refine"]

View file

@ -5,6 +5,20 @@ from langflow.template.frontend_node.base import FrontendNode
class MemoryFrontendNode(FrontendNode):
#! Needs testing
def add_extra_fields(self) -> None:
# add return_messages field
self.template.add_field(
TemplateField(
field_type="bool",
required=False,
show=True,
name="return_messages",
advanced=False,
value=False,
)
)
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
@ -18,3 +32,7 @@ class MemoryFrontendNode(FrontendNode):
field.value = 10
field.display_name = "Memory Size"
field.password = False
if field.name == "return_messages":
field.required = False
field.show = True
field.advanced = False

View file

@ -74,7 +74,7 @@ class BasePromptFrontendNode(FrontendNode):
class ZeroShotPromptNode(BasePromptFrontendNode):
name: str = "ZeroShotPrompt"
template: Template = Template(
type_name="zero_shot",
type_name="ZeroShotPrompt",
fields=[
TemplateField(
field_type="str",

View file

@ -108,7 +108,7 @@ class PythonFunctionToolNode(FrontendNode):
class PythonFunctionNode(FrontendNode):
name: str = "PythonFunction"
template: Template = Template(
type_name="python_function",
type_name="PythonFunction",
fields=[
TemplateField(
field_type="code",

View file

@ -20,6 +20,9 @@ class VectorStoreFrontendNode(FrontendNode):
self.template.add_field(extra_field)
def add_extra_base_classes(self) -> None:
self.base_classes.append("BaseRetriever")
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)

View file

@ -352,7 +352,7 @@
"type": "str",
"list": false
},
"_type": "zero_shot"
"_type": "ZeroShotPrompt"
},
"description": "Prompt template for Zero Shot Agent.",
"base_classes": [

View file

@ -1,15 +1,4 @@
from fastapi.testclient import TestClient
from langflow.settings import settings
# check that all agents are in settings.agents
# are in json_response["agents"]
def test_agents_settings(client: TestClient):
response = client.get("api/v1/all")
assert response.status_code == 200
json_response = response.json()
agents = json_response["agents"]
assert set(agents.keys()) == set(settings.agents)
def test_zero_shot_agent(client: TestClient):
@ -131,7 +120,7 @@ def test_initialize_agent(client: TestClient):
json_response = response.json()
agents = json_response["agents"]
initialize_agent = agents["initialize_agent"]
initialize_agent = agents["AgentInitializer"]
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
template = initialize_agent["template"]