diff --git a/src/backend/langflow/cache/__init__.py b/src/backend/langflow/cache/__init__.py index f7aac380b..583d5ac6d 100644 --- a/src/backend/langflow/cache/__init__.py +++ b/src/backend/langflow/cache/__init__.py @@ -1 +1 @@ -from langflow.cache.base import add_pandas, add_image, get # noqa +from langflow.cache.manager import cache_manager # noqa diff --git a/src/backend/langflow/cache/base.py b/src/backend/langflow/cache/base.py index 9dd5c1780..ba250da6b 100644 --- a/src/backend/langflow/cache/base.py +++ b/src/backend/langflow/cache/base.py @@ -152,33 +152,3 @@ def load_cache(hash_val): with cache_path.open("rb") as cache_file: return dill.load(cache_file) return None - - -def add_pandas(name: str, obj: Any): - if isinstance(obj, (pd.DataFrame, pd.Series)): - CACHE[name] = {"obj": obj, "type": "pandas"} - else: - raise ValueError("Object is not a pandas DataFrame or Series") - - -def add_image(name: str, obj: Any): - if isinstance(obj, Image.Image): - CACHE[name] = {"obj": obj, "type": "image"} - else: - raise ValueError("Object is not a PIL Image") - - -def get(name: str): - return CACHE.get(name, {}).get("obj", None) - - -# get last added item -def get_last(): - obj_dict = list(CACHE.values())[-1] - if obj_dict["type"] == "pandas": - # return a csv string - return obj_dict["obj"].to_csv() - elif obj_dict["type"] == "image": - # return a base64 encoded string - return base64.b64encode(obj_dict["obj"].tobytes()).decode("utf-8") - return obj_dict["obj"] diff --git a/src/backend/langflow/cache/manager.py b/src/backend/langflow/cache/manager.py new file mode 100644 index 000000000..ba34a3a8d --- /dev/null +++ b/src/backend/langflow/cache/manager.py @@ -0,0 +1,126 @@ +from contextlib import contextmanager +from typing import Any, Awaitable, Callable, List +from PIL import Image +import pandas as pd + + +class Subject: + """Base class for implementing the observer pattern.""" + + def __init__(self): + self.observers: List[Callable[[], None]] = [] + + def attach(self, observer: Callable[[], None]): + """Attach an observer to the subject.""" + self.observers.append(observer) + + def detach(self, observer: Callable[[], None]): + """Detach an observer from the subject.""" + self.observers.remove(observer) + + def notify(self): + """Notify all observers about an event.""" + for observer in self.observers: + if observer is None: + continue + observer() + + +class AsyncSubject: + """Base class for implementing the async observer pattern.""" + + def __init__(self): + self.observers: List[Callable[[], Awaitable]] = [] + + def attach(self, observer: Callable[[], Awaitable]): + """Attach an observer to the subject.""" + self.observers.append(observer) + + def detach(self, observer: Callable[[], Awaitable]): + """Detach an observer from the subject.""" + self.observers.remove(observer) + + async def notify(self): + """Notify all observers about an event.""" + for observer in self.observers: + if observer is None: + continue + await observer() + + +class CacheManager(Subject): + """Manages cache for different clients and notifies observers on changes.""" + + def __init__(self): + super().__init__() + self.CACHE = {} + self.current_client_id = None + + @contextmanager + def set_client_id(self, client_id: str): + """ + Context manager to set the current client_id and associated cache. + + Args: + client_id (str): The client identifier. + """ + previous_client_id = self.current_client_id + self.current_client_id = client_id + self.current_cache = self.CACHE.setdefault(client_id, {}) + try: + yield + finally: + self.current_client_id = previous_client_id + self.current_cache = self.CACHE.get(self.current_client_id, {}) + + def add_pandas(self, name: str, obj: Any): + """ + Add a pandas DataFrame or Series to the current client's cache. + + Args: + name (str): The cache key. + obj (Any): The pandas DataFrame or Series object. + """ + if isinstance(obj, (pd.DataFrame, pd.Series)): + self.current_cache[name] = {"obj": obj, "type": "pandas"} + self.notify() + else: + raise ValueError("Object is not a pandas DataFrame or Series") + + def add_image(self, name: str, obj: Any): + """ + Add a PIL Image to the current client's cache. + + Args: + name (str): The cache key. + obj (Any): The PIL Image object. + """ + if isinstance(obj, Image.Image): + self.current_cache[name] = {"obj": obj, "type": "image"} + self.notify() + else: + raise ValueError("Object is not a PIL Image") + + def get(self, name: str): + """ + Get an object from the current client's cache. + + Args: + name (str): The cache key. + + Returns: + The cached object associated with the given cache key. + """ + return self.current_cache[name] + + def get_last(self): + """ + Get the last added item in the current client's cache. + + Returns: + The last added item in the cache. + """ + return list(self.current_cache.values())[-1] + + +cache_manager = CacheManager()