From ff86d02a92c7fc3c96c283e53acea6840da7f267 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 28 Nov 2023 19:38:22 -0300 Subject: [PATCH] Refactor process_graph_cached_task function --- src/backend/langflow/worker.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/backend/langflow/worker.py b/src/backend/langflow/worker.py index be4a7e318..a66c35b78 100644 --- a/src/backend/langflow/worker.py +++ b/src/backend/langflow/worker.py @@ -3,10 +3,11 @@ from typing import TYPE_CHECKING, Any, Dict, Optional from asgiref.sync import async_to_sync from celery.exceptions import SoftTimeLimitExceeded # type: ignore from langflow.core.celery_app import celery_app -from langflow.processing.process import generate_result, process_inputs +from langflow.processing.process import Result, generate_result, process_inputs from langflow.services.deps import get_session_service from langflow.services.manager import initialize_session_service from loguru import logger +from rich import print if TYPE_CHECKING: from langflow.graph.vertex.base import Vertex @@ -30,9 +31,8 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex": raise self.retry(exc=SoftTimeLimitExceeded("Task took too long"), countdown=2) from e -@celery_app.task(bind=True, acks_late=True) -async def process_graph_cached_task( - self, +@celery_app.task(acks_late=True) +def process_graph_cached_task( data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False, @@ -49,7 +49,7 @@ async def process_graph_cached_task( session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph) # Use async_to_sync to handle the asynchronous part of the session service - session_data = await session_service.load_session(session_id, data_graph) + session_data = async_to_sync(session_service.load_session, force_new_loop=True)(session_id, data_graph) logger.warning(f"session_data: {session_data}") graph, artifacts = session_data if session_data else (None, None) @@ -57,17 +57,18 @@ async def process_graph_cached_task( raise ValueError("Graph not found in the session") # Use async_to_sync for the asynchronous build method - built_object = await graph.build() + built_object = async_to_sync(graph.build, force_new_loop=True)() - logger.info(f"Built object: {built_object}") + logger.debug(f"Built object: {built_object}") processed_inputs = process_inputs(inputs, artifacts or {}) result = generate_result(built_object, processed_inputs) # Update the session with the new data session_service.update_session(session_id, (graph, artifacts)) - - return {"result": result, "session_id": session_id} + result_object = Result(result=result, session_id=session_id).model_dump() + print(f"Result object: {result_object}") + return result_object except Exception as e: logger.error(f"Error in process_graph_cached_task: {e}") # Handle the exception as needed, maybe re-raise or return an error message