From fa1ebea378d043a1f62a48fb0bb218328fd09a3c Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 15 Sep 2023 18:16:25 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(process.py):=20remove=20unne?= =?UTF-8?q?cessary=20session=5Fid=20parameter=20in=20build=5Fsorted=5Fvert?= =?UTF-8?q?ices=20function=20calls=20=F0=9F=90=9B=20fix(endpoints.py):=20r?= =?UTF-8?q?emove=20unused=20import=20and=20type=20hinting=20to=20improve?= =?UTF-8?q?=20code=20readability=20and=20maintainability=20=F0=9F=90=9B=20?= =?UTF-8?q?fix(endpoints.py):=20fix=20incorrect=20import=20statement=20for?= =?UTF-8?q?=20TaskService=20=F0=9F=90=9B=20fix(endpoints.py):=20fix=20inco?= =?UTF-8?q?rrect=20return=20value=20for=20get=5Ftask=5Fstatus=20function?= =?UTF-8?q?=20=E2=9C=A8=20feat(endpoints.py):=20add=20test=20case=20for=20?= =?UTF-8?q?async=20task=20processing=20to=20validate=20task=20completion?= =?UTF-8?q?=20and=20result?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/endpoints.py | 18 ++++++++---- src/backend/langflow/processing/process.py | 4 +-- tests/test_endpoints.py | 32 ++++++++++++++++++++-- tests/utils.py | 19 +++++++++++++ 4 files changed, 63 insertions(+), 10 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index d7fed1a70..ec107092b 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import TYPE_CHECKING, Annotated, Any, Optional, Union +from typing import Annotated, Any, Optional, Union from langflow.services.auth.utils import api_key_security, get_current_active_user @@ -33,8 +33,8 @@ from langflow.services.utils import get_session from langflow.worker import process_graph_cached_task from sqlmodel import Session -if TYPE_CHECKING: - from langflow.services.task.manager import TaskService + +from langflow.services.task.manager import TaskService # build router router = APIRouter(tags=["Base"]) @@ -183,11 +183,17 @@ async def process_flow( async def get_task_status(task_id: str): task_service = get_task_service() task = task_service.get_task(task_id) + result = None + if task.ready(): + result = task.result + if isinstance(result, dict) and "result" in result: + result = result["result"] + elif hasattr(result, "result"): + result = result.result + if task is None: raise HTTPException(status_code=404, detail="Task not found") - return TaskStatusResponse( - status=task.status, result=task.result if task.ready() else None - ) + return TaskStatusResponse(status=task.status, result=result) @router.post( diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 877d32df2..6490f02fb 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -94,13 +94,13 @@ def get_build_result(data_graph, session_id): # otherwise, build the graph and return the result if session_id: logger.debug(f"Loading LangChain object from session {session_id}") - result = build_sorted_vertices(data_graph=data_graph, session_id=session_id) + result = build_sorted_vertices(data_graph=data_graph) if result is not None: logger.debug("Loaded LangChain object") return result logger.debug("Building langchain object") - return build_sorted_vertices(data_graph, session_id) + return build_sorted_vertices(data_graph) def load_langchain_object( diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 3344633ab..da67cbb1a 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -8,7 +8,7 @@ from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS from langflow.template.frontend_node.chains import TimeTravelGuideChainNode -from tests.utils import run_post +from tests.utils import poll_task_status, run_post PROMPT_REQUEST = { @@ -409,11 +409,11 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key): # session_id should be returned assert "session_id" in response.json() assert response.json()["session_id"] is not None + session_id1 = response.json()["session_id"] # New request with a different session_id # asking "What is my name?" should return "Gabriel" post_data = { "inputs": {"text": "What is my name?"}, - "session_id": "other session id", } response = client.post( f"api/v1/process/{added_flow.get('id')}", @@ -422,6 +422,7 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key): ) assert response.status_code == 200, response.json() assert "Gabriel" not in response.json()["result"]["text"] + assert session_id1 != response.json()["session_id"] def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_api_key): @@ -448,3 +449,30 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a response_json = run_post(client, flow_id, headers, post_data) assert name in response_json["result"]["text"] + + +# Test function without loop +def test_async_task_processing(client, added_flow, created_api_key): + headers = {"x-api-key": created_api_key.api_key} + post_data = {"inputs": {"text": "Hi, My name is Gabriel"}} + + # Run the /api/v1/process/{flow_id} endpoint with sync=False + response = client.post( + f"api/v1/process/{added_flow.get('id')}", + headers=headers, + json={**post_data, "sync": False}, + ) + assert response.status_code == 200, response.json() + + # Extract the task ID from the response + task_id = response.json().get("id") + assert task_id is not None + + # Polling the task status using the helper function + task_status_json = poll_task_status(client, headers, task_id) + assert task_status_json is not None, "Task did not complete in time" + + # Validate that the task completed successfully and the result is as expected + assert "result" in task_status_json, task_status_json + assert "text" in task_status_json["result"], task_status_json["result"] + assert "Gabriel" in task_status_json["result"]["text"], task_status_json["result"] diff --git a/tests/utils.py b/tests/utils.py index ce913c4b9..a4bae9863 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,6 @@ +import time + + def run_post(client, flow_id, headers, post_data): response = client.post( f"api/v1/process/{flow_id}", @@ -6,3 +9,19 @@ def run_post(client, flow_id, headers, post_data): ) assert response.status_code == 200, response.json() return response.json() + + +# Helper function to poll task status +def poll_task_status(client, headers, task_id, max_attempts=20, sleep_time=1): + for _ in range(max_attempts): + task_status_response = client.get( + f"api/v1/task/{task_id}/status", + headers=headers, + ) + if ( + task_status_response.status_code == 200 + and task_status_response.json()["status"] == "SUCCESS" + ): + return task_status_response.json() + time.sleep(sleep_time) + return None # Return None if task did not complete in time