🔧 fix(process.py): remove unused import and add missing import for BaseModel from pydantic
🔧 fix(manager.py): change return type annotation of launch_task method to handle both Coroutine and Any types 🔧 fix(worker.py): add missing import for Result class from process module and update return statement to return Result object as dictionary
This commit is contained in:
parent
b8d8eccbff
commit
05a6f4d067
3 changed files with 13 additions and 7 deletions
|
|
@ -1,4 +1,3 @@
|
|||
from dataclasses import dataclass
|
||||
import json
|
||||
from pathlib import Path
|
||||
from langchain.schema import AgentAction
|
||||
|
|
@ -14,6 +13,8 @@ from langchain.chains.base import Chain
|
|||
from langchain.vectorstores.base import VectorStore
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def fix_memory_inputs(langchain_object):
|
||||
"""
|
||||
|
|
@ -146,8 +147,7 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
|
|||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
class Result(BaseModel):
|
||||
result: Any
|
||||
session_id: str
|
||||
|
||||
|
|
@ -173,7 +173,7 @@ async def process_graph_cached(
|
|||
# we need to update the cache with the updated langchain_object
|
||||
session_service.update_session(session_id, (langchain_object, artifacts))
|
||||
|
||||
return Result(result, session_id)
|
||||
return Result(result=result, session_id=session_id)
|
||||
|
||||
|
||||
def load_flow_from_json(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Callable, Union
|
||||
from typing import Any, Callable, Coroutine, Union
|
||||
import logging
|
||||
|
||||
from langflow.services.base import Service
|
||||
|
|
@ -51,7 +51,8 @@ class TaskService(Service):
|
|||
async def launch_task(
|
||||
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
return await self.backend.launch_task(task_func, *args, **kwargs)
|
||||
task = self.backend.launch_task(task_func, *args, **kwargs)
|
||||
return await task if isinstance(task, Coroutine) else task
|
||||
|
||||
def get_task(self, task_id: Union[int, str]) -> Any:
|
||||
return self.backend.get_task(task_id)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||
from typing import TYPE_CHECKING
|
||||
from celery.exceptions import SoftTimeLimitExceeded # type: ignore
|
||||
from langflow.processing.process import (
|
||||
Result,
|
||||
generate_result,
|
||||
process_inputs,
|
||||
)
|
||||
|
|
@ -44,6 +45,10 @@ def process_graph_cached_task(
|
|||
session_service = get_session_service()
|
||||
if clear_cache:
|
||||
session_service.clear_session(session_id)
|
||||
if session_id is None:
|
||||
session_id = session_service.generate_key(
|
||||
session_id=session_id, data_graph=data_graph
|
||||
)
|
||||
# Load the graph using SessionService
|
||||
langchain_object, artifacts = session_service.load_session(session_id, data_graph)
|
||||
processed_inputs = process_inputs(inputs, artifacts)
|
||||
|
|
@ -52,4 +57,4 @@ def process_graph_cached_task(
|
|||
# we need to update the cache with the updated langchain_object
|
||||
session_service.update_session(session_id, (langchain_object, artifacts))
|
||||
|
||||
return result, session_id
|
||||
return Result(result=result, session_id=session_id).dict()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue