diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index ad1d53ee8..b820cd53c 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -1,7 +1,9 @@ from langflow.template import nodes CUSTOM_NODES = { - "prompts": {**nodes.ZeroShotPromptNode().to_dict()}, + "prompts": { + **nodes.ZeroShotPromptNode().to_dict(), + }, "tools": {**nodes.PythonFunctionNode().to_dict(), **nodes.ToolNode().to_dict()}, "agents": { **nodes.JsonAgentNode().to_dict(), diff --git a/src/backend/langflow/template/base.py b/src/backend/langflow/template/base.py index eb7b8898b..8ed6cd7e3 100644 --- a/src/backend/langflow/template/base.py +++ b/src/backend/langflow/template/base.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Union from pydantic import BaseModel +from langflow.template.constants import FORCE_SHOW_FIELDS from langflow.utils import constants @@ -68,17 +69,7 @@ class TemplateFieldCreator(BaseModel, ABC): # Show or not field self.show = bool( (self.required and key not in ["input_variables"]) - or key - in [ - "allowed_tools", - "memory", - "prefix", - "examples", - "temperature", - "model_name", - "headers", - "max_value_length", - ] + or key in FORCE_SHOW_FIELDS or "api_key" in key ) diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index 3504bd910..46d39db9d 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -49,6 +49,16 @@ class ZeroShotPromptNode(FrontendNode): return super().to_dict() +class PromptTemplateNode(FrontendNode): + name: str = "PromptTemplate" + template: Template + description: str + base_classes: list[str] = ["BasePromptTemplate"] + + def to_dict(self): + return super().to_dict() + + class PythonFunctionNode(FrontendNode): name: str = "PythonFunction" template: Template = Template( diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index 59a19eb33..68dfb0a69 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -3,6 +3,7 @@ import inspect import re from typing import Dict, Optional +from langflow.template.constants import FORCE_SHOW_FIELDS from langflow.utils import constants @@ -284,17 +285,7 @@ def format_dict(d, name: Optional[str] = None): # Show or not field value["show"] = bool( (value["required"] and key not in ["input_variables"]) - or key - in [ - "allowed_tools", - "memory", - "prefix", - "examples", - "temperature", - "model_name", - "headers", - "max_value_length", - ] + or key in FORCE_SHOW_FIELDS or "api_key" in key ) diff --git a/tests/conftest.py b/tests/conftest.py index e6eb3562f..7e8316384 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from pathlib import Path import pytest from fastapi.testclient import TestClient +from langflow.graph.graph import Graph def pytest_configure(): diff --git a/tests/test_graph.py b/tests/test_graph.py index 0ea2d0f51..bbdacc7cb 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -340,43 +340,21 @@ def test_build_params(basic_graph): assert isinstance(llm_node.params["model_name"], str) -def test_build(basic_graph, complex_graph): +def test_build(basic_graph, complex_graph, openapi_graph): """Test Node's build method""" - # def build(self): - # # The params dict is used to build the module - # # it contains values and keys that point to nodes which - # # have their own params dict - # # When build is called, we iterate through the params dict - # # and if the value is a node, we call build on that node - # # and use the output of that build as the value for the param - # # if the value is not a node, then we use the value as the param - # # and continue - # # Another aspect is that the node_type is the class that we need to import - # # and instantiate with these built params + assert_agent_was_built(basic_graph) + assert_agent_was_built(complex_graph) + assert_agent_was_built(openapi_graph) - # # Build each node in the params dict - # for key, value in self.params.items(): - # if isinstance(value, Node): - # self.params[key] = value.build() - # # Get the class from LANGCHAIN_TYPES_DICT - # # and instantiate it with the params - # # and return the instance - # return LANGCHAIN_TYPES_DICT[self.node_type](**self.params) - - assert isinstance(basic_graph, Graph) +def assert_agent_was_built(graph): + """Assert that the agent was built""" + assert isinstance(graph, Graph) # Now we test the build method # Build the Agent - agent = basic_graph.build() + result = graph.build() # The agent should be a AgentExecutor - assert isinstance(agent, AgentExecutor) - - # Now we test the complex example - assert isinstance(complex_graph, Graph) - # Now we test the build method - agent = complex_graph.build() - # The agent should be a AgentExecutor - assert isinstance(agent, AgentExecutor) + assert isinstance(result, AgentExecutor) def test_agent_node_build(basic_graph): @@ -384,7 +362,6 @@ def test_agent_node_build(basic_graph): assert agent_node is not None built_object = agent_node.build() assert built_object is not None - # Add any further assertions specific to the AgentNode's build() method def test_tool_node_build(basic_graph): diff --git a/tests/test_template.py b/tests/test_template.py index 9f7d78c55..f557256e0 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -1,4 +1,5 @@ import importlib +import re from typing import Dict, List, Optional import pytest