🐛 fix(process.py): remove unnecessary session_id parameter in build_sorted_vertices function calls
🐛 fix(endpoints.py): remove unused import and type hinting to improve code readability and maintainability 🐛 fix(endpoints.py): fix incorrect import statement for TaskService 🐛 fix(endpoints.py): fix incorrect return value for get_task_status function ✨ feat(endpoints.py): add test case for async task processing to validate task completion and result
This commit is contained in:
parent
8d1bff38fe
commit
fa1ebea378
4 changed files with 63 additions and 10 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union
|
||||
from typing import Annotated, Any, Optional, Union
|
||||
from langflow.services.auth.utils import api_key_security, get_current_active_user
|
||||
|
||||
|
||||
|
|
@ -33,8 +33,8 @@ from langflow.services.utils import get_session
|
|||
from langflow.worker import process_graph_cached_task
|
||||
from sqlmodel import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.task.manager import TaskService
|
||||
|
||||
from langflow.services.task.manager import TaskService
|
||||
|
||||
# build router
|
||||
router = APIRouter(tags=["Base"])
|
||||
|
|
@ -183,11 +183,17 @@ async def process_flow(
|
|||
async def get_task_status(task_id: str):
|
||||
task_service = get_task_service()
|
||||
task = task_service.get_task(task_id)
|
||||
result = None
|
||||
if task.ready():
|
||||
result = task.result
|
||||
if isinstance(result, dict) and "result" in result:
|
||||
result = result["result"]
|
||||
elif hasattr(result, "result"):
|
||||
result = result.result
|
||||
|
||||
if task is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return TaskStatusResponse(
|
||||
status=task.status, result=task.result if task.ready() else None
|
||||
)
|
||||
return TaskStatusResponse(status=task.status, result=result)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
|
|
|||
|
|
@ -94,13 +94,13 @@ def get_build_result(data_graph, session_id):
|
|||
# otherwise, build the graph and return the result
|
||||
if session_id:
|
||||
logger.debug(f"Loading LangChain object from session {session_id}")
|
||||
result = build_sorted_vertices(data_graph=data_graph, session_id=session_id)
|
||||
result = build_sorted_vertices(data_graph=data_graph)
|
||||
if result is not None:
|
||||
logger.debug("Loaded LangChain object")
|
||||
return result
|
||||
|
||||
logger.debug("Building langchain object")
|
||||
return build_sorted_vertices(data_graph, session_id)
|
||||
return build_sorted_vertices(data_graph)
|
||||
|
||||
|
||||
def load_langchain_object(
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from fastapi.testclient import TestClient
|
|||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
|
||||
|
||||
from tests.utils import run_post
|
||||
from tests.utils import poll_task_status, run_post
|
||||
|
||||
|
||||
PROMPT_REQUEST = {
|
||||
|
|
@ -409,11 +409,11 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key):
|
|||
# session_id should be returned
|
||||
assert "session_id" in response.json()
|
||||
assert response.json()["session_id"] is not None
|
||||
session_id1 = response.json()["session_id"]
|
||||
# New request with a different session_id
|
||||
# asking "What is my name?" should return "Gabriel"
|
||||
post_data = {
|
||||
"inputs": {"text": "What is my name?"},
|
||||
"session_id": "other session id",
|
||||
}
|
||||
response = client.post(
|
||||
f"api/v1/process/{added_flow.get('id')}",
|
||||
|
|
@ -422,6 +422,7 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key):
|
|||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
assert "Gabriel" not in response.json()["result"]["text"]
|
||||
assert session_id1 != response.json()["session_id"]
|
||||
|
||||
|
||||
def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_api_key):
|
||||
|
|
@ -448,3 +449,30 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a
|
|||
response_json = run_post(client, flow_id, headers, post_data)
|
||||
|
||||
assert name in response_json["result"]["text"]
|
||||
|
||||
|
||||
# Test function without loop
|
||||
def test_async_task_processing(client, added_flow, created_api_key):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"text": "Hi, My name is Gabriel"}}
|
||||
|
||||
# Run the /api/v1/process/{flow_id} endpoint with sync=False
|
||||
response = client.post(
|
||||
f"api/v1/process/{added_flow.get('id')}",
|
||||
headers=headers,
|
||||
json={**post_data, "sync": False},
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
# Extract the task ID from the response
|
||||
task_id = response.json().get("id")
|
||||
assert task_id is not None
|
||||
|
||||
# Polling the task status using the helper function
|
||||
task_status_json = poll_task_status(client, headers, task_id)
|
||||
assert task_status_json is not None, "Task did not complete in time"
|
||||
|
||||
# Validate that the task completed successfully and the result is as expected
|
||||
assert "result" in task_status_json, task_status_json
|
||||
assert "text" in task_status_json["result"], task_status_json["result"]
|
||||
assert "Gabriel" in task_status_json["result"]["text"], task_status_json["result"]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
import time
|
||||
|
||||
|
||||
def run_post(client, flow_id, headers, post_data):
|
||||
response = client.post(
|
||||
f"api/v1/process/{flow_id}",
|
||||
|
|
@ -6,3 +9,19 @@ def run_post(client, flow_id, headers, post_data):
|
|||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
return response.json()
|
||||
|
||||
|
||||
# Helper function to poll task status
|
||||
def poll_task_status(client, headers, task_id, max_attempts=20, sleep_time=1):
|
||||
for _ in range(max_attempts):
|
||||
task_status_response = client.get(
|
||||
f"api/v1/task/{task_id}/status",
|
||||
headers=headers,
|
||||
)
|
||||
if (
|
||||
task_status_response.status_code == 200
|
||||
and task_status_response.json()["status"] == "SUCCESS"
|
||||
):
|
||||
return task_status_response.json()
|
||||
time.sleep(sleep_time)
|
||||
return None # Return None if task did not complete in time
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue