diff --git a/src/backend/langflow/interface/agents/base.py b/src/backend/langflow/interface/agents/base.py index d404f845d..b272144bc 100644 --- a/src/backend/langflow/interface/agents/base.py +++ b/src/backend/langflow/interface/agents/base.py @@ -6,13 +6,20 @@ from langflow.custom.customs import get_custom_nodes from langflow.interface.agents.custom import CUSTOM_AGENTS from langflow.interface.base import LangChainTypeCreator from langflow.settings import settings +from langflow.template.frontend_node.agents import AgentFrontendNode from langflow.utils.logger import logger -from langflow.utils.util import build_template_from_class +from langflow.utils.util import build_template_from_class, build_template_from_method class AgentCreator(LangChainTypeCreator): type_name: str = "agents" + from_method_nodes = {"ZeroShotAgent": "from_llm_and_tools"} + + @property + def frontend_node_class(self) -> type[AgentFrontendNode]: + return AgentFrontendNode + @property def type_to_loader_dict(self) -> Dict: if self.type_dict is None: @@ -27,6 +34,13 @@ class AgentCreator(LangChainTypeCreator): try: if name in get_custom_nodes(self.type_name).keys(): return get_custom_nodes(self.type_name)[name] + elif name in self.from_method_nodes: + return build_template_from_method( + name, + type_to_cls_dict=self.type_to_loader_dict, + add_function=True, + method_name=self.from_method_nodes[name], + ) return build_template_from_class( name, self.type_to_loader_dict, add_function=True ) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index cd5585656..3c3616dd8 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -15,6 +15,7 @@ from pydantic import ValidationError from langflow.interface.custom_lists import CUSTOM_NODES from langflow.interface.importing.utils import get_function, import_by_type +from langflow.interface.agents.base import agent_creator from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.chains.base import chain_creator from langflow.interface.output_parsers.base import output_parser_creator @@ -61,7 +62,7 @@ def convert_kwargs(params): def instantiate_based_on_type(class_object, base_type, node_type, params): if base_type == "agents": - return instantiate_agent(class_object, params) + return instantiate_agent(node_type, class_object, params) elif base_type == "prompts": return instantiate_prompt(node_type, class_object, params) elif base_type == "tools": @@ -159,7 +160,16 @@ def instantiate_chains(node_type, class_object: Type[Chain], params: Dict): return class_object(**params) -def instantiate_agent(class_object: Type[agent_module.Agent], params: Dict): +def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params: Dict): + if node_type in agent_creator.from_method_nodes: + method = agent_creator.from_method_nodes[node_type] + if class_method := getattr(class_object, method, None): + agent = class_method(**params) + tools = params.get("tools", []) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + ) return load_agent_executor(class_object, params) diff --git a/src/backend/langflow/template/frontend_node/agents.py b/src/backend/langflow/template/frontend_node/agents.py index f692a7d6c..b988cbe20 100644 --- a/src/backend/langflow/template/frontend_node/agents.py +++ b/src/backend/langflow/template/frontend_node/agents.py @@ -13,6 +13,17 @@ NON_CHAT_AGENTS = { } +class AgentFrontendNode(FrontendNode): + @staticmethod + def format_field(field: TemplateField, name: str | None = None) -> None: + if field.name in ["suffix", "prefix"]: + field.show = True + if field.name == "Tools" and name == "ZeroShotAgent": + # field. + field.type_name = "BaseTool" + field.is_list = True + + class SQLAgentNode(FrontendNode): name: str = "SQLAgent" template: Template = Template(