From 05a6f4d0679566a8e19ef0d554e122812f132b13 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sat, 16 Sep 2023 11:13:53 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix(process.py):=20remove=20unus?= =?UTF-8?q?ed=20import=20and=20add=20missing=20import=20for=20BaseModel=20?= =?UTF-8?q?from=20pydantic=20=F0=9F=94=A7=20fix(manager.py):=20change=20re?= =?UTF-8?q?turn=20type=20annotation=20of=20launch=5Ftask=20method=20to=20h?= =?UTF-8?q?andle=20both=20Coroutine=20and=20Any=20types=20=F0=9F=94=A7=20f?= =?UTF-8?q?ix(worker.py):=20add=20missing=20import=20for=20Result=20class?= =?UTF-8?q?=20from=20process=20module=20and=20update=20return=20statement?= =?UTF-8?q?=20to=20return=20Result=20object=20as=20dictionary?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/processing/process.py | 8 ++++---- src/backend/langflow/services/task/manager.py | 5 +++-- src/backend/langflow/worker.py | 7 ++++++- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 6490f02fb..0fefa8deb 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass import json from pathlib import Path from langchain.schema import AgentAction @@ -14,6 +13,8 @@ from langchain.chains.base import Chain from langchain.vectorstores.base import VectorStore from typing import Any, Dict, List, Optional, Tuple, Union +from pydantic import BaseModel + def fix_memory_inputs(langchain_object): """ @@ -146,8 +147,7 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict): return result -@dataclass -class Result: +class Result(BaseModel): result: Any session_id: str @@ -173,7 +173,7 @@ async def process_graph_cached( # we need to update the cache with the updated langchain_object session_service.update_session(session_id, (langchain_object, artifacts)) - return Result(result, session_id) + return Result(result=result, session_id=session_id) def load_flow_from_json( diff --git a/src/backend/langflow/services/task/manager.py b/src/backend/langflow/services/task/manager.py index e0448ab66..4083978df 100644 --- a/src/backend/langflow/services/task/manager.py +++ b/src/backend/langflow/services/task/manager.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Any, Callable, Coroutine, Union import logging from langflow.services.base import Service @@ -51,7 +51,8 @@ class TaskService(Service): async def launch_task( self, task_func: Callable[..., Any], *args: Any, **kwargs: Any ) -> Any: - return await self.backend.launch_task(task_func, *args, **kwargs) + task = self.backend.launch_task(task_func, *args, **kwargs) + return await task if isinstance(task, Coroutine) else task def get_task(self, task_id: Union[int, str]) -> Any: return self.backend.get_task(task_id) diff --git a/src/backend/langflow/worker.py b/src/backend/langflow/worker.py index ad847a01c..018cfa5e7 100644 --- a/src/backend/langflow/worker.py +++ b/src/backend/langflow/worker.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple from typing import TYPE_CHECKING from celery.exceptions import SoftTimeLimitExceeded # type: ignore from langflow.processing.process import ( + Result, generate_result, process_inputs, ) @@ -44,6 +45,10 @@ def process_graph_cached_task( session_service = get_session_service() if clear_cache: session_service.clear_session(session_id) + if session_id is None: + session_id = session_service.generate_key( + session_id=session_id, data_graph=data_graph + ) # Load the graph using SessionService langchain_object, artifacts = session_service.load_session(session_id, data_graph) processed_inputs = process_inputs(inputs, artifacts) @@ -52,4 +57,4 @@ def process_graph_cached_task( # we need to update the cache with the updated langchain_object session_service.update_session(session_id, (langchain_object, artifacts)) - return result, session_id + return Result(result=result, session_id=session_id).dict()