From 380a42e988cf5684511818d60e216fa77762cb0c Mon Sep 17 00:00:00 2001
From: Gabriel Luiz Freitas Almeida
Date: Tue, 28 Nov 2023 14:34:37 -0300
Subject: [PATCH] Refactor process_graph_cached_task function

---
 src/backend/langflow/worker.py | 51 +++++++++++++++++++++++-----------
 1 file changed, 35 insertions(+), 16 deletions(-)

diff --git a/src/backend/langflow/worker.py b/src/backend/langflow/worker.py
index b9d646184..8e07c0e55 100644
--- a/src/backend/langflow/worker.py
+++ b/src/backend/langflow/worker.py
@@ -3,9 +3,10 @@
 from typing import TYPE_CHECKING, Any, Dict, Optional
 from asgiref.sync import async_to_sync
 from celery.exceptions import SoftTimeLimitExceeded  # type: ignore
 from langflow.core.celery_app import celery_app
-from langflow.processing.process import Result, generate_result, process_inputs
+from langflow.processing.process import generate_result, process_inputs
 from langflow.services.deps import get_session_service
 from langflow.services.manager import initialize_session_service
+from loguru import logger
 if TYPE_CHECKING:
     from langflow.graph.vertex.base import Vertex
@@ -36,19 +37,37 @@ def process_graph_cached_task(
     clear_cache=False,
     session_id=None,
 ) -> Dict[str, Any]:
-    initialize_session_service()
-    session_service = get_session_service()
-    if clear_cache:
-        session_service.clear_session(session_id)
-    if session_id is None:
-        session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
-    # Load the graph using SessionService
-    graph, artifacts = async_to_sync(session_service.load_session)(session_id, data_graph)
-    built_object = graph.build()
-    processed_inputs = process_inputs(inputs, artifacts)
-    result = generate_result(built_object, processed_inputs)
-    # langchain_object is now updated with the new memory
-    # we need to update the cache with the updated langchain_object
-    session_service.update_session(session_id, (graph, artifacts))
-    return Result(result=result, session_id=session_id).model_dump()
+    try:
+        initialize_session_service()
+        session_service = get_session_service()
+        if clear_cache:
+            session_service.clear_session(session_id)
+
+        if session_id is None:
+            session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
+
+        # Use async_to_sync to handle the asynchronous part of the session service
+        session_data = async_to_sync(session_service.load_session)(session_id, data_graph)
+        logger.warning(f"session_data: {session_data}")
+        graph, artifacts = session_data if session_data else (None, None)
+
+        if not graph:
+            raise ValueError("Graph not found in the session")
+
+        # Use async_to_sync for the asynchronous build method
+        built_object = async_to_sync(graph.build)()
+
+        logger.info(f"Built object: {built_object}")
+
+        processed_inputs = process_inputs(inputs, artifacts or {})
+        result = generate_result(built_object, processed_inputs)
+
+        # Update the session with the new data
+        session_service.update_session(session_id, (graph, artifacts))
+
+        return {"result": result, "session_id": session_id}
+    except Exception as e:
+        logger.error(f"Error in process_graph_cached_task: {e}")
+        # Handle the exception as needed, maybe re-raise or return an error message
+        raise
 