fix: Fix issues with use of async (#4296)
* Fix issues with use of async * Update src/backend/base/langflow/custom/custom_component/component.py Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org> --------- Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
73b5cc0098
commit
d567b8518c
11 changed files with 51 additions and 51 deletions
|
|
@ -40,6 +40,7 @@ from langflow.exceptions.component import ComponentBuildError
|
|||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.utils import log_vertex_build
|
||||
from langflow.schema.schema import OutputValue
|
||||
from langflow.services.cache.utils import CacheMiss
|
||||
from langflow.services.chat.service import ChatService
|
||||
from langflow.services.deps import get_chat_service, get_session, get_telemetry_service
|
||||
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
|
||||
|
|
@ -493,7 +494,7 @@ async def build_vertex(
|
|||
error_message = None
|
||||
try:
|
||||
cache = await chat_service.get_cache(flow_id_str)
|
||||
if not cache:
|
||||
if isinstance(cache, CacheMiss):
|
||||
# If there's no cache
|
||||
logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
|
||||
graph: Graph = await build_graph_from_db(
|
||||
|
|
@ -621,7 +622,7 @@ async def _stream_vertex(flow_id: str, vertex_id: str, chat_service: ChatService
|
|||
yield str(StreamData(event="error", data={"error": str(exc)}))
|
||||
return
|
||||
|
||||
if not cache:
|
||||
if isinstance(cache, CacheMiss):
|
||||
# If there's no cache
|
||||
msg = f"No cache found for {flow_id}."
|
||||
logger.error(msg)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import requests
|
|||
from astra_assistants import OpenAIWithDefaultKey, patch
|
||||
from astra_assistants.tools.tool_interface import ToolInterface
|
||||
|
||||
from langflow.services.cache.utils import CacheMiss
|
||||
|
||||
client_lock = threading.Lock()
|
||||
client = None
|
||||
|
||||
|
|
@ -17,7 +19,7 @@ client = None
|
|||
def get_patched_openai_client(shared_component_cache):
|
||||
os.environ["ASTRA_ASSISTANTS_QUIET"] = "true"
|
||||
client = shared_component_cache.get("client")
|
||||
if client is None:
|
||||
if isinstance(client, CacheMiss):
|
||||
client = patch(OpenAIWithDefaultKey())
|
||||
shared_component_cache.set("client", client)
|
||||
return client
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from textwrap import dedent
|
||||
|
|
@ -506,11 +507,10 @@ class Component(CustomComponent):
|
|||
async def _run(self):
|
||||
# Resolve callable inputs
|
||||
for key, _input in self._inputs.items():
|
||||
if callable(_input.value):
|
||||
result = _input.value()
|
||||
if inspect.iscoroutine(result):
|
||||
result = await result
|
||||
self._inputs[key].value = result
|
||||
if asyncio.iscoroutinefunction(_input.value):
|
||||
self._inputs[key].value = await _input.value()
|
||||
elif callable(_input.value):
|
||||
self._inputs[key].value = await asyncio.to_thread(_input.value)
|
||||
|
||||
self.set_attributes({})
|
||||
|
||||
|
|
@ -718,10 +718,11 @@ class Component(CustomComponent):
|
|||
_results[output.name] = output.value
|
||||
result = output.value
|
||||
else:
|
||||
result = method()
|
||||
# If the method is asynchronous, we need to await it
|
||||
if inspect.iscoroutinefunction(method):
|
||||
result = await result
|
||||
result = await method()
|
||||
else:
|
||||
result = await asyncio.to_thread(method)
|
||||
if (
|
||||
self._vertex is not None
|
||||
and isinstance(result, Message)
|
||||
|
|
|
|||
|
|
@ -1356,7 +1356,7 @@ class Graph:
|
|||
if get_cache is not None:
|
||||
cached_result = await get_cache(key=vertex.id)
|
||||
else:
|
||||
cached_result = None
|
||||
cached_result = CacheMiss()
|
||||
if isinstance(cached_result, CacheMiss):
|
||||
should_build = True
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -814,10 +814,7 @@ class Vertex:
|
|||
# Run steps
|
||||
for step in self.steps:
|
||||
if step not in self.steps_ran:
|
||||
if inspect.iscoroutinefunction(step):
|
||||
await step(user_id=user_id, event_manager=event_manager, **kwargs)
|
||||
else:
|
||||
step(user_id=user_id, event_manager=event_manager, **kwargs)
|
||||
await step(user_id=user_id, event_manager=event_manager, **kwargs)
|
||||
self.steps_ran.append(step)
|
||||
|
||||
self.finalize_build()
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class CacheService(Service, Generic[LockType]):
|
|||
lock: A lock to use for the operation.
|
||||
|
||||
Returns:
|
||||
The value associated with the key, or None if the key is not found.
|
||||
The value associated with the key, or CACHE_MISS if the key is not found.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
|
|
@ -121,7 +121,7 @@ class AsyncBaseCacheService(Service, Generic[AsyncLockType]):
|
|||
lock: A lock to use for the operation.
|
||||
|
||||
Returns:
|
||||
The value associated with the key, or None if the key is not found.
|
||||
The value associated with the key, or CACHE_MISS if the key is not found.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
|
|
|
|||
14
src/backend/base/langflow/services/cache/disk.py
vendored
14
src/backend/base/langflow/services/cache/disk.py
vendored
|
|
@ -26,18 +26,18 @@ class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]):
|
|||
async def get(self, key, lock: asyncio.Lock | None = None):
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
return await self._get(key)
|
||||
return await asyncio.to_thread(self._get, key)
|
||||
else:
|
||||
return await self._get(key)
|
||||
return await asyncio.to_thread(self._get, key)
|
||||
|
||||
async def _get(self, key):
|
||||
item = await asyncio.to_thread(self.cache.get, key, default=None)
|
||||
def _get(self, key):
|
||||
item = self.cache.get(key, default=None)
|
||||
if item:
|
||||
if time.time() - item["time"] < self.expiration_time:
|
||||
await asyncio.to_thread(self.cache.touch, key) # Refresh the expiry time
|
||||
self.cache.touch(key) # Refresh the expiry time
|
||||
return pickle.loads(item["value"]) if isinstance(item["value"], bytes) else item["value"]
|
||||
logger.info(f"Cache item for key '{key}' has expired and will be deleted.")
|
||||
await self._delete(key) # Log before deleting the expired item
|
||||
self.cache.delete(key) # Log before deleting the expired item
|
||||
return CACHE_MISS
|
||||
|
||||
async def set(self, key, value, lock: asyncio.Lock | None = None) -> None:
|
||||
|
|
@ -81,7 +81,7 @@ class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]):
|
|||
await self._upsert(key, value)
|
||||
|
||||
async def _upsert(self, key, value) -> None:
|
||||
existing_value = await self.get(key)
|
||||
existing_value = await asyncio.to_thread(self._get, key)
|
||||
if existing_value is not CACHE_MISS and isinstance(existing_value, dict) and isinstance(value, dict):
|
||||
existing_value.update(value)
|
||||
value = existing_value
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
|
|||
lock: A lock to use for the operation.
|
||||
|
||||
Returns:
|
||||
The value associated with the key, or None if the key is not found or the item has expired.
|
||||
The value associated with the key, or CACHE_MISS if the key is not found or the item has expired.
|
||||
"""
|
||||
with lock or self._lock:
|
||||
return self._get_without_lock(key)
|
||||
|
|
@ -70,7 +70,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
|
|||
# Check if the value is pickled
|
||||
return pickle.loads(item["value"]) if isinstance(item["value"], bytes) else item["value"]
|
||||
self.delete(key)
|
||||
return None
|
||||
return CACHE_MISS
|
||||
|
||||
def set(self, key, value, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
|
||||
"""Add an item to the cache.
|
||||
|
|
@ -105,7 +105,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
|
|||
"""
|
||||
with lock or self._lock:
|
||||
existing_value = self._get_without_lock(key)
|
||||
if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
|
||||
if existing_value is not CACHE_MISS and isinstance(existing_value, dict) and isinstance(value, dict):
|
||||
existing_value.update(value)
|
||||
value = existing_value
|
||||
|
||||
|
|
@ -233,9 +233,9 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
|
|||
@override
|
||||
async def get(self, key, lock=None):
|
||||
if key is None:
|
||||
return None
|
||||
return CACHE_MISS
|
||||
value = await self._client.get(str(key))
|
||||
return pickle.loads(value) if value else None
|
||||
return pickle.loads(value) if value else CACHE_MISS
|
||||
|
||||
@override
|
||||
async def set(self, key, value, lock=None) -> None:
|
||||
|
|
|
|||
|
|
@ -63,5 +63,5 @@ class ChatService(Service):
|
|||
lock (Optional[asyncio.Lock], optional): The lock to use for the cache operation. Defaults to None.
|
||||
"""
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
return await self.cache_service.get(key, lock=lock or self.async_cache_locks[key])
|
||||
return await self.cache_service.delete(key, lock=lock or self.async_cache_locks[key])
|
||||
return await asyncio.to_thread(self.cache_service.delete, key, lock=lock or self._sync_cache_locks[key])
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING
|
||||
|
|
@ -95,9 +94,7 @@ class ServiceManager:
|
|||
continue
|
||||
logger.debug(f"Teardown service {service.name}")
|
||||
try:
|
||||
result = service.teardown()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
await service.teardown()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(exc)
|
||||
self.services = {}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from collections.abc import Coroutine
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.cache.base import AsyncBaseCacheService
|
||||
from langflow.services.cache.utils import CacheMiss
|
||||
from langflow.services.session.utils import compute_dict_hash, session_id_generator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -12,21 +14,21 @@ class SessionService(Service):
|
|||
name = "session_service"
|
||||
|
||||
def __init__(self, cache_service) -> None:
|
||||
self.cache_service: CacheService = cache_service
|
||||
self.cache_service: CacheService | AsyncBaseCacheService = cache_service
|
||||
|
||||
async def load_session(self, key, flow_id: str, data_graph: dict | None = None):
|
||||
# Check if the data is cached
|
||||
is_cached = self.cache_service.contains(key)
|
||||
if isinstance(is_cached, Coroutine):
|
||||
if await is_cached:
|
||||
return await self.cache_service.get(key)
|
||||
elif is_cached:
|
||||
return self.cache_service.get(key)
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
value = await self.cache_service.get(key)
|
||||
else:
|
||||
value = await asyncio.to_thread(self.cache_service.get, key)
|
||||
if not isinstance(value, CacheMiss):
|
||||
return value
|
||||
|
||||
if key is None:
|
||||
key = self.generate_key(session_id=None, data_graph=data_graph)
|
||||
if data_graph is None:
|
||||
return (None, None)
|
||||
return None, None
|
||||
# If not cached, build the graph and cache it
|
||||
from langflow.graph.graph.base import Graph
|
||||
|
||||
|
|
@ -48,13 +50,13 @@ class SessionService(Service):
|
|||
return self.build_key(session_id, data_graph=data_graph)
|
||||
|
||||
async def update_session(self, session_id, value) -> None:
|
||||
result = self.cache_service.set(session_id, value)
|
||||
# if it is a coroutine, await it
|
||||
if isinstance(result, Coroutine):
|
||||
await result
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
await self.cache_service.set(session_id, value)
|
||||
else:
|
||||
await asyncio.to_thread(self.cache_service.set, session_id, value)
|
||||
|
||||
async def clear_session(self, session_id) -> None:
|
||||
result = self.cache_service.delete(session_id)
|
||||
# if it is a coroutine, await it
|
||||
if isinstance(result, Coroutine):
|
||||
await result
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
await self.cache_service.delete(session_id)
|
||||
else:
|
||||
await asyncio.to_thread(self.cache_service.delete, session_id)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue