Add New Chains: RetrievalQA, RetrievalQAWithSourcesChain, ConversationalRetrievalChain, CombineDocsChain
This commit is contained in:
parent
5e99ba0e99
commit
144f2b470e
17 changed files with 761 additions and 66 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from langflow.cache import cache_manager
|
||||
from langflow.interface.loading import load_flow_from_json
|
||||
from langflow.processing.process import load_flow_from_json
|
||||
|
||||
__all__ = ["load_flow_from_json", "cache_manager"]
|
||||
|
|
|
|||
|
|
@ -16,6 +16,10 @@ chains:
|
|||
- MidJourneyPromptChain
|
||||
- TimeTravelGuideChain
|
||||
- SQLDatabaseChain
|
||||
- RetrievalQA
|
||||
- RetrievalQAWithSourcesChain
|
||||
- ConversationalRetrievalChain
|
||||
- CombineDocsChain
|
||||
documentloaders:
|
||||
- AirbyteJSONLoader
|
||||
- CoNLLULoader
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ CUSTOM_NODES = {
|
|||
"SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(),
|
||||
"TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(),
|
||||
"MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(),
|
||||
"load_qa_chain": frontend_node.chains.CombineDocsChainNode(),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
|
|
@ -81,5 +81,24 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
)
|
||||
|
||||
signature.add_extra_fields()
|
||||
signature.add_extra_base_classes()
|
||||
|
||||
return signature
|
||||
|
||||
|
||||
class CustomChain(Chain, ABC):
|
||||
"""Custom chain"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CustomChain"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import chain_type_to_cls_dict
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.template.frontend_node.chains import ChainFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
from langchain import chains
|
||||
|
||||
# Assuming necessary imports for Field, Template, and FrontendNode classes
|
||||
|
||||
|
|
@ -18,10 +19,16 @@ class ChainCreator(LangChainTypeCreator):
|
|||
def frontend_node_class(self) -> Type[ChainFrontendNode]:
|
||||
return ChainFrontendNode
|
||||
|
||||
#! We need to find a better solution for this
|
||||
from_method_nodes = {"ConversationalRetrievalChain": "from_llm"}
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict = chain_type_to_cls_dict
|
||||
self.type_dict: dict[str, Any] = {
|
||||
chain_name: import_class(f"langchain.chains.{chain_name}")
|
||||
for chain_name in chains.__all__
|
||||
}
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
|
||||
self.type_dict.update(CUSTOM_CHAINS)
|
||||
|
|
@ -37,20 +44,32 @@ class ChainCreator(LangChainTypeCreator):
|
|||
try:
|
||||
if name in get_custom_nodes(self.type_name).keys():
|
||||
return get_custom_nodes(self.type_name)[name]
|
||||
elif name in self.from_method_nodes.keys():
|
||||
return build_template_from_method(
|
||||
name,
|
||||
type_to_cls_dict=self.type_to_loader_dict,
|
||||
method_name=self.from_method_nodes[name],
|
||||
add_function=True,
|
||||
)
|
||||
return build_template_from_class(
|
||||
name, self.type_to_loader_dict, add_function=True
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Chain not found") from exc
|
||||
raise ValueError(f"Chain {name} not found: {exc}") from exc
|
||||
except AttributeError as exc:
|
||||
logger.error(f"Chain {name} not loaded: {exc}")
|
||||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
custom_chains = list(get_custom_nodes("chains").keys())
|
||||
default_chains = list(self.type_to_loader_dict.keys())
|
||||
|
||||
return default_chains + custom_chains
|
||||
names = []
|
||||
for _, chain in self.type_to_loader_dict.items():
|
||||
chain_name = (
|
||||
chain.function_name()
|
||||
if hasattr(chain, "function_name")
|
||||
else chain.__name__
|
||||
)
|
||||
names.append(chain_name)
|
||||
return names
|
||||
|
||||
|
||||
chain_creator = ChainCreator()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
from typing import Dict, Optional, Type
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
from langchain.memory.buffer import ConversationBufferMemory
|
||||
from langchain.schema import BaseMemory
|
||||
from langflow.interface.base import CustomChain
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langflow.graph.utils import extract_input_variables_from_prompt
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
DEFAULT_SUFFIX = """"
|
||||
Current conversation:
|
||||
|
|
@ -14,7 +16,7 @@ Human: {input}
|
|||
{ai_prefix}"""
|
||||
|
||||
|
||||
class BaseCustomChain(ConversationChain):
|
||||
class BaseCustomConversationChain(ConversationChain):
|
||||
"""BaseCustomChain is a chain you can use to have a conversation with a custom character."""
|
||||
|
||||
template: Optional[str]
|
||||
|
|
@ -47,7 +49,7 @@ class BaseCustomChain(ConversationChain):
|
|||
return values
|
||||
|
||||
|
||||
class SeriesCharacterChain(BaseCustomChain):
|
||||
class SeriesCharacterChain(BaseCustomConversationChain):
|
||||
"""SeriesCharacterChain is a chain you can use to have a conversation with a character from a series."""
|
||||
|
||||
character: str
|
||||
|
|
@ -66,7 +68,7 @@ Human: {input}
|
|||
"""Default memory store."""
|
||||
|
||||
|
||||
class MidJourneyPromptChain(BaseCustomChain):
|
||||
class MidJourneyPromptChain(BaseCustomConversationChain):
|
||||
"""MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."""
|
||||
|
||||
template: Optional[
|
||||
|
|
@ -84,7 +86,7 @@ class MidJourneyPromptChain(BaseCustomChain):
|
|||
AI:""" # noqa: E501
|
||||
|
||||
|
||||
class TimeTravelGuideChain(BaseCustomChain):
|
||||
class TimeTravelGuideChain(BaseCustomConversationChain):
|
||||
template: Optional[
|
||||
str
|
||||
] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
|
||||
|
|
@ -94,7 +96,26 @@ class TimeTravelGuideChain(BaseCustomChain):
|
|||
AI:""" # noqa: E501
|
||||
|
||||
|
||||
CUSTOM_CHAINS: Dict[str, Type[ConversationChain]] = {
|
||||
class CombineDocsChain(CustomChain):
|
||||
"""Implementation of AgentInitializer function"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "load_qa_chain"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, llm: BaseLanguageModel, chain_type: str):
|
||||
return load_qa_chain(llm=llm, chain_type=chain_type)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
CUSTOM_CHAINS: Dict[str, Type[Union[ConversationChain, CustomChain]]] = {
|
||||
"CombineDocsChain": CombineDocsChain,
|
||||
"SeriesCharacterChain": SeriesCharacterChain,
|
||||
"MidJourneyPromptChain": MidJourneyPromptChain,
|
||||
"TimeTravelGuideChain": TimeTravelGuideChain,
|
||||
|
|
|
|||
|
|
@ -14,18 +14,20 @@ from langchain.agents import agent_toolkits
|
|||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
|
||||
## LLMs
|
||||
# LLMs
|
||||
llm_type_to_cls_dict = llms.type_to_cls_dict
|
||||
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
|
||||
|
||||
## Chains
|
||||
# Chains
|
||||
chain_type_to_cls_dict: dict[str, Any] = {
|
||||
chain_name: import_class(f"langchain.chains.{chain_name}")
|
||||
for chain_name in chains.__all__
|
||||
}
|
||||
|
||||
## Toolkits
|
||||
# Toolkits
|
||||
toolkit_type_to_loader_dict: dict[str, Any] = {
|
||||
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
|
||||
# if toolkit_name is lower case it is a loader
|
||||
|
|
@ -40,25 +42,25 @@ toolkit_type_to_cls_dict: dict[str, Any] = {
|
|||
if not toolkit_name.islower()
|
||||
}
|
||||
|
||||
## Memories
|
||||
# Memories
|
||||
memory_type_to_cls_dict: dict[str, Any] = {
|
||||
memory_name: import_class(f"langchain.memory.{memory_name}")
|
||||
for memory_name in memory.__all__
|
||||
}
|
||||
|
||||
## Wrappers
|
||||
# Wrappers
|
||||
wrapper_type_to_cls_dict: dict[str, Any] = {
|
||||
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]
|
||||
}
|
||||
|
||||
## Embeddings
|
||||
# Embeddings
|
||||
embedding_type_to_cls_dict: dict[str, Any] = {
|
||||
embedding_name: import_class(f"langchain.embeddings.{embedding_name}")
|
||||
for embedding_name in embeddings.__all__
|
||||
}
|
||||
|
||||
|
||||
## Document Loaders
|
||||
# Document Loaders
|
||||
documentloaders_type_to_cls_dict: dict[str, Any] = {
|
||||
documentloader_name: import_class(
|
||||
f"langchain.document_loaders.{documentloader_name}"
|
||||
|
|
@ -66,7 +68,10 @@ documentloaders_type_to_cls_dict: dict[str, Any] = {
|
|||
for documentloader_name in document_loaders.__all__
|
||||
}
|
||||
|
||||
## Text Splitters
|
||||
# Text Splitters
|
||||
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
|
||||
inspect.getmembers(text_splitter, inspect.isclass)
|
||||
)
|
||||
|
||||
# merge CUSTOM_AGENTS and CUSTOM_CHAINS
|
||||
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
|
|||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.tools import BaseTool
|
||||
from langflow.utils import validate
|
||||
|
||||
|
||||
def import_module(module_path: str) -> Any:
|
||||
|
|
@ -147,3 +148,10 @@ def import_utility(utility: str) -> Any:
|
|||
if utility == "SQLDatabase":
|
||||
return import_class(f"langchain.sql_database.{utility}")
|
||||
return import_class(f"langchain.utilities.{utility}")
|
||||
|
||||
|
||||
def get_function(code):
|
||||
"""Get the function"""
|
||||
function_name = validate.extract_function_name(code)
|
||||
|
||||
return validate.create_function(code, function_name)
|
||||
|
|
|
|||
|
|
@ -19,10 +19,10 @@ from langchain.chains.loading import load_chain_from_config
|
|||
from langchain.llms.loading import load_llm_from_config
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.custom_lists import CUSTOM_NODES
|
||||
from langflow.interface.importing.utils import import_by_type
|
||||
from langflow.interface.run import fix_memory_inputs
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.types import get_type_list
|
||||
from langflow.interface.utils import load_file_into_dict
|
||||
from langflow.utils import util, validate
|
||||
|
|
@ -32,10 +32,11 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
|||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
if node_type in CUSTOM_AGENTS:
|
||||
custom_agent = CUSTOM_AGENTS.get(node_type)
|
||||
if custom_agent:
|
||||
return custom_agent.initialize(**params)
|
||||
if node_type in CUSTOM_NODES:
|
||||
if custom_node := CUSTOM_NODES.get(node_type):
|
||||
if hasattr(custom_node, "initialize"):
|
||||
return custom_node.initialize(**params)
|
||||
return custom_node(**params)
|
||||
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
return instantiate_based_on_type(class_object, base_type, node_type, params)
|
||||
|
|
@ -79,10 +80,24 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
|
|||
return instantiate_textsplitter(class_object, params)
|
||||
elif base_type == "utilities":
|
||||
return instantiate_utility(node_type, class_object, params)
|
||||
elif base_type == "chains":
|
||||
return instantiate_chains(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_chains(node_type, class_object, params):
|
||||
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
|
||||
params["retriever"] = params["retriever"].as_retriever()
|
||||
if node_type in chain_creator.from_method_nodes:
|
||||
method = chain_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
return class_method(**params)
|
||||
raise ValueError(f"Method {method} not found in {class_object}")
|
||||
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_agent(class_object, params):
|
||||
return load_agent_executor(class_object, params)
|
||||
|
||||
|
|
@ -161,38 +176,6 @@ def instantiate_utility(node_type, class_object, params):
|
|||
return class_object(**params)
|
||||
|
||||
|
||||
def load_flow_from_json(path: str, build=True):
|
||||
"""Load flow from json file"""
|
||||
# This is done to avoid circular imports
|
||||
from langflow.graph import Graph
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
# Substitute ZeroShotPrompt with PromptTemplate
|
||||
# nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
|
||||
# Add input variables
|
||||
# nodes = payload.extract_input_variables(nodes)
|
||||
|
||||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
if build:
|
||||
langchain_object = graph.build()
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
fix_memory_inputs(langchain_object)
|
||||
return langchain_object
|
||||
return graph
|
||||
|
||||
|
||||
def replace_zero_shot_prompt_with_prompt_template(nodes):
|
||||
"""Replace ZeroShotPrompt with PromptTemplate"""
|
||||
for node in nodes:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import base64
|
|||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
import re
|
||||
|
||||
import yaml
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
|
@ -52,3 +53,8 @@ def try_setting_streaming_options(langchain_object, websocket):
|
|||
llm.stream = True
|
||||
|
||||
return langchain_object
|
||||
|
||||
|
||||
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
|
||||
"""Extract input variables from prompt."""
|
||||
return re.findall(r"{(.*?)}", prompt)
|
||||
|
|
|
|||
0
src/backend/langflow/processing/__init__.py
Normal file
0
src/backend/langflow/processing/__init__.py
Normal file
61
src/backend/langflow/processing/process.py
Normal file
61
src/backend/langflow/processing/process.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import json
|
||||
from langflow.interface.run import (
|
||||
get_memory_key,
|
||||
update_memory_keys,
|
||||
)
|
||||
from langflow.graph import Graph
|
||||
|
||||
|
||||
def fix_memory_inputs(langchain_object):
|
||||
"""
|
||||
Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
|
||||
object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
|
||||
get_memory_key function and updates the memory keys using the update_memory_keys function.
|
||||
"""
|
||||
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
|
||||
try:
|
||||
if langchain_object.memory.memory_key in langchain_object.input_variables:
|
||||
return
|
||||
except AttributeError:
|
||||
input_variables = (
|
||||
langchain_object.prompt.input_variables
|
||||
if hasattr(langchain_object, "prompt")
|
||||
else langchain_object.input_keys
|
||||
)
|
||||
if langchain_object.memory.memory_key in input_variables:
|
||||
return
|
||||
|
||||
possible_new_mem_key = get_memory_key(langchain_object)
|
||||
if possible_new_mem_key is not None:
|
||||
update_memory_keys(langchain_object, possible_new_mem_key)
|
||||
|
||||
|
||||
def load_flow_from_json(path: str, build=True):
|
||||
"""Load flow from json file"""
|
||||
# This is done to avoid circular imports
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
# Substitute ZeroShotPrompt with PromptTemplate
|
||||
# nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
|
||||
# Add input variables
|
||||
# nodes = payload.extract_input_variables(nodes)
|
||||
|
||||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
if build:
|
||||
langchain_object = graph.build()
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
fix_memory_inputs(langchain_object)
|
||||
return langchain_object
|
||||
return graph
|
||||
|
|
@ -27,6 +27,9 @@ class FrontendNode(BaseModel):
|
|||
def add_extra_fields(self) -> None:
|
||||
pass
|
||||
|
||||
def add_extra_base_classes(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
"""Formats a given field based on its attributes and value."""
|
||||
|
|
|
|||
|
|
@ -2,10 +2,24 @@ from typing import Optional
|
|||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.constants import QA_CHAIN_TYPES
|
||||
from langflow.template.template.base import Template
|
||||
|
||||
|
||||
class ChainFrontendNode(FrontendNode):
|
||||
def add_extra_fields(self) -> None:
|
||||
if self.template.type_name == "ConversationalRetrievalChain":
|
||||
# add memory
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="BaseChatMemory",
|
||||
required=False,
|
||||
show=True,
|
||||
name="memory",
|
||||
advanced=False,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
|
|
@ -155,3 +169,41 @@ class MidJourneyPromptChainNode(FrontendNode):
|
|||
"ConversationChain",
|
||||
"MidJourneyPromptChain",
|
||||
]
|
||||
|
||||
|
||||
class CombineDocsChainNode(FrontendNode):
|
||||
name: str = "CombineDocsChain"
|
||||
template: Template = Template(
|
||||
type_name="load_qa_chain",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
is_list=True,
|
||||
show=True,
|
||||
multiline=False,
|
||||
options=QA_CHAIN_TYPES,
|
||||
value=QA_CHAIN_TYPES[0],
|
||||
name="chain_type",
|
||||
advanced=False,
|
||||
),
|
||||
TemplateField(
|
||||
field_type="BaseLanguageModel",
|
||||
required=True,
|
||||
show=True,
|
||||
name="llm",
|
||||
display_name="LLM",
|
||||
advanced=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = """Construct a zero shot agent from an LLM and tools."""
|
||||
base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
# do nothing and don't return anything
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -30,3 +30,5 @@ You are a good listener and you can talk about anything.
|
|||
"""
|
||||
|
||||
HUMAN_PROMPT = "{input}"
|
||||
|
||||
QA_CHAIN_TYPES = ["stuff", "map_reduce", "map_rerank", "refine"]
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
|
||||
self.template.add_field(extra_field)
|
||||
|
||||
def add_extra_base_classes(self) -> None:
|
||||
self.base_classes.append("BaseRetriever")
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue