🚀 feat(langflow): add new Retrieval chains (#456)
This commit is contained in:
commit
c216c78b1b
21 changed files with 226 additions and 71 deletions
1
Makefile
1
Makefile
|
|
@ -51,7 +51,6 @@ build_and_run:
|
|||
rm -rf dist
|
||||
make build && poetry run pip install dist/*.tar.gz && poetry run langflow
|
||||
|
||||
|
||||
build_frontend:
|
||||
cd src/frontend && CI='' npm run build
|
||||
cp -r src/frontend/build src/backend/langflow/frontend
|
||||
|
|
|
|||
|
|
@ -16,6 +16,10 @@ chains:
|
|||
- MidJourneyPromptChain
|
||||
- TimeTravelGuideChain
|
||||
- SQLDatabaseChain
|
||||
- RetrievalQA
|
||||
- RetrievalQAWithSourcesChain
|
||||
- ConversationalRetrievalChain
|
||||
- CombineDocsChain
|
||||
documentloaders:
|
||||
- AirbyteJSONLoader
|
||||
- CoNLLULoader
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ from langflow.template import frontend_node
|
|||
|
||||
# These should always be instantiated
|
||||
CUSTOM_NODES = {
|
||||
"prompts": {"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode()},
|
||||
"prompts": {
|
||||
"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode(),
|
||||
},
|
||||
"tools": {
|
||||
"PythonFunctionTool": frontend_node.tools.PythonFunctionToolNode(),
|
||||
"PythonFunction": frontend_node.tools.PythonFunctionNode(),
|
||||
|
|
@ -23,6 +25,7 @@ CUSTOM_NODES = {
|
|||
"SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(),
|
||||
"TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(),
|
||||
"MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(),
|
||||
"load_qa_chain": frontend_node.chains.CombineDocsChainNode(),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -49,8 +49,11 @@ class Vertex:
|
|||
|
||||
template_dict = self.data["node"]["template"]
|
||||
self.vertex_type = (
|
||||
self.data["type"] if "Tool" not in self.output else template_dict["_type"]
|
||||
self.data["type"]
|
||||
if "Tool" not in self.output or template_dict["_type"].islower()
|
||||
else template_dict["_type"]
|
||||
)
|
||||
|
||||
if self.base_type is None:
|
||||
for base_type, value in ALL_TYPES_DICT.items():
|
||||
if self.vertex_type in value:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from abc import ABC
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
|
|
@ -33,24 +32,7 @@ from langchain.memory.chat_memory import BaseChatMemory
|
|||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor, ABC):
|
||||
"""Custom agent executor"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CustomAgentExecutor"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
from langflow.interface.base import CustomAgentExecutor
|
||||
|
||||
|
||||
class JsonAgent(CustomAgentExecutor):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.agents import AgentExecutor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
|
|
@ -81,5 +82,42 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
)
|
||||
|
||||
signature.add_extra_fields()
|
||||
signature.add_extra_base_classes()
|
||||
|
||||
return signature
|
||||
|
||||
|
||||
class CustomChain(Chain, ABC):
|
||||
"""Custom chain"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CustomChain"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor, ABC):
|
||||
"""Custom chain"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "CustomChain"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import chain_type_to_cls_dict
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.template.frontend_node.chains import ChainFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
from langchain import chains
|
||||
|
||||
# Assuming necessary imports for Field, Template, and FrontendNode classes
|
||||
|
||||
|
|
@ -18,10 +19,16 @@ class ChainCreator(LangChainTypeCreator):
|
|||
def frontend_node_class(self) -> Type[ChainFrontendNode]:
|
||||
return ChainFrontendNode
|
||||
|
||||
#! We need to find a better solution for this
|
||||
from_method_nodes = {"ConversationalRetrievalChain": "from_llm"}
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict = chain_type_to_cls_dict
|
||||
self.type_dict: dict[str, Any] = {
|
||||
chain_name: import_class(f"langchain.chains.{chain_name}")
|
||||
for chain_name in chains.__all__
|
||||
}
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
|
||||
self.type_dict.update(CUSTOM_CHAINS)
|
||||
|
|
@ -37,20 +44,32 @@ class ChainCreator(LangChainTypeCreator):
|
|||
try:
|
||||
if name in get_custom_nodes(self.type_name).keys():
|
||||
return get_custom_nodes(self.type_name)[name]
|
||||
elif name in self.from_method_nodes.keys():
|
||||
return build_template_from_method(
|
||||
name,
|
||||
type_to_cls_dict=self.type_to_loader_dict,
|
||||
method_name=self.from_method_nodes[name],
|
||||
add_function=True,
|
||||
)
|
||||
return build_template_from_class(
|
||||
name, self.type_to_loader_dict, add_function=True
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Chain not found") from exc
|
||||
raise ValueError(f"Chain {name} not found: {exc}") from exc
|
||||
except AttributeError as exc:
|
||||
logger.error(f"Chain {name} not loaded: {exc}")
|
||||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
custom_chains = list(get_custom_nodes("chains").keys())
|
||||
default_chains = list(self.type_to_loader_dict.keys())
|
||||
|
||||
return default_chains + custom_chains
|
||||
names = []
|
||||
for _, chain in self.type_to_loader_dict.items():
|
||||
chain_name = (
|
||||
chain.function_name()
|
||||
if hasattr(chain, "function_name")
|
||||
else chain.__name__
|
||||
)
|
||||
names.append(chain_name)
|
||||
return names
|
||||
|
||||
|
||||
chain_creator = ChainCreator()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
from typing import Dict, Optional, Type
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
from langchain.memory.buffer import ConversationBufferMemory
|
||||
from langchain.schema import BaseMemory
|
||||
from langflow.interface.base import CustomChain
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
DEFAULT_SUFFIX = """"
|
||||
Current conversation:
|
||||
|
|
@ -14,7 +16,7 @@ Human: {input}
|
|||
{ai_prefix}"""
|
||||
|
||||
|
||||
class BaseCustomChain(ConversationChain):
|
||||
class BaseCustomConversationChain(ConversationChain):
|
||||
"""BaseCustomChain is a chain you can use to have a conversation with a custom character."""
|
||||
|
||||
template: Optional[str]
|
||||
|
|
@ -47,7 +49,7 @@ class BaseCustomChain(ConversationChain):
|
|||
return values
|
||||
|
||||
|
||||
class SeriesCharacterChain(BaseCustomChain):
|
||||
class SeriesCharacterChain(BaseCustomConversationChain):
|
||||
"""SeriesCharacterChain is a chain you can use to have a conversation with a character from a series."""
|
||||
|
||||
character: str
|
||||
|
|
@ -66,7 +68,7 @@ Human: {input}
|
|||
"""Default memory store."""
|
||||
|
||||
|
||||
class MidJourneyPromptChain(BaseCustomChain):
|
||||
class MidJourneyPromptChain(BaseCustomConversationChain):
|
||||
"""MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."""
|
||||
|
||||
template: Optional[
|
||||
|
|
@ -84,7 +86,7 @@ class MidJourneyPromptChain(BaseCustomChain):
|
|||
AI:""" # noqa: E501
|
||||
|
||||
|
||||
class TimeTravelGuideChain(BaseCustomChain):
|
||||
class TimeTravelGuideChain(BaseCustomConversationChain):
|
||||
template: Optional[
|
||||
str
|
||||
] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
|
||||
|
|
@ -94,7 +96,26 @@ class TimeTravelGuideChain(BaseCustomChain):
|
|||
AI:""" # noqa: E501
|
||||
|
||||
|
||||
CUSTOM_CHAINS: Dict[str, Type[ConversationChain]] = {
|
||||
class CombineDocsChain(CustomChain):
|
||||
"""Implementation of initialize_agent function"""
|
||||
|
||||
@staticmethod
|
||||
def function_name():
|
||||
return "load_qa_chain"
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, llm: BaseLanguageModel, chain_type: str):
|
||||
return load_qa_chain(llm=llm, chain_type=chain_type)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
CUSTOM_CHAINS: Dict[str, Type[Union[ConversationChain, CustomChain]]] = {
|
||||
"CombineDocsChain": CombineDocsChain,
|
||||
"SeriesCharacterChain": SeriesCharacterChain,
|
||||
"MidJourneyPromptChain": MidJourneyPromptChain,
|
||||
"TimeTravelGuideChain": TimeTravelGuideChain,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import inspect
|
|||
from typing import Any
|
||||
|
||||
from langchain import (
|
||||
chains,
|
||||
document_loaders,
|
||||
embeddings,
|
||||
llms,
|
||||
|
|
@ -15,6 +14,8 @@ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
|||
from langchain.chat_models import ChatAnthropic
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
|
||||
## LLMs
|
||||
llm_type_to_cls_dict = llms.type_to_cls_dict
|
||||
|
|
@ -22,11 +23,6 @@ llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
|
|||
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
|
||||
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
|
||||
|
||||
## Chains
|
||||
chain_type_to_cls_dict: dict[str, Any] = {
|
||||
chain_name: import_class(f"langchain.chains.{chain_name}")
|
||||
for chain_name in chains.__all__
|
||||
}
|
||||
|
||||
## Toolkits
|
||||
toolkit_type_to_loader_dict: dict[str, Any] = {
|
||||
|
|
@ -73,3 +69,6 @@ documentloaders_type_to_cls_dict: dict[str, Any] = {
|
|||
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
|
||||
inspect.getmembers(text_splitter, inspect.isclass)
|
||||
)
|
||||
|
||||
# merge CUSTOM_AGENTS and CUSTOM_CHAINS
|
||||
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ from langchain.chains.loading import load_chain_from_config
|
|||
from langchain.llms.loading import load_llm_from_config
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.custom_lists import CUSTOM_NODES
|
||||
from langflow.interface.importing.utils import get_function, import_by_type
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.types import get_type_list
|
||||
from langflow.interface.utils import load_file_into_dict
|
||||
from langflow.utils import util, validate
|
||||
|
|
@ -31,10 +32,11 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
|||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
if node_type in CUSTOM_AGENTS:
|
||||
custom_agent = CUSTOM_AGENTS.get(node_type)
|
||||
if custom_agent:
|
||||
return custom_agent.initialize(**params)
|
||||
if node_type in CUSTOM_NODES:
|
||||
if custom_node := CUSTOM_NODES.get(node_type):
|
||||
if hasattr(custom_node, "initialize"):
|
||||
return custom_node.initialize(**params)
|
||||
return custom_node(**params)
|
||||
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
return instantiate_based_on_type(class_object, base_type, node_type, params)
|
||||
|
|
@ -78,10 +80,24 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
|
|||
return instantiate_textsplitter(class_object, params)
|
||||
elif base_type == "utilities":
|
||||
return instantiate_utility(node_type, class_object, params)
|
||||
elif base_type == "chains":
|
||||
return instantiate_chains(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_chains(node_type, class_object, params):
|
||||
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
|
||||
params["retriever"] = params["retriever"].as_retriever()
|
||||
if node_type in chain_creator.from_method_nodes:
|
||||
method = chain_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
return class_method(**params)
|
||||
raise ValueError(f"Method {method} not found in {class_object}")
|
||||
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_agent(class_object, params):
|
||||
return load_agent_executor(class_object, params)
|
||||
|
||||
|
|
@ -142,6 +158,14 @@ def instantiate_vectorstore(class_object, params):
|
|||
"The source you provided did not load correctly or was empty."
|
||||
"This may cause an error in the vectorstore."
|
||||
)
|
||||
# Chroma requires all metadata values to not be None
|
||||
if class_object.__name__ == "Chroma":
|
||||
for doc in params["documents"]:
|
||||
if doc.metadata is None:
|
||||
doc.metadata = {}
|
||||
for key, value in doc.metadata.items():
|
||||
if value is None:
|
||||
doc.metadata[key] = ""
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,3 @@ Human: {input}
|
|||
CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = {
|
||||
"SeriesCharacterPrompt": SeriesCharacterPrompt
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompt = SeriesCharacterPrompt(character="Harry Potter", series="Harry Potter")
|
||||
print(prompt.template)
|
||||
|
|
|
|||
|
|
@ -154,9 +154,9 @@ class CSVAgentNode(FrontendNode):
|
|||
|
||||
|
||||
class InitializeAgentNode(FrontendNode):
|
||||
name: str = "initialize_agent"
|
||||
name: str = "AgentInitializer"
|
||||
template: Template = Template(
|
||||
type_name="initailize_agent",
|
||||
type_name="initialize_agent",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ class FrontendNode(BaseModel):
|
|||
def add_extra_fields(self) -> None:
|
||||
pass
|
||||
|
||||
def add_extra_base_classes(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
"""Formats a given field based on its attributes and value."""
|
||||
|
|
|
|||
|
|
@ -2,10 +2,24 @@ from typing import Optional
|
|||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.constants import QA_CHAIN_TYPES
|
||||
from langflow.template.template.base import Template
|
||||
|
||||
|
||||
class ChainFrontendNode(FrontendNode):
|
||||
def add_extra_fields(self) -> None:
|
||||
if self.template.type_name == "ConversationalRetrievalChain":
|
||||
# add memory
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="BaseChatMemory",
|
||||
required=False,
|
||||
show=True,
|
||||
name="memory",
|
||||
advanced=False,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
|
|
@ -155,3 +169,41 @@ class MidJourneyPromptChainNode(FrontendNode):
|
|||
"ConversationChain",
|
||||
"MidJourneyPromptChain",
|
||||
]
|
||||
|
||||
|
||||
class CombineDocsChainNode(FrontendNode):
|
||||
name: str = "CombineDocsChain"
|
||||
template: Template = Template(
|
||||
type_name="load_qa_chain",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
is_list=True,
|
||||
show=True,
|
||||
multiline=False,
|
||||
options=QA_CHAIN_TYPES,
|
||||
value=QA_CHAIN_TYPES[0],
|
||||
name="chain_type",
|
||||
advanced=False,
|
||||
),
|
||||
TemplateField(
|
||||
field_type="BaseLanguageModel",
|
||||
required=True,
|
||||
show=True,
|
||||
name="llm",
|
||||
display_name="LLM",
|
||||
advanced=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = """Construct a zero shot agent from an LLM and tools."""
|
||||
base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
# do nothing and don't return anything
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -30,3 +30,5 @@ You are a good listener and you can talk about anything.
|
|||
"""
|
||||
|
||||
HUMAN_PROMPT = "{input}"
|
||||
|
||||
QA_CHAIN_TYPES = ["stuff", "map_reduce", "map_rerank", "refine"]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,20 @@ from langflow.template.frontend_node.base import FrontendNode
|
|||
|
||||
|
||||
class MemoryFrontendNode(FrontendNode):
|
||||
#! Needs testing
|
||||
def add_extra_fields(self) -> None:
|
||||
# add return_messages field
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
name="return_messages",
|
||||
advanced=False,
|
||||
value=False,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
|
|
@ -18,3 +32,7 @@ class MemoryFrontendNode(FrontendNode):
|
|||
field.value = 10
|
||||
field.display_name = "Memory Size"
|
||||
field.password = False
|
||||
if field.name == "return_messages":
|
||||
field.required = False
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class BasePromptFrontendNode(FrontendNode):
|
|||
class ZeroShotPromptNode(BasePromptFrontendNode):
|
||||
name: str = "ZeroShotPrompt"
|
||||
template: Template = Template(
|
||||
type_name="zero_shot",
|
||||
type_name="ZeroShotPrompt",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class PythonFunctionToolNode(FrontendNode):
|
|||
class PythonFunctionNode(FrontendNode):
|
||||
name: str = "PythonFunction"
|
||||
template: Template = Template(
|
||||
type_name="python_function",
|
||||
type_name="PythonFunction",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="code",
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
|
||||
self.template.add_field(extra_field)
|
||||
|
||||
def add_extra_base_classes(self) -> None:
|
||||
self.base_classes.append("BaseRetriever")
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
|
|
|
|||
|
|
@ -352,7 +352,7 @@
|
|||
"type": "str",
|
||||
"list": false
|
||||
},
|
||||
"_type": "zero_shot"
|
||||
"_type": "ZeroShotPrompt"
|
||||
},
|
||||
"description": "Prompt template for Zero Shot Agent.",
|
||||
"base_classes": [
|
||||
|
|
|
|||
|
|
@ -1,15 +1,4 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from langflow.settings import settings
|
||||
|
||||
|
||||
# check that all agents are in settings.agents
|
||||
# are in json_response["agents"]
|
||||
def test_agents_settings(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
agents = json_response["agents"]
|
||||
assert set(agents.keys()) == set(settings.agents)
|
||||
|
||||
|
||||
def test_zero_shot_agent(client: TestClient):
|
||||
|
|
@ -131,7 +120,7 @@ def test_initialize_agent(client: TestClient):
|
|||
json_response = response.json()
|
||||
agents = json_response["agents"]
|
||||
|
||||
initialize_agent = agents["initialize_agent"]
|
||||
initialize_agent = agents["AgentInitializer"]
|
||||
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
|
||||
template = initialize_agent["template"]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue