refactor: Update CacheService to use generic types for locks

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-20 18:46:41 -03:00
commit 69146c682e
2 changed files with 21 additions and 18 deletions

View file

@ -1,12 +1,15 @@
import abc
import asyncio
import threading
from typing import Optional
from typing import Generic, Optional, TypeVar
from langflow.services.base import Service
LockType = TypeVar("LockType", bound=threading.Lock)
AsyncLockType = TypeVar("AsyncLockType", bound=asyncio.Lock)
class CacheService(Service):
class CacheService(Service, Generic[LockType]):
"""
Abstract base class for a cache.
"""
@ -14,7 +17,7 @@ class CacheService(Service):
name = "cache_service"
@abc.abstractmethod
def get(self, key, lock: Optional[threading.Lock] = None):
def get(self, key, lock: Optional[LockType] = None):
"""
Retrieve an item from the cache.
@ -26,7 +29,7 @@ class CacheService(Service):
"""
@abc.abstractmethod
def set(self, key, value, lock: Optional[threading.Lock] = None):
def set(self, key, value, lock: Optional[LockType] = None):
"""
Add an item to the cache.
@ -36,7 +39,7 @@ class CacheService(Service):
"""
@abc.abstractmethod
def upsert(self, key, value, lock: Optional[threading.Lock] = None):
def upsert(self, key, value, lock: Optional[LockType] = None):
"""
Add an item to the cache if it doesn't exist, or update it if it does.
@ -46,7 +49,7 @@ class CacheService(Service):
"""
@abc.abstractmethod
def delete(self, key, lock: Optional[threading.Lock] = None):
def delete(self, key, lock: Optional[LockType] = None):
"""
Remove an item from the cache.
@ -55,7 +58,7 @@ class CacheService(Service):
"""
@abc.abstractmethod
def clear(self, lock: Optional[threading.Lock] = None):
def clear(self, lock: Optional[LockType] = None):
"""
Clear all items from the cache.
"""
@ -101,7 +104,7 @@ class CacheService(Service):
"""
class AsyncBaseCacheService(Service):
class AsyncBaseCacheService(Service, Generic[AsyncLockType]):
"""
Abstract base class for a async cache.
"""
@ -109,7 +112,7 @@ class AsyncBaseCacheService(Service):
name = "cache_service"
@abc.abstractmethod
async def get(self, key, lock: Optional[asyncio.Lock] = None):
async def get(self, key, lock: Optional[AsyncLockType] = None):
"""
Retrieve an item from the cache.
@ -121,7 +124,7 @@ class AsyncBaseCacheService(Service):
"""
@abc.abstractmethod
async def set(self, key, value, lock: Optional[asyncio.Lock] = None):
async def set(self, key, value, lock: Optional[AsyncLockType] = None):
"""
Add an item to the cache.
@ -131,7 +134,7 @@ class AsyncBaseCacheService(Service):
"""
@abc.abstractmethod
async def upsert(self, key, value, lock: Optional[asyncio.Lock] = None):
async def upsert(self, key, value, lock: Optional[AsyncLockType] = None):
"""
Add an item to the cache if it doesn't exist, or update it if it does.
@ -141,7 +144,7 @@ class AsyncBaseCacheService(Service):
"""
@abc.abstractmethod
async def delete(self, key, lock: Optional[asyncio.Lock] = None):
async def delete(self, key, lock: Optional[AsyncLockType] = None):
"""
Remove an item from the cache.
@ -150,7 +153,7 @@ class AsyncBaseCacheService(Service):
"""
@abc.abstractmethod
async def clear(self, lock: Optional[asyncio.Lock] = None):
async def clear(self, lock: Optional[AsyncLockType] = None):
"""
Clear all items from the cache.
"""

View file

@ -3,18 +3,18 @@ import pickle
import threading
import time
from collections import OrderedDict
from typing import Optional
from typing import Generic, Optional
from loguru import logger
from langflow.services.base import Service
from langflow.services.cache.base import AsyncBaseCacheService, CacheService
from langflow.services.cache.base import AsyncBaseCacheService, AsyncLockType, CacheService, LockType
from langflow.services.cache.utils import CacheMiss
CACHE_MISS = CacheMiss()
class ThreadingInMemoryCache(CacheService, Service):
class ThreadingInMemoryCache(CacheService, Generic[LockType]):
"""
A simple in-memory cache using an OrderedDict.
@ -182,7 +182,7 @@ class ThreadingInMemoryCache(CacheService, Service):
return f"InMemoryCache(max_size={self.max_size}, expiration_time={self.expiration_time})"
class RedisCache(CacheService):
class RedisCache(CacheService, Generic[LockType]):
"""
A Redis-based cache implementation.
@ -332,7 +332,7 @@ class RedisCache(CacheService):
return f"RedisCache(expiration_time={self.expiration_time})"
class AsyncInMemoryCache(AsyncBaseCacheService, Service):
class AsyncInMemoryCache(AsyncBaseCacheService, Generic[AsyncLockType]):
def __init__(self, max_size=None, expiration_time=3600):
self.cache = OrderedDict()