parent
2880f585fb
commit
3131c0ce08
8 changed files with 59 additions and 83 deletions
13
src/backend/base/langflow/services/cache/base.py
vendored
13
src/backend/base/langflow/services/cache/base.py
vendored
|
|
@ -59,6 +59,17 @@ class CacheService(Service, Generic[LockType]):
|
|||
def clear(self, lock: LockType | None = None):
|
||||
"""Clear all items from the cache."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def contains(self, key) -> bool:
|
||||
"""Check if the key is in the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to check.
|
||||
|
||||
Returns:
|
||||
True if the key is in the cache, False otherwise.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Check if the key is in the cache.
|
||||
|
|
@ -147,7 +158,7 @@ class AsyncBaseCacheService(Service, Generic[AsyncLockType]):
|
|||
"""Clear all items from the cache."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __contains__(self, key) -> bool:
|
||||
async def contains(self, key) -> bool:
|
||||
"""Check if the key is in the cache.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -87,8 +87,8 @@ class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]):
|
|||
value = existing_value
|
||||
await self.set(key, value)
|
||||
|
||||
def __contains__(self, key) -> bool:
|
||||
return asyncio.run(asyncio.to_thread(self.cache.__contains__, key))
|
||||
async def contains(self, key) -> bool:
|
||||
return await asyncio.to_thread(self.cache.__contains__, key)
|
||||
|
||||
async def teardown(self) -> None:
|
||||
# Clean up the cache directory
|
||||
|
|
|
|||
|
|
@ -139,10 +139,14 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
|
|||
with lock or self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
def __contains__(self, key) -> bool:
|
||||
def contains(self, key) -> bool:
|
||||
"""Check if the key is in the cache."""
|
||||
return key in self._cache
|
||||
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Check if the key is in the cache."""
|
||||
return self.contains(key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Retrieve an item from the cache using the square bracket notation."""
|
||||
return self.get(key)
|
||||
|
|
@ -274,11 +278,11 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
|
|||
"""Clear all items from the cache."""
|
||||
await self._client.flushdb()
|
||||
|
||||
def __contains__(self, key) -> bool:
|
||||
async def contains(self, key) -> bool:
|
||||
"""Check if the key is in the cache."""
|
||||
if key is None:
|
||||
return False
|
||||
return bool(asyncio.run(self._client.exists(str(key))))
|
||||
return bool(await self._client.exists(str(key)))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the RedisCache instance."""
|
||||
|
|
@ -364,5 +368,5 @@ class AsyncInMemoryCache(AsyncBaseCacheService, Generic[AsyncLockType]):
|
|||
value = existing_value
|
||||
await self.set(key, value)
|
||||
|
||||
def __contains__(self, key) -> bool:
|
||||
async def contains(self, key) -> bool:
|
||||
return key in self.cache
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from threading import RLock
|
|||
from typing import Any
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.cache.base import AsyncBaseCacheService
|
||||
from langflow.services.cache.base import AsyncBaseCacheService, CacheService
|
||||
from langflow.services.deps import get_cache_service
|
||||
|
||||
|
||||
|
|
@ -16,60 +16,7 @@ class ChatService(Service):
|
|||
def __init__(self) -> None:
|
||||
self.async_cache_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._sync_cache_locks: dict[str, RLock] = defaultdict(RLock)
|
||||
self.cache_service = get_cache_service()
|
||||
|
||||
def _get_lock(self, key: str):
|
||||
"""Retrieves the lock associated with the given key.
|
||||
|
||||
Args:
|
||||
key (str): The key to retrieve the lock for.
|
||||
|
||||
Returns:
|
||||
threading.Lock or asyncio.Lock: The lock associated with the given key.
|
||||
"""
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
return self.async_cache_locks[key]
|
||||
return self._sync_cache_locks[key]
|
||||
|
||||
async def _perform_cache_operation(
|
||||
self, operation: str, key: str, data: Any = None, lock: asyncio.Lock | None = None
|
||||
):
|
||||
"""Perform a cache operation based on the given operation type.
|
||||
|
||||
Args:
|
||||
operation (str): The type of cache operation to perform. Possible values are "upsert", "get", or "delete".
|
||||
key (str): The key associated with the cache operation.
|
||||
data (Any, optional): The data to be stored in the cache. Only applicable for "upsert" operation.
|
||||
Defaults to None.
|
||||
lock (Optional[asyncio.Lock], optional): The lock to be used for the cache operation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Any: The result of the cache operation. Only applicable for "get" operation.
|
||||
|
||||
Raises:
|
||||
None
|
||||
|
||||
"""
|
||||
lock = lock or self._get_lock(key)
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
if operation == "upsert":
|
||||
await self.cache_service.upsert(str(key), data, lock=lock)
|
||||
return None
|
||||
if operation == "get":
|
||||
return await self.cache_service.get(key, lock=lock)
|
||||
if operation == "delete":
|
||||
await self.cache_service.delete(key, lock=lock)
|
||||
return None
|
||||
return None
|
||||
if operation == "upsert":
|
||||
self.cache_service.upsert(str(key), data, lock=lock)
|
||||
return None
|
||||
if operation == "get":
|
||||
return self.cache_service.get(key, lock=lock)
|
||||
if operation == "delete":
|
||||
self.cache_service.delete(key, lock=lock)
|
||||
return None
|
||||
return None
|
||||
self.cache_service: CacheService | AsyncBaseCacheService = get_cache_service()
|
||||
|
||||
async def set_cache(self, key: str, data: Any, lock: asyncio.Lock | None = None) -> bool:
|
||||
"""Set the cache for a client.
|
||||
|
|
@ -86,7 +33,12 @@ class ChatService(Service):
|
|||
"result": data,
|
||||
"type": type(data),
|
||||
}
|
||||
await self._perform_cache_operation("upsert", key, result_dict, lock)
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
await self.cache_service.upsert(str(key), result_dict, lock=lock or self.async_cache_locks[key])
|
||||
return await self.cache_service.contains(key)
|
||||
await asyncio.to_thread(
|
||||
self.cache_service.upsert, str(key), result_dict, lock=lock or self._sync_cache_locks[key]
|
||||
)
|
||||
return key in self.cache_service
|
||||
|
||||
async def get_cache(self, key: str, lock: asyncio.Lock | None = None) -> Any:
|
||||
|
|
@ -99,7 +51,9 @@ class ChatService(Service):
|
|||
Returns:
|
||||
Any: The cached data.
|
||||
"""
|
||||
return await self._perform_cache_operation("get", key, lock=lock or self._get_lock(key))
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
return await self.cache_service.get(key, lock=lock or self.async_cache_locks[key])
|
||||
return await asyncio.to_thread(self.cache_service.get, key, lock=lock or self._sync_cache_locks[key])
|
||||
|
||||
async def clear_cache(self, key: str, lock: asyncio.Lock | None = None) -> None:
|
||||
"""Clear the cache for a client.
|
||||
|
|
@ -108,4 +62,6 @@ class ChatService(Service):
|
|||
key (str): The cache key.
|
||||
lock (Optional[asyncio.Lock], optional): The lock to use for the cache operation. Defaults to None.
|
||||
"""
|
||||
await self._perform_cache_operation("delete", key, lock=lock or self._get_lock(key))
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
return await self.cache_service.get(key, lock=lock or self.async_cache_locks[key])
|
||||
return await asyncio.to_thread(self.cache_service.delete, key, lock=lock or self._sync_cache_locks[key])
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.services.cache.service import CacheService
|
||||
from langflow.services.cache.service import AsyncBaseCacheService, CacheService
|
||||
from langflow.services.chat.service import ChatService
|
||||
from langflow.services.database.service import DatabaseService
|
||||
from langflow.services.plugins.service import PluginService
|
||||
|
|
@ -188,7 +188,7 @@ def session_scope() -> Generator[Session, None, None]:
|
|||
raise
|
||||
|
||||
|
||||
def get_cache_service() -> CacheService:
|
||||
def get_cache_service() -> CacheService | AsyncBaseCacheService:
|
||||
"""Retrieves the cache service from the service manager.
|
||||
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -16,11 +16,12 @@ class SessionService(Service):
|
|||
|
||||
async def load_session(self, key, flow_id: str, data_graph: dict | None = None):
|
||||
# Check if the data is cached
|
||||
if key in self.cache_service:
|
||||
result = self.cache_service.get(key)
|
||||
if isinstance(result, Coroutine):
|
||||
result = await result
|
||||
return result
|
||||
is_cached = self.cache_service.contains(key)
|
||||
if isinstance(is_cached, Coroutine):
|
||||
if await is_cached:
|
||||
return await self.cache_service.get(key)
|
||||
elif is_cached:
|
||||
return self.cache_service.get(key)
|
||||
|
||||
if key is None:
|
||||
key = self.generate_key(session_id=None, data_graph=data_graph)
|
||||
|
|
|
|||
|
|
@ -1,20 +1,18 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import socketio
|
||||
from loguru import logger
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.cache.base import AsyncBaseCacheService, CacheService
|
||||
from langflow.services.deps import get_chat_service
|
||||
from langflow.services.socket.utils import build_vertex, get_vertices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.cache.service import CacheService
|
||||
|
||||
|
||||
class SocketIOService(Service):
|
||||
name = "socket_service"
|
||||
|
||||
def __init__(self, cache_service: "CacheService"):
|
||||
def __init__(self, cache_service: CacheService | AsyncBaseCacheService):
|
||||
self.cache_service = cache_service
|
||||
|
||||
def init(self, sio: socketio.AsyncServer) -> None:
|
||||
|
|
@ -63,11 +61,14 @@ class SocketIOService(Service):
|
|||
set_cache=self.set_cache,
|
||||
)
|
||||
|
||||
def get_cache(self, sid: str) -> Any:
|
||||
async def get_cache(self, sid: str) -> Any:
|
||||
"""Get the cache for a client."""
|
||||
return self.cache_service.get(sid)
|
||||
value = self.cache_service.get(sid)
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
return await value
|
||||
return value
|
||||
|
||||
def set_cache(self, sid: str, build_result: Any) -> bool:
|
||||
async def set_cache(self, sid: str, build_result: Any) -> bool:
|
||||
"""Set the cache for a client."""
|
||||
# client_id is the flow id but that already exists in the cache
|
||||
# so we need to change it to something else
|
||||
|
|
@ -76,5 +77,8 @@ class SocketIOService(Service):
|
|||
"result": build_result,
|
||||
"type": type(build_result),
|
||||
}
|
||||
self.cache_service.upsert(sid, result_dict)
|
||||
result = self.cache_service.upsert(sid, result_dict)
|
||||
if isinstance(self.cache_service, AsyncBaseCacheService):
|
||||
await result
|
||||
return await self.cache_service.contains(sid)
|
||||
return sid in self.cache_service
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ async def build_vertex(
|
|||
set_cache: Callable,
|
||||
) -> None:
|
||||
try:
|
||||
cache = get_cache(flow_id)
|
||||
cache = await get_cache(flow_id)
|
||||
graph = cache.get("result")
|
||||
|
||||
if not isinstance(graph, Graph):
|
||||
|
|
@ -86,7 +86,7 @@ async def build_vertex(
|
|||
valid = False
|
||||
result_dict = ResultDataResponse(results={})
|
||||
artifacts = {}
|
||||
set_cache(flow_id, graph)
|
||||
await set_cache(flow_id, graph)
|
||||
log_vertex_build(
|
||||
flow_id=flow_id,
|
||||
vertex_id=vertex_id,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue