Refactor process_graph_cached_task function

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-28 14:34:37 -03:00
commit 380a42e988

View file

@ -3,9 +3,10 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
from asgiref.sync import async_to_sync
from celery.exceptions import SoftTimeLimitExceeded # type: ignore
from langflow.core.celery_app import celery_app
from langflow.processing.process import Result, generate_result, process_inputs
from langflow.processing.process import generate_result, process_inputs
from langflow.services.deps import get_session_service
from langflow.services.manager import initialize_session_service
from loguru import logger
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
@ -36,19 +37,37 @@ def process_graph_cached_task(
clear_cache=False,
session_id=None,
) -> Dict[str, Any]:
initialize_session_service()
session_service = get_session_service()
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
# Load the graph using SessionService
graph, artifacts = async_to_sync(session_service.load_session)(session_id, data_graph)
built_object = graph.build()
processed_inputs = process_inputs(inputs, artifacts)
result = generate_result(built_object, processed_inputs)
# langchain_object is now updated with the new memory
# we need to update the cache with the updated langchain_object
session_service.update_session(session_id, (graph, artifacts))
try:
initialize_session_service()
session_service = get_session_service()
return Result(result=result, session_id=session_id).model_dump()
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
# Use async_to_sync to handle the asynchronous part of the session service
session_data = async_to_sync(session_service.load_session)(session_id, data_graph)
logger.warning(f"session_data: {session_data}")
graph, artifacts = session_data if session_data else (None, None)
if not graph:
raise ValueError("Graph not found in the session")
# Use async_to_sync for the asynchronous build method
built_object = async_to_sync(graph.build)()
logger.info(f"Built object: {built_object}")
processed_inputs = process_inputs(inputs, artifacts or {})
result = generate_result(built_object, processed_inputs)
# Update the session with the new data
session_service.update_session(session_id, (graph, artifacts))
return {"result": result, "session_id": session_id}
except Exception as e:
logger.error(f"Error in process_graph_cached_task: {e}")
# Handle the exception as needed, maybe re-raise or return an error message
raise