🔧 fix(process.py): remove unused import and add missing import for BaseModel from pydantic

🔧 fix(manager.py): change return type annotation of launch_task method to handle both Coroutine and Any types
🔧 fix(worker.py): add missing import for Result class from process module and update return statement to return Result object as dictionary
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-09-16 11:13:53 -03:00
commit 05a6f4d067
3 changed files with 13 additions and 7 deletions

View file

@ -1,4 +1,3 @@
from dataclasses import dataclass
import json
from pathlib import Path
from langchain.schema import AgentAction
@ -14,6 +13,8 @@ from langchain.chains.base import Chain
from langchain.vectorstores.base import VectorStore
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
def fix_memory_inputs(langchain_object):
"""
@ -146,8 +147,7 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
return result
@dataclass
class Result:
class Result(BaseModel):
result: Any
session_id: str
@ -173,7 +173,7 @@ async def process_graph_cached(
# we need to update the cache with the updated langchain_object
session_service.update_session(session_id, (langchain_object, artifacts))
return Result(result, session_id)
return Result(result=result, session_id=session_id)
def load_flow_from_json(

View file

@ -1,4 +1,4 @@
from typing import Any, Callable, Union
from typing import Any, Callable, Coroutine, Union
import logging
from langflow.services.base import Service
@ -51,7 +51,8 @@ class TaskService(Service):
async def launch_task(
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
return await self.backend.launch_task(task_func, *args, **kwargs)
task = self.backend.launch_task(task_func, *args, **kwargs)
return await task if isinstance(task, Coroutine) else task
def get_task(self, task_id: Union[int, str]) -> Any:
return self.backend.get_task(task_id)

View file

@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple
from typing import TYPE_CHECKING
from celery.exceptions import SoftTimeLimitExceeded # type: ignore
from langflow.processing.process import (
Result,
generate_result,
process_inputs,
)
@ -44,6 +45,10 @@ def process_graph_cached_task(
session_service = get_session_service()
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(
session_id=session_id, data_graph=data_graph
)
# Load the graph using SessionService
langchain_object, artifacts = session_service.load_session(session_id, data_graph)
processed_inputs = process_inputs(inputs, artifacts)
@ -52,4 +57,4 @@ def process_graph_cached_task(
# we need to update the cache with the updated langchain_object
session_service.update_session(session_id, (langchain_object, artifacts))
return result, session_id
return Result(result=result, session_id=session_id).dict()