feat: implemented caching of agent

This commit is contained in:
Gabriel Almeida 2023-04-01 10:58:46 -03:00
commit 2235767d22
4 changed files with 141 additions and 10 deletions

View file

49
src/backend/langflow/cache/utils.py vendored Normal file
View file

@ -0,0 +1,49 @@
import contextlib
import hashlib
import json
import os
from pathlib import Path
import tempfile
import dill
PREFIX = "langflow_cache"
def clear_old_cache_files(max_cache_size: int = 10):
cache_dir = Path(tempfile.gettempdir())
cache_files = list(cache_dir.glob(f"{PREFIX}_*.dill"))
if len(cache_files) > max_cache_size:
cache_files_sorted_by_mtime = sorted(
cache_files, key=lambda x: x.stat().st_mtime, reverse=True
)
for cache_file in cache_files_sorted_by_mtime[max_cache_size:]:
with contextlib.suppress(OSError):
os.remove(cache_file)
def remove_position_info(node):
node.pop("position", None)
def compute_hash(graph_data):
for node in graph_data["nodes"]:
remove_position_info(node)
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
def save_cache(hash_val, chat_data):
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
with cache_path.open("wb") as cache_file:
dill.dump(chat_data, cache_file)
def load_cache(hash_val):
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
if cache_path.exists():
with cache_path.open("rb") as cache_file:
return dill.load(cache_file)
return None

View file

@ -2,30 +2,48 @@ import contextlib
import io
import re
from typing import Any, Dict
from langflow.cache.utils import compute_hash, load_cache, save_cache
from langflow.graph.graph import Graph
from langflow.interface import loading
from langflow.utils import payload
def load_langchain_object(data_graph):
computed_hash = compute_hash(data_graph)
# Load langchain_object from cache if it exists
langchain_object = load_cache(computed_hash)
if langchain_object is None:
nodes = data_graph["nodes"]
# Add input variables
nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
langchain_object = graph.build()
return computed_hash, langchain_object
def process_graph(data_graph: Dict[str, Any]):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
with PromptTemplate,then run the graph and return the result and thought.
"""
nodes = data_graph["nodes"]
# Add input variables
# ? Is this necessary?
nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
langchain_object = graph.build()
# Load langchain object
computed_hash, langchain_object = load_langchain_object(data_graph)
message = data_graph["message"]
# Process json
# Generate result and thought
result, thought = get_result_and_thought_using_graph(langchain_object, message)
# Save langchain_object to cache
# We have to save it here because if the
# memory is updated we need to keep the new values
save_cache(computed_hash, langchain_object)
return {
"result": result,
"thought": re.sub(

64
tests/test_cache.py Normal file
View file

@ -0,0 +1,64 @@
import json
import hashlib
from pathlib import Path
import dill
import tempfile
from langflow.cache.utils import compute_hash, load_cache, save_cache, PREFIX
from langflow.interface.run import load_langchain_object, process_graph
import pytest
def get_graph(_type="basic"):
"""Get a graph from a json file"""
if _type == "basic":
path = pytest.BASIC_EXAMPLE_PATH
elif _type == "complex":
path = pytest.COMPLEX_EXAMPLE_PATH
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with open(path, "r") as f:
flow_graph = json.load(f)
return flow_graph["data"]
@pytest.fixture
def basic_data_graph():
return get_graph()
@pytest.fixture
def complex_data_graph():
return get_graph("complex")
@pytest.fixture
def openapi_data_graph():
return get_graph("openapi")
def langchain_objects_are_equal(obj1, obj2):
return str(obj1) == str(obj2)
def test_cache_creation(basic_data_graph):
# Compute hash for the input data_graph
computed_hash = compute_hash(basic_data_graph)
# Call process_graph function to build and cache the langchain_object
_ = load_langchain_object(basic_data_graph)
# Check if the cache file exists
cache_file = Path(tempfile.gettempdir()) / f"{PREFIX}_{computed_hash}.dill"
assert cache_file.exists()
def test_cache_reuse(basic_data_graph):
# Call process_graph function to build and cache the langchain_object
result1 = load_langchain_object(basic_data_graph)
# Call process_graph function again to use the cached langchain_object
result2 = load_langchain_object(basic_data_graph)
# Compare the results to ensure the same langchain_object was used
assert langchain_objects_are_equal(result1, result2)