diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 6490f02fb..0fefa8deb 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass import json from pathlib import Path from langchain.schema import AgentAction @@ -14,6 +13,8 @@ from langchain.chains.base import Chain from langchain.vectorstores.base import VectorStore from typing import Any, Dict, List, Optional, Tuple, Union +from pydantic import BaseModel + def fix_memory_inputs(langchain_object): """ @@ -146,8 +147,7 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict): return result -@dataclass -class Result: +class Result(BaseModel): result: Any session_id: str @@ -173,7 +173,7 @@ async def process_graph_cached( # we need to update the cache with the updated langchain_object session_service.update_session(session_id, (langchain_object, artifacts)) - return Result(result, session_id) + return Result(result=result, session_id=session_id) def load_flow_from_json( diff --git a/src/backend/langflow/services/task/manager.py b/src/backend/langflow/services/task/manager.py index e0448ab66..4083978df 100644 --- a/src/backend/langflow/services/task/manager.py +++ b/src/backend/langflow/services/task/manager.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Any, Callable, Coroutine, Union import logging from langflow.services.base import Service @@ -51,7 +51,8 @@ class TaskService(Service): async def launch_task( self, task_func: Callable[..., Any], *args: Any, **kwargs: Any ) -> Any: - return await self.backend.launch_task(task_func, *args, **kwargs) + task = self.backend.launch_task(task_func, *args, **kwargs) + return await task if isinstance(task, Coroutine) else task def get_task(self, task_id: Union[int, str]) -> Any: return self.backend.get_task(task_id) diff --git a/src/backend/langflow/worker.py b/src/backend/langflow/worker.py index ad847a01c..018cfa5e7 100644 --- a/src/backend/langflow/worker.py +++ b/src/backend/langflow/worker.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple from typing import TYPE_CHECKING from celery.exceptions import SoftTimeLimitExceeded # type: ignore from langflow.processing.process import ( + Result, generate_result, process_inputs, ) @@ -44,6 +45,10 @@ def process_graph_cached_task( session_service = get_session_service() if clear_cache: session_service.clear_session(session_id) + if session_id is None: + session_id = session_service.generate_key( + session_id=session_id, data_graph=data_graph + ) # Load the graph using SessionService langchain_object, artifacts = session_service.load_session(session_id, data_graph) processed_inputs = process_inputs(inputs, artifacts) @@ -52,4 +57,4 @@ def process_graph_cached_task( # we need to update the cache with the updated langchain_object session_service.update_session(session_id, (langchain_object, artifacts)) - return result, session_id + return Result(result=result, session_id=session_id).dict()