diff --git a/.github/changes-filter.yaml b/.github/changes-filter.yaml index 925f23746..f0d36110c 100644 --- a/.github/changes-filter.yaml +++ b/.github/changes-filter.yaml @@ -25,6 +25,7 @@ starter-projects: - "src/backend/base/langflow/components/**" - "src/backend/base/langflow/services/**" - "src/backend/base/langflow/custom/**" + - "src/backend/base/langflow/api/v1/chat.py" - "src/frontend/src/pages/MainPage/**" - "src/frontend/src/utils/reactflowUtils.ts" - "src/frontend/tests/extended/features/**" diff --git a/src/backend/base/langflow/api/build.py b/src/backend/base/langflow/api/build.py new file mode 100644 index 000000000..384a74bdd --- /dev/null +++ b/src/backend/base/langflow/api/build.py @@ -0,0 +1,428 @@ +import asyncio +import json +import time +import traceback +import uuid +from collections.abc import AsyncIterator + +from fastapi import BackgroundTasks, HTTPException +from fastapi.responses import JSONResponse +from loguru import logger +from sqlmodel import select + +from langflow.api.disconnect import DisconnectHandlerStreamingResponse +from langflow.api.utils import ( + CurrentActiveUser, + build_graph_from_data, + build_graph_from_db, + format_elapsed_time, + format_exception_message, + get_top_level_vertices, + parse_exception, +) +from langflow.api.v1.schemas import ( + FlowDataRequest, + InputValueRequest, + ResultDataResponse, + VertexBuildResponse, +) +from langflow.events.event_manager import EventManager +from langflow.exceptions.component import ComponentBuildError +from langflow.graph.graph.base import Graph +from langflow.graph.utils import log_vertex_build +from langflow.schema.message import ErrorMessage +from langflow.schema.schema import OutputValue +from langflow.services.database.models.flow import Flow +from langflow.services.deps import get_chat_service, get_telemetry_service, session_scope +from langflow.services.job_queue.service import JobQueueService +from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload + + +async def start_flow_build( + *, + flow_id: uuid.UUID, + background_tasks: BackgroundTasks, + inputs: InputValueRequest | None, + data: FlowDataRequest | None, + files: list[str] | None, + stop_component_id: str | None, + start_component_id: str | None, + log_builds: bool, + current_user: CurrentActiveUser, + queue_service: JobQueueService, +) -> str: + """Start the flow build process by setting up the queue and starting the build task. + + Returns: + the job_id. 
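+
+    Clients pass the returned ID to ``get_flow_events_response`` (exposed as
+    ``GET /build/{job_id}/events`` in ``chat.py`` below) to stream or poll the
+    build's events. A minimal call sketch, mirroring the ``build_flow`` endpoint
+    in this change::
+
+        job_id = await start_flow_build(
+            flow_id=flow_id,
+            background_tasks=background_tasks,
+            inputs=inputs,
+            data=data,
+            files=files,
+            stop_component_id=stop_component_id,
+            start_component_id=start_component_id,
+            log_builds=log_builds,
+            current_user=current_user,
+            queue_service=queue_service,
+        )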
+ """ + job_id = str(uuid.uuid4()) + try: + _, event_manager = queue_service.create_queue(job_id) + task_coro = generate_flow_events( + flow_id=flow_id, + background_tasks=background_tasks, + event_manager=event_manager, + inputs=inputs, + data=data, + files=files, + stop_component_id=stop_component_id, + start_component_id=start_component_id, + log_builds=log_builds, + current_user=current_user, + ) + queue_service.start_job(job_id, task_coro) + except Exception as e: + logger.exception("Failed to create queue and start task") + raise HTTPException(status_code=500, detail=str(e)) from e + return job_id + + +async def get_flow_events_response( + *, + job_id: str, + queue_service: JobQueueService, + stream: bool = True, +): + """Get events for a specific build job, either as a stream or single event.""" + try: + main_queue, event_manager, event_task = queue_service.get_queue_data(job_id) + if stream: + if event_task is None: + raise HTTPException(status_code=404, detail="No event task found for job") + return await create_flow_response( + queue=main_queue, + event_manager=event_manager, + event_task=event_task, + ) + + # Polling mode - get exactly one event + _, value, _ = await main_queue.get() + if value is None: + # End of stream, trigger end event + if event_task is not None: + event_task.cancel() + event_manager.on_end(data={}) + + return JSONResponse({"event": value.decode("utf-8") if value else None}) + + except ValueError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + + +async def create_flow_response( + queue: asyncio.Queue, + event_manager: EventManager, + event_task: asyncio.Task, +) -> DisconnectHandlerStreamingResponse: + """Create a streaming response for the flow build process.""" + + async def consume_and_yield() -> AsyncIterator[str]: + while True: + try: + event_id, value, put_time = await queue.get() + if value is None: + break + get_time = time.time() + yield value.decode("utf-8") + logger.debug(f"Event {event_id} consumed in {get_time - put_time:.4f}s") + except Exception as exc: # noqa: BLE001 + logger.exception(f"Error consuming event: {exc}") + break + + def on_disconnect() -> None: + logger.debug("Client disconnected, closing tasks") + event_task.cancel() + event_manager.on_end(data={}) + + return DisconnectHandlerStreamingResponse( + consume_and_yield(), + media_type="application/x-ndjson", + on_disconnect=on_disconnect, + ) + + +async def generate_flow_events( + *, + flow_id: uuid.UUID, + background_tasks: BackgroundTasks, + event_manager: EventManager, + inputs: InputValueRequest | None, + data: FlowDataRequest | None, + files: list[str] | None, + stop_component_id: str | None, + start_component_id: str | None, + log_builds: bool, + current_user: CurrentActiveUser, +) -> None: + """Generate events for flow building process. 
+ + This function handles the core flow building logic and generates appropriate events: + - Building and validating the graph + - Processing vertices + - Handling errors and cleanup + """ + chat_service = get_chat_service() + telemetry_service = get_telemetry_service() + if not inputs: + inputs = InputValueRequest(session=str(flow_id)) + + async def build_graph_and_get_order() -> tuple[list[str], list[str], Graph]: + start_time = time.perf_counter() + components_count = 0 + graph = None + try: + flow_id_str = str(flow_id) + # Create a fresh session for database operations + async with session_scope() as fresh_session: + graph = await create_graph(fresh_session, flow_id_str) + + graph.validate_stream() + first_layer = sort_vertices(graph) + + if inputs is not None and getattr(inputs, "session", None) is not None: + graph.session_id = inputs.session + + for vertex_id in first_layer: + graph.run_manager.add_to_vertices_being_run(vertex_id) + + # Now vertices is a list of lists + # We need to get the id of each vertex + # and return the same structure but only with the ids + components_count = len(graph.vertices) + vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run))) + + await chat_service.set_cache(flow_id_str, graph) + await log_telemetry(start_time, components_count, success=True) + + except Exception as exc: + await log_telemetry(start_time, components_count, success=False, error_message=str(exc)) + + if "stream or streaming set to True" in str(exc): + raise HTTPException(status_code=400, detail=str(exc)) from exc + logger.exception("Error checking build status") + raise HTTPException(status_code=500, detail=str(exc)) from exc + return first_layer, vertices_to_run, graph + + async def log_telemetry( + start_time: float, components_count: int, *, success: bool, error_message: str | None = None + ): + background_tasks.add_task( + telemetry_service.log_package_playground, + PlaygroundPayload( + playground_seconds=int(time.perf_counter() - start_time), + playground_component_count=components_count, + playground_success=success, + playground_error_message=str(error_message) if error_message else "", + ), + ) + + async def create_graph(fresh_session, flow_id_str: str) -> Graph: + if not data: + return await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service) + + result = await fresh_session.exec(select(Flow.name).where(Flow.id == flow_id)) + flow_name = result.first() + + return await build_graph_from_data( + flow_id=flow_id_str, + payload=data.model_dump(), + user_id=str(current_user.id), + flow_name=flow_name, + ) + + def sort_vertices(graph: Graph) -> list[str]: + try: + return graph.sort_vertices(stop_component_id, start_component_id) + except Exception: # noqa: BLE001 + logger.exception("Error sorting vertices") + return graph.sort_vertices() + + async def _build_vertex(vertex_id: str, graph: Graph, event_manager: EventManager) -> VertexBuildResponse: + flow_id_str = str(flow_id) + next_runnable_vertices = [] + top_level_vertices = [] + start_time = time.perf_counter() + error_message = None + try: + vertex = graph.get_vertex(vertex_id) + try: + lock = chat_service.async_cache_locks[flow_id_str] + vertex_build_result = await graph.build_vertex( + vertex_id=vertex_id, + user_id=str(current_user.id), + inputs_dict=inputs.model_dump() if inputs else {}, + files=files, + get_cache=chat_service.get_cache, + set_cache=chat_service.set_cache, + event_manager=event_manager, + ) + result_dict = 
vertex_build_result.result_dict + params = vertex_build_result.params + valid = vertex_build_result.valid + artifacts = vertex_build_result.artifacts + next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False) + top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices) + + result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True) + except Exception as exc: # noqa: BLE001 + if isinstance(exc, ComponentBuildError): + params = exc.message + tb = exc.formatted_traceback + else: + tb = traceback.format_exc() + logger.exception("Error building Component") + params = format_exception_message(exc) + message = {"errorMessage": params, "stackTrace": tb} + valid = False + error_message = params + output_label = vertex.outputs[0]["name"] if vertex.outputs else "output" + outputs = {output_label: OutputValue(message=message, type="error")} + result_data_response = ResultDataResponse(results={}, outputs=outputs) + artifacts = {} + background_tasks.add_task(graph.end_all_traces, error=exc) + + result_data_response.message = artifacts + + # Log the vertex build + if not vertex.will_stream and log_builds: + background_tasks.add_task( + log_vertex_build, + flow_id=flow_id_str, + vertex_id=vertex_id, + valid=valid, + params=params, + data=result_data_response, + artifacts=artifacts, + ) + else: + await chat_service.set_cache(flow_id_str, graph) + + timedelta = time.perf_counter() - start_time + duration = format_elapsed_time(timedelta) + result_data_response.duration = duration + result_data_response.timedelta = timedelta + vertex.add_build_time(timedelta) + inactivated_vertices = list(graph.inactivated_vertices) + graph.reset_inactivated_vertices() + graph.reset_activated_vertices() + # graph.stop_vertex tells us if the user asked + # to stop the build of the graph at a certain vertex + # if it is in next_vertices_ids, we need to remove other + # vertices from next_vertices_ids + if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices: + next_runnable_vertices = [graph.stop_vertex] + + if not graph.run_manager.vertices_being_run and not next_runnable_vertices: + background_tasks.add_task(graph.end_all_traces) + + build_response = VertexBuildResponse( + inactivated_vertices=list(set(inactivated_vertices)), + next_vertices_ids=list(set(next_runnable_vertices)), + top_level_vertices=list(set(top_level_vertices)), + valid=valid, + params=params, + id=vertex.id, + data=result_data_response, + ) + background_tasks.add_task( + telemetry_service.log_package_component, + ComponentPayload( + component_name=vertex_id.split("-")[0], + component_seconds=int(time.perf_counter() - start_time), + component_success=valid, + component_error_message=error_message, + ), + ) + except Exception as exc: + background_tasks.add_task( + telemetry_service.log_package_component, + ComponentPayload( + component_name=vertex_id.split("-")[0], + component_seconds=int(time.perf_counter() - start_time), + component_success=False, + component_error_message=str(exc), + ), + ) + logger.exception("Error building Component") + message = parse_exception(exc) + raise HTTPException(status_code=500, detail=message) from exc + + return build_response + + async def build_vertices( + vertex_id: str, + graph: Graph, + event_manager: EventManager, + ) -> None: + """Build vertices and handle their events. 
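+
+        On success it emits an end-of-vertex event, then recursively schedules
+        builds for each next runnable vertex as concurrent tasks.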
+ + Args: + vertex_id: The ID of the vertex to build + graph: The graph instance + event_manager: Manager for handling events + """ + try: + vertex_build_response: VertexBuildResponse = await _build_vertex(vertex_id, graph, event_manager) + except asyncio.CancelledError as exc: + logger.exception(exc) + raise + + # send built event or error event + try: + vertex_build_response_json = vertex_build_response.model_dump_json() + build_data = json.loads(vertex_build_response_json) + except Exception as exc: + msg = f"Error serializing vertex build response: {exc}" + raise ValueError(msg) from exc + + event_manager.on_end_vertex(data={"build_data": build_data}) + + if vertex_build_response.valid and vertex_build_response.next_vertices_ids: + tasks = [] + for next_vertex_id in vertex_build_response.next_vertices_ids: + task = asyncio.create_task( + build_vertices( + next_vertex_id, + graph, + event_manager, + ) + ) + tasks.append(task) + await asyncio.gather(*tasks) + + try: + ids, vertices_to_run, graph = await build_graph_and_get_order() + except Exception as e: + error_message = ErrorMessage( + flow_id=flow_id, + exception=e, + ) + event_manager.on_error(data=error_message.data) + raise + + event_manager.on_vertices_sorted(data={"ids": ids, "to_run": vertices_to_run}) + + tasks = [] + for vertex_id in ids: + task = asyncio.create_task(build_vertices(vertex_id, graph, event_manager)) + tasks.append(task) + try: + await asyncio.gather(*tasks) + except asyncio.CancelledError: + background_tasks.add_task(graph.end_all_traces) + raise + except Exception as e: + logger.error(f"Error building vertices: {e}") + custom_component = graph.get_vertex(vertex_id).custom_component + trace_name = getattr(custom_component, "trace_name", None) + error_message = ErrorMessage( + flow_id=flow_id, + exception=e, + session_id=graph.session_id, + trace_name=trace_name, + ) + event_manager.on_error(data=error_message.data) + raise + event_manager.on_end(data={}) + await event_manager.queue.put((None, None, time.time())) diff --git a/src/backend/base/langflow/api/disconnect.py b/src/backend/base/langflow/api/disconnect.py new file mode 100644 index 000000000..a11454c73 --- /dev/null +++ b/src/backend/base/langflow/api/disconnect.py @@ -0,0 +1,31 @@ +import asyncio +import typing + +from fastapi.responses import StreamingResponse +from starlette.background import BackgroundTask +from starlette.responses import ContentStream +from starlette.types import Receive + + +class DisconnectHandlerStreamingResponse(StreamingResponse): + def __init__( + self, + content: ContentStream, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + on_disconnect: typing.Callable | None = None, + ): + super().__init__(content, status_code, headers, media_type, background) + self.on_disconnect = on_disconnect + + async def listen_for_disconnect(self, receive: Receive) -> None: + while True: + message = await receive() + if message["type"] == "http.disconnect": + if self.on_disconnect: + coro = self.on_disconnect() + if asyncio.iscoroutine(coro): + await coro + break diff --git a/src/backend/base/langflow/api/limited_background_tasks.py b/src/backend/base/langflow/api/limited_background_tasks.py new file mode 100644 index 000000000..b09bc31db --- /dev/null +++ b/src/backend/base/langflow/api/limited_background_tasks.py @@ -0,0 +1,29 @@ +from fastapi import BackgroundTasks + +from langflow.graph.utils import log_vertex_build +from 
langflow.services.deps import get_settings_service + + +class LimitVertexBuildBackgroundTasks(BackgroundTasks): + """A subclass of FastAPI BackgroundTasks that limits the number of tasks added per vertex_id. + + If more than max_vertex_builds_per_vertex tasks are added for a given vertex_id, + the oldest task is removed so that only the most recent remain. + This only applies to log_vertex_build tasks. + """ + + def add_task(self, func, *args, **kwargs): + # Only apply limiting logic to log_vertex_build tasks + if func == log_vertex_build: + vertex_id = kwargs.get("vertex_id") + if vertex_id is not None: + # Filter tasks that are log_vertex_build calls with the same vertex_id + relevant_tasks = [ + t for t in self.tasks if t.func == log_vertex_build and t.kwargs.get("vertex_id") == vertex_id + ] + if len(relevant_tasks) >= get_settings_service().settings.max_vertex_builds_per_vertex: + # Remove the oldest task for this vertex_id + oldest_task = relevant_tasks[0] + self.tasks.remove(oldest_task) + + super().add_task(func, *args, **kwargs) diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index f0213e167..d9c88908e 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -1,26 +1,23 @@ from __future__ import annotations -import asyncio -import json import time import traceback -import typing import uuid from typing import TYPE_CHECKING, Annotated -from fastapi import APIRouter, BackgroundTasks, Body, HTTPException +from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException from fastapi.responses import StreamingResponse from loguru import logger -from sqlmodel import select -from starlette.background import BackgroundTask -from starlette.responses import ContentStream -from starlette.types import Receive +from langflow.api.build import ( + get_flow_events_response, + start_flow_build, +) +from langflow.api.limited_background_tasks import LimitVertexBuildBackgroundTasks from langflow.api.utils import ( CurrentActiveUser, DbSession, build_and_cache_graph_from_data, - build_graph_from_data, build_graph_from_db, format_elapsed_time, format_exception_message, @@ -35,16 +32,21 @@ from langflow.api.v1.schemas import ( VertexBuildResponse, VerticesOrderResponse, ) -from langflow.events.event_manager import EventManager, create_default_event_manager from langflow.exceptions.component import ComponentBuildError from langflow.graph.graph.base import Graph from langflow.graph.utils import log_vertex_build -from langflow.schema.message import ErrorMessage from langflow.schema.schema import OutputValue from langflow.services.cache.utils import CacheMiss from langflow.services.chat.service import ChatService from langflow.services.database.models.flow.model import Flow -from langflow.services.deps import get_chat_service, get_session, get_telemetry_service, session_scope +from langflow.services.deps import ( + get_chat_service, + get_queue_service, + get_session, + get_telemetry_service, + session_scope, +) +from langflow.services.job_queue.service import JobQueueService from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload if TYPE_CHECKING: @@ -53,22 +55,6 @@ if TYPE_CHECKING: router = APIRouter(tags=["Chat"]) -async def try_running_celery_task(vertex, user_id): - # Try running the task in celery - # and set the task_id to the local vertex - # if it fails, run the task locally - try: - from langflow.worker import build_vertex - - task = build_vertex.delay(vertex) - 
vertex.task_id = task.id - except Exception: # noqa: BLE001 - logger.opt(exception=True).debug("Error running task in celery") - vertex.task_id = None - await vertex.build(user_id=user_id) - return vertex - - @router.post("/build/{flow_id}/vertices", deprecated=True) async def retrieve_vertices_order( *, @@ -143,322 +129,52 @@ async def retrieve_vertices_order( @router.post("/build/{flow_id}/flow") async def build_flow( *, - background_tasks: BackgroundTasks, flow_id: uuid.UUID, + background_tasks: LimitVertexBuildBackgroundTasks, inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None, data: Annotated[FlowDataRequest | None, Body(embed=True)] = None, files: list[str] | None = None, stop_component_id: str | None = None, start_component_id: str | None = None, - log_builds: bool | None = True, + log_builds: bool = True, current_user: CurrentActiveUser, + queue_service: Annotated[JobQueueService, Depends(get_queue_service)], ): - chat_service = get_chat_service() - telemetry_service = get_telemetry_service() - if not inputs: - inputs = InputValueRequest(session=str(flow_id)) + """Build and process a flow, returning a job ID for event polling.""" + # First verify the flow exists + async with session_scope() as session: + flow = await session.get(Flow, flow_id) + if not flow: + raise HTTPException(status_code=404, detail=f"Flow with id {flow_id} not found") - async def build_graph_and_get_order() -> tuple[list[str], list[str], Graph]: - start_time = time.perf_counter() - components_count = 0 - graph = None - try: - flow_id_str = str(flow_id) - # Create a fresh session for database operations - async with session_scope() as fresh_session: - graph = await create_graph(fresh_session, flow_id_str) - - graph.validate_stream() - first_layer = sort_vertices(graph) - - if inputs is not None and hasattr(inputs, "session") and inputs.session is not None: - graph.session_id = inputs.session - - for vertex_id in first_layer: - graph.run_manager.add_to_vertices_being_run(vertex_id) - - # Now vertices is a list of lists - # We need to get the id of each vertex - # and return the same structure but only with the ids - components_count = len(graph.vertices) - vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run))) - - await chat_service.set_cache(flow_id_str, graph) - await log_telemetry(start_time, components_count, success=True) - - except Exception as exc: - await log_telemetry(start_time, components_count, success=False, error_message=str(exc)) - - if "stream or streaming set to True" in str(exc): - raise HTTPException(status_code=400, detail=str(exc)) from exc - logger.exception("Error checking build status") - raise HTTPException(status_code=500, detail=str(exc)) from exc - return first_layer, vertices_to_run, graph - - async def log_telemetry( - start_time: float, components_count: int, *, success: bool, error_message: str | None = None - ): - background_tasks.add_task( - telemetry_service.log_package_playground, - PlaygroundPayload( - playground_seconds=int(time.perf_counter() - start_time), - playground_component_count=components_count, - playground_success=success, - playground_error_message=str(error_message) if error_message else "", - ), - ) - - async def create_graph(fresh_session, flow_id_str: str) -> Graph: - if not data: - return await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service) - - result = await fresh_session.exec(select(Flow.name).where(Flow.id == flow_id)) - flow_name = result.first() 
- - return await build_graph_from_data( - flow_id=flow_id_str, - payload=data.model_dump(), - user_id=str(current_user.id), - flow_name=flow_name, - ) - - def sort_vertices(graph: Graph) -> list[str]: - try: - return graph.sort_vertices(stop_component_id, start_component_id) - except Exception: # noqa: BLE001 - logger.exception("Error sorting vertices") - return graph.sort_vertices() - - async def _build_vertex(vertex_id: str, graph: Graph, event_manager: EventManager) -> VertexBuildResponse: - flow_id_str = str(flow_id) - next_runnable_vertices = [] - top_level_vertices = [] - start_time = time.perf_counter() - error_message = None - try: - vertex = graph.get_vertex(vertex_id) - try: - lock = chat_service.async_cache_locks[flow_id_str] - vertex_build_result = await graph.build_vertex( - vertex_id=vertex_id, - user_id=str(current_user.id), - inputs_dict=inputs.model_dump() if inputs else {}, - files=files, - get_cache=chat_service.get_cache, - set_cache=chat_service.set_cache, - event_manager=event_manager, - ) - result_dict = vertex_build_result.result_dict - params = vertex_build_result.params - valid = vertex_build_result.valid - artifacts = vertex_build_result.artifacts - next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False) - top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices) - - result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True) - except Exception as exc: # noqa: BLE001 - if isinstance(exc, ComponentBuildError): - params = exc.message - tb = exc.formatted_traceback - else: - tb = traceback.format_exc() - logger.exception("Error building Component") - params = format_exception_message(exc) - message = {"errorMessage": params, "stackTrace": tb} - valid = False - error_message = params - output_label = vertex.outputs[0]["name"] if vertex.outputs else "output" - outputs = {output_label: OutputValue(message=message, type="error")} - result_data_response = ResultDataResponse(results={}, outputs=outputs) - artifacts = {} - background_tasks.add_task(graph.end_all_traces, error=exc) - - result_data_response.message = artifacts - - # Log the vertex build - if not vertex.will_stream and log_builds: - background_tasks.add_task( - log_vertex_build, - flow_id=flow_id_str, - vertex_id=vertex_id, - valid=valid, - params=params, - data=result_data_response, - artifacts=artifacts, - ) - else: - await chat_service.set_cache(flow_id_str, graph) - - timedelta = time.perf_counter() - start_time - duration = format_elapsed_time(timedelta) - result_data_response.duration = duration - result_data_response.timedelta = timedelta - vertex.add_build_time(timedelta) - inactivated_vertices = list(graph.inactivated_vertices) - graph.reset_inactivated_vertices() - graph.reset_activated_vertices() - # graph.stop_vertex tells us if the user asked - # to stop the build of the graph at a certain vertex - # if it is in next_vertices_ids, we need to remove other - # vertices from next_vertices_ids - if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices: - next_runnable_vertices = [graph.stop_vertex] - - if not graph.run_manager.vertices_being_run and not next_runnable_vertices: - background_tasks.add_task(graph.end_all_traces) - - build_response = VertexBuildResponse( - inactivated_vertices=list(set(inactivated_vertices)), - next_vertices_ids=list(set(next_runnable_vertices)), - top_level_vertices=list(set(top_level_vertices)), - valid=valid, - params=params, - id=vertex.id, - 
data=result_data_response, - ) - background_tasks.add_task( - telemetry_service.log_package_component, - ComponentPayload( - component_name=vertex_id.split("-")[0], - component_seconds=int(time.perf_counter() - start_time), - component_success=valid, - component_error_message=error_message, - ), - ) - except Exception as exc: - background_tasks.add_task( - telemetry_service.log_package_component, - ComponentPayload( - component_name=vertex_id.split("-")[0], - component_seconds=int(time.perf_counter() - start_time), - component_success=False, - component_error_message=str(exc), - ), - ) - logger.exception("Error building Component") - message = parse_exception(exc) - raise HTTPException(status_code=500, detail=message) from exc - - return build_response - - async def build_vertices( - vertex_id: str, - graph: Graph, - client_consumed_queue: asyncio.Queue, - event_manager: EventManager, - ) -> None: - try: - vertex_build_response: VertexBuildResponse = await _build_vertex(vertex_id, graph, event_manager) - except asyncio.CancelledError as exc: - logger.exception(exc) - raise - - # send built event or error event - try: - vertex_build_response_json = vertex_build_response.model_dump_json() - build_data = json.loads(vertex_build_response_json) - except Exception as exc: - msg = f"Error serializing vertex build response: {exc}" - raise ValueError(msg) from exc - event_manager.on_end_vertex(data={"build_data": build_data}) - await client_consumed_queue.get() - if vertex_build_response.valid and vertex_build_response.next_vertices_ids: - tasks = [] - for next_vertex_id in vertex_build_response.next_vertices_ids: - task = asyncio.create_task(build_vertices(next_vertex_id, graph, client_consumed_queue, event_manager)) - tasks.append(task) - await asyncio.gather(*tasks) - - async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None: - try: - ids, vertices_to_run, graph = await build_graph_and_get_order() - except Exception as e: - error_message = ErrorMessage( - flow_id=flow_id, - exception=e, - ) - event_manager.on_error(data=error_message.data) - raise - event_manager.on_vertices_sorted(data={"ids": ids, "to_run": vertices_to_run}) - await client_consumed_queue.get() - tasks = [] - for vertex_id in ids: - task = asyncio.create_task(build_vertices(vertex_id, graph, client_consumed_queue, event_manager)) - tasks.append(task) - try: - await asyncio.gather(*tasks) - except asyncio.CancelledError: - background_tasks.add_task(graph.end_all_traces) - raise - - except Exception as e: - logger.error(f"Error building vertices: {e}") - custom_component = graph.get_vertex(vertex_id).custom_component - trace_name = getattr(custom_component, "trace_name", None) - error_message = ErrorMessage( - flow_id=flow_id, - exception=e, - session_id=graph.session_id, - trace_name=trace_name, - ) - event_manager.on_error(data=error_message.data) - raise - event_manager.on_end(data={}) - await event_manager.queue.put((None, None, time.time)) - - async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> typing.AsyncGenerator: - while True: - event_id, value, put_time = await queue.get() - if value is None: - break - get_time = time.time() - yield value - get_time_yield = time.time() - client_consumed_queue.put_nowait(event_id) - logger.debug( - f"consumed event {event_id} " - f"(time in queue, {get_time - put_time:.4f}, " - f"client {get_time_yield - get_time:.4f})" - ) - - asyncio_queue: asyncio.Queue = asyncio.Queue() - 
asyncio_queue_client_consumed: asyncio.Queue = asyncio.Queue() - event_manager = create_default_event_manager(queue=asyncio_queue) - main_task = asyncio.create_task(event_generator(event_manager, asyncio_queue_client_consumed)) - - def on_disconnect() -> None: - logger.debug("Client disconnected, closing tasks") - main_task.cancel() - - return DisconnectHandlerStreamingResponse( - consume_and_yield(asyncio_queue, asyncio_queue_client_consumed), - media_type="application/x-ndjson", - on_disconnect=on_disconnect, + job_id = await start_flow_build( + flow_id=flow_id, + background_tasks=background_tasks, + inputs=inputs, + data=data, + files=files, + stop_component_id=stop_component_id, + start_component_id=start_component_id, + log_builds=log_builds, + current_user=current_user, + queue_service=queue_service, ) + return {"job_id": job_id} -class DisconnectHandlerStreamingResponse(StreamingResponse): - def __init__( - self, - content: ContentStream, - status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, - media_type: str | None = None, - background: BackgroundTask | None = None, - on_disconnect: typing.Callable | None = None, - ): - super().__init__(content, status_code, headers, media_type, background) - self.on_disconnect = on_disconnect - - async def listen_for_disconnect(self, receive: Receive) -> None: - while True: - message = await receive() - if message["type"] == "http.disconnect": - if self.on_disconnect: - coro = self.on_disconnect() - if asyncio.iscoroutine(coro): - await coro - break +@router.get("/build/{job_id}/events") +async def get_build_events( + job_id: str, + queue_service: Annotated[JobQueueService, Depends(get_queue_service)], + *, + stream: bool = True, +): + """Get events for a specific build job.""" + return await get_flow_events_response( + job_id=job_id, + queue_service=queue_service, + stream=stream, + ) @router.post("/build/{flow_id}/vertices/{vertex_id}", deprecated=True) diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 89a31daa0..3a6ff403a 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -44,7 +44,7 @@ from langflow.services.database.models.flow import Flow from langflow.services.database.models.flow.model import FlowRead from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow from langflow.services.database.models.user.model import User, UserRead -from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service +from langflow.services.deps import get_session_service, get_settings_service, get_telemetry_service from langflow.services.settings.feature_flags import FEATURE_FLAGS from langflow.services.telemetry.schema import RunPayload from langflow.utils.version import get_version_info @@ -599,29 +599,16 @@ async def process() -> None: ) -@router.get("/task/{task_id}") -async def get_task_status(task_id: str) -> TaskStatusResponse: - task_service = get_task_service() - task = task_service.get_task(task_id) - result = None - if task is None: - raise HTTPException(status_code=404, detail="Task not found") - if task.ready(): - result = task.result - # If result isinstance of Exception, can we get the traceback? 
- if isinstance(result, Exception): - logger.exception(task.traceback) +@router.get("/task/{_task_id}", deprecated=True) +async def get_task_status(_task_id: str) -> TaskStatusResponse: + """Get the status of a task by ID (Deprecated). - if isinstance(result, dict) and "result" in result: - result = result["result"] - elif hasattr(result, "result"): - result = result.result - - if task.status == "FAILURE": - result = str(task.result) - logger.error(f"Task {task_id} failed: {task.traceback}") - - return TaskStatusResponse(status=task.status, result=result) + This endpoint is deprecated and will be removed in a future version. + """ + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="The /task endpoint is deprecated and will be removed in a future version. Please use /run instead.", + ) @router.post( diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index 5181a9ec4..718f426f4 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -35,6 +35,7 @@ async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSessio async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSession) -> None: try: await delete_vertex_builds_by_flow_id(session, flow_id) + await session.commit() except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/src/backend/base/langflow/api/v1/schemas.py b/src/backend/base/langflow/api/v1/schemas.py index 0d90d12df..a1a58f5b4 100644 --- a/src/backend/base/langflow/api/v1/schemas.py +++ b/src/backend/base/langflow/api/v1/schemas.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from enum import Enum from pathlib import Path -from typing import Any +from typing import Any, Literal from uuid import UUID from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer @@ -376,3 +376,4 @@ class ConfigResponse(BaseModel): auto_saving_interval: int health_check_max_retries: int max_file_size_upload: int + event_delivery: Literal["polling", "streaming"] diff --git a/src/backend/base/langflow/base/agents/agent.py b/src/backend/base/langflow/base/agents/agent.py index ce4ee01ad..33c95e138 100644 --- a/src/backend/base/langflow/base/agents/agent.py +++ b/src/backend/base/langflow/base/agents/agent.py @@ -169,8 +169,9 @@ class LCAgentComponent(Component): cast("SendMessageFunctionType", self.send_message), ) except ExceptionWithMessageError as e: - msg_id = e.agent_message.id - await delete_message(id_=msg_id) + if hasattr(e, "agent_message") and hasattr(e.agent_message, "id"): + msg_id = e.agent_message.id + await delete_message(id_=msg_id) await self._send_message_event(e.agent_message, category="remove_message") logger.error(f"ExceptionWithMessageError: {e}") raise diff --git a/src/backend/base/langflow/events/event_manager.py b/src/backend/base/langflow/events/event_manager.py index 9ae4e9172..e499030d9 100644 --- a/src/backend/base/langflow/events/event_manager.py +++ b/src/backend/base/langflow/events/event_manager.py @@ -1,21 +1,26 @@ -import asyncio +from __future__ import annotations + import inspect import json import time import uuid from functools import partial -from typing import Literal +from typing import TYPE_CHECKING, Literal from fastapi.encoders import jsonable_encoder from loguru import logger from typing_extensions import Protocol -from langflow.schema.log import LoggableType from langflow.schema.playground_events import 
create_event_by_type +if TYPE_CHECKING: + import asyncio + + from langflow.schema.log import LoggableType + class EventCallback(Protocol): - def __call__(self, *, manager: "EventManager", event_type: str, data: LoggableType): ... + def __call__(self, *, manager: EventManager, event_type: str, data: LoggableType): ... class PartialEventCallback(Protocol): diff --git a/src/backend/base/langflow/graph/utils.py b/src/backend/base/langflow/graph/utils.py index 6d029d900..47ef56dce 100644 --- a/src/backend/base/langflow/graph/utils.py +++ b/src/backend/base/langflow/graph/utils.py @@ -137,7 +137,8 @@ async def log_transaction( async with session_getter(get_db_service()) as session: with session.no_autoflush: inserted = await crud_log_transaction(session, transaction) - logger.debug(f"Logged transaction: {inserted.id}") + if inserted: + logger.debug(f"Logged transaction: {inserted.id}") except Exception: # noqa: BLE001 logger.error("Error logging transaction") diff --git a/src/backend/base/langflow/main.py b/src/backend/base/langflow/main.py index e5caf6bde..c5297e1c7 100644 --- a/src/backend/base/langflow/main.py +++ b/src/backend/base/langflow/main.py @@ -34,7 +34,7 @@ from langflow.interface.components import get_and_cache_all_types_dict from langflow.interface.utils import setup_llm_caching from langflow.logging.logger import configure from langflow.middleware import ContentSizeLimitMiddleware -from langflow.services.deps import get_settings_service, get_telemetry_service +from langflow.services.deps import get_queue_service, get_settings_service, get_telemetry_service from langflow.services.utils import initialize_services, teardown_services if TYPE_CHECKING: @@ -43,6 +43,7 @@ if TYPE_CHECKING: # Ignore Pydantic deprecation warnings from Langchain warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) +_tasks: list[asyncio.Task] = [] MAX_PORT = 65535 @@ -127,6 +128,9 @@ def get_lifespan(*, fix_migration=False, version=None): await create_or_update_starter_projects(all_types_dict) telemetry_service.start() await load_flows_from_directory() + queue_service = get_queue_service() + if not queue_service.is_started(): # Start if not already started + queue_service.start() yield except Exception as exc: diff --git a/src/backend/base/langflow/services/database/models/transactions/crud.py b/src/backend/base/langflow/services/database/models/transactions/crud.py index 409671d93..810a03d7b 100644 --- a/src/backend/base/langflow/services/database/models/transactions/crud.py +++ b/src/backend/base/langflow/services/database/models/transactions/crud.py @@ -1,5 +1,6 @@ from uuid import UUID +from loguru import logger from sqlmodel import col, delete, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -25,7 +26,7 @@ async def get_transactions_by_flow_id( return list(transactions) -async def log_transaction(db: AsyncSession, transaction: TransactionBase) -> TransactionTable: +async def log_transaction(db: AsyncSession, transaction: TransactionBase) -> TransactionTable | None: """Log a transaction and maintain a maximum number of transactions in the database. 
This function logs a new transaction into the database and ensures that the number of transactions @@ -42,6 +43,9 @@ async def log_transaction(db: AsyncSession, transaction: TransactionBase) -> Tra Raises: IntegrityError: If there is a database integrity error """ + if not transaction.flow_id: + logger.debug("Transaction flow_id is None") + return None table = TransactionTable(**transaction.model_dump()) try: @@ -63,7 +67,6 @@ async def log_transaction(db: AsyncSession, transaction: TransactionBase) -> Tra db.add(table) await db.exec(delete_older) await db.commit() - await db.refresh(table) except Exception: await db.rollback() diff --git a/src/backend/base/langflow/services/database/models/vertex_builds/crud.py b/src/backend/base/langflow/services/database/models/vertex_builds/crud.py index abcc8b093..2f8d1c249 100644 --- a/src/backend/base/langflow/services/database/models/vertex_builds/crud.py +++ b/src/backend/base/langflow/services/database/models/vertex_builds/crud.py @@ -143,4 +143,3 @@ async def delete_vertex_builds_by_flow_id(db: AsyncSession, flow_id: UUID) -> No """ stmt = delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow_id) await db.exec(stmt) - await db.commit() diff --git a/src/backend/base/langflow/services/deps.py b/src/backend/base/langflow/services/deps.py index 4dd4a2639..a60dcb407 100644 --- a/src/backend/base/langflow/services/deps.py +++ b/src/backend/base/langflow/services/deps.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from langflow.services.cache.service import AsyncBaseCacheService, CacheService from langflow.services.chat.service import ChatService from langflow.services.database.service import DatabaseService + from langflow.services.job_queue.service import JobQueueService from langflow.services.session.service import SessionService from langflow.services.settings.service import SettingsService from langflow.services.socket.service import SocketIOService @@ -239,3 +240,10 @@ def get_store_service() -> StoreService: StoreService: The StoreService instance. 
""" return get_service(ServiceType.STORE_SERVICE) + + +def get_queue_service() -> JobQueueService: + """Retrieves the QueueService instance from the service manager.""" + from langflow.services.job_queue.factory import JobQueueServiceFactory + + return get_service(ServiceType.JOB_QUEUE_SERVICE, JobQueueServiceFactory()) diff --git a/src/backend/base/langflow/services/job_queue/__init__.py b/src/backend/base/langflow/services/job_queue/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/base/langflow/services/job_queue/factory.py b/src/backend/base/langflow/services/job_queue/factory.py new file mode 100644 index 000000000..71d629fdc --- /dev/null +++ b/src/backend/base/langflow/services/job_queue/factory.py @@ -0,0 +1,11 @@ +from langflow.services.base import Service +from langflow.services.factory import ServiceFactory +from langflow.services.job_queue.service import JobQueueService + + +class JobQueueServiceFactory(ServiceFactory): + def __init__(self): + super().__init__(JobQueueService) + + def create(self) -> Service: + return JobQueueService() diff --git a/src/backend/base/langflow/services/job_queue/service.py b/src/backend/base/langflow/services/job_queue/service.py new file mode 100644 index 000000000..e042a9892 --- /dev/null +++ b/src/backend/base/langflow/services/job_queue/service.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import asyncio + +from loguru import logger + +from langflow.events.event_manager import EventManager, create_default_event_manager +from langflow.services.base import Service + + +class JobQueueService(Service): + """Asynchronous service for managing job-specific queues and their associated tasks. + + This service allows clients to: + - Create dedicated asyncio queues for individual jobs. + - Associate each queue with an EventManager, enabling event-driven handling. + - Launch and manage asynchronous tasks that process these job queues. + - Safely clean up resources by cancelling active tasks and emptying queues. + - Automatically perform periodic cleanup of inactive or completed job queues. + + Attributes: + name (str): Unique identifier for the service. + _queues (dict[str, tuple[asyncio.Queue, EventManager, asyncio.Task | None]]): + Dictionary mapping job IDs to a tuple containing: + * The job's asyncio.Queue instance. + * The associated EventManager instance. + * The asyncio.Task processing the job (if any). + _cleanup_task (asyncio.Task | None): Background task for periodic cleanup. + _closed (bool): Flag indicating whether the service is currently active. + + Example: + service = JobQueueService() + await service.start() + queue, event_manager = service.create_queue("job123") + service.start_job("job123", some_async_coroutine()) + # Retrieve and use the queue data as needed + data = service.get_queue_data("job123") + await service.cleanup_job("job123") + await service.stop() + """ + + name = "job_queue_service" + + def __init__(self) -> None: + """Initialize the JobQueueService. + + Sets up the internal registry for job queues, initializes the cleanup task, and sets the service state + to active. + """ + self._queues: dict[str, tuple[asyncio.Queue, EventManager, asyncio.Task | None]] = {} + self._cleanup_task: asyncio.Task | None = None + self._closed = False + self.ready = False + + def is_started(self) -> bool: + """Check if the JobQueueService has started. + + Returns: + bool: True if the service has started, False otherwise. 
+ """ + return self._cleanup_task is not None + + def set_ready(self) -> None: + if not self.is_started(): + self.start() + super().set_ready() + + def start(self) -> None: + """Start the JobQueueService and begin the periodic cleanup routine. + + This method marks the service as active and launches a background task that + periodically checks and cleans up job queues whose tasks have been completed or cancelled. + """ + self._closed = False + self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) + logger.debug("JobQueueService started: periodic cleanup task initiated.") + + async def stop(self) -> None: + """Gracefully stop the JobQueueService by terminating background operations and cleaning up all resources. + + This coroutine performs the following steps: + 1. Marks the service as closed, preventing further job queue creation. + 2. Cancels the background periodic cleanup task and awaits its termination. + 3. Iterates over all registered job queues to clean up their resources—cancelling active tasks and + clearing queued items. + """ + self._closed = True + if self._cleanup_task: + self._cleanup_task.cancel() + await asyncio.wait([self._cleanup_task]) + if not self._cleanup_task.cancelled(): + exc = self._cleanup_task.exception() + if exc is not None: + raise exc + + # Clean up each registered job queue. + for job_id in list(self._queues.keys()): + await self.cleanup_job(job_id) + logger.info("JobQueueService stopped: all job queues have been cleaned up.") + + async def teardown(self) -> None: + await self.stop() + + def create_queue(self, job_id: str) -> tuple[asyncio.Queue, EventManager]: + """Create and register a new queue along with its corresponding event manager for a job. + + Args: + job_id (str): Unique identifier for the job. + + Returns: + tuple[asyncio.Queue, EventManager]: A tuple containing: + - The asyncio.Queue instance for handling the job's tasks or messages. + - The EventManager instance for event handling tied to the queue. + """ + if job_id in self._queues: + msg = f"Queue for job_id {job_id} already exists" + logger.error(msg) + raise ValueError(msg) + + if self._closed: + msg = "Queue service is closed" + logger.error(msg) + raise RuntimeError(msg) + + main_queue: asyncio.Queue = asyncio.Queue() + event_manager = create_default_event_manager(main_queue) + + # Register the queue without an active task. + self._queues[job_id] = (main_queue, event_manager, None) + logger.debug(f"Queue and event manager successfully created for job_id {job_id}") + return main_queue, event_manager + + def start_job(self, job_id: str, task_coro) -> None: + """Start an asynchronous task for a given job, replacing any existing active task. + + The method performs the following: + - Verifies the presence of a registered queue for the job. + - Cancels any currently running task associated with the job. + - Launches a new asynchronous task using the provided coroutine. + - Updates the internal registry with the new task. + + Args: + job_id (str): Unique identifier for the job. + task_coro: A coroutine representing the job's asynchronous task. 
+ """ + if job_id not in self._queues: + msg = f"No queue found for job_id {job_id}" + logger.error(msg) + raise ValueError(msg) + + if self._closed: + msg = "Queue service is closed" + logger.error(msg) + raise RuntimeError(msg) + + main_queue, event_manager, existing_task = self._queues[job_id] + + if existing_task and not existing_task.done(): + logger.debug(f"Existing task for job_id {job_id} detected; cancelling it.") + existing_task.cancel() + + # Initiate the new asynchronous task. + task = asyncio.create_task(task_coro) + self._queues[job_id] = (main_queue, event_manager, task) + logger.debug(f"New task started for job_id {job_id}") + + def get_queue_data(self, job_id: str) -> tuple[asyncio.Queue, EventManager, asyncio.Task | None]: + """Retrieve the complete data structure associated with a job's queue. + + Args: + job_id (str): Unique identifier for the job. + + Returns: + tuple[asyncio.Queue, EventManager, asyncio.Task | None]: + A tuple containing the job's main queue, its linked event manager, and the associated task (if any). + """ + if job_id not in self._queues: + msg = f"No queue found for job_id {job_id}" + logger.error(msg) + raise ValueError(msg) + + if self._closed: + msg = "Queue service is closed" + logger.error(msg) + raise RuntimeError(msg) + + return self._queues[job_id] + + async def cleanup_job(self, job_id: str) -> None: + """Clean up and release resources for a specific job. + + The cleanup process includes: + 1. Verifying if the job's queue is registered. + 2. Cancelling the running task (if active) and awaiting its termination. + 3. Clearing all items from the job's queue. + 4. Removing the job's entry from the internal registry. + + Args: + job_id (str): Unique identifier for the job to be cleaned up. + """ + if job_id not in self._queues: + logger.debug(f"No queue found for job_id {job_id} during cleanup.") + return + + logger.info(f"Commencing cleanup for job_id {job_id}") + main_queue, event_manager, task = self._queues[job_id] + + # Cancel the associated task if it is still running. + if task and not task.done(): + logger.debug(f"Cancelling active task for job_id {job_id}") + task.cancel() + await asyncio.wait([task]) + # Log any exceptions that occurred during the task's execution. + if exc := task.exception(): + logger.error(f"Error in task for job_id {job_id}: {exc}") + logger.debug(f"Task cancellation complete for job_id {job_id}") + + # Clear the queue since we just cancelled the task or it has completed + items_cleared = 0 + while not main_queue.empty(): + try: + main_queue.get_nowait() + items_cleared += 1 + except asyncio.QueueEmpty: + break + + logger.debug(f"Removed {items_cleared} items from queue for job_id {job_id}") + # Remove the job entry from the registry + self._queues.pop(job_id, None) + logger.info(f"Cleanup successful for job_id {job_id}: resources have been released.") + + async def _periodic_cleanup(self) -> None: + """Execute a periodic task that cleans up completed or cancelled job queues. + + This internal coroutine continuously: + - Sleeps for a fixed interval (60 seconds). + - Initiates the cleanup of job queues by calling _cleanup_old_queues. + - Monitors and logs any exceptions during the cleanup cycle. + + The loop terminates when the service is marked as closed. + """ + while not self._closed: + try: + await asyncio.sleep(60) # Sleep for 60 seconds before next cleanup attempt. 
+ await self._cleanup_old_queues() + except asyncio.CancelledError: + logger.debug("Periodic cleanup task received cancellation signal.") + raise + except Exception as exc: # noqa: BLE001 + logger.error(f"Exception encountered during periodic cleanup: {exc}") + + async def _cleanup_old_queues(self) -> None: + """Scan all registered job queues and clean up those with inactive tasks. + + For each job: + - Check whether the associated task is either complete or cancelled. + - If so, execute the cleanup_job method to release the job's resources. + """ + for job_id in list(self._queues.keys()): + _, _, task = self._queues[job_id] + if task and task.done(): + logger.debug(f"Job queue for job_id {job_id} marked for cleanup.") + await self.cleanup_job(job_id) diff --git a/src/backend/base/langflow/services/schema.py b/src/backend/base/langflow/services/schema.py index 8227d0f69..c8282d122 100644 --- a/src/backend/base/langflow/services/schema.py +++ b/src/backend/base/langflow/services/schema.py @@ -19,3 +19,4 @@ class ServiceType(str, Enum): STATE_SERVICE = "state_service" TRACING_SERVICE = "tracing_service" TELEMETRY_SERVICE = "telemetry_service" + JOB_QUEUE_SERVICE = "job_queue_service" diff --git a/src/backend/base/langflow/services/settings/base.py b/src/backend/base/langflow/services/settings/base.py index 73afa7d3f..bb023775f 100644 --- a/src/backend/base/langflow/services/settings/base.py +++ b/src/backend/base/langflow/services/settings/base.py @@ -219,6 +219,9 @@ class Settings(BaseSettings): mcp_server_enable_progress_notifications: bool = False """If set to False, Langflow will not send progress notifications in the MCP server.""" + event_delivery: Literal["polling", "streaming"] = "streaming" + """How to deliver build events to the frontend. Can be 'polling' or 'streaming'.""" + @field_validator("dev") @classmethod def set_dev(cls, value): diff --git a/src/backend/base/langflow/services/task/backends/anyio.py b/src/backend/base/langflow/services/task/backends/anyio.py index 8f167283f..817a27079 100644 --- a/src/backend/base/langflow/services/task/backends/anyio.py +++ b/src/backend/base/langflow/services/task/backends/anyio.py @@ -1,19 +1,24 @@ +from __future__ import annotations + import traceback -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import anyio -from loguru import logger from langflow.services.task.backends.base import TaskBackend +if TYPE_CHECKING: + from collections.abc import Callable + from types import TracebackType + class AnyIOTaskResult: - def __init__(self, scope) -> None: - self._scope = scope + def __init__(self) -> None: self._status = "PENDING" self._result = None self._exception: Exception | None = None + self._traceback: TracebackType | None = None + self.cancel_scope: anyio.CancelScope | None = None @property def status(self) -> str: @@ -34,9 +39,11 @@ class AnyIOTaskResult: def ready(self) -> bool: return self._status == "DONE" - async def run(self, func, *args, **kwargs) -> None: + async def run(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None: try: - self._result = await func(*args, **kwargs) + async with anyio.CancelScope() as scope: + self.cancel_scope = scope + self._result = await func(*args, **kwargs) except Exception as e: # noqa: BLE001 self._exception = e self._traceback = e.__traceback__ @@ -45,36 +52,66 @@ class AnyIOTaskResult: class AnyIOBackend(TaskBackend): + """Backend for handling asynchronous tasks using AnyIO.""" + name = "anyio" def __init__(self) -> None: + 
"""Initialize the AnyIO backend with an empty task dictionary.""" self.tasks: dict[str, AnyIOTaskResult] = {} + self._run_tasks: list[anyio.TaskGroup] = [] async def launch_task( self, task_func: Callable[..., Any], *args: Any, **kwargs: Any - ) -> tuple[str | None, AnyIOTaskResult | None]: + ) -> tuple[str, AnyIOTaskResult]: """Launch a new task in an asynchronous manner. - Parameters: + Args: task_func: The asynchronous function to run. *args: Positional arguments to pass to task_func. **kwargs: Keyword arguments to pass to task_func. Returns: - A tuple containing a unique task ID and the task result object. - """ - async with anyio.create_task_group() as tg: - try: - task_result = AnyIOTaskResult(tg) - tg.start_soon(task_result.run, task_func, *args, **kwargs) - except Exception: # noqa: BLE001 - logger.exception("An error occurred while launching the task") - return None, None + tuple[str, AnyIOTaskResult]: A tuple containing the task ID and task result object. + Raises: + RuntimeError: If task creation fails. + """ + try: + task_result = AnyIOTaskResult() + + # Create task ID before starting the task task_id = str(id(task_result)) self.tasks[task_id] = task_result - logger.info(f"Task {task_id} started.") - return task_id, task_result - def get_task(self, task_id: str) -> Any: + # Start the task in the background using TaskGroup + async with anyio.create_task_group() as tg: + tg.start_soon(task_result.run, task_func, *args, **kwargs) + self._run_tasks.append(tg) + + except Exception as e: + msg = f"Failed to launch task: {e!s}" + raise RuntimeError(msg) from e + return task_id, task_result + + def get_task(self, task_id: str) -> AnyIOTaskResult | None: + """Retrieve a task by its ID. + + Args: + task_id: The unique identifier of the task. + + Returns: + AnyIOTaskResult | None: The task result object if found, None otherwise. + """ return self.tasks.get(task_id) + + async def cleanup_task(self, task_id: str) -> None: + """Clean up a completed task and its resources. + + Args: + task_id: The unique identifier of the task to clean up. 
+ """ + if task := self.tasks.get(task_id): + if task.cancel_scope: + task.cancel_scope.cancel() + self.tasks.pop(task_id, None) diff --git a/src/backend/base/langflow/services/task/service.py b/src/backend/base/langflow/services/task/service.py index b113cdc5d..01ce048d6 100644 --- a/src/backend/base/langflow/services/task/service.py +++ b/src/backend/base/langflow/services/task/service.py @@ -3,45 +3,20 @@ from __future__ import annotations from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any -from loguru import logger - from langflow.services.base import Service from langflow.services.task.backends.anyio import AnyIOBackend -from langflow.services.task.utils import get_celery_worker_status if TYPE_CHECKING: from langflow.services.settings.service import SettingsService from langflow.services.task.backends.base import TaskBackend -def check_celery_availability(): - try: - from langflow.worker import celery_app - - status = get_celery_worker_status(celery_app) - logger.debug(f"Celery status: {status}") - except Exception: # noqa: BLE001 - logger.opt(exception=True).debug("Celery not available") - status = {"availability": None} - return status - - class TaskService(Service): name = "task_service" def __init__(self, settings_service: SettingsService): self.settings_service = settings_service - try: - if self.settings_service.settings.celery_enabled: - status = check_celery_availability() - - use_celery = status.get("availability") is not None - else: - use_celery = False - except ImportError: - use_celery = False - - self.use_celery = use_celery + self.use_celery = False self.backend = self.get_backend() @property @@ -49,12 +24,6 @@ class TaskService(Service): return self.backend.name def get_backend(self) -> TaskBackend: - if self.use_celery: - from langflow.services.task.backends.celery import CeleryBackend - - logger.debug("Using Celery backend") - return CeleryBackend() - logger.debug("Using AnyIO backend") return AnyIOBackend() # In your TaskService class @@ -64,24 +33,8 @@ class TaskService(Service): *args: Any, **kwargs: Any, ) -> Any: - if not self.use_celery: - return None, await task_func(*args, **kwargs) - if not hasattr(task_func, "apply"): - msg = f"Task function {task_func} does not have an apply method" - raise ValueError(msg) - task = task_func.apply(args=args, kwargs=kwargs) - - result = task.get() - # if result is coroutine - if isinstance(result, Coroutine): - result = await result - return task.id, result + return await task_func(*args, **kwargs) async def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: - logger.debug(f"Launching task {task_func} with args {args} and kwargs {kwargs}") - logger.debug(f"Using backend {self.backend}") task = self.backend.launch_task(task_func, *args, **kwargs) return await task if isinstance(task, Coroutine) else task - - def get_task(self, task_id: str) -> Any: - return self.backend.get_task(task_id) diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index 727f6d0e7..17bca1e1b 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -135,11 +135,12 @@ def get_text(): async def delete_transactions_by_flow_id(db: AsyncSession, flow_id: UUID): + if not flow_id: + return stmt = select(TransactionTable).where(TransactionTable.flow_id == flow_id) transactions = await db.exec(stmt) for transaction in transactions: await db.delete(transaction) - await db.commit() async def _delete_transactions_and_vertex_builds(session, flows: 
list[Flow]): @@ -147,8 +148,14 @@ async def _delete_transactions_and_vertex_builds(session, flows: list[Flow]): for flow_id in flow_ids: if not flow_id: continue - await delete_vertex_builds_by_flow_id(session, flow_id) - await delete_transactions_by_flow_id(session, flow_id) + try: + await delete_vertex_builds_by_flow_id(session, flow_id) + except Exception as e: # noqa: BLE001 + logger.debug(f"Error deleting vertex builds for flow {flow_id}: {e}") + try: + await delete_transactions_by_flow_id(session, flow_id) + except Exception as e: # noqa: BLE001 + logger.debug(f"Error deleting transactions for flow {flow_id}: {e}") @pytest.fixture @@ -433,12 +440,21 @@ async def active_user(client): # noqa: ARG001 yield user # Clean up # Now cleanup transactions, vertex_build - async with db_manager.with_session() as session: - user = await session.get(User, user.id, options=[selectinload(User.flows)]) - await _delete_transactions_and_vertex_builds(session, user.flows) - await session.delete(user) + try: + async with db_manager.with_session() as session: + user = await session.get(User, user.id, options=[selectinload(User.flows)]) + await _delete_transactions_and_vertex_builds(session, user.flows) + await session.commit() + except Exception as e: # noqa: BLE001 + logger.exception(f"Error deleting transactions and vertex builds for user: {e}") - await session.commit() + try: + async with db_manager.with_session() as session: + user = await session.get(User, user.id) + await session.delete(user) + await session.commit() + except Exception as e: # noqa: BLE001 + logger.exception(f"Error deleting user: {e}") @pytest.fixture diff --git a/src/backend/tests/unit/base/tools/test_component_toolkit.py b/src/backend/tests/unit/base/tools/test_component_toolkit.py index 07a67e797..4db6f9f6a 100644 --- a/src/backend/tests/unit/base/tools/test_component_toolkit.py +++ b/src/backend/tests/unit/base/tools/test_component_toolkit.py @@ -35,17 +35,21 @@ def test_component_tool(): @pytest.mark.api_key_required -def test_component_tool_with_api_key(): +@pytest.mark.usefixtures("client") +async def test_component_tool_with_api_key(): chat_output = ChatOutput() openai_llm = OpenAIModelComponent() openai_llm.set(api_key=os.environ["OPENAI_API_KEY"]) tool_calling_agent = ToolCallingAgentComponent() + tools = await chat_output.to_toolkit() tool_calling_agent.set( - llm=openai_llm.build_model, tools=[chat_output], input_value="Which tools are available? Please tell its name." + llm=openai_llm.build_model, + tools=tools, + input_value="Which tools are available? 
Please tell its name.", ) g = Graph(start=tool_calling_agent, end=tool_calling_agent) assert g is not None - results = list(g.start()) + results = [result async for result in g.async_start()] assert len(results) == 4 assert "message_response" in tool_calling_agent._outputs_map["response"].value.get_text() diff --git a/src/backend/tests/unit/build_utils.py b/src/backend/tests/unit/build_utils.py new file mode 100644 index 000000000..abad2d769 --- /dev/null +++ b/src/backend/tests/unit/build_utils.py @@ -0,0 +1,75 @@ +import json +from typing import Any +from uuid import UUID + +from httpx import AsyncClient, codes + + +async def create_flow(client: AsyncClient, flow_data: str, headers: dict[str, str]) -> UUID: + """Create a flow and return its ID.""" + response = await client.post("api/v1/flows/", json=json.loads(flow_data), headers=headers) + assert response.status_code == codes.CREATED + return UUID(response.json()["id"]) + + +async def build_flow( + client: AsyncClient, flow_id: UUID, headers: dict[str, str], json: dict[str, Any] | None = None +) -> dict[str, Any]: + """Start a flow build and return the job_id.""" + if json is None: + json = {} + response = await client.post(f"api/v1/build/{flow_id}/flow", json=json, headers=headers) + assert response.status_code == codes.OK + return response.json() + + +async def get_build_events(client: AsyncClient, job_id: str, headers: dict[str, str]): + """Get events for a build job.""" + return await client.get(f"api/v1/build/{job_id}/events", headers=headers) + + +async def consume_and_assert_stream(response, job_id): + """Consume the event stream and assert the expected event structure.""" + count = 0 + lines = [] + async for line in response.aiter_lines(): + # Skip empty lines (ndjson uses double newlines) + if not line: + continue + + lines.append(line) + parsed = json.loads(line) + if "job_id" in parsed: + assert parsed["job_id"] == job_id + continue + + if count == 0: + # First event should be vertices_sorted + assert parsed["event"] == "vertices_sorted", ( + "Invalid first event. Expected 'vertices_sorted'. Full event stream:\n" + "\n".join(lines) + ) + ids = parsed["data"]["ids"] + ids.sort() + assert ids == ["ChatInput-CIGht"], "Invalid ids in first event. Full event stream:\n" + "\n".join(lines) + + to_run = parsed["data"]["to_run"] + to_run.sort() + assert to_run == ["ChatInput-CIGht", "ChatOutput-QA7ej", "Memory-amN4Z", "Prompt-iWbCC"], ( + "Invalid to_run list in first event. Full event stream:\n" + "\n".join(lines) + ) + elif count > 0 and count < 5: + # Next events should be end_vertex events + assert parsed["event"] == "end_vertex", ( + f"Invalid event at position {count}. Expected 'end_vertex'. Full event stream:\n" + "\n".join(lines) + ) + assert parsed["data"]["build_data"] is not None, ( + f"Missing build_data at position {count}. Full event stream:\n" + "\n".join(lines) + ) + elif count == 5: + # Final event should be end + assert parsed["event"] == "end", "Invalid final event. Expected 'end'. Full event stream:\n" + "\n".join( + lines + ) + else: + raise ValueError(f"Unexpected event at position {count}. 
Full event stream:\n" + "\n".join(lines)) + count += 1 diff --git a/src/backend/tests/unit/components/agents/test_agent_component.py b/src/backend/tests/unit/components/agents/test_agent_component.py index 45efa5a71..75895d56e 100644 --- a/src/backend/tests/unit/components/agents/test_agent_component.py +++ b/src/backend/tests/unit/components/agents/test_agent_component.py @@ -78,9 +78,8 @@ class TestAgentComponent(ComponentTestBaseWithoutClient): assert all(provider in updated_config["agent_llm"]["options"] for provider in MODEL_PROVIDERS_DICT) assert "Anthropic" in updated_config["agent_llm"]["options"] assert updated_config["agent_llm"]["input_types"] == [] - assert any("sonnet" in option.lower() for option in updated_config["model_name"]["options"]), ( - f"Options: {updated_config['model_name']['options']}" - ) + options = updated_config["model_name"]["options"] + assert any("sonnet" in option.lower() for option in options), f"Options: {options}" # Test updating build config for Custom updated_config = await component.update_build_config(build_config, "Custom", "agent_llm") @@ -113,6 +112,7 @@ async def test_agent_component_with_calculator(): model_name="gpt-4o", llm_type="OpenAI", temperature=temperature, + _session_id=str(uuid4()), ) response = await agent.message_response() diff --git a/src/backend/tests/unit/components/logic/test_loop.py b/src/backend/tests/unit/components/logic/test_loop.py index 208ee54f5..e3373dbab 100644 --- a/src/backend/tests/unit/components/logic/test_loop.py +++ b/src/backend/tests/unit/components/logic/test_loop.py @@ -1,14 +1,15 @@ from uuid import UUID +import orjson import pytest from httpx import AsyncClient from langflow.components.logic.loop import LoopComponent from langflow.memory import aget_messages from langflow.schema.data import Data from langflow.services.database.models.flow import FlowCreate -from orjson import orjson from tests.base import ComponentTestBaseWithClient +from tests.unit.build_utils import build_flow, get_build_events TEXT = ( "lorem ipsum dolor sit amet lorem ipsum dolor sit amet lorem ipsum dolor sit amet. 
" @@ -62,15 +63,25 @@ class TestLoopComponentWithAPI(ComponentTestBaseWithClient): assert len(messages[1].text) > 0 async def test_build_flow_loop(self, client, json_loop_test, logged_in_headers): - # TODO: Add a test for the loop where the loop component gets updated even the component in json + """Test building a flow with a loop component.""" + # Create the flow flow_id = await self._create_flow(client, json_loop_test, logged_in_headers) - async with client.stream("POST", f"api/v1/build/{flow_id}/flow", json={}, headers=logged_in_headers) as r: - async for line in r.aiter_lines(): - # httpx split by \n, but ndjson sends two \n for each line - if line: - # Process the line if needed - pass + # Start the build and get job_id + build_response = await build_flow(client, flow_id, logged_in_headers) + job_id = build_response["job_id"] + assert job_id is not None + + # Get the events stream + events_response = await get_build_events(client, job_id, logged_in_headers) + assert events_response.status_code == 200 + + # Process the events stream + async for line in events_response.aiter_lines(): + if not line: # Skip empty lines + continue + # Process events if needed + # We could add specific assertions here for loop-related events await self.check_messages(flow_id) diff --git a/src/backend/tests/unit/test_chat_endpoint.py b/src/backend/tests/unit/test_chat_endpoint.py index 3a443358b..5a629245b 100644 --- a/src/backend/tests/unit/test_chat_endpoint.py +++ b/src/backend/tests/unit/test_chat_endpoint.py @@ -1,42 +1,63 @@ -import json +import asyncio +import uuid from uuid import UUID import pytest +from httpx import codes from langflow.memory import aget_messages -from langflow.services.database.models.flow import FlowCreate, FlowUpdate -from orjson import orjson +from langflow.services.database.models.flow import FlowUpdate + +from tests.unit.build_utils import build_flow, consume_and_assert_stream, create_flow, get_build_events @pytest.mark.benchmark async def test_build_flow(client, json_memory_chatbot_no_llm, logged_in_headers): - flow_id = await _create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) + """Test the build flow endpoint with the new two-step process.""" + # First create the flow + flow_id = await create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) - async with client.stream("POST", f"api/v1/build/{flow_id}/flow", json={}, headers=logged_in_headers) as r: - await consume_and_assert_stream(r) + # Start the build and get job_id + build_response = await build_flow(client, flow_id, logged_in_headers) + job_id = build_response["job_id"] + assert job_id is not None - await check_messages(flow_id) + # Get the events stream + events_response = await get_build_events(client, job_id, logged_in_headers) + assert events_response.status_code == codes.OK + + # Consume and verify the events + await consume_and_assert_stream(events_response, job_id) @pytest.mark.benchmark async def test_build_flow_from_request_data(client, json_memory_chatbot_no_llm, logged_in_headers): - flow_id = await _create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) - response = await client.get("api/v1/flows/" + str(flow_id), headers=logged_in_headers) + """Test building a flow from request data.""" + flow_id = await create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) + response = await client.get(f"api/v1/flows/{flow_id}", headers=logged_in_headers) flow_data = response.json() - async with client.stream( - "POST", f"api/v1/build/{flow_id}/flow", json={"data": 
flow_data["data"]}, headers=logged_in_headers - ) as r: - await consume_and_assert_stream(r) + # Start the build and get job_id + build_response = await build_flow(client, flow_id, logged_in_headers, json={"data": flow_data["data"]}) + job_id = build_response["job_id"] + # Get the events stream + events_response = await get_build_events(client, job_id, logged_in_headers) + assert events_response.status_code == codes.OK + + # Consume and verify the events + await consume_and_assert_stream(events_response, job_id) await check_messages(flow_id) async def test_build_flow_with_frozen_path(client, json_memory_chatbot_no_llm, logged_in_headers): - flow_id = await _create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) + """Test building a flow with a frozen path.""" + flow_id = await create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) - response = await client.get("api/v1/flows/" + str(flow_id), headers=logged_in_headers) + response = await client.get(f"api/v1/flows/{flow_id}", headers=logged_in_headers) flow_data = response.json() flow_data["data"]["nodes"][0]["data"]["node"]["frozen"] = True + + # Update the flow with frozen path response = await client.patch( f"api/v1/flows/{flow_id}", json=FlowUpdate(name="Flow", description="description", data=flow_data["data"]).model_dump(), @@ -44,151 +65,131 @@ async def test_build_flow_with_frozen_path(client, json_memory_chatbot_no_llm, l ) response.raise_for_status() - async with client.stream("POST", f"api/v1/build/{flow_id}/flow", json={}, headers=logged_in_headers) as r: - await consume_and_assert_stream(r) + # Start the build and get job_id + build_response = await build_flow(client, flow_id, logged_in_headers) + job_id = build_response["job_id"] + # Get the events stream + events_response = await get_build_events(client, job_id, logged_in_headers) + assert events_response.status_code == codes.OK + + # Consume and verify the events + await consume_and_assert_stream(events_response, job_id) await check_messages(flow_id) async def check_messages(flow_id): - messages = await aget_messages(flow_id=UUID(flow_id), order="ASC") + if isinstance(flow_id, str): + flow_id = UUID(flow_id) + messages = await aget_messages(flow_id=flow_id, order="ASC") + flow_id_str = str(flow_id) assert len(messages) == 2 - assert messages[0].session_id == flow_id + assert messages[0].session_id == flow_id_str assert messages[0].sender == "User" assert messages[0].sender_name == "User" assert messages[0].text == "" - assert messages[1].session_id == flow_id + assert messages[1].session_id == flow_id_str assert messages[1].sender == "Machine" assert messages[1].sender_name == "AI" -async def consume_and_assert_stream(r): - count = 0 - async for line in r.aiter_lines(): - # httpx split by \n, but ndjson sends two \n for each line - if not line: - continue - parsed = json.loads(line) - if count == 0: - assert parsed["event"] == "vertices_sorted" - ids = parsed["data"]["ids"] - ids.sort() - assert ids == ["ChatInput-CIGht"] - - to_run = parsed["data"]["to_run"] - to_run.sort() - assert to_run == ["ChatInput-CIGht", "ChatOutput-QA7ej", "Memory-amN4Z", "Prompt-iWbCC"] - elif count > 0 and count < 5: - assert parsed["event"] == "end_vertex" - assert parsed["data"]["build_data"] is not None - elif count == 5: - assert parsed["event"] == "end" - else: - msg = f"Unexpected line: {line}" - raise ValueError(msg) - count += 1 +@pytest.mark.benchmark +async def test_build_flow_invalid_job_id(client, logged_in_headers): + """Test getting events for an invalid job 
ID.""" + invalid_job_id = str(uuid.uuid4()) + response = await get_build_events(client, invalid_job_id, logged_in_headers) + assert response.status_code == codes.NOT_FOUND + assert "No queue found for job_id" in response.json()["detail"] -async def _create_flow(client, json_memory_chatbot_no_llm, logged_in_headers): - vector_store = orjson.loads(json_memory_chatbot_no_llm) - data = vector_store["data"] - vector_store = FlowCreate(name="Flow", description="description", data=data, endpoint_name="f") - response = await client.post("api/v1/flows/", json=vector_store.model_dump(), headers=logged_in_headers) - response.raise_for_status() - return response.json()["id"] +@pytest.mark.benchmark +async def test_build_flow_invalid_flow_id(client, logged_in_headers): + """Test starting a build with an invalid flow ID.""" + invalid_flow_id = uuid.uuid4() + response = await client.post(f"api/v1/build/{invalid_flow_id}/flow", json={}, headers=logged_in_headers) + assert response.status_code == codes.NOT_FOUND -# TODO: Fix this test -# async def test_multiple_runs_with_no_payload_generate_max_vertex_builds( -# client, json_memory_chatbot_no_llm, logged_in_headers -# ): -# """Test that multiple builds of a flow generate the correct number of vertex builds.""" -# # Create the initial flow -# flow_id = await _create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) +@pytest.mark.benchmark +async def test_build_flow_start_only(client, json_memory_chatbot_no_llm, logged_in_headers): + """Test only the build flow start endpoint.""" + # First create the flow + flow_id = await create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) -# # Get the flow data to count nodes before making requests -# response = await client.get(f"api/v1/flows/{flow_id}", headers=logged_in_headers) -# flow_data = response.json() -# num_nodes = len(flow_data["data"]["nodes"]) -# max_vertex_builds = get_settings_service().settings.max_vertex_builds_per_vertex + # Start the build and get job_id + build_response = await build_flow(client, flow_id, logged_in_headers) -# logger.debug(f"Starting test with {num_nodes} nodes, max_vertex_builds={max_vertex_builds}") + # Assert response structure + assert "job_id" in build_response + assert isinstance(build_response["job_id"], str) + # Verify it's a valid UUID + assert uuid.UUID(build_response["job_id"]) -# # Make multiple build requests - ensure we exceed max_vertex_builds significantly -# num_requests = max_vertex_builds * 3 # Triple the max to ensure rotation -# for i in range(num_requests): -# # Generate a random session ID for each request -# session_id = session_id_generator() -# payload = {"inputs": {"session": session_id, "type": "chat", "input_value": f"Test message {i + 1}"}} -# async with client.stream("POST", f"api/v1/build/{flow_id}/flow", -# json=payload, headers=logged_in_headers) as r: -# await consume_and_assert_stream(r) +@pytest.mark.benchmark +async def test_build_flow_start_with_inputs(client, json_memory_chatbot_no_llm, logged_in_headers): + """Test the build flow start endpoint with input data.""" + flow_id = await create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) -# # Add a small delay between requests to ensure proper ordering -# await asyncio.sleep(0.1) + # Start build with some input data + test_inputs = {"inputs": {"session": "test_session", "input_value": "test message"}} -# # Track builds after each request -# async with session_scope() as session: -# builds = await get_vertex_builds_by_flow_id(db=session, flow_id=flow_id) -# 
by_vertex = {} -# for build in builds: -# build_dict = build.model_dump() -# vertex_id = build_dict.get("id") -# by_vertex.setdefault(vertex_id, []).append(build_dict) + build_response = await build_flow(client, flow_id, logged_in_headers, json=test_inputs) -# # Log state of each vertex with more details -# for vertex_id, vertex_builds in by_vertex.items(): -# vertex_builds.sort(key=lambda x: x.get("timestamp")) -# logger.debug( -# f"Request {i + 1} (session={session_id}) - Vertex {vertex_id}: {len(vertex_builds)} builds " -# f"(max allowed: {max_vertex_builds}), " -# f"build_ids: {[b.get('build_id') for b in vertex_builds]}" -# ) + assert "job_id" in build_response + assert isinstance(build_response["job_id"], str) + assert uuid.UUID(build_response["job_id"]) -# # Wait a bit before final verification to ensure all DB operations complete -# await asyncio.sleep(0.5) -# # Final verification with detailed logging -# async with session_scope() as session: -# vertex_builds = await get_vertex_builds_by_flow_id(db=session, flow_id=flow_id) -# assert len(vertex_builds) > 0, "No vertex builds found" +@pytest.mark.benchmark +async def test_build_flow_polling(client, json_memory_chatbot_no_llm, logged_in_headers): + """Test the build flow endpoint with polling (non-streaming).""" + # First create the flow + flow_id = await create_flow(client, json_memory_chatbot_no_llm, logged_in_headers) -# builds_by_vertex = {} -# for build in vertex_builds: -# build_dict = build.model_dump() -# vertex_id = build_dict.get("id") -# builds_by_vertex.setdefault(vertex_id, []).append(build_dict) + # Start the build and get job_id + build_response = await build_flow(client, flow_id, logged_in_headers) + job_id = build_response["job_id"] + assert job_id is not None -# # Log detailed final state -# logger.debug(f"\nFinal state after {num_requests} requests:") -# for vertex_id, builds in builds_by_vertex.items(): -# builds.sort(key=lambda x: x.get("timestamp")) -# logger.debug( -# f"Vertex {vertex_id}: {len(builds)} builds " -# f"(oldest: {builds[0].get('timestamp')}, " -# f"newest: {builds[-1].get('timestamp')}), " -# f"build_ids: {[b.get('build_id') for b in builds]}" -# ) + # Create a response object that mimics a streaming response but uses polling + class PollingResponse: + def __init__(self, client, job_id, headers): + self.client = client + self.job_id = job_id + self.headers = headers + self.status_code = codes.OK -# # Log individual build details for debugging -# for build in builds: -# logger.debug( -# f" - Build {build.get('build_id')}: timestamp={build.get('timestamp')}, " -# f"valid={build.get('valid')}" -# ) + async def aiter_lines(self): + try: + sleeps = 0 + max_sleeps = 100 + while True: + response = await self.client.get( + f"api/v1/build/{self.job_id}/events?stream=false", headers=self.headers + ) + assert response.status_code == codes.OK + data = response.json() -# # Verify each vertex has correct number of builds -# for vertex_id, vertex_builds_list in builds_by_vertex.items(): -# assert len(vertex_builds_list) == max_vertex_builds, ( -# f"Vertex {vertex_id} has {len(vertex_builds_list)} builds, expected {max_vertex_builds}" -# ) + if data["event"] is None: + # No event available, add delay to prevent tight polling + await asyncio.sleep(0.1) + sleeps += 1 + continue -# # Verify total number of builds -# total_builds = len(vertex_builds) -# expected_total = max_vertex_builds * num_nodes -# assert total_builds == expected_total, ( -# f"Total builds ({total_builds}) doesn't match expected " -# 
f"({max_vertex_builds} builds/vertex * {num_nodes} nodes = {expected_total})" -# ) -# assert all(vertex_build.get("valid") for vertex_build in vertex_builds) + yield data["event"] + + # If this was the end event, stop polling + if '"end"' in data["event"]: + break + if sleeps > max_sleeps: + msg = "Build event polling timed out." + raise TimeoutError(msg) + except asyncio.TimeoutError as e: + msg = "Build event polling timed out." + raise TimeoutError(msg) from e + + polling_response = PollingResponse(client, job_id, logged_in_headers) + + # Use the same consume_and_assert_stream function to verify the events + await consume_and_assert_stream(polling_response, job_id) diff --git a/src/frontend/src/CustomNodes/GenericNode/components/NodeStatus/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/NodeStatus/index.tsx index 8b16cccc0..434e9ce36 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/NodeStatus/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/NodeStatus/index.tsx @@ -6,7 +6,8 @@ import ShadTooltip from "@/components/common/shadTooltipComponent"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { ICON_STROKE_WIDTH } from "@/constants/constants"; -import { BuildStatus } from "@/constants/enums"; +import { BuildStatus, EventDeliveryType } from "@/constants/enums"; +import { useGetConfig } from "@/controllers/API/queries/config/use-get-config"; import { track } from "@/customization/utils/analytics"; import { useDarkStore } from "@/stores/darkStore"; import useFlowStore from "@/stores/flowStore"; @@ -69,11 +70,16 @@ export default function NodeStatus({ const isBuilding = useFlowStore((state) => state.isBuilding); const setNode = useFlowStore((state) => state.setNode); const version = useDarkStore((state) => state.version); + const config = useGetConfig(); + const shouldStreamEvents = () => { + // Get from useGetConfig store + return config.data?.event_delivery === EventDeliveryType.STREAMING; + }; function handlePlayWShortcut() { if (buildStatus === BuildStatus.BUILDING || isBuilding || !selected) return; setValidationStatus(null); - buildFlow({ stopNodeId: nodeId }); + buildFlow({ stopNodeId: nodeId, stream: shouldStreamEvents() }); } const play = useShortcutsStore((state) => state.play); @@ -164,7 +170,7 @@ export default function NodeStatus({ return; } if (buildStatus === BuildStatus.BUILDING || isBuilding) return; - buildFlow({ stopNodeId: nodeId }); + buildFlow({ stopNodeId: nodeId, stream: shouldStreamEvents() }); track("Flow Build - Clicked", { stopNodeId: nodeId }); }; diff --git a/src/frontend/src/constants/constants.ts b/src/frontend/src/constants/constants.ts index f70d2bb63..bd0016792 100644 --- a/src/frontend/src/constants/constants.ts +++ b/src/frontend/src/constants/constants.ts @@ -715,7 +715,7 @@ export const STORE_TITLE = "Langflow Store"; export const NO_API_KEY = "You don't have an API key."; export const INSERT_API_KEY = "Insert your Langflow API key."; export const INVALID_API_KEY = "Your API key is not valid. "; -export const CREATE_API_KEY = `Don’t have an API key? Sign up at`; +export const CREATE_API_KEY = `Don't have an API key? 
Sign up at`; export const STATUS_BUILD = "Build to validate status."; export const STATUS_INACTIVE = "Execution blocked"; export const STATUS_BUILDING = "Building..."; @@ -1005,3 +1005,10 @@ export const ICON_STROKE_WIDTH = 1.5; export const DEFAULT_PLACEHOLDER = "Type something..."; export const DEFAULT_TOOLSET_PLACEHOLDER = "Used as a tool"; + +export const POLLING_MESSAGES = { + ENDPOINT_NOT_AVAILABLE: "Endpoint not available", + STREAMING_NOT_SUPPORTED: "Streaming not supported", +} as const; + +export const POLLING_INTERVAL = 100; // milliseconds between polling attempts diff --git a/src/frontend/src/constants/enums.ts b/src/frontend/src/constants/enums.ts index edcafabfa..915ccf4fb 100644 --- a/src/frontend/src/constants/enums.ts +++ b/src/frontend/src/constants/enums.ts @@ -38,3 +38,8 @@ export enum IOOutputTypes { STRING_LIST = "StringListOutput", DATA = "DataOutput", } + +export enum EventDeliveryType { + STREAMING = "streaming", + POLLING = "polling", +} diff --git a/src/frontend/src/controllers/API/queries/config/use-get-config.ts b/src/frontend/src/controllers/API/queries/config/use-get-config.ts index b3457b13a..4fcc71040 100644 --- a/src/frontend/src/controllers/API/queries/config/use-get-config.ts +++ b/src/frontend/src/controllers/API/queries/config/use-get-config.ts @@ -1,3 +1,4 @@ +import { EventDeliveryType } from "@/constants/enums"; import useFlowsManagerStore from "@/stores/flowsManagerStore"; import { useUtilityStore } from "@/stores/utilityStore"; import axios from "axios"; @@ -13,6 +14,7 @@ export interface ConfigResponse { health_check_max_retries: number; max_file_size_upload: number; feature_flags: Record; + event_delivery: EventDeliveryType; } export const useGetConfig: useQueryFunctionType = ( diff --git a/src/frontend/src/modals/IOModal/new-modal.tsx b/src/frontend/src/modals/IOModal/new-modal.tsx index 39170d886..71dd9478f 100644 --- a/src/frontend/src/modals/IOModal/new-modal.tsx +++ b/src/frontend/src/modals/IOModal/new-modal.tsx @@ -1,4 +1,5 @@ -import { Separator } from "@/components/ui/separator"; +import { EventDeliveryType } from "@/constants/enums"; +import { useGetConfig } from "@/controllers/API/queries/config/use-get-config"; import { useDeleteMessages, useGetMessagesQuery, @@ -16,7 +17,6 @@ import { IOModalPropsType } from "../../types/components"; import { cn } from "../../utils/utils"; import BaseModal from "../baseModal"; import { ChatViewWrapper } from "./components/chat-view-wrapper"; -import ChatView from "./components/chatView/chat-view"; import { SelectedViewField } from "./components/selected-view-field"; import { SidebarOpenView } from "./components/sidebar-open-view"; @@ -136,6 +136,11 @@ export default function IOModal({ const chatValue = useUtilityStore((state) => state.chatValueStore); const setChatValue = useUtilityStore((state) => state.setChatValueStore); + const config = useGetConfig(); + + function shouldStreamEvents() { + return config.data?.event_delivery === EventDeliveryType.STREAMING; + } const sendMessage = useCallback( async ({ @@ -154,6 +159,7 @@ export default function IOModal({ files: files, silent: true, session: sessionId, + stream: shouldStreamEvents(), }).catch((err) => { console.error(err); }); diff --git a/src/frontend/src/stores/flowStore.ts b/src/frontend/src/stores/flowStore.ts index 8564a5c3e..64ef4ba21 100644 --- a/src/frontend/src/stores/flowStore.ts +++ b/src/frontend/src/stores/flowStore.ts @@ -594,6 +594,7 @@ const useFlowStore = create((set, get) => ({ files, silent, session, + stream = true, }: { 
    startNodeId?: string;
    stopNodeId?: string;
@@ -601,6 +602,7 @@ const useFlowStore = create((set, get) => ({
    files?: string[];
    silent?: boolean;
    session?: string;
+   stream?: boolean;
  }) => {
    get().setIsBuilding(true);
    const currentFlow = useFlowsManagerStore.getState().currentFlow;
@@ -825,6 +827,7 @@ const useFlowStore = create((set, get) => ({
      nodes: get().nodes || undefined,
      edges: get().edges || undefined,
      logBuilds: get().onFlowPage,
+     stream,
    });
    get().setIsBuilding(false);
    get().revertBuiltStatusFromBuilding();
diff --git a/src/frontend/src/types/zustand/flow/index.ts b/src/frontend/src/types/zustand/flow/index.ts
index 119c99e08..bcfa59374 100644
--- a/src/frontend/src/types/zustand/flow/index.ts
+++ b/src/frontend/src/types/zustand/flow/index.ts
@@ -146,6 +146,7 @@ export type FlowStoreType = {
    files,
    silent,
    session,
+   stream,
  }: {
    startNodeId?: string;
    stopNodeId?: string;
@@ -153,6 +154,7 @@ export type FlowStoreType = {
    files?: string[];
    silent?: boolean;
    session?: string;
+   stream?: boolean;
  }) => Promise<void>;
  getFlow: () => { nodes: Node[]; edges: EdgeType[]; viewport: Viewport };
  updateVerticesBuild: (
diff --git a/src/frontend/src/utils/buildUtils.ts b/src/frontend/src/utils/buildUtils.ts
index febb926c2..eed7e1540 100644
--- a/src/frontend/src/utils/buildUtils.ts
+++ b/src/frontend/src/utils/buildUtils.ts
@@ -1,4 +1,8 @@
-import { BASE_URL_API } from "@/constants/constants";
+import {
+  BASE_URL_API,
+  POLLING_INTERVAL,
+  POLLING_MESSAGES,
+} from "@/constants/constants";
 import { performStreamingRequest } from "@/controllers/API/api";
 import { useMessagesStore } from "@/stores/messagesStore";
 import { Edge, Node } from "@xyflow/react";
@@ -34,6 +38,7 @@ type BuildVerticesParams = {
   edges?: Edge[];
   logBuilds?: boolean;
   session?: string;
+  stream?: boolean;
 };
 
 function getInactiveVertexData(vertexId: string): VertexBuildTypeAPI {
@@ -124,10 +129,15 @@ export async function buildFlowVerticesWithFallback(
   params: BuildVerticesParams,
 ) {
   try {
-    return await buildFlowVertices(params);
+    // Try the requested event delivery mode first (streaming by default)
+    return await buildFlowVertices({ ...params });
   } catch (e: any) {
-    if (e.message === "Endpoint not available") {
-      return await buildVertices(params);
+    if (
+      e.message === POLLING_MESSAGES.ENDPOINT_NOT_AVAILABLE ||
+      e.message === POLLING_MESSAGES.STREAMING_NOT_SUPPORTED
+    ) {
+      // Fallback to polling
+      return await buildFlowVertices({ ...params, stream: false });
     }
     throw e;
   }
@@ -135,6 +145,63 @@ export async function buildFlowVerticesWithFallback(
 
 const MIN_VISUAL_BUILD_TIME_MS = 300;
 
+async function pollBuildEvents(
+  url: string,
+  buildResults: Array<boolean>,
+  verticesStartTimeMs: Map<string, number>,
+  callbacks: {
+    onBuildStart?: (idList: VertexLayerElementType[]) => void;
+    onBuildUpdate?: (data: any, status: BuildStatus, buildId: string) => void;
+    onBuildComplete?: (allNodesValid: boolean) => void;
+    onBuildError?: (
+      title: string,
+      list: string[],
+      idList?: VertexLayerElementType[],
+    ) => void;
+    onGetOrderSuccess?: () => void;
+    onValidateNodes?: (nodes: string[]) => void;
+  },
+): Promise<void> {
+  let isDone = false;
+  while (!isDone) {
+    const response = await fetch(`${url}?stream=false`, {
+      method: "GET",
+      headers: {
+        "Content-Type": "application/json",
+      },
+    });
+
+    if (!response.ok) {
+      throw new Error("Error polling build events");
+    }
+
+    const data = await response.json();
+    if (!data.event) {
+      // No event in this request, try again
+      await new Promise((resolve) => setTimeout(resolve, 100));
+      continue;
+    }
+
+    // Process the event
+    const event = JSON.parse(data.event);
+    await onEvent(
+      event.event,
+      event.data,
+      buildResults,
+      verticesStartTimeMs,
+      callbacks,
+    );
+
+    // Check if this was the end event or if we got a null value
+    if (event.event === "end" || data.event === null) {
+      isDone = true;
+    }
+
+    // Add a small delay between polls to avoid overwhelming the server
+    await new Promise((resolve) => setTimeout(resolve, POLLING_INTERVAL));
+  }
+}
+
 export async function buildFlowVertices({
   flowId,
   input_value,
@@ -152,18 +219,26 @@ export async function buildFlowVertices({
   edges,
   logBuilds,
   session,
+  stream = true,
 }: BuildVerticesParams) {
   const inputs = {};
-  let url = `${BASE_URL_API}build/${flowId}/flow?`;
+  let buildUrl = `${BASE_URL_API}build/${flowId}/flow`;
+  const queryParams = new URLSearchParams();
+
   if (startNodeId) {
-    url = `${url}&start_component_id=${startNodeId}`;
+    queryParams.append("start_component_id", startNodeId);
   }
   if (stopNodeId) {
-    url = `${url}&stop_component_id=${stopNodeId}`;
+    queryParams.append("stop_component_id", stopNodeId);
   }
   if (logBuilds !== undefined) {
-    url = `${url}&log_builds=${logBuilds}`;
+    queryParams.append("log_builds", logBuilds.toString());
   }
+
+  if (queryParams.toString()) {
+    buildUrl = `${buildUrl}?${queryParams.toString()}`;
+  }
+
   const postData = {};
   if (files) {
     postData["files"] = files;
@@ -184,182 +259,272 @@ export async function buildFlowVertices({
     postData["inputs"] = inputs;
   }
 
-  const buildResults: Array<boolean> = [];
+  try {
+    // First, start the build and get the job ID
+    const buildResponse = await fetch(buildUrl, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+      },
+      body: JSON.stringify(postData),
+    });
 
-  const verticesStartTimeMs: Map<string, number> = new Map();
-
-  const onEvent = async (type, data): Promise<boolean> => {
-    const onStartVertices = (ids: Array<string>) => {
-      useFlowStore.getState().updateBuildStatus(ids, BuildStatus.TO_BUILD);
-      if (onBuildStart)
-        onBuildStart(ids.map((id) => ({ id: id, reference: id })));
-      ids.forEach((id) => verticesStartTimeMs.set(id, Date.now()));
-    };
-    console.log("type", type);
-    console.log("data", data);
-    switch (type) {
-      case "vertices_sorted": {
-        const verticesToRun = data.to_run;
-        const verticesIds = data.ids;
-
-        onStartVertices(verticesIds);
-
-        let verticesLayers: Array<Array<VertexLayerElementType>> =
-          verticesIds.map((id: string) => {
-            return [{ id: id, reference: id }];
-          });
-
-        useFlowStore.getState().updateVerticesBuild({
-          verticesLayers,
-          verticesIds,
-          verticesToRun,
-        });
-        if (onValidateNodes) {
-          try {
-            onValidateNodes(data.to_run);
-            if (onGetOrderSuccess) onGetOrderSuccess();
-            useFlowStore.getState().setIsBuilding(true);
-            return true;
-          } catch (e) {
-            useFlowStore.getState().setIsBuilding(false);
-            return false;
-          }
-        }
-        return true;
+    if (!buildResponse.ok) {
+      if (buildResponse.status === 404) {
+        throw new Error("Flow not found");
       }
-      case "end_vertex": {
-        const buildData = data.build_data;
-        const startTimeMs = verticesStartTimeMs.get(buildData.id);
-        if (startTimeMs) {
-          const delta = Date.now() - startTimeMs;
-          if (delta < MIN_VISUAL_BUILD_TIME_MS) {
-            // this is a visual trick to make the build process look more natural
-            await new Promise((resolve) =>
-              setTimeout(resolve, MIN_VISUAL_BUILD_TIME_MS - delta),
-            );
-          }
-        }
+      throw new Error("Error starting build process");
+    }
 
-        if (onBuildUpdate) {
-          if (!buildData.valid) {
-            // lots is a dictionary with the key the output field name and the value the log object
-            // logs: { [key: string]: { message: any; type: string }[] };
-            const errorMessages = Object.keys(buildData.data.outputs).flatMap(
-              (key) => {
-                const outputs = buildData.data.outputs[key];
-                if (Array.isArray(outputs)) {
-                  return outputs
-                    .filter((log) => isErrorLogType(log.message))
-                    .map((log) => log.message.errorMessage);
-                }
-                if (!isErrorLogType(outputs.message)) {
-                  return [];
-                }
-                return [outputs.message.errorMessage];
-              },
-            );
-            onBuildError!("Error Building Component", errorMessages, [
+    const { job_id } = await buildResponse.json();
+
+    // Then stream the events
+    const eventsUrl = `${BASE_URL_API}build/${job_id}/events`;
+    const buildResults: Array<boolean> = [];
+    const verticesStartTimeMs: Map<string, number> = new Map();
+
+    if (stream) {
+      return performStreamingRequest({
+        method: "GET",
+        url: eventsUrl,
+        onData: async (event) => {
+          const type = event["event"];
+          const data = event["data"];
+          return await onEvent(type, data, buildResults, verticesStartTimeMs, {
+            onBuildStart,
+            onBuildUpdate,
+            onBuildComplete,
+            onBuildError,
+            onGetOrderSuccess,
+            onValidateNodes,
+          });
+        },
+        onError: (statusCode) => {
+          if (statusCode === 404) {
+            throw new Error("Build job not found");
+          }
+          throw new Error("Error processing build events");
+        },
+        onNetworkError: (error: Error) => {
+          if (error.name === "AbortError") {
+            onBuildStopped && onBuildStopped();
+            return;
+          }
+          onBuildError!("Error Building Component", [
+            "Network error. Please check the connection to the server.",
+          ]);
+        },
+      });
+    } else {
+      const callbacks = {
+        onBuildStart,
+        onBuildUpdate,
+        onBuildComplete,
+        onBuildError,
+        onGetOrderSuccess,
+        onValidateNodes,
+      };
+      return pollBuildEvents(
+        eventsUrl,
+        buildResults,
+        verticesStartTimeMs,
+        callbacks,
+      );
+    }
+  } catch (error) {
+    console.error("Build process error:", error);
+    onBuildError!("Error Building Flow", [
+      (error as Error).message || "An unexpected error occurred",
+    ]);
+    throw error;
+  }
+}
+/**
+ * Handles various build events and calls corresponding callbacks.
+ *
+ * @param {string} type - The event type.
+ * @param {any} data - The event data.
+ * @param {boolean[]} buildResults - Array tracking build results.
+ * @param {Map<string, number>} verticesStartTimeMs - Map tracking start times for vertices.
+ * @param {Object} callbacks - Object containing callback functions.
+ * @param {(idList: VertexLayerElementType[]) => void} [callbacks.onBuildStart] - Callback when vertices start building.
+ * @param {(data: any, status: BuildStatus, buildId: string) => void} [callbacks.onBuildUpdate] - Callback for build updates.
+ * @param {(allNodesValid: boolean) => void} [callbacks.onBuildComplete] - Callback when build completes.
+ * @param {(title: string, list: string[], idList?: VertexLayerElementType[]) => void} [callbacks.onBuildError] - Callback on build errors.
+ * @param {() => void} [callbacks.onGetOrderSuccess] - Callback for successful ordering.
+ * @param {(nodes: string[]) => void} [callbacks.onValidateNodes] - Callback to validate nodes.
+ * @returns {Promise<boolean>} Promise that resolves to true if the event was handled successfully.
+ */
+async function onEvent(
+  type: string,
+  data: any,
+  buildResults: boolean[],
+  verticesStartTimeMs: Map<string, number>,
+  callbacks: {
+    onBuildStart?: (idList: VertexLayerElementType[]) => void;
+    onBuildUpdate?: (data: any, status: BuildStatus, buildId: string) => void;
+    onBuildComplete?: (allNodesValid: boolean) => void;
+    onBuildError?: (
+      title: string,
+      list: string[],
+      idList?: VertexLayerElementType[],
+    ) => void;
+    onGetOrderSuccess?: () => void;
+    onValidateNodes?: (nodes: string[]) => void;
+  },
+): Promise<boolean> {
+  const {
+    onBuildStart,
+    onBuildUpdate,
+    onBuildComplete,
+    onBuildError,
+    onGetOrderSuccess,
+    onValidateNodes,
+  } = callbacks;
+
+  // Helper to update status and register start times for an array of vertex IDs.
+  const onStartVertices = (ids: Array<string>) => {
+    useFlowStore.getState().updateBuildStatus(ids, BuildStatus.TO_BUILD);
+    if (onBuildStart) {
+      onBuildStart(ids.map((id) => ({ id: id, reference: id })));
+    }
+    ids.forEach((id) => verticesStartTimeMs.set(id, Date.now()));
+  };
+
+  switch (type) {
+    case "vertices_sorted": {
+      const verticesToRun = data.to_run;
+      const verticesIds = data.ids;
+
+      onStartVertices(verticesIds);
+
+      const verticesLayers: Array<Array<VertexLayerElementType>> =
+        verticesIds.map((id: string) => [{ id: id, reference: id }]);
+
+      useFlowStore.getState().updateVerticesBuild({
+        verticesLayers,
+        verticesIds,
+        verticesToRun,
+      });
+      if (onValidateNodes) {
+        try {
+          onValidateNodes(data.to_run);
+          if (onGetOrderSuccess) onGetOrderSuccess();
+          useFlowStore.getState().setIsBuilding(true);
+          return true;
+        } catch (e) {
+          useFlowStore.getState().setIsBuilding(false);
+          return false;
+        }
+      }
+      return true;
+    }
+    case "end_vertex": {
+      const buildData = data.build_data;
+      const startTimeMs = verticesStartTimeMs.get(buildData.id);
+      if (startTimeMs) {
+        const delta = Date.now() - startTimeMs;
+        if (delta < MIN_VISUAL_BUILD_TIME_MS) {
+          // Ensure a minimum visual build time for a smoother UI experience.
+          await new Promise((resolve) =>
+            setTimeout(resolve, MIN_VISUAL_BUILD_TIME_MS - delta),
+          );
+        }
+      }
+
+      if (onBuildUpdate) {
+        if (!buildData.valid) {
+          // Aggregate error messages from the build outputs.
+          const errorMessages = Object.keys(buildData.data.outputs).flatMap(
+            (key) => {
+              const outputs = buildData.data.outputs[key];
+              if (Array.isArray(outputs)) {
+                return outputs
+                  .filter((log) => isErrorLogType(log.message))
+                  .map((log) => log.message.errorMessage);
+              }
+              if (!isErrorLogType(outputs.message)) {
+                return [];
+              }
+              return [outputs.message.errorMessage];
+            },
+          );
+          onBuildError &&
+            onBuildError("Error Building Component", errorMessages, [
               { id: buildData.id },
             ]);
-            onBuildUpdate(buildData, BuildStatus.ERROR, "");
-            buildResults.push(false);
-            return false;
-          } else {
-            onBuildUpdate(buildData, BuildStatus.BUILT, "");
-            buildResults.push(true);
-          }
+          onBuildUpdate(buildData, BuildStatus.ERROR, "");
+          buildResults.push(false);
+          return false;
+        } else {
+          onBuildUpdate(buildData, BuildStatus.BUILT, "");
+          buildResults.push(true);
         }
+      }
 
-        await useFlowStore.getState().clearEdgesRunningByNodes();
+      await useFlowStore.getState().clearEdgesRunningByNodes();
 
-        if (buildData.next_vertices_ids) {
-          if (isStringArray(buildData.next_vertices_ids)) {
-            useFlowStore
-              .getState()
-              .setCurrentBuildingNodeId(buildData?.next_vertices_ids ?? []);
-            useFlowStore
-              .getState()
-              .updateEdgesRunningByNodes(
-                buildData?.next_vertices_ids ?? 
[], - true, - ); - } - onStartVertices(buildData.next_vertices_ids); + if (buildData.next_vertices_ids) { + if (isStringArray(buildData.next_vertices_ids)) { + useFlowStore + .getState() + .setCurrentBuildingNodeId(buildData.next_vertices_ids ?? []); + useFlowStore + .getState() + .updateEdgesRunningByNodes(buildData.next_vertices_ids ?? [], true); } - return true; + onStartVertices(buildData.next_vertices_ids); } - case "add_message": { - //adds a message to the messsage table - useMessagesStore.getState().addMessage(data); - return true; - } - case "token": { - // flushSync and timeout is needed to avoid react batched updates - setTimeout(() => { - flushSync(() => { - useMessagesStore.getState().updateMessageText(data.id, data.chunk); - }); - }, 10); - return true; - } - case "remove_message": { - useMessagesStore.getState().removeMessage(data); - return true; - } - case "end": { - const allNodesValid = buildResults.every((result) => result); - onBuildComplete!(allNodesValid); - useFlowStore.getState().setIsBuilding(false); - return true; - } - case "error": { - if (data?.category === "error") { - useMessagesStore.getState().addMessage(data); - if (!data?.properties?.source?.id) { - onBuildError!("Error Building Flow", [data.text]); - } - } - buildResults.push(false); - return true; - } - case "build_start": - useFlowStore - .getState() - .updateBuildStatus([data.id], BuildStatus.BUILDING); - break; - case "build_end": - useFlowStore.getState().updateBuildStatus([data.id], BuildStatus.BUILT); - break; - default: - return true; + return true; } - return true; - }; - return performStreamingRequest({ - method: "POST", - url, - body: postData, - onData: async (event) => { - const type = event["event"]; - const data = event["data"]; - return await onEvent(type, data); - }, - onError: (statusCode) => { - if (statusCode === 404) { - throw new Error("Endpoint not available"); + case "add_message": { + // Add a message to the messages store. + useMessagesStore.getState().addMessage(data); + return true; + } + case "token": { + // Use flushSync with a timeout to avoid React batching issues. + setTimeout(() => { + flushSync(() => { + useMessagesStore.getState().updateMessageText(data.id, data.chunk); + }); + }, 10); + return true; + } + case "remove_message": { + useMessagesStore.getState().removeMessage(data); + return true; + } + case "end": { + const allNodesValid = buildResults.every((result) => result); + onBuildComplete && onBuildComplete(allNodesValid); + useFlowStore.getState().setIsBuilding(false); + return true; + } + case "error": { + if (data?.category === "error") { + useMessagesStore.getState().addMessage(data); + // Use a falsy check to correctly determine if the source ID is missing. + if (!data?.properties?.source?.id) { + onBuildError && onBuildError("Error Building Flow", [data.text]); + } } - throw new Error("Error Building Component"); - }, - onNetworkError: (error: Error) => { - if (error.name === "AbortError") { - onBuildStopped && onBuildStopped(); - return; - } - onBuildError!("Error Building Component", [ - "Network error. 
Please check the connection to the server.", - ]); - }, - }); + buildResults.push(false); + return true; + } + case "build_start": + useFlowStore + .getState() + .updateBuildStatus([data.id], BuildStatus.BUILDING); + break; + case "build_end": + useFlowStore.getState().updateBuildStatus([data.id], BuildStatus.BUILT); + break; + default: + return true; + } + return true; } export async function buildVertices({ diff --git a/src/frontend/tests/core/integrations/Basic Prompting.spec.ts b/src/frontend/tests/core/integrations/Basic Prompting.spec.ts index 04a4e93b5..18b2f5da4 100644 --- a/src/frontend/tests/core/integrations/Basic Prompting.spec.ts +++ b/src/frontend/tests/core/integrations/Basic Prompting.spec.ts @@ -3,8 +3,9 @@ import * as dotenv from "dotenv"; import path from "path"; import { awaitBootstrapTest } from "../../utils/await-bootstrap-test"; import { initialGPTsetup } from "../../utils/initialGPTsetup"; +import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes"; -test( +withEventDeliveryModes( "Basic Prompting (Hello, World)", { tag: ["@release", "@starter-projects"] }, async ({ page }) => { diff --git a/src/frontend/tests/core/integrations/Blog Writer.spec.ts b/src/frontend/tests/core/integrations/Blog Writer.spec.ts index 08fcc67f9..58f46fa3b 100644 --- a/src/frontend/tests/core/integrations/Blog Writer.spec.ts +++ b/src/frontend/tests/core/integrations/Blog Writer.spec.ts @@ -3,8 +3,9 @@ import * as dotenv from "dotenv"; import path from "path"; import { awaitBootstrapTest } from "../../utils/await-bootstrap-test"; import { initialGPTsetup } from "../../utils/initialGPTsetup"; +import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes"; -test( +withEventDeliveryModes( "Blog Writer", { tag: ["@release", "@starter-projects"] }, async ({ page }) => { diff --git a/src/frontend/tests/core/integrations/Custom Component Generator.spec.ts b/src/frontend/tests/core/integrations/Custom Component Generator.spec.ts index 279d30520..db405e929 100644 --- a/src/frontend/tests/core/integrations/Custom Component Generator.spec.ts +++ b/src/frontend/tests/core/integrations/Custom Component Generator.spec.ts @@ -5,8 +5,9 @@ import { awaitBootstrapTest } from "../../utils/await-bootstrap-test"; import { getAllResponseMessage } from "../../utils/get-all-response-message"; import { initialGPTsetup } from "../../utils/initialGPTsetup"; import { waitForOpenModalWithChatInput } from "../../utils/wait-for-open-modal"; +import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes"; -test( +withEventDeliveryModes( "Custom Component Generator", { tag: ["@release", "@starter-projects"] }, async ({ page }) => { diff --git a/src/frontend/tests/core/integrations/Document QA.spec.ts b/src/frontend/tests/core/integrations/Document QA.spec.ts index e1a26348c..12521bf34 100644 --- a/src/frontend/tests/core/integrations/Document QA.spec.ts +++ b/src/frontend/tests/core/integrations/Document QA.spec.ts @@ -3,8 +3,9 @@ import * as dotenv from "dotenv"; import path from "path"; import { awaitBootstrapTest } from "../../utils/await-bootstrap-test"; import { initialGPTsetup } from "../../utils/initialGPTsetup"; +import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes"; -test( +withEventDeliveryModes( "Document Q&A", { tag: ["@release", "@starter-projects"] }, async ({ page }) => { diff --git a/src/frontend/tests/core/integrations/Image Sentiment Analysis.spec.ts b/src/frontend/tests/core/integrations/Image Sentiment Analysis.spec.ts index 
893929df8..b96d9020f 100644
--- a/src/frontend/tests/core/integrations/Image Sentiment Analysis.spec.ts
+++ b/src/frontend/tests/core/integrations/Image Sentiment Analysis.spec.ts
@@ -2,18 +2,14 @@
 import { expect, test } from "@playwright/test";
 import * as dotenv from "dotenv";
 import { readFileSync } from "fs";
 import path from "path";
-import { addNewApiKeys } from "../../utils/add-new-api-keys";
-import { adjustScreenView } from "../../utils/adjust-screen-view";
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { buildDataTransfer } from "../../utils/build-data-transfer";
 import { getAllResponseMessage } from "../../utils/get-all-response-message";
 import { initialGPTsetup } from "../../utils/initialGPTsetup";
-import { removeOldApiKeys } from "../../utils/remove-old-api-keys";
-import { selectGptModel } from "../../utils/select-gpt-model";
-import { updateOldComponents } from "../../utils/update-old-components";
 import { waitForOpenModalWithoutChatInput } from "../../utils/wait-for-open-modal";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
-test(
+withEventDeliveryModes(
   "Image Sentiment Analysis",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/core/integrations/Market Research.spec.ts b/src/frontend/tests/core/integrations/Market Research.spec.ts
index a6c1a859f..84a9744e7 100644
--- a/src/frontend/tests/core/integrations/Market Research.spec.ts
+++ b/src/frontend/tests/core/integrations/Market Research.spec.ts
@@ -5,8 +5,9 @@
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { getAllResponseMessage } from "../../utils/get-all-response-message";
 import { initialGPTsetup } from "../../utils/initialGPTsetup";
 import { waitForOpenModalWithChatInput } from "../../utils/wait-for-open-modal";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
-test(
+withEventDeliveryModes(
   "Market Research",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/core/integrations/Memory Chatbot.spec.ts b/src/frontend/tests/core/integrations/Memory Chatbot.spec.ts
index 0d12fb09b..ab9bf2e55 100644
--- a/src/frontend/tests/core/integrations/Memory Chatbot.spec.ts
+++ b/src/frontend/tests/core/integrations/Memory Chatbot.spec.ts
@@ -3,8 +3,9 @@
 import * as dotenv from "dotenv";
 import path from "path";
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { initialGPTsetup } from "../../utils/initialGPTsetup";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
-test(
+withEventDeliveryModes(
   "Memory Chatbot",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/core/integrations/Prompt Chaining.spec.ts b/src/frontend/tests/core/integrations/Prompt Chaining.spec.ts
index 262bcbb76..babcf171e 100644
--- a/src/frontend/tests/core/integrations/Prompt Chaining.spec.ts
+++ b/src/frontend/tests/core/integrations/Prompt Chaining.spec.ts
@@ -1,17 +1,13 @@
 import { expect, test } from "@playwright/test";
 import * as dotenv from "dotenv";
 import path from "path";
-import { addNewApiKeys } from "../../utils/add-new-api-keys";
-import { adjustScreenView } from "../../utils/adjust-screen-view";
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { getAllResponseMessage } from "../../utils/get-all-response-message";
 import { initialGPTsetup } from "../../utils/initialGPTsetup";
-import { removeOldApiKeys } from "../../utils/remove-old-api-keys";
-import { selectGptModel } from "../../utils/select-gpt-model";
-import { updateOldComponents } from "../../utils/update-old-components";
 import { waitForOpenModalWithChatInput } from "../../utils/wait-for-open-modal";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
 
-test(
+withEventDeliveryModes(
   "Prompt Chaining",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/core/integrations/SEO Keyword Generator.spec.ts b/src/frontend/tests/core/integrations/SEO Keyword Generator.spec.ts
index b9561d37a..adda1fcb5 100644
--- a/src/frontend/tests/core/integrations/SEO Keyword Generator.spec.ts
+++ b/src/frontend/tests/core/integrations/SEO Keyword Generator.spec.ts
@@ -1,17 +1,13 @@
 import { expect, test } from "@playwright/test";
 import * as dotenv from "dotenv";
 import path from "path";
-import { addNewApiKeys } from "../../utils/add-new-api-keys";
-import { adjustScreenView } from "../../utils/adjust-screen-view";
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { getAllResponseMessage } from "../../utils/get-all-response-message";
 import { initialGPTsetup } from "../../utils/initialGPTsetup";
-import { removeOldApiKeys } from "../../utils/remove-old-api-keys";
-import { selectGptModel } from "../../utils/select-gpt-model";
-import { updateOldComponents } from "../../utils/update-old-components";
 import { waitForOpenModalWithoutChatInput } from "../../utils/wait-for-open-modal";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
 
-test(
+withEventDeliveryModes(
   "SEO Keyword Generator",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/core/integrations/SaaS Pricing.spec.ts b/src/frontend/tests/core/integrations/SaaS Pricing.spec.ts
index 6c047f77c..34265db09 100644
--- a/src/frontend/tests/core/integrations/SaaS Pricing.spec.ts
+++ b/src/frontend/tests/core/integrations/SaaS Pricing.spec.ts
@@ -1,17 +1,13 @@
 import { expect, test } from "@playwright/test";
 import * as dotenv from "dotenv";
 import path from "path";
-import { addNewApiKeys } from "../../utils/add-new-api-keys";
-import { adjustScreenView } from "../../utils/adjust-screen-view";
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { getAllResponseMessage } from "../../utils/get-all-response-message";
 import { initialGPTsetup } from "../../utils/initialGPTsetup";
-import { removeOldApiKeys } from "../../utils/remove-old-api-keys";
-import { selectGptModel } from "../../utils/select-gpt-model";
-import { updateOldComponents } from "../../utils/update-old-components";
 import { waitForOpenModalWithoutChatInput } from "../../utils/wait-for-open-modal";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
 
-test(
+withEventDeliveryModes(
   "SaaS Pricing",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/core/integrations/Simple Agent.spec.ts b/src/frontend/tests/core/integrations/Simple Agent.spec.ts
index 3bfa8c893..080448b04 100644
--- a/src/frontend/tests/core/integrations/Simple Agent.spec.ts
+++ b/src/frontend/tests/core/integrations/Simple Agent.spec.ts
@@ -3,8 +3,9 @@
 import * as dotenv from "dotenv";
 import path from "path";
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { initialGPTsetup } from "../../utils/initialGPTsetup";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
-test(
+withEventDeliveryModes(
   "Simple Agent",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/core/integrations/Travel Planning Agent.spec.ts b/src/frontend/tests/core/integrations/Travel Planning Agent.spec.ts
index fb0458f51..e140ab1d2 100644
--- a/src/frontend/tests/core/integrations/Travel Planning Agent.spec.ts
+++ b/src/frontend/tests/core/integrations/Travel Planning Agent.spec.ts
@@ -1,15 +1,11 @@
 import { expect, Page, test } from "@playwright/test";
 import * as dotenv from "dotenv";
 import path from "path";
-import { addNewApiKeys } from "../../utils/add-new-api-keys";
-import { adjustScreenView } from "../../utils/adjust-screen-view";
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { initialGPTsetup } from "../../utils/initialGPTsetup";
-import { removeOldApiKeys } from "../../utils/remove-old-api-keys";
-import { selectGptModel } from "../../utils/select-gpt-model";
-import { updateOldComponents } from "../../utils/update-old-components";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
 
-test(
+withEventDeliveryModes(
   "Travel Planning Agent",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/core/integrations/Twitter Thread Generator.spec.ts b/src/frontend/tests/core/integrations/Twitter Thread Generator.spec.ts
index ba0a92815..03edbbd4a 100644
--- a/src/frontend/tests/core/integrations/Twitter Thread Generator.spec.ts
+++ b/src/frontend/tests/core/integrations/Twitter Thread Generator.spec.ts
@@ -1,17 +1,13 @@
 import { expect, test } from "@playwright/test";
 import * as dotenv from "dotenv";
 import path from "path";
-import { addNewApiKeys } from "../../utils/add-new-api-keys";
-import { adjustScreenView } from "../../utils/adjust-screen-view";
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { getAllResponseMessage } from "../../utils/get-all-response-message";
 import { initialGPTsetup } from "../../utils/initialGPTsetup";
-import { removeOldApiKeys } from "../../utils/remove-old-api-keys";
-import { selectGptModel } from "../../utils/select-gpt-model";
-import { updateOldComponents } from "../../utils/update-old-components";
 import { waitForOpenModalWithoutChatInput } from "../../utils/wait-for-open-modal";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
 
-test(
+withEventDeliveryModes(
   "Twitter Thread Generator",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/core/integrations/Vector Store.spec.ts b/src/frontend/tests/core/integrations/Vector Store.spec.ts
index 9bff2b9f1..66a6d52e3 100644
--- a/src/frontend/tests/core/integrations/Vector Store.spec.ts
+++ b/src/frontend/tests/core/integrations/Vector Store.spec.ts
@@ -1,9 +1,10 @@
-import { expect, Page, test } from "@playwright/test";
+import { expect, test } from "@playwright/test";
 import path from "path";
 import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
 import { extractAndCleanCode } from "../../utils/extract-and-clean-code";
+import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
 
-test(
+withEventDeliveryModes(
   "Vector Store RAG",
   { tag: ["@release", "@starter-projects"] },
   async ({ page }) => {
diff --git a/src/frontend/tests/utils/withEventDeliveryModes.ts b/src/frontend/tests/utils/withEventDeliveryModes.ts
new file mode 100644
index 000000000..535452978
--- /dev/null
+++ b/src/frontend/tests/utils/withEventDeliveryModes.ts
@@ -0,0 +1,34 @@
+import { Page, test } from "@playwright/test";
+
+type TestFunction = (args: { page: Page }) => Promise<void>;
+type TestConfig = Parameters<typeof test>[1];
+
+/**
+ * Wraps a test function to run it with both streaming and polling event delivery modes.
+ *
+ * @param title The test title
+ * @param config The test configuration (tags, etc)
+ * @param testFn The test function to wrap
+ */
+export function withEventDeliveryModes(
+  title: string,
+  config: TestConfig,
+  testFn: TestFunction,
+) {
+  const eventDeliveryModes = ["streaming", "polling"] as const;
+
+  for (const eventDelivery of eventDeliveryModes) {
+    test(`${title} - ${eventDelivery}`, config, async ({ page }) => {
+      // Intercept the config request and modify the event_delivery setting
+      await page.route("**/api/v1/config", async (route) => {
+        const response = await route.fetch();
+        const json = await response.json();
+        json.event_delivery = eventDelivery;
+        await route.fulfill({ response, json });
+      });
+
+      // Run the original test function
+      await testFn({ page });
+    });
+  }
+}
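
Note on usage: the wrapper above registers one Playwright test per delivery mode, so each migrated spec now runs twice, once with event_delivery forced to "streaming" and once to "polling" via the intercepted /api/v1/config response. A minimal sketch of how a migrated spec invokes it follows; the flow title and the commented assertion are hypothetical placeholders for illustration, not part of this change:

import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";

// Registers two tests: "Example Flow - streaming" and "Example Flow - polling".
// Each run intercepts /api/v1/config so the frontend consumes build events
// through the selected delivery mode.
withEventDeliveryModes(
  "Example Flow", // hypothetical title, for illustration only
  { tag: ["@release", "@starter-projects"] },
  async ({ page }) => {
    await awaitBootstrapTest(page);
    // ...open a starter project, run the flow, and assert on the response,
    // e.g. via getAllResponseMessage(page), as the real specs above do.
  },
);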