From f4fca2b6bbf0f9d99663b0832c2dbe5b1858df15 Mon Sep 17 00:00:00 2001 From: anovazzi1 Date: Wed, 15 Jan 2025 17:43:19 -0300 Subject: [PATCH] Fix: update tweaks processing to allow input_type without input_value (#5656) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: tweaks processing in api usage * refactor code * fix: unit tests * Refactor input handling in simple_run_flow function * Refactor input handling in test_endpoints.py --------- Co-authored-by: Ítalo Johnny --- src/backend/base/langflow/api/v1/endpoints.py | 82 +++++++++++++------ src/backend/tests/unit/test_endpoints.py | 4 - 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 47e20e808..573684dd7 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -8,7 +8,16 @@ from typing import TYPE_CHECKING, Annotated from uuid import UUID import sqlalchemy as sa -from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + Depends, + HTTPException, + Request, + UploadFile, + status, +) from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse from loguru import logger @@ -27,7 +36,11 @@ from langflow.api.v1.schemas import ( UploadFileResponse, ) from langflow.custom.custom_component.component import Component -from langflow.custom.utils import build_custom_component_template, get_instance_name, update_component_build_config +from langflow.custom.utils import ( + build_custom_component_template, + get_instance_name, + update_component_build_config, +) from langflow.events.event_manager import create_stream_tokens_event_manager from langflow.exceptions.api import APIException, InvalidChatInputError from langflow.exceptions.serialization import SerializationError @@ -42,9 +55,16 @@ from langflow.services.auth.utils import api_key_security, get_current_active_us from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow from langflow.services.database.models.flow.model import FlowRead -from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow +from langflow.services.database.models.flow.utils import ( + get_all_webhook_components_in_flow, +) from langflow.services.database.models.user.model import User, UserRead -from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service +from langflow.services.deps import ( + get_session_service, + get_settings_service, + get_task_service, + get_telemetry_service, +) from langflow.services.settings.feature_flags import FEATURE_FLAGS from langflow.services.telemetry.schema import RunPayload from langflow.utils.version import get_version_info @@ -72,23 +92,32 @@ def validate_input_and_tweaks(input_request: SimplifiedAPIRequest) -> None: # then we need to check the tweaks if the ChatInput component is present # and if its input_value is not None # if so, we raise an error - if input_request.tweaks is None: + if not input_request.tweaks: return + for key, value in input_request.tweaks.items(): - if "ChatInput" in key or "Chat Input" in key: - if isinstance(value, dict): - has_input_value = value.get("input_value") is not None - input_value_is_chat = input_request.input_value is not None and input_request.input_type == "chat" - if has_input_value and input_value_is_chat: - msg = "If you pass an input_value to the chat input, you cannot pass a tweak with the same name." - raise InvalidChatInputError(msg) - elif ("Text Input" in key or "TextInput" in key) and isinstance(value, dict): - has_input_value = value.get("input_value") is not None - input_value_is_text = input_request.input_value is not None and input_request.input_type == "text" - if has_input_value and input_value_is_text: - msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name." + if not isinstance(value, dict): + continue + + input_value = value.get("input_value") + if input_value is None: + continue + + request_has_input = input_request.input_value is not None + + if any(chat_key in key for chat_key in ("ChatInput", "Chat Input")): + if request_has_input and input_request.input_type == "chat": + msg = "If you pass an input_value to the chat input, you cannot pass a tweak with the same name." raise InvalidChatInputError(msg) + elif ( + any(text_key in key for text_key in ("TextInput", "Text Input")) + and request_has_input + and input_request.input_type == "text" + ): + msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name." + raise InvalidChatInputError(msg) + async def simple_run_flow( flow: Flow, @@ -98,8 +127,7 @@ async def simple_run_flow( api_key_user: User | None = None, event_manager: EventManager | None = None, ): - if input_request.input_value is not None and input_request.tweaks is not None: - validate_input_and_tweaks(input_request) + validate_input_and_tweaks(input_request) try: task_result: list[RunOutputs] = [] user_id = api_key_user.id if api_key_user else None @@ -110,13 +138,15 @@ async def simple_run_flow( graph_data = flow.data.copy() graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream) graph = Graph.from_payload(graph_data, flow_id=flow_id_str, user_id=str(user_id), flow_name=flow.name) - inputs = [ - InputValueRequest( - components=[], - input_value=input_request.input_value, - type=input_request.input_type, - ) - ] + inputs = None + if input_request.input_value is not None: + inputs = [ + InputValueRequest( + components=[], + input_value=input_request.input_value, + type=input_request.input_type, + ) + ] if input_request.output_component: outputs = [input_request.output_component] else: diff --git a/src/backend/tests/unit/test_endpoints.py b/src/backend/tests/unit/test_endpoints.py index 1c1b7dd3f..7e2e9dbc0 100644 --- a/src/backend/tests/unit/test_endpoints.py +++ b/src/backend/tests/unit/test_endpoints.py @@ -287,7 +287,6 @@ async def test_successful_run_no_payload(client, simple_api_test, created_api_ke assert len(outputs_dict) == 2 assert "inputs" in outputs_dict assert "outputs" in outputs_dict - assert outputs_dict.get("inputs") == {"input_value": ""} assert isinstance(outputs_dict.get("outputs"), list) assert len(outputs_dict.get("outputs")) == 1 ids = [output.get("component_id") for output in outputs_dict.get("outputs")] @@ -318,7 +317,6 @@ async def test_successful_run_with_output_type_text(client, simple_api_test, cre assert len(outputs_dict) == 2 assert "inputs" in outputs_dict assert "outputs" in outputs_dict - assert outputs_dict.get("inputs") == {"input_value": ""} assert isinstance(outputs_dict.get("outputs"), list) assert len(outputs_dict.get("outputs")) == 1 ids = [output.get("component_id") for output in outputs_dict.get("outputs")] @@ -350,7 +348,6 @@ async def test_successful_run_with_output_type_any(client, simple_api_test, crea assert len(outputs_dict) == 2 assert "inputs" in outputs_dict assert "outputs" in outputs_dict - assert outputs_dict.get("inputs") == {"input_value": ""} assert isinstance(outputs_dict.get("outputs"), list) assert len(outputs_dict.get("outputs")) == 1 ids = [output.get("component_id") for output in outputs_dict.get("outputs")] @@ -383,7 +380,6 @@ async def test_successful_run_with_output_type_debug(client, simple_api_test, cr assert len(outputs_dict) == 2 assert "inputs" in outputs_dict assert "outputs" in outputs_dict - assert outputs_dict.get("inputs") == {"input_value": ""} assert isinstance(outputs_dict.get("outputs"), list) assert len(outputs_dict.get("outputs")) == 3