From 229717a98af322c4dbf481c33b1e25fa8e3e96fc Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 25 Oct 2023 18:13:58 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(endpoints.py):=20fix=20missi?= =?UTF-8?q?ng=20return=20statement=20in=20get=5Fall=20function=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(base.py):=20fix=20issue=20with=20args=5Fsche?= =?UTF-8?q?ma=20value=20for=20PythonInputs=20template=20=F0=9F=90=9B=20fix?= =?UTF-8?q?(test=5Fprompts=5Ftemplate.py):=20fix=20incorrect=20value=20for?= =?UTF-8?q?=20validate=5Ftemplate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/endpoints.py | 3 ++- src/backend/langflow/interface/tools/base.py | 9 ++++++++- tests/test_prompts_template.py | 17 +---------------- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 4e68193d8..5c76539ab 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -54,9 +54,10 @@ def get_all( logger.debug("Building langchain types dict") try: - return get_all_types_dict(settings_service) + types_dict = get_all_types_dict(settings_service) except Exception as exc: raise HTTPException(status_code=500, detail=str(exc)) from exc + return types_dict # For backwards compatibility we will keep the old endpoint diff --git a/src/backend/langflow/interface/tools/base.py b/src/backend/langflow/interface/tools/base.py index a99025ff7..796d9d69c 100644 --- a/src/backend/langflow/interface/tools/base.py +++ b/src/backend/langflow/interface/tools/base.py @@ -5,6 +5,7 @@ from langchain.agents.load_tools import ( _EXTRA_OPTIONAL_TOOLS, _LLM_TOOLS, ) +from langchain.tools.python.tool import PythonInputs from langflow.custom import customs from langflow.interface.base import LangChainTypeCreator @@ -161,8 +162,14 @@ class ToolCreator(LangChainTypeCreator): template = Template(fields=fields, type_name=tool_type) tool_params = {**tool_params, **self.type_to_loader_dict[name]["params"]} + template_dict = template.to_dict() + if ( + "args_schema" in template_dict + and template_dict.get("args_schema").get("value") == PythonInputs + ): + template_dict["args_schema"]["value"] = "" return { - "template": util.format_dict(template.to_dict()), + "template": util.format_dict(template_dict), **tool_params, "base_classes": base_classes, } diff --git a/tests/test_prompts_template.py b/tests/test_prompts_template.py index ae9c1b4f6..f9292f441 100644 --- a/tests/test_prompts_template.py +++ b/tests/test_prompts_template.py @@ -75,28 +75,13 @@ def test_prompt_template(client: TestClient, logged_in_headers): "info": "", } - assert template["template_format"] == { - "required": False, - "dynamic": False, - "placeholder": "", - "show": False, - "multiline": False, - "value": "f-string", - "password": False, - "name": "template_format", - "type": "str", - "list": False, - "advanced": False, - "info": "", - } - assert template["validate_template"] == { "required": False, "dynamic": False, "placeholder": "", "show": False, "multiline": False, - "value": True, + "value": False, "password": False, "name": "validate_template", "type": "bool",