fix: Fix issues with use of async (#4296)

* Fix issues with use of async

* Update src/backend/base/langflow/custom/custom_component/component.py

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
Christophe Bornet 2024-10-27 15:30:33 +01:00 committed by GitHub
commit d567b8518c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 51 additions and 51 deletions

View file

@ -40,6 +40,7 @@ from langflow.exceptions.component import ComponentBuildError
from langflow.graph.graph.base import Graph
from langflow.graph.utils import log_vertex_build
from langflow.schema.schema import OutputValue
from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service, get_session, get_telemetry_service
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
@ -493,7 +494,7 @@ async def build_vertex(
error_message = None
try:
cache = await chat_service.get_cache(flow_id_str)
if not cache:
if isinstance(cache, CacheMiss):
# If there's no cache
logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
graph: Graph = await build_graph_from_db(
@ -621,7 +622,7 @@ async def _stream_vertex(flow_id: str, vertex_id: str, chat_service: ChatService
yield str(StreamData(event="error", data={"error": str(exc)}))
return
if not cache:
if isinstance(cache, CacheMiss):
# If there's no cache
msg = f"No cache found for {flow_id}."
logger.error(msg)

View file

@ -10,6 +10,8 @@ import requests
from astra_assistants import OpenAIWithDefaultKey, patch
from astra_assistants.tools.tool_interface import ToolInterface
from langflow.services.cache.utils import CacheMiss
client_lock = threading.Lock()
client = None
@ -17,7 +19,7 @@ client = None
def get_patched_openai_client(shared_component_cache):
os.environ["ASTRA_ASSISTANTS_QUIET"] = "true"
client = shared_component_cache.get("client")
if client is None:
if isinstance(client, CacheMiss):
client = patch(OpenAIWithDefaultKey())
shared_component_cache.set("client", client)
return client

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import ast
import asyncio
import inspect
from copy import deepcopy
from textwrap import dedent
@ -506,11 +507,10 @@ class Component(CustomComponent):
async def _run(self):
# Resolve callable inputs
for key, _input in self._inputs.items():
if callable(_input.value):
result = _input.value()
if inspect.iscoroutine(result):
result = await result
self._inputs[key].value = result
if asyncio.iscoroutinefunction(_input.value):
self._inputs[key].value = await _input.value()
elif callable(_input.value):
self._inputs[key].value = await asyncio.to_thread(_input.value)
self.set_attributes({})
@ -718,10 +718,11 @@ class Component(CustomComponent):
_results[output.name] = output.value
result = output.value
else:
result = method()
# If the method is asynchronous, we need to await it
if inspect.iscoroutinefunction(method):
result = await result
result = await method()
else:
result = await asyncio.to_thread(method)
if (
self._vertex is not None
and isinstance(result, Message)

View file

@ -1356,7 +1356,7 @@ class Graph:
if get_cache is not None:
cached_result = await get_cache(key=vertex.id)
else:
cached_result = None
cached_result = CacheMiss()
if isinstance(cached_result, CacheMiss):
should_build = True
else:

View file

@ -814,10 +814,7 @@ class Vertex:
# Run steps
for step in self.steps:
if step not in self.steps_ran:
if inspect.iscoroutinefunction(step):
await step(user_id=user_id, event_manager=event_manager, **kwargs)
else:
step(user_id=user_id, event_manager=event_manager, **kwargs)
await step(user_id=user_id, event_manager=event_manager, **kwargs)
self.steps_ran.append(step)
self.finalize_build()

View file

@ -23,7 +23,7 @@ class CacheService(Service, Generic[LockType]):
lock: A lock to use for the operation.
Returns:
The value associated with the key, or None if the key is not found.
The value associated with the key, or CACHE_MISS if the key is not found.
"""
@abc.abstractmethod
@ -121,7 +121,7 @@ class AsyncBaseCacheService(Service, Generic[AsyncLockType]):
lock: A lock to use for the operation.
Returns:
The value associated with the key, or None if the key is not found.
The value associated with the key, or CACHE_MISS if the key is not found.
"""
@abc.abstractmethod

View file

@ -26,18 +26,18 @@ class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]):
async def get(self, key, lock: asyncio.Lock | None = None):
if not lock:
async with self.lock:
return await self._get(key)
return await asyncio.to_thread(self._get, key)
else:
return await self._get(key)
return await asyncio.to_thread(self._get, key)
async def _get(self, key):
item = await asyncio.to_thread(self.cache.get, key, default=None)
def _get(self, key):
item = self.cache.get(key, default=None)
if item:
if time.time() - item["time"] < self.expiration_time:
await asyncio.to_thread(self.cache.touch, key) # Refresh the expiry time
self.cache.touch(key) # Refresh the expiry time
return pickle.loads(item["value"]) if isinstance(item["value"], bytes) else item["value"]
logger.info(f"Cache item for key '{key}' has expired and will be deleted.")
await self._delete(key) # Log before deleting the expired item
self.cache.delete(key) # Log before deleting the expired item
return CACHE_MISS
async def set(self, key, value, lock: asyncio.Lock | None = None) -> None:
@ -81,7 +81,7 @@ class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]):
await self._upsert(key, value)
async def _upsert(self, key, value) -> None:
existing_value = await self.get(key)
existing_value = await asyncio.to_thread(self._get, key)
if existing_value is not CACHE_MISS and isinstance(existing_value, dict) and isinstance(value, dict):
existing_value.update(value)
value = existing_value

View file

@ -56,7 +56,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
lock: A lock to use for the operation.
Returns:
The value associated with the key, or None if the key is not found or the item has expired.
The value associated with the key, or CACHE_MISS if the key is not found or the item has expired.
"""
with lock or self._lock:
return self._get_without_lock(key)
@ -70,7 +70,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
# Check if the value is pickled
return pickle.loads(item["value"]) if isinstance(item["value"], bytes) else item["value"]
self.delete(key)
return None
return CACHE_MISS
def set(self, key, value, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
"""Add an item to the cache.
@ -105,7 +105,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
"""
with lock or self._lock:
existing_value = self._get_without_lock(key)
if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
if existing_value is not CACHE_MISS and isinstance(existing_value, dict) and isinstance(value, dict):
existing_value.update(value)
value = existing_value
@ -233,9 +233,9 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
@override
async def get(self, key, lock=None):
if key is None:
return None
return CACHE_MISS
value = await self._client.get(str(key))
return pickle.loads(value) if value else None
return pickle.loads(value) if value else CACHE_MISS
@override
async def set(self, key, value, lock=None) -> None:

View file

@ -63,5 +63,5 @@ class ChatService(Service):
lock (Optional[asyncio.Lock], optional): The lock to use for the cache operation. Defaults to None.
"""
if isinstance(self.cache_service, AsyncBaseCacheService):
return await self.cache_service.get(key, lock=lock or self.async_cache_locks[key])
return await self.cache_service.delete(key, lock=lock or self.async_cache_locks[key])
return await asyncio.to_thread(self.cache_service.delete, key, lock=lock or self._sync_cache_locks[key])

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import importlib
import inspect
from typing import TYPE_CHECKING
@ -95,9 +94,7 @@ class ServiceManager:
continue
logger.debug(f"Teardown service {service.name}")
try:
result = service.teardown()
if asyncio.iscoroutine(result):
await result
await service.teardown()
except Exception as exc: # noqa: BLE001
logger.exception(exc)
self.services = {}

View file

@ -1,7 +1,9 @@
from collections.abc import Coroutine
import asyncio
from typing import TYPE_CHECKING
from langflow.services.base import Service
from langflow.services.cache.base import AsyncBaseCacheService
from langflow.services.cache.utils import CacheMiss
from langflow.services.session.utils import compute_dict_hash, session_id_generator
if TYPE_CHECKING:
@ -12,21 +14,21 @@ class SessionService(Service):
name = "session_service"
def __init__(self, cache_service) -> None:
self.cache_service: CacheService = cache_service
self.cache_service: CacheService | AsyncBaseCacheService = cache_service
async def load_session(self, key, flow_id: str, data_graph: dict | None = None):
# Check if the data is cached
is_cached = self.cache_service.contains(key)
if isinstance(is_cached, Coroutine):
if await is_cached:
return await self.cache_service.get(key)
elif is_cached:
return self.cache_service.get(key)
if isinstance(self.cache_service, AsyncBaseCacheService):
value = await self.cache_service.get(key)
else:
value = await asyncio.to_thread(self.cache_service.get, key)
if not isinstance(value, CacheMiss):
return value
if key is None:
key = self.generate_key(session_id=None, data_graph=data_graph)
if data_graph is None:
return (None, None)
return None, None
# If not cached, build the graph and cache it
from langflow.graph.graph.base import Graph
@ -48,13 +50,13 @@ class SessionService(Service):
return self.build_key(session_id, data_graph=data_graph)
async def update_session(self, session_id, value) -> None:
result = self.cache_service.set(session_id, value)
# if it is a coroutine, await it
if isinstance(result, Coroutine):
await result
if isinstance(self.cache_service, AsyncBaseCacheService):
await self.cache_service.set(session_id, value)
else:
await asyncio.to_thread(self.cache_service.set, session_id, value)
async def clear_session(self, session_id) -> None:
result = self.cache_service.delete(session_id)
# if it is a coroutine, await it
if isinstance(result, Coroutine):
await result
if isinstance(self.cache_service, AsyncBaseCacheService):
await self.cache_service.delete(session_id)
else:
await asyncio.to_thread(self.cache_service.delete, session_id)