Refactor process_graph_cached_task function
This commit is contained in:
parent
a255afbabe
commit
ff86d02a92
1 changed files with 10 additions and 9 deletions
|
|
@ -3,10 +3,11 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
|
|||
from asgiref.sync import async_to_sync
|
||||
from celery.exceptions import SoftTimeLimitExceeded # type: ignore
|
||||
from langflow.core.celery_app import celery_app
|
||||
from langflow.processing.process import generate_result, process_inputs
|
||||
from langflow.processing.process import Result, generate_result, process_inputs
|
||||
from langflow.services.deps import get_session_service
|
||||
from langflow.services.manager import initialize_session_service
|
||||
from loguru import logger
|
||||
from rich import print
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
|
@ -30,9 +31,8 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex":
|
|||
raise self.retry(exc=SoftTimeLimitExceeded("Task took too long"), countdown=2) from e
|
||||
|
||||
|
||||
@celery_app.task(bind=True, acks_late=True)
|
||||
async def process_graph_cached_task(
|
||||
self,
|
||||
@celery_app.task(acks_late=True)
|
||||
def process_graph_cached_task(
|
||||
data_graph: Dict[str, Any],
|
||||
inputs: Optional[dict] = None,
|
||||
clear_cache=False,
|
||||
|
|
@ -49,7 +49,7 @@ async def process_graph_cached_task(
|
|||
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
|
||||
|
||||
# Use async_to_sync to handle the asynchronous part of the session service
|
||||
session_data = await session_service.load_session(session_id, data_graph)
|
||||
session_data = async_to_sync(session_service.load_session, force_new_loop=True)(session_id, data_graph)
|
||||
logger.warning(f"session_data: {session_data}")
|
||||
graph, artifacts = session_data if session_data else (None, None)
|
||||
|
||||
|
|
@ -57,17 +57,18 @@ async def process_graph_cached_task(
|
|||
raise ValueError("Graph not found in the session")
|
||||
|
||||
# Use async_to_sync for the asynchronous build method
|
||||
built_object = await graph.build()
|
||||
built_object = async_to_sync(graph.build, force_new_loop=True)()
|
||||
|
||||
logger.info(f"Built object: {built_object}")
|
||||
logger.debug(f"Built object: {built_object}")
|
||||
|
||||
processed_inputs = process_inputs(inputs, artifacts or {})
|
||||
result = generate_result(built_object, processed_inputs)
|
||||
|
||||
# Update the session with the new data
|
||||
session_service.update_session(session_id, (graph, artifacts))
|
||||
|
||||
return {"result": result, "session_id": session_id}
|
||||
result_object = Result(result=result, session_id=session_id).model_dump()
|
||||
print(f"Result object: {result_object}")
|
||||
return result_object
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_graph_cached_task: {e}")
|
||||
# Handle the exception as needed, maybe re-raise or return an error message
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue