Fix: update tweaks processing to allow input_type without input_value (#5656)

* fix: tweaks processing in api usage

* refactor code

* fix: unit tests

* Refactor input handling in simple_run_flow function

* Refactor input handling in test_endpoints.py

---------

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
anovazzi1 2025-01-15 17:43:19 -03:00 committed by GitHub
commit f4fca2b6bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 56 additions and 30 deletions

View file

@ -8,7 +8,16 @@ from typing import TYPE_CHECKING, Annotated
from uuid import UUID
import sqlalchemy as sa
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Depends,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse
from loguru import logger
@ -27,7 +36,11 @@ from langflow.api.v1.schemas import (
UploadFileResponse,
)
from langflow.custom.custom_component.component import Component
from langflow.custom.utils import build_custom_component_template, get_instance_name, update_component_build_config
from langflow.custom.utils import (
build_custom_component_template,
get_instance_name,
update_component_build_config,
)
from langflow.events.event_manager import create_stream_tokens_event_manager
from langflow.exceptions.api import APIException, InvalidChatInputError
from langflow.exceptions.serialization import SerializationError
@ -42,9 +55,16 @@ from langflow.services.auth.utils import api_key_security, get_current_active_us
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.flow.model import FlowRead
from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow
from langflow.services.database.models.flow.utils import (
get_all_webhook_components_in_flow,
)
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service
from langflow.services.deps import (
get_session_service,
get_settings_service,
get_task_service,
get_telemetry_service,
)
from langflow.services.settings.feature_flags import FEATURE_FLAGS
from langflow.services.telemetry.schema import RunPayload
from langflow.utils.version import get_version_info
@ -72,23 +92,32 @@ def validate_input_and_tweaks(input_request: SimplifiedAPIRequest) -> None:
# then we need to check the tweaks if the ChatInput component is present
# and if its input_value is not None
# if so, we raise an error
if input_request.tweaks is None:
if not input_request.tweaks:
return
for key, value in input_request.tweaks.items():
if "ChatInput" in key or "Chat Input" in key:
if isinstance(value, dict):
has_input_value = value.get("input_value") is not None
input_value_is_chat = input_request.input_value is not None and input_request.input_type == "chat"
if has_input_value and input_value_is_chat:
msg = "If you pass an input_value to the chat input, you cannot pass a tweak with the same name."
raise InvalidChatInputError(msg)
elif ("Text Input" in key or "TextInput" in key) and isinstance(value, dict):
has_input_value = value.get("input_value") is not None
input_value_is_text = input_request.input_value is not None and input_request.input_type == "text"
if has_input_value and input_value_is_text:
msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name."
if not isinstance(value, dict):
continue
input_value = value.get("input_value")
if input_value is None:
continue
request_has_input = input_request.input_value is not None
if any(chat_key in key for chat_key in ("ChatInput", "Chat Input")):
if request_has_input and input_request.input_type == "chat":
msg = "If you pass an input_value to the chat input, you cannot pass a tweak with the same name."
raise InvalidChatInputError(msg)
elif (
any(text_key in key for text_key in ("TextInput", "Text Input"))
and request_has_input
and input_request.input_type == "text"
):
msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name."
raise InvalidChatInputError(msg)
async def simple_run_flow(
flow: Flow,
@ -98,8 +127,7 @@ async def simple_run_flow(
api_key_user: User | None = None,
event_manager: EventManager | None = None,
):
if input_request.input_value is not None and input_request.tweaks is not None:
validate_input_and_tweaks(input_request)
validate_input_and_tweaks(input_request)
try:
task_result: list[RunOutputs] = []
user_id = api_key_user.id if api_key_user else None
@ -110,13 +138,15 @@ async def simple_run_flow(
graph_data = flow.data.copy()
graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream)
graph = Graph.from_payload(graph_data, flow_id=flow_id_str, user_id=str(user_id), flow_name=flow.name)
inputs = [
InputValueRequest(
components=[],
input_value=input_request.input_value,
type=input_request.input_type,
)
]
inputs = None
if input_request.input_value is not None:
inputs = [
InputValueRequest(
components=[],
input_value=input_request.input_value,
type=input_request.input_type,
)
]
if input_request.output_component:
outputs = [input_request.output_component]
else:

View file

@ -287,7 +287,6 @@ async def test_successful_run_no_payload(client, simple_api_test, created_api_ke
assert len(outputs_dict) == 2
assert "inputs" in outputs_dict
assert "outputs" in outputs_dict
assert outputs_dict.get("inputs") == {"input_value": ""}
assert isinstance(outputs_dict.get("outputs"), list)
assert len(outputs_dict.get("outputs")) == 1
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
@ -318,7 +317,6 @@ async def test_successful_run_with_output_type_text(client, simple_api_test, cre
assert len(outputs_dict) == 2
assert "inputs" in outputs_dict
assert "outputs" in outputs_dict
assert outputs_dict.get("inputs") == {"input_value": ""}
assert isinstance(outputs_dict.get("outputs"), list)
assert len(outputs_dict.get("outputs")) == 1
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
@ -350,7 +348,6 @@ async def test_successful_run_with_output_type_any(client, simple_api_test, crea
assert len(outputs_dict) == 2
assert "inputs" in outputs_dict
assert "outputs" in outputs_dict
assert outputs_dict.get("inputs") == {"input_value": ""}
assert isinstance(outputs_dict.get("outputs"), list)
assert len(outputs_dict.get("outputs")) == 1
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
@ -383,7 +380,6 @@ async def test_successful_run_with_output_type_debug(client, simple_api_test, cr
assert len(outputs_dict) == 2
assert "inputs" in outputs_dict
assert "outputs" in outputs_dict
assert outputs_dict.get("inputs") == {"input_value": ""}
assert isinstance(outputs_dict.get("outputs"), list)
assert len(outputs_dict.get("outputs")) == 3