refactor(cache): move cache functionality to a separate class

feat(cache): add support for multiple clients and context manager to set client_id
feat(cache): add observer pattern to notify on cache changes
feat(cache): add async observer pattern to notify on cache changes in async functions
feat(cache): add methods to add pandas DataFrame or Series and PIL Image to cache
feat(cache): add method to get an object from cache by key
feat(cache): add method to get the last added item in cache
This commit is contained in:
Gabriel Almeida 2023-04-20 11:09:11 -03:00
commit 3da30cc5bf
3 changed files with 127 additions and 31 deletions

View file

@ -1 +1 @@
from langflow.cache.base import add_pandas, add_image, get # noqa
from langflow.cache.manager import cache_manager # noqa

View file

@ -152,33 +152,3 @@ def load_cache(hash_val):
with cache_path.open("rb") as cache_file:
return dill.load(cache_file)
return None
def add_pandas(name: str, obj: Any):
if isinstance(obj, (pd.DataFrame, pd.Series)):
CACHE[name] = {"obj": obj, "type": "pandas"}
else:
raise ValueError("Object is not a pandas DataFrame or Series")
def add_image(name: str, obj: Any):
if isinstance(obj, Image.Image):
CACHE[name] = {"obj": obj, "type": "image"}
else:
raise ValueError("Object is not a PIL Image")
def get(name: str):
return CACHE.get(name, {}).get("obj", None)
# get last added item
def get_last():
obj_dict = list(CACHE.values())[-1]
if obj_dict["type"] == "pandas":
# return a csv string
return obj_dict["obj"].to_csv()
elif obj_dict["type"] == "image":
# return a base64 encoded string
return base64.b64encode(obj_dict["obj"].tobytes()).decode("utf-8")
return obj_dict["obj"]

126
src/backend/langflow/cache/manager.py vendored Normal file
View file

@ -0,0 +1,126 @@
from contextlib import contextmanager
from typing import Any, Awaitable, Callable, List
from PIL import Image
import pandas as pd
class Subject:
"""Base class for implementing the observer pattern."""
def __init__(self):
self.observers: List[Callable[[], None]] = []
def attach(self, observer: Callable[[], None]):
"""Attach an observer to the subject."""
self.observers.append(observer)
def detach(self, observer: Callable[[], None]):
"""Detach an observer from the subject."""
self.observers.remove(observer)
def notify(self):
"""Notify all observers about an event."""
for observer in self.observers:
if observer is None:
continue
observer()
class AsyncSubject:
"""Base class for implementing the async observer pattern."""
def __init__(self):
self.observers: List[Callable[[], Awaitable]] = []
def attach(self, observer: Callable[[], Awaitable]):
"""Attach an observer to the subject."""
self.observers.append(observer)
def detach(self, observer: Callable[[], Awaitable]):
"""Detach an observer from the subject."""
self.observers.remove(observer)
async def notify(self):
"""Notify all observers about an event."""
for observer in self.observers:
if observer is None:
continue
await observer()
class CacheManager(Subject):
"""Manages cache for different clients and notifies observers on changes."""
def __init__(self):
super().__init__()
self.CACHE = {}
self.current_client_id = None
@contextmanager
def set_client_id(self, client_id: str):
"""
Context manager to set the current client_id and associated cache.
Args:
client_id (str): The client identifier.
"""
previous_client_id = self.current_client_id
self.current_client_id = client_id
self.current_cache = self.CACHE.setdefault(client_id, {})
try:
yield
finally:
self.current_client_id = previous_client_id
self.current_cache = self.CACHE.get(self.current_client_id, {})
def add_pandas(self, name: str, obj: Any):
"""
Add a pandas DataFrame or Series to the current client's cache.
Args:
name (str): The cache key.
obj (Any): The pandas DataFrame or Series object.
"""
if isinstance(obj, (pd.DataFrame, pd.Series)):
self.current_cache[name] = {"obj": obj, "type": "pandas"}
self.notify()
else:
raise ValueError("Object is not a pandas DataFrame or Series")
def add_image(self, name: str, obj: Any):
"""
Add a PIL Image to the current client's cache.
Args:
name (str): The cache key.
obj (Any): The PIL Image object.
"""
if isinstance(obj, Image.Image):
self.current_cache[name] = {"obj": obj, "type": "image"}
self.notify()
else:
raise ValueError("Object is not a PIL Image")
def get(self, name: str):
"""
Get an object from the current client's cache.
Args:
name (str): The cache key.
Returns:
The cached object associated with the given cache key.
"""
return self.current_cache[name]
def get_last(self):
"""
Get the last added item in the current client's cache.
Returns:
The last added item in the cache.
"""
return list(self.current_cache.values())[-1]
cache_manager = CacheManager()