Fix: update tweaks processing to allow input_type without input_value (#5656)
* fix: tweaks processing in api usage * refactor code * fix: unit tests * Refactor input handling in simple_run_flow function * Refactor input handling in test_endpoints.py --------- Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
parent
1f63ebeaef
commit
f4fca2b6bb
2 changed files with 56 additions and 30 deletions
|
|
@ -8,7 +8,16 @@ from typing import TYPE_CHECKING, Annotated
|
|||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
|
@ -27,7 +36,11 @@ from langflow.api.v1.schemas import (
|
|||
UploadFileResponse,
|
||||
)
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.custom.utils import build_custom_component_template, get_instance_name, update_component_build_config
|
||||
from langflow.custom.utils import (
|
||||
build_custom_component_template,
|
||||
get_instance_name,
|
||||
update_component_build_config,
|
||||
)
|
||||
from langflow.events.event_manager import create_stream_tokens_event_manager
|
||||
from langflow.exceptions.api import APIException, InvalidChatInputError
|
||||
from langflow.exceptions.serialization import SerializationError
|
||||
|
|
@ -42,9 +55,16 @@ from langflow.services.auth.utils import api_key_security, get_current_active_us
|
|||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.flow.model import FlowRead
|
||||
from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow
|
||||
from langflow.services.database.models.flow.utils import (
|
||||
get_all_webhook_components_in_flow,
|
||||
)
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service
|
||||
from langflow.services.deps import (
|
||||
get_session_service,
|
||||
get_settings_service,
|
||||
get_task_service,
|
||||
get_telemetry_service,
|
||||
)
|
||||
from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
||||
from langflow.services.telemetry.schema import RunPayload
|
||||
from langflow.utils.version import get_version_info
|
||||
|
|
@ -72,23 +92,32 @@ def validate_input_and_tweaks(input_request: SimplifiedAPIRequest) -> None:
|
|||
# then we need to check the tweaks if the ChatInput component is present
|
||||
# and if its input_value is not None
|
||||
# if so, we raise an error
|
||||
if input_request.tweaks is None:
|
||||
if not input_request.tweaks:
|
||||
return
|
||||
|
||||
for key, value in input_request.tweaks.items():
|
||||
if "ChatInput" in key or "Chat Input" in key:
|
||||
if isinstance(value, dict):
|
||||
has_input_value = value.get("input_value") is not None
|
||||
input_value_is_chat = input_request.input_value is not None and input_request.input_type == "chat"
|
||||
if has_input_value and input_value_is_chat:
|
||||
msg = "If you pass an input_value to the chat input, you cannot pass a tweak with the same name."
|
||||
raise InvalidChatInputError(msg)
|
||||
elif ("Text Input" in key or "TextInput" in key) and isinstance(value, dict):
|
||||
has_input_value = value.get("input_value") is not None
|
||||
input_value_is_text = input_request.input_value is not None and input_request.input_type == "text"
|
||||
if has_input_value and input_value_is_text:
|
||||
msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name."
|
||||
if not isinstance(value, dict):
|
||||
continue
|
||||
|
||||
input_value = value.get("input_value")
|
||||
if input_value is None:
|
||||
continue
|
||||
|
||||
request_has_input = input_request.input_value is not None
|
||||
|
||||
if any(chat_key in key for chat_key in ("ChatInput", "Chat Input")):
|
||||
if request_has_input and input_request.input_type == "chat":
|
||||
msg = "If you pass an input_value to the chat input, you cannot pass a tweak with the same name."
|
||||
raise InvalidChatInputError(msg)
|
||||
|
||||
elif (
|
||||
any(text_key in key for text_key in ("TextInput", "Text Input"))
|
||||
and request_has_input
|
||||
and input_request.input_type == "text"
|
||||
):
|
||||
msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name."
|
||||
raise InvalidChatInputError(msg)
|
||||
|
||||
|
||||
async def simple_run_flow(
|
||||
flow: Flow,
|
||||
|
|
@ -98,8 +127,7 @@ async def simple_run_flow(
|
|||
api_key_user: User | None = None,
|
||||
event_manager: EventManager | None = None,
|
||||
):
|
||||
if input_request.input_value is not None and input_request.tweaks is not None:
|
||||
validate_input_and_tweaks(input_request)
|
||||
validate_input_and_tweaks(input_request)
|
||||
try:
|
||||
task_result: list[RunOutputs] = []
|
||||
user_id = api_key_user.id if api_key_user else None
|
||||
|
|
@ -110,13 +138,15 @@ async def simple_run_flow(
|
|||
graph_data = flow.data.copy()
|
||||
graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream)
|
||||
graph = Graph.from_payload(graph_data, flow_id=flow_id_str, user_id=str(user_id), flow_name=flow.name)
|
||||
inputs = [
|
||||
InputValueRequest(
|
||||
components=[],
|
||||
input_value=input_request.input_value,
|
||||
type=input_request.input_type,
|
||||
)
|
||||
]
|
||||
inputs = None
|
||||
if input_request.input_value is not None:
|
||||
inputs = [
|
||||
InputValueRequest(
|
||||
components=[],
|
||||
input_value=input_request.input_value,
|
||||
type=input_request.input_type,
|
||||
)
|
||||
]
|
||||
if input_request.output_component:
|
||||
outputs = [input_request.output_component]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -287,7 +287,6 @@ async def test_successful_run_no_payload(client, simple_api_test, created_api_ke
|
|||
assert len(outputs_dict) == 2
|
||||
assert "inputs" in outputs_dict
|
||||
assert "outputs" in outputs_dict
|
||||
assert outputs_dict.get("inputs") == {"input_value": ""}
|
||||
assert isinstance(outputs_dict.get("outputs"), list)
|
||||
assert len(outputs_dict.get("outputs")) == 1
|
||||
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
|
||||
|
|
@ -318,7 +317,6 @@ async def test_successful_run_with_output_type_text(client, simple_api_test, cre
|
|||
assert len(outputs_dict) == 2
|
||||
assert "inputs" in outputs_dict
|
||||
assert "outputs" in outputs_dict
|
||||
assert outputs_dict.get("inputs") == {"input_value": ""}
|
||||
assert isinstance(outputs_dict.get("outputs"), list)
|
||||
assert len(outputs_dict.get("outputs")) == 1
|
||||
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
|
||||
|
|
@ -350,7 +348,6 @@ async def test_successful_run_with_output_type_any(client, simple_api_test, crea
|
|||
assert len(outputs_dict) == 2
|
||||
assert "inputs" in outputs_dict
|
||||
assert "outputs" in outputs_dict
|
||||
assert outputs_dict.get("inputs") == {"input_value": ""}
|
||||
assert isinstance(outputs_dict.get("outputs"), list)
|
||||
assert len(outputs_dict.get("outputs")) == 1
|
||||
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
|
||||
|
|
@ -383,7 +380,6 @@ async def test_successful_run_with_output_type_debug(client, simple_api_test, cr
|
|||
assert len(outputs_dict) == 2
|
||||
assert "inputs" in outputs_dict
|
||||
assert "outputs" in outputs_dict
|
||||
assert outputs_dict.get("inputs") == {"input_value": ""}
|
||||
assert isinstance(outputs_dict.get("outputs"), list)
|
||||
assert len(outputs_dict.get("outputs")) == 3
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue