🔨 refactor(base.py): refactor BaseCache class to use abstract methods

This commit refactors the BaseCache class to use abstract methods instead of concrete methods. This makes the class more flexible and allows for different implementations of the cache. The abstract methods include get, set, delete, clear, __contains__, __getitem__, __setitem__, and __delitem__.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-11 18:07:17 -03:00
commit a27e5d58dd

View file

@ -1,154 +1,95 @@
import base64
import contextlib
import functools
import hashlib
import json
import os
import tempfile
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict
import dill # type: ignore
CACHE: Dict[str, Any] = {}
import abc
def create_cache_folder(func):
def wrapper(*args, **kwargs):
# Get the destination folder
cache_path = Path(tempfile.gettempdir()) / PREFIX
# Create the destination folder if it doesn't exist
os.makedirs(cache_path, exist_ok=True)
return func(*args, **kwargs)
return wrapper
def memoize_dict(maxsize=128):
cache = OrderedDict()
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
hashed = compute_dict_hash(args[0])
key = (func.__name__, hashed, frozenset(kwargs.items()))
if key not in cache:
result = func(*args, **kwargs)
cache[key] = result
if len(cache) > maxsize:
cache.popitem(last=False)
else:
result = cache[key]
return result
def clear_cache():
cache.clear()
wrapper.clear_cache = clear_cache # type: ignore
wrapper.cache = cache # type: ignore
return wrapper
return decorator
PREFIX = "langflow_cache"
@create_cache_folder
def clear_old_cache_files(max_cache_size: int = 3):
cache_dir = Path(tempfile.gettempdir()) / PREFIX
cache_files = list(cache_dir.glob("*.dill"))
if len(cache_files) > max_cache_size:
cache_files_sorted_by_mtime = sorted(
cache_files, key=lambda x: x.stat().st_mtime, reverse=True
)
for cache_file in cache_files_sorted_by_mtime[max_cache_size:]:
with contextlib.suppress(OSError):
os.remove(cache_file)
def compute_dict_hash(graph_data):
graph_data = filter_json(graph_data)
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
def filter_json(json_data):
filtered_data = json_data.copy()
# Remove 'viewport' and 'chatHistory' keys
if "viewport" in filtered_data:
del filtered_data["viewport"]
if "chatHistory" in filtered_data:
del filtered_data["chatHistory"]
# Filter nodes
if "nodes" in filtered_data:
for node in filtered_data["nodes"]:
if "position" in node:
del node["position"]
if "positionAbsolute" in node:
del node["positionAbsolute"]
if "selected" in node:
del node["selected"]
if "dragging" in node:
del node["dragging"]
return filtered_data
@create_cache_folder
def save_binary_file(content: str, file_name: str, accepted_types: list[str]) -> str:
class BaseCache(abc.ABC):
"""
Save a binary file to the specified folder.
Args:
content: The content of the file as a bytes object.
file_name: The name of the file, including its extension.
Returns:
The path to the saved file.
Abstract base class for a cache.
"""
if not any(file_name.endswith(suffix) for suffix in accepted_types):
raise ValueError(f"File {file_name} is not accepted")
# Get the destination folder
cache_path = Path(tempfile.gettempdir()) / PREFIX
if not content:
raise ValueError("Please, reload the file in the loader.")
data = content.split(",")[1]
decoded_bytes = base64.b64decode(data)
@abc.abstractmethod
def get(self, key):
"""
Retrieve an item from the cache.
# Create the full file path
file_path = os.path.join(cache_path, file_name)
Args:
key: The key of the item to retrieve.
# Save the binary content to the file
with open(file_path, "wb") as file:
file.write(decoded_bytes)
Returns:
The value associated with the key, or None if the key is not found.
"""
pass
return file_path
@abc.abstractmethod
def set(self, key, value):
"""
Add an item to the cache.
Args:
key: The key of the item.
value: The value to cache.
"""
pass
@create_cache_folder
def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
with cache_path.open("wb") as cache_file:
dill.dump(chat_data, cache_file)
@abc.abstractmethod
def delete(self, key):
"""
Remove an item from the cache.
if clean_old_cache_files:
clear_old_cache_files()
Args:
key: The key of the item to remove.
"""
pass
@abc.abstractmethod
def clear(self):
"""
Clear all items from the cache.
"""
pass
@create_cache_folder
def load_cache(hash_val):
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
if cache_path.exists():
with cache_path.open("rb") as cache_file:
return dill.load(cache_file)
return None
@abc.abstractmethod
def __contains__(self, key):
"""
Check if the key is in the cache.
Args:
key: The key of the item to check.
Returns:
True if the key is in the cache, False otherwise.
"""
pass
@abc.abstractmethod
def __getitem__(self, key):
"""
Retrieve an item from the cache using the square bracket notation.
Args:
key: The key of the item to retrieve.
Returns:
The value associated with the key, or None if the key is not found.
"""
pass
@abc.abstractmethod
def __setitem__(self, key, value):
"""
Add an item to the cache using the square bracket notation.
Args:
key: The key of the item.
value: The value to cache.
"""
pass
@abc.abstractmethod
def __delitem__(self, key):
"""
Remove an item from the cache using the square bracket notation.
Args:
key: The key of the item to remove.
"""
pass