refactor(cache): add cache attribute to memoized function wrapper

refactor(test_cache.py): update import statements and function names
test(cache): add tests for load_or_build_langchain_object, build_langchain_object_with_caching, build_graph, and cache size limit
Gabriel Almeida 2023-05-06 10:25:44 -03:00 committed by Gabriel Luiz Freitas Almeida
commit 1d5f156a22
2 changed files with 43 additions and 23 deletions


@@ -48,6 +48,7 @@ def memoize_dict(maxsize=128):
             cache.clear()

         wrapper.clear_cache = clear_cache  # type: ignore
+        wrapper.cache = cache  # type: ignore
         return wrapper

     return decorator
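
For reference, here is a minimal sketch of what a memoize_dict decorator exposing both clear_cache() and the new cache attribute could look like. Only the maxsize parameter and the clear_cache/cache attributes come from the hunk above; the JSON-based key derivation and the OrderedDict LRU eviction are assumptions made for illustration, not the actual langflow implementation.

import json
from collections import OrderedDict
from functools import wraps


def memoize_dict(maxsize=128):
    def decorator(func):
        cache = OrderedDict()

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Serialize the (dict-heavy) arguments to build a hashable key.
            key = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True, default=str)
            if key in cache:
                cache.move_to_end(key)  # mark the entry as most recently used
                return cache[key]
            result = func(*args, **kwargs)
            cache[key] = result
            if len(cache) > maxsize:
                cache.popitem(last=False)  # evict the least recently used entry
            return result

        def clear_cache():
            cache.clear()

        wrapper.clear_cache = clear_cache  # type: ignore
        wrapper.cache = cache  # type: ignore
        return wrapper

    return decorator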


@@ -1,10 +1,11 @@
 import json
-import tempfile
-from pathlib import Path

 import pytest

-from langflow.cache.base import PREFIX, save_cache
-from langflow.interface.run import load_langchain_object
+from langflow.interface.run import (
+    build_graph,
+    build_langchain_object_with_caching,
+    load_or_build_langchain_object,
+)

 def get_graph(_type="basic"):
@@ -40,26 +41,44 @@ def langchain_objects_are_equal(obj1, obj2):
     return str(obj1) == str(obj2)


-def test_cache_creation(basic_data_graph):
-    # Compute hash for the input data_graph
-    # Call process_graph function to build and cache the langchain_object
-    is_first_message = True
-    computed_hash, langchain_object = load_langchain_object(
-        basic_data_graph, is_first_message=is_first_message
-    )
-    save_cache(computed_hash, langchain_object, is_first_message)
-
-    # Check if the cache file exists
-    cache_file = Path(tempfile.gettempdir()) / f"{PREFIX}/{computed_hash}.dill"
-    assert cache_file.exists()
+# Test load_or_build_langchain_object
+def test_load_or_build_langchain_object_first_message_true(basic_data_graph):
+    build_langchain_object_with_caching.clear_cache()
+    graph = load_or_build_langchain_object(basic_data_graph, is_first_message=True)
+    assert graph is not None


-def test_cache_reuse(basic_data_graph):
-    # Call process_graph function to build and cache the langchain_object
-    result1 = load_langchain_object(basic_data_graph)
+def test_load_or_build_langchain_object_first_message_false(basic_data_graph):
+    graph = load_or_build_langchain_object(basic_data_graph, is_first_message=False)
+    assert graph is not None

-    # Call process_graph function again to use the cached langchain_object
-    result2 = load_langchain_object(basic_data_graph)

-    # Compare the results to ensure the same langchain_object was used
-    assert langchain_objects_are_equal(result1, result2)
+# Test build_langchain_object_with_caching
+def test_build_langchain_object_with_caching(basic_data_graph):
+    build_langchain_object_with_caching.clear_cache()
+    graph = build_langchain_object_with_caching(basic_data_graph)
+    assert graph is not None


+# Test build_graph
+def test_build_graph(basic_data_graph):
+    graph = build_graph(basic_data_graph)
+    assert graph is not None
+    assert len(graph.nodes) == len(basic_data_graph["nodes"])
+    assert len(graph.edges) == len(basic_data_graph["edges"])


+# Test cache size limit
+def test_cache_size_limit(basic_data_graph):
+    build_langchain_object_with_caching.clear_cache()
+    for i in range(11):
+        modified_data_graph = basic_data_graph.copy()
+        nodes = modified_data_graph["nodes"]
+        node_id = nodes[0]["id"]
+        # Now we replace all instances of node_id with a new id in the json
+        json_string = json.dumps(modified_data_graph)
+        modified_json_string = json_string.replace(node_id, f"{node_id}_{i}")
+        modified_data_graph_new_id = json.loads(modified_json_string)
+        build_langchain_object_with_caching(modified_data_graph_new_id)
+
+    assert len(build_langchain_object_with_caching.cache) == 10
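
As a usage illustration of what test_cache_size_limit asserts: if build_langchain_object_with_caching is wrapped with memoize_dict(maxsize=10), building eleven distinct graphs leaves only the ten most recent entries in the cache. The maxsize=10 value and the toy build function below are assumptions inferred from the asserted cache size, using the memoize_dict sketch shown earlier, not code taken from the diff.

# Hypothetical stand-in for build_langchain_object_with_caching.
@memoize_dict(maxsize=10)
def build_object(data_graph):
    return {"built_from": data_graph["nodes"][0]["id"]}


build_object.clear_cache()
for i in range(11):
    # Each graph carries a distinct node id, so each call is a cache miss.
    build_object({"nodes": [{"id": f"node_{i}"}], "edges": []})

assert len(build_object.cache) == 10  # the oldest entry was evicted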