From 1d5f156a223317e2f30c4e5a5a9e2ef26447d33d Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Sat, 6 May 2023 10:25:44 -0300 Subject: [PATCH] refactor(cache): add cache attribute to memoized function wrapper refactor(test_cache.py): update import statements and function names test(cache): add tests for load_or_build_langchain_object, build_langchain_object_with_caching, build_graph, and cache size limit --- src/backend/langflow/cache/base.py | 1 + tests/test_cache.py | 65 +++++++++++++++++++----------- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/src/backend/langflow/cache/base.py b/src/backend/langflow/cache/base.py index 3d667b8b4..1f2039b27 100644 --- a/src/backend/langflow/cache/base.py +++ b/src/backend/langflow/cache/base.py @@ -48,6 +48,7 @@ def memoize_dict(maxsize=128): cache.clear() wrapper.clear_cache = clear_cache # type: ignore + wrapper.cache = cache # type: ignore return wrapper return decorator diff --git a/tests/test_cache.py b/tests/test_cache.py index 131e015f3..3d3e951fc 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,10 +1,11 @@ import json -import tempfile -from pathlib import Path import pytest -from langflow.cache.base import PREFIX, save_cache -from langflow.interface.run import load_langchain_object +from langflow.interface.run import ( + build_graph, + build_langchain_object_with_caching, + load_or_build_langchain_object, +) def get_graph(_type="basic"): @@ -40,26 +41,44 @@ def langchain_objects_are_equal(obj1, obj2): return str(obj1) == str(obj2) -def test_cache_creation(basic_data_graph): - # Compute hash for the input data_graph - # Call process_graph function to build and cache the langchain_object - is_first_message = True - computed_hash, langchain_object = load_langchain_object( - basic_data_graph, is_first_message=is_first_message - ) - save_cache(computed_hash, langchain_object, is_first_message) - # Check if the cache file exists - cache_file = Path(tempfile.gettempdir()) / 
f"{PREFIX}/{computed_hash}.dill" - - assert cache_file.exists() +# Test load_or_build_langchain_object +def test_load_or_build_langchain_object_first_message_true(basic_data_graph): + build_langchain_object_with_caching.clear_cache() + graph = load_or_build_langchain_object(basic_data_graph, is_first_message=True) + assert graph is not None -def test_cache_reuse(basic_data_graph): - # Call process_graph function to build and cache the langchain_object - result1 = load_langchain_object(basic_data_graph) +def test_load_or_build_langchain_object_first_message_false(basic_data_graph): + graph = load_or_build_langchain_object(basic_data_graph, is_first_message=False) + assert graph is not None - # Call process_graph function again to use the cached langchain_object - result2 = load_langchain_object(basic_data_graph) - # Compare the results to ensure the same langchain_object was used - assert langchain_objects_are_equal(result1, result2) +# Test build_langchain_object_with_caching +def test_build_langchain_object_with_caching(basic_data_graph): + build_langchain_object_with_caching.clear_cache() + graph = build_langchain_object_with_caching(basic_data_graph) + assert graph is not None + + +# Test build_graph +def test_build_graph(basic_data_graph): + graph = build_graph(basic_data_graph) + assert graph is not None + assert len(graph.nodes) == len(basic_data_graph["nodes"]) + assert len(graph.edges) == len(basic_data_graph["edges"]) + + +# Test cache size limit +def test_cache_size_limit(basic_data_graph): + build_langchain_object_with_caching.clear_cache() + for i in range(11): + modified_data_graph = basic_data_graph.copy() + nodes = modified_data_graph["nodes"] + node_id = nodes[0]["id"] + # Now we replace all instances of node_id with a new id in the json + json_string = json.dumps(modified_data_graph) + modified_json_string = json_string.replace(node_id, f"{node_id}_{i}") + modified_data_graph_new_id = json.loads(modified_json_string) + 
build_langchain_object_with_caching(modified_data_graph_new_id) + + assert len(build_langchain_object_with_caching.cache) == 10