feat: implemented caching of agent
This commit is contained in:
parent
5125fde38a
commit
2235767d22
4 changed files with 141 additions and 10 deletions
0
src/backend/langflow/cache/__init__.py
vendored
Normal file
0
src/backend/langflow/cache/__init__.py
vendored
Normal file
49
src/backend/langflow/cache/utils.py
vendored
Normal file
49
src/backend/langflow/cache/utils.py
vendored
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import dill
|
||||
|
||||
# Filename prefix shared by every cache file this module reads or writes in the temp directory.
PREFIX = "langflow_cache"
|
||||
|
||||
|
||||
def clear_old_cache_files(max_cache_size: int = 10):
    """Delete the oldest cache files so at most ``max_cache_size`` remain.

    Cache files are the ``{PREFIX}_*.dill`` entries in the system temp
    directory. They are ranked by modification time and the most recently
    modified ``max_cache_size`` files are kept.

    Args:
        max_cache_size: Number of most-recent cache files to keep.
    """
    cache_dir = Path(tempfile.gettempdir())

    # Collect (mtime, path) pairs up front. A file can disappear between
    # glob() and stat() (e.g. removed by another process); such files are
    # skipped instead of aborting the whole clean-up.
    dated_files = []
    for cache_file in cache_dir.glob(f"{PREFIX}_*.dill"):
        with contextlib.suppress(OSError):
            dated_files.append((cache_file.stat().st_mtime, cache_file))

    if len(dated_files) <= max_cache_size:
        return

    # Newest first, so everything past the first `max_cache_size` is stale.
    dated_files.sort(key=lambda pair: pair[0], reverse=True)
    for _, stale_file in dated_files[max_cache_size:]:
        # Best effort: a file that cannot be removed is simply left behind.
        with contextlib.suppress(OSError):
            stale_file.unlink()
|
||||
|
||||
|
||||
def remove_position_info(node):
    """Remove the ``"position"`` entry from ``node`` in place, if present.

    Args:
        node: A mutable mapping describing a graph node.
    """
    node.pop("position", None)
|
||||
|
||||
|
||||
def compute_hash(graph_data):
    """Return a SHA-256 hex digest identifying ``graph_data``.

    Each node's ``"position"`` entry is left out of the hash, so two graphs
    that differ only in where their nodes are placed produce the same digest.
    The input is not modified: the hash is computed over a cleaned copy.

    NOTE(review): every other key of ``graph_data`` contributes to the
    digest, including any extra top-level keys the caller leaves in the dict.

    Args:
        graph_data: JSON-serialisable mapping with a ``"nodes"`` list of
            node mappings.

    Returns:
        The hexadecimal SHA-256 digest of the canonical (key-sorted) JSON.

    Raises:
        KeyError: If ``graph_data`` has no ``"nodes"`` key.
    """
    # Shallow-copy the top level and rebuild each node without "position"
    # so the caller's graph keeps its layout information.
    cleaned_graph = dict(graph_data)
    cleaned_graph["nodes"] = [
        {key: value for key, value in node.items() if key != "position"}
        for node in graph_data["nodes"]
    ]

    # sort_keys makes the serialisation independent of dict insertion order.
    cleaned_graph_json = json.dumps(cleaned_graph, sort_keys=True)
    return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def save_cache(hash_val, chat_data):
    """Serialise ``chat_data`` with dill into the cache file for ``hash_val``.

    The cache file lives in the system temp directory and is named
    ``{PREFIX}_{hash_val}.dill``. The data is first written to a temporary
    file in the same directory and then renamed into place, so a failed or
    interrupted dump never leaves a truncated cache file behind.

    Args:
        hash_val: Cache key used to build the cache filename.
        chat_data: Object to serialise.

    Raises:
        Whatever ``dill.dump`` or the underlying file operations raise; an
        existing cache file for ``hash_val`` is left untouched in that case.
    """
    cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"

    # The ".tmp" suffix keeps the scratch file from being picked up as a
    # cache entry (cache entries end in ".dill").
    fd, tmp_name = tempfile.mkstemp(
        dir=str(cache_path.parent), prefix=f"{PREFIX}_", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "wb") as cache_file:
            dill.dump(chat_data, cache_file)
        os.replace(tmp_name, cache_path)
    finally:
        # After a successful replace the temp name no longer exists; after a
        # failure this removes the partial file.
        with contextlib.suppress(FileNotFoundError):
            os.remove(tmp_name)
|
||||
|
||||
|
||||
def load_cache(hash_val):
    """Load the object cached under ``hash_val``, or return ``None``.

    Args:
        hash_val: Cache key used to build the cache filename
            ``{PREFIX}_{hash_val}.dill`` in the system temp directory.

    Returns:
        The deserialised object, or ``None`` when no cache file exists.
    """
    cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"

    # Open directly instead of checking exists() first: the file may be
    # removed between the check and the open.
    try:
        cache_file = cache_path.open("rb")
    except FileNotFoundError:
        return None

    with cache_file:
        # NOTE(security): dill.load can execute arbitrary code during
        # deserialisation, and the system temp directory may be writable by
        # other users. Only safe while that directory is trusted.
        return dill.load(cache_file)
|
||||
|
|
@ -2,30 +2,48 @@ import contextlib
|
|||
import io
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
from langflow.cache.utils import compute_hash, load_cache, save_cache
|
||||
|
||||
from langflow.graph.graph import Graph
|
||||
from langflow.interface import loading
|
||||
from langflow.utils import payload
|
||||
|
||||
|
||||
def load_langchain_object(data_graph):
    """Return the cache key for ``data_graph`` and its langchain object.

    The object is taken from the cache when an entry exists for the graph's
    hash; otherwise it is built from the graph's nodes and edges. This
    function only reads the cache, it never writes to it.

    Args:
        data_graph: Mapping with ``"nodes"`` and ``"edges"`` entries.

    Returns:
        A ``(computed_hash, langchain_object)`` tuple.
    """
    computed_hash = compute_hash(data_graph)

    # Load langchain_object from cache if it exists
    langchain_object = load_cache(computed_hash)
    if langchain_object is None:
        # Cache miss: add input variables to the nodes, then build the graph.
        nodes = payload.extract_input_variables(data_graph["nodes"])
        edges = data_graph["edges"]
        langchain_object = Graph(nodes, edges).build()

    return computed_hash, langchain_object
|
||||
|
||||
|
||||
def process_graph(data_graph: Dict[str, Any]):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate, then run the graph and return the result and thought.
|
||||
"""
|
||||
nodes = data_graph["nodes"]
|
||||
# Add input variables
|
||||
# ? Is this necessary?
|
||||
nodes = payload.extract_input_variables(nodes)
|
||||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
|
||||
langchain_object = graph.build()
|
||||
# Load langchain object
|
||||
computed_hash, langchain_object = load_langchain_object(data_graph)
|
||||
message = data_graph["message"]
|
||||
# Process json
|
||||
|
||||
# Generate result and thought
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
|
||||
# Save langchain_object to cache
|
||||
# We have to save it here because if the
|
||||
# memory is updated we need to keep the new values
|
||||
save_cache(computed_hash, langchain_object)
|
||||
|
||||
return {
|
||||
"result": result,
|
||||
"thought": re.sub(
|
||||
|
|
|
|||
64
tests/test_cache.py
Normal file
64
tests/test_cache.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
import json
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
import dill
|
||||
import tempfile
|
||||
from langflow.cache.utils import compute_hash, load_cache, save_cache, PREFIX
|
||||
from langflow.interface.run import load_langchain_object, process_graph
|
||||
import pytest
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
    """Get a graph from a json file.

    Args:
        _type: Which example flow to load: ``"basic"``, ``"complex"`` or
            ``"openapi"``.

    Returns:
        The ``"data"`` entry of the example flow's JSON document.

    Raises:
        ValueError: If ``_type`` is not one of the known example names.
    """
    # Name of the pytest attribute holding each example's file path.
    path_attr_by_type = {
        "basic": "BASIC_EXAMPLE_PATH",
        "complex": "COMPLEX_EXAMPLE_PATH",
        "openapi": "OPENAPI_EXAMPLE_PATH",
    }
    if _type not in path_attr_by_type:
        # Fail with a clear message instead of an unbound `path` later on.
        raise ValueError(
            f"Unknown graph type {_type!r}; expected one of {sorted(path_attr_by_type)}"
        )
    path = getattr(pytest, path_attr_by_type[_type])

    with open(path, "r", encoding="utf-8") as f:
        flow_graph = json.load(f)
    return flow_graph["data"]
|
||||
|
||||
|
||||
@pytest.fixture
def basic_data_graph():
    """Provide the data of the basic example flow."""
    return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
def complex_data_graph():
    """Provide the data of the complex example flow."""
    return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
def openapi_data_graph():
    """Provide the data of the OpenAPI example flow."""
    return get_graph("openapi")
|
||||
|
||||
|
||||
def langchain_objects_are_equal(obj1, obj2):
    """Return True when both objects have the same ``str()`` representation."""
    return str(obj1) == str(obj2)
|
||||
|
||||
|
||||
def test_cache_creation(basic_data_graph):
    """Saving a loaded langchain object creates the expected cache file."""
    # Compute hash for the input data_graph
    computed_hash = compute_hash(basic_data_graph)
    cache_file = Path(tempfile.gettempdir()) / f"{PREFIX}_{computed_hash}.dill"

    # Start from a clean slate so a file left by an earlier run cannot make
    # the assertion below pass on its own.
    if cache_file.exists():
        cache_file.unlink()

    # load_langchain_object only reads the cache; the object has to be
    # saved explicitly for a cache file to be created.
    loaded_hash, langchain_object = load_langchain_object(basic_data_graph)
    save_cache(loaded_hash, langchain_object)

    # Check that the key matches and the cache file exists
    assert loaded_hash == computed_hash
    assert cache_file.exists()
|
||||
|
||||
|
||||
def test_cache_reuse(basic_data_graph):
    """A second load of the same graph returns an equivalent cached object."""
    # First call: build (or load) the langchain_object, then make sure it is
    # in the cache so the second call exercises the cache path.
    first_hash, first_object = load_langchain_object(basic_data_graph)
    save_cache(first_hash, first_object)

    # Second call: should pick up the cached langchain_object
    second_hash, second_object = load_langchain_object(basic_data_graph)

    # Same key, and the objects compare equal by their string form
    assert first_hash == second_hash
    assert langchain_objects_are_equal(first_object, second_object)
|
||||
Loading…
Add table
Add a link
Reference in a new issue