Merge branch 'cache' into vecstores

Ibis Prevedello 2023-04-10 09:28:41 -03:00
commit 3dc894a57c
4 changed files with 92 additions and 12 deletions

View file

@@ -3,7 +3,7 @@ from typing import Any, Dict
 from fastapi import APIRouter, HTTPException
-from langflow.interface.run import process_graph
+from langflow.interface.run import process_graph_cached
 from langflow.interface.types import build_langchain_types_dict

 # build router
@@ -19,7 +19,7 @@ def get_all():
 @router.post("/predict")
 def get_load(data: Dict[str, Any]):
     try:
-        return process_graph(data)
+        return process_graph_cached(data)
     except Exception as e:
         # Log stack trace
         logger.exception(e)

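From the client's perspective the /predict route is unchanged; only the server-side call switches to the cached variant. A minimal request sketch, assuming the router is mounted at the root of a locally running server on langflow's default port (both assumptions, not part of this diff):

import requests

# Hypothetical flow payload; real "nodes"/"edges" come from a flow exported in the UI.
payload = {"nodes": [], "edges": [], "message": "Hi", "chatHistory": []}
response = requests.post("http://localhost:7860/predict", json=payload)
print(response.json())  # {"result": "...", "thought": "..."}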
View file

@@ -1,11 +1,40 @@
 import contextlib
 import hashlib
 import json
 import os
 import tempfile
 from pathlib import Path

 import dill  # type: ignore
+import functools
+from collections import OrderedDict
+
+
+def memoize_dict(maxsize=128):
+    cache = OrderedDict()
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            hashed = compute_dict_hash(args[0])
+            key = (func.__name__, hashed, frozenset(kwargs.items()))
+            if key not in cache:
+                result = func(*args, **kwargs)
+                cache[key] = result
+                if len(cache) > maxsize:
+                    cache.popitem(last=False)
+            else:
+                result = cache[key]
+            return result
+
+        def clear_cache():
+            cache.clear()
+
+        wrapper.clear_cache = clear_cache
+        return wrapper
+
+    return decorator
+

 PREFIX = "langflow_cache"
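The memoize_dict decorator added above keys its cache on a SHA-256 digest of the first positional argument (a dict), plus the function name and keyword arguments, and evicts the oldest entry once maxsize is exceeded. A self-contained sketch of the same pattern, with a simplified stand-in for compute_dict_hash:

import functools
import hashlib
import json
from collections import OrderedDict

def dict_hash(d):
    # Stand-in for compute_dict_hash: stable digest of a JSON-serializable dict.
    return hashlib.sha256(json.dumps(d, sort_keys=True).encode("utf-8")).hexdigest()

def memoize_dict(maxsize=128):
    cache = OrderedDict()

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = (func.__name__, dict_hash(args[0]), frozenset(kwargs.items()))
            if key not in cache:
                cache[key] = func(*args, **kwargs)
                if len(cache) > maxsize:
                    cache.popitem(last=False)  # drop the oldest insertion
            return cache[key]

        wrapper.clear_cache = cache.clear  # exposed so callers can invalidate
        return wrapper

    return decorator

@memoize_dict(maxsize=1)
def build(graph_data):
    print("building")  # printed only on a cache miss
    return len(graph_data["nodes"])

build({"nodes": [1], "edges": []})  # miss: prints "building"
build({"nodes": [1], "edges": []})  # hit: silent
build.clear_cache()                 # next call rebuilds

Note that hits are not moved to the end of the OrderedDict, so eviction follows insertion order (FIFO) rather than true LRU.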
@@ -24,6 +53,13 @@ def clear_old_cache_files(max_cache_size: int = 3):
             os.remove(cache_file)


+def compute_dict_hash(graph_data):
+    graph_data = filter_json(graph_data)
+
+    cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
+
+    return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
+
+
 def filter_json(json_data):
     filtered_data = json_data.copy()
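Because compute_dict_hash canonicalizes the graph with json.dumps(..., sort_keys=True) before hashing, two dicts with the same content but different key order produce the same digest. A quick illustration of that property (omitting the filter_json step):

import hashlib
import json

def stable_hash(d):
    return hashlib.sha256(json.dumps(d, sort_keys=True).encode("utf-8")).hexdigest()

a = stable_hash({"nodes": [], "edges": []})
b = stable_hash({"edges": [], "nodes": []})  # same content, different key order
assert a == b  # sort_keys=True makes the digest order-independent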
@@ -48,13 +84,6 @@ def filter_json(json_data):
     return filtered_data


-def compute_hash(graph_data):
-    graph_data = filter_json(graph_data)
-
-    cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
-
-    return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
-
-
 def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
     cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
     with cache_path.open("wb") as cache_file:

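save_cache (truncated above) writes the object with dill rather than pickle; dill serializes a wider range of objects, including lambdas and closures, which plain pickle rejects. A minimal round-trip sketch under that assumption:

import tempfile
from pathlib import Path

import dill  # type: ignore

path = Path(tempfile.gettempdir()) / "langflow_cache_demo.dill"
with path.open("wb") as f:
    dill.dump({"greet": lambda: "hello"}, f)  # a lambda would crash pickle.dump
with path.open("rb") as f:
    data = dill.load(f)
print(data["greet"]())  # -> "hello"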
View file

@@ -2,7 +2,7 @@ import contextlib
 import io
 from typing import Any, Dict

-from langflow.cache.utils import compute_hash, load_cache
+from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict
 from langflow.graph.graph import Graph
 from langflow.interface import loading
 from langflow.utils.logger import logger
@@ -12,7 +12,7 @@ def load_langchain_object(data_graph, is_first_message=False):
     """
     Load langchain object from cache if it exists, otherwise build it.
     """
-    computed_hash = compute_hash(data_graph)
+    computed_hash = compute_dict_hash(data_graph)
     if is_first_message:
         langchain_object = build_langchain_object(data_graph)
     else:
@@ -22,6 +22,32 @@
     return computed_hash, langchain_object


+def load_or_build_langchain_object(data_graph, is_first_message=False):
+    """
+    Load langchain object from cache if it exists, otherwise build it.
+    """
+    if is_first_message:
+        build_langchain_object_with_caching.clear_cache()
+
+    return build_langchain_object_with_caching(data_graph)
+
+
+@memoize_dict(maxsize=1)
+def build_langchain_object_with_caching(data_graph):
+    """
+    Build langchain object from data_graph.
+    """
+    logger.debug("Building langchain object")
+    nodes = data_graph["nodes"]
+    # Add input variables
+    # nodes = payload.extract_input_variables(nodes)
+    # Nodes, edges and root node
+    edges = data_graph["edges"]
+    graph = Graph(nodes, edges)
+    return graph.build()
+
+
 def build_langchain_object(data_graph):
     """
     Build langchain object from data_graph.
@@ -72,6 +98,30 @@ def process_graph(data_graph: Dict[str, Any]):
     return {"result": str(result), "thought": thought.strip()}


+def process_graph_cached(data_graph: Dict[str, Any]):
+    """
+    Process graph by extracting input variables and replacing ZeroShotPrompt
+    with PromptTemplate, then run the graph and return the result and thought.
+    """
+    # Load langchain object
+    message = data_graph.pop("message", "")
+    is_first_message = len(data_graph.get("chatHistory", [])) == 0
+    langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
+    logger.debug("Loaded langchain object")
+
+    if langchain_object is None:
+        # Raise a user-facing error
+        raise ValueError(
+            "There was an error loading the langchain_object. Please check all the nodes and try again."
+        )
+
+    # Generate result and thought
+    logger.debug("Generating result and thought")
+    result, thought = get_result_and_thought_using_graph(langchain_object, message)
+    logger.debug("Generated result and thought")
+
+    return {"result": str(result), "thought": thought.strip()}
+
+
 def get_memory_key(langchain_object):
     """
     Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.

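Taken together: process_graph_cached pops "message" off the payload before hashing, so the chat message itself never invalidates the cache, and an empty "chatHistory" marks the first message, which clears the single-entry cache (maxsize=1) to force a rebuild. A sketch of the expected payload shape, with field names taken from the code above and placeholder contents:

data_graph = {
    "nodes": [],          # graph nodes, e.g. from a flow exported in the UI
    "edges": [],          # connections between the nodes
    "message": "Hello",   # popped before hashing, so it never busts the cache
    "chatHistory": [],    # empty list -> is_first_message=True -> cache cleared
}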
View file

@@ -18,6 +18,7 @@ class VectorstoreCreator(LangChainTypeCreator):
         try:
             signature = build_template_from_class(name, vectorstores_type_to_cls_dict)
+            # TODO: Use FrontendNode class to build the signature
             signature["template"] = {
                 "documents": {
                     "type": "TextSplitter",