diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 0ea2e0779..31cc11968 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -48,10 +48,12 @@ class Vertex: ] template_dict = self.data["node"]["template"] - # self.vertex_type = ( - # self.data["type"] if "Tool" not in self.output else template_dict["_type"] - # ) - self.vertex_type = template_dict["_type"] + self.vertex_type = ( + self.data["type"] + if "Tool" not in self.output or template_dict["_type"].islower() + else template_dict["_type"] + ) + # self.vertex_type = template_dict["_type"] if self.base_type is None: for base_type, value in ALL_TYPES_DICT.items(): if self.vertex_type in value: diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 80f451f03..a765d3b9b 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -33,8 +33,10 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any: params = convert_params_to_sets(params) params = convert_kwargs(params) if node_type in CUSTOM_NODES: - if custom_agent := CUSTOM_NODES.get(node_type): - return custom_agent.initialize(**params) + if custom_node := CUSTOM_NODES.get(node_type): + if hasattr(custom_node, "initialize"): + return custom_node.initialize(**params) + return custom_node(**params) class_object = import_by_type(_type=base_type, name=node_type) return instantiate_based_on_type(class_object, base_type, node_type, params) diff --git a/src/backend/langflow/template/frontend_node/prompts.py b/src/backend/langflow/template/frontend_node/prompts.py index 8738f1795..da5d2a300 100644 --- a/src/backend/langflow/template/frontend_node/prompts.py +++ b/src/backend/langflow/template/frontend_node/prompts.py @@ -74,7 +74,7 @@ class BasePromptFrontendNode(FrontendNode): class ZeroShotPromptNode(BasePromptFrontendNode): name: str = "ZeroShotPrompt" template: Template = Template( - type_name="zero_shot", + type_name="ZeroShotPrompt", fields=[ TemplateField( field_type="str", diff --git a/src/backend/langflow/template/frontend_node/tools.py b/src/backend/langflow/template/frontend_node/tools.py index 3094f3568..dd312f906 100644 --- a/src/backend/langflow/template/frontend_node/tools.py +++ b/src/backend/langflow/template/frontend_node/tools.py @@ -108,7 +108,7 @@ class PythonFunctionToolNode(FrontendNode): class PythonFunctionNode(FrontendNode): name: str = "PythonFunction" template: Template = Template( - type_name="python_function", + type_name="PythonFunction", fields=[ TemplateField( field_type="code", diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py index 8e181711f..e58007238 100644 --- a/tests/test_agents_template.py +++ b/tests/test_agents_template.py @@ -1,15 +1,4 @@ from fastapi.testclient import TestClient -from langflow.settings import settings - - -# check that all agents are in settings.agents -# are in json_response["agents"] -def test_agents_settings(client: TestClient): - response = client.get("api/v1/all") - assert response.status_code == 200 - json_response = response.json() - agents = json_response["agents"] - assert set(agents.keys()) == set(settings.agents) def test_zero_shot_agent(client: TestClient): @@ -131,7 +120,7 @@ def test_initialize_agent(client: TestClient): json_response = response.json() agents = json_response["agents"] - initialize_agent = agents["initialize_agent"] + initialize_agent = agents["AgentInitializer"] assert initialize_agent["base_classes"] == ["AgentExecutor", "function"] template = initialize_agent["template"]