From 2235767d22913b52eafd12a55de18141697fa120 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Sat, 1 Apr 2023 10:58:46 -0300 Subject: [PATCH] feat: implemented caching of agent --- src/backend/langflow/cache/__init__.py | 0 src/backend/langflow/cache/utils.py | 49 ++++++++++++++++++++ src/backend/langflow/interface/run.py | 38 +++++++++++---- tests/test_cache.py | 64 ++++++++++++++++++++++++++ 4 files changed, 141 insertions(+), 10 deletions(-) create mode 100644 src/backend/langflow/cache/__init__.py create mode 100644 src/backend/langflow/cache/utils.py create mode 100644 tests/test_cache.py diff --git a/src/backend/langflow/cache/__init__.py b/src/backend/langflow/cache/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/cache/utils.py b/src/backend/langflow/cache/utils.py new file mode 100644 index 000000000..2c4acd074 --- /dev/null +++ b/src/backend/langflow/cache/utils.py @@ -0,0 +1,49 @@ +import contextlib +import hashlib +import json +import os +from pathlib import Path +import tempfile +import dill + +PREFIX = "langflow_cache" + + +def clear_old_cache_files(max_cache_size: int = 10): + cache_dir = Path(tempfile.gettempdir()) + cache_files = list(cache_dir.glob(f"{PREFIX}_*.dill")) + + if len(cache_files) > max_cache_size: + cache_files_sorted_by_mtime = sorted( + cache_files, key=lambda x: x.stat().st_mtime, reverse=True + ) + + for cache_file in cache_files_sorted_by_mtime[max_cache_size:]: + with contextlib.suppress(OSError): + os.remove(cache_file) + + +def remove_position_info(node): + node.pop("position", None) + + +def compute_hash(graph_data): + for node in graph_data["nodes"]: + remove_position_info(node) + + cleaned_graph_json = json.dumps(graph_data, sort_keys=True) + return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest() + + +def save_cache(hash_val, chat_data): + cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill" + with cache_path.open("wb") as cache_file: + dill.dump(chat_data, cache_file) + + +def load_cache(hash_val): + cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill" + if cache_path.exists(): + with cache_path.open("rb") as cache_file: + return dill.load(cache_file) + return None diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 360bad364..998ee297d 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -2,30 +2,48 @@ import contextlib import io import re from typing import Any, Dict +from langflow.cache.utils import compute_hash, load_cache, save_cache from langflow.graph.graph import Graph from langflow.interface import loading from langflow.utils import payload +def load_langchain_object(data_graph): + computed_hash = compute_hash(data_graph) + + # Load langchain_object from cache if it exists + langchain_object = load_cache(computed_hash) + if langchain_object is None: + nodes = data_graph["nodes"] + # Add input variables + nodes = payload.extract_input_variables(nodes) + # Nodes, edges and root node + edges = data_graph["edges"] + graph = Graph(nodes, edges) + + langchain_object = graph.build() + + return computed_hash, langchain_object + + def process_graph(data_graph: Dict[str, Any]): """ Process graph by extracting input variables and replacing ZeroShotPrompt with PromptTemplate,then run the graph and return the result and thought. """ - nodes = data_graph["nodes"] - # Add input variables - # ? Is this necessary? - nodes = payload.extract_input_variables(nodes) - # Nodes, edges and root node - edges = data_graph["edges"] - graph = Graph(nodes, edges) - - langchain_object = graph.build() + # Load langchain object + computed_hash, langchain_object = load_langchain_object(data_graph) message = data_graph["message"] - # Process json + + # Generate result and thought result, thought = get_result_and_thought_using_graph(langchain_object, message) + # Save langchain_object to cache + # We have to save it here because if the + # memory is updated we need to keep the new values + save_cache(computed_hash, langchain_object) + return { "result": result, "thought": re.sub( diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 000000000..0d9102b49 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,64 @@ +import json +import hashlib +from pathlib import Path +import dill +import tempfile +from langflow.cache.utils import compute_hash, load_cache, save_cache, PREFIX +from langflow.interface.run import load_langchain_object, process_graph +import pytest + + +def get_graph(_type="basic"): + """Get a graph from a json file""" + if _type == "basic": + path = pytest.BASIC_EXAMPLE_PATH + elif _type == "complex": + path = pytest.COMPLEX_EXAMPLE_PATH + elif _type == "openapi": + path = pytest.OPENAPI_EXAMPLE_PATH + + with open(path, "r") as f: + flow_graph = json.load(f) + return flow_graph["data"] + + +@pytest.fixture +def basic_data_graph(): + return get_graph() + + +@pytest.fixture +def complex_data_graph(): + return get_graph("complex") + + +@pytest.fixture +def openapi_data_graph(): + return get_graph("openapi") + + +def langchain_objects_are_equal(obj1, obj2): + return str(obj1) == str(obj2) + + +def test_cache_creation(basic_data_graph): + # Compute hash for the input data_graph + computed_hash = compute_hash(basic_data_graph) + + # Call process_graph function to build and cache the langchain_object + _ = load_langchain_object(basic_data_graph) + + # Check if the cache file exists + cache_file = Path(tempfile.gettempdir()) / f"{PREFIX}_{computed_hash}.dill" + assert cache_file.exists() + + +def test_cache_reuse(basic_data_graph): + # Call process_graph function to build and cache the langchain_object + result1 = load_langchain_object(basic_data_graph) + + # Call process_graph function again to use the cached langchain_object + result2 = load_langchain_object(basic_data_graph) + + # Compare the results to ensure the same langchain_object was used + assert langchain_objects_are_equal(result1, result2)