refactor(cache): add cache attribute to memoized function wrapper
refactor(test_cache.py): update import statements and function names test(cache): add tests for load_or_build_langchain_object, build_langchain_object_with_caching, build_graph, and cache size limit
This commit is contained in:
parent
fa9825849a
commit
1d5f156a22
2 changed files with 43 additions and 23 deletions
1
src/backend/langflow/cache/base.py
vendored
1
src/backend/langflow/cache/base.py
vendored
|
|
@ -48,6 +48,7 @@ def memoize_dict(maxsize=128):
|
|||
cache.clear()
|
||||
|
||||
wrapper.clear_cache = clear_cache # type: ignore
|
||||
wrapper.cache = cache # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from langflow.cache.base import PREFIX, save_cache
|
||||
from langflow.interface.run import load_langchain_object
|
||||
from langflow.interface.run import (
|
||||
build_graph,
|
||||
build_langchain_object_with_caching,
|
||||
load_or_build_langchain_object,
|
||||
)
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
|
|
@ -40,26 +41,44 @@ def langchain_objects_are_equal(obj1, obj2):
|
|||
return str(obj1) == str(obj2)
|
||||
|
||||
|
||||
def test_cache_creation(basic_data_graph):
|
||||
# Compute hash for the input data_graph
|
||||
# Call process_graph function to build and cache the langchain_object
|
||||
is_first_message = True
|
||||
computed_hash, langchain_object = load_langchain_object(
|
||||
basic_data_graph, is_first_message=is_first_message
|
||||
)
|
||||
save_cache(computed_hash, langchain_object, is_first_message)
|
||||
# Check if the cache file exists
|
||||
cache_file = Path(tempfile.gettempdir()) / f"{PREFIX}/{computed_hash}.dill"
|
||||
|
||||
assert cache_file.exists()
|
||||
# Test load_or_build_langchain_object
|
||||
def test_load_or_build_langchain_object_first_message_true(basic_data_graph):
|
||||
build_langchain_object_with_caching.clear_cache()
|
||||
graph = load_or_build_langchain_object(basic_data_graph, is_first_message=True)
|
||||
assert graph is not None
|
||||
|
||||
|
||||
def test_cache_reuse(basic_data_graph):
|
||||
# Call process_graph function to build and cache the langchain_object
|
||||
result1 = load_langchain_object(basic_data_graph)
|
||||
def test_load_or_build_langchain_object_first_message_false(basic_data_graph):
|
||||
graph = load_or_build_langchain_object(basic_data_graph, is_first_message=False)
|
||||
assert graph is not None
|
||||
|
||||
# Call process_graph function again to use the cached langchain_object
|
||||
result2 = load_langchain_object(basic_data_graph)
|
||||
|
||||
# Compare the results to ensure the same langchain_object was used
|
||||
assert langchain_objects_are_equal(result1, result2)
|
||||
# Test build_langchain_object_with_caching
|
||||
def test_build_langchain_object_with_caching(basic_data_graph):
|
||||
build_langchain_object_with_caching.clear_cache()
|
||||
graph = build_langchain_object_with_caching(basic_data_graph)
|
||||
assert graph is not None
|
||||
|
||||
|
||||
# Test build_graph
|
||||
def test_build_graph(basic_data_graph):
|
||||
graph = build_graph(basic_data_graph)
|
||||
assert graph is not None
|
||||
assert len(graph.nodes) == len(basic_data_graph["nodes"])
|
||||
assert len(graph.edges) == len(basic_data_graph["edges"])
|
||||
|
||||
|
||||
# Test cache size limit
|
||||
def test_cache_size_limit(basic_data_graph):
|
||||
build_langchain_object_with_caching.clear_cache()
|
||||
for i in range(11):
|
||||
modified_data_graph = basic_data_graph.copy()
|
||||
nodes = modified_data_graph["nodes"]
|
||||
node_id = nodes[0]["id"]
|
||||
# Now we replace all instances ode node_id with a new id in the json
|
||||
json_string = json.dumps(modified_data_graph)
|
||||
modified_json_string = json_string.replace(node_id, f"{node_id}_{i}")
|
||||
modified_data_graph_new_id = json.loads(modified_json_string)
|
||||
build_langchain_object_with_caching(modified_data_graph_new_id)
|
||||
|
||||
assert len(build_langchain_object_with_caching.cache) == 10
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue