diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index d7fed1a70..ec107092b 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import TYPE_CHECKING, Annotated, Any, Optional, Union +from typing import Annotated, Any, Optional, Union from langflow.services.auth.utils import api_key_security, get_current_active_user @@ -33,8 +33,8 @@ from langflow.services.utils import get_session from langflow.worker import process_graph_cached_task from sqlmodel import Session -if TYPE_CHECKING: - from langflow.services.task.manager import TaskService + +from langflow.services.task.manager import TaskService # build router router = APIRouter(tags=["Base"]) @@ -183,11 +183,17 @@ async def process_flow( async def get_task_status(task_id: str): task_service = get_task_service() task = task_service.get_task(task_id) + result = None + if task.ready(): + result = task.result + if isinstance(result, dict) and "result" in result: + result = result["result"] + elif hasattr(result, "result"): + result = result.result + if task is None: raise HTTPException(status_code=404, detail="Task not found") - return TaskStatusResponse( - status=task.status, result=task.result if task.ready() else None - ) + return TaskStatusResponse(status=task.status, result=result) @router.post( diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 877d32df2..6490f02fb 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -94,13 +94,13 @@ def get_build_result(data_graph, session_id): # otherwise, build the graph and return the result if session_id: logger.debug(f"Loading LangChain object from session {session_id}") - result = build_sorted_vertices(data_graph=data_graph, session_id=session_id) + result = build_sorted_vertices(data_graph=data_graph) if result is not None: logger.debug("Loaded LangChain object") return result logger.debug("Building langchain object") - return build_sorted_vertices(data_graph, session_id) + return build_sorted_vertices(data_graph) def load_langchain_object( diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 3344633ab..da67cbb1a 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -8,7 +8,7 @@ from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS from langflow.template.frontend_node.chains import TimeTravelGuideChainNode -from tests.utils import run_post +from tests.utils import poll_task_status, run_post PROMPT_REQUEST = { @@ -409,11 +409,11 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key): # session_id should be returned assert "session_id" in response.json() assert response.json()["session_id"] is not None + session_id1 = response.json()["session_id"] # New request with a different session_id # asking "What is my name?" should return "Gabriel" post_data = { "inputs": {"text": "What is my name?"}, - "session_id": "other session id", } response = client.post( f"api/v1/process/{added_flow.get('id')}", @@ -422,6 +422,7 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key): ) assert response.status_code == 200, response.json() assert "Gabriel" not in response.json()["result"]["text"] + assert session_id1 != response.json()["session_id"] def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_api_key): @@ -448,3 +449,30 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a response_json = run_post(client, flow_id, headers, post_data) assert name in response_json["result"]["text"] + + +# Test function without loop +def test_async_task_processing(client, added_flow, created_api_key): + headers = {"x-api-key": created_api_key.api_key} + post_data = {"inputs": {"text": "Hi, My name is Gabriel"}} + + # Run the /api/v1/process/{flow_id} endpoint with sync=False + response = client.post( + f"api/v1/process/{added_flow.get('id')}", + headers=headers, + json={**post_data, "sync": False}, + ) + assert response.status_code == 200, response.json() + + # Extract the task ID from the response + task_id = response.json().get("id") + assert task_id is not None + + # Polling the task status using the helper function + task_status_json = poll_task_status(client, headers, task_id) + assert task_status_json is not None, "Task did not complete in time" + + # Validate that the task completed successfully and the result is as expected + assert "result" in task_status_json, task_status_json + assert "text" in task_status_json["result"], task_status_json["result"] + assert "Gabriel" in task_status_json["result"]["text"], task_status_json["result"] diff --git a/tests/utils.py b/tests/utils.py index ce913c4b9..a4bae9863 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,6 @@ +import time + + def run_post(client, flow_id, headers, post_data): response = client.post( f"api/v1/process/{flow_id}", @@ -6,3 +9,19 @@ def run_post(client, flow_id, headers, post_data): ) assert response.status_code == 200, response.json() return response.json() + + +# Helper function to poll task status +def poll_task_status(client, headers, task_id, max_attempts=20, sleep_time=1): + for _ in range(max_attempts): + task_status_response = client.get( + f"api/v1/task/{task_id}/status", + headers=headers, + ) + if ( + task_status_response.status_code == 200 + and task_status_response.json()["status"] == "SUCCESS" + ): + return task_status_response.json() + time.sleep(sleep_time) + return None # Return None if task did not complete in time