Merge branch 'cache' into vecstores
This commit is contained in:
commit
3dc894a57c
4 changed files with 92 additions and 12 deletions
|
|
@ -3,7 +3,7 @@ from typing import Any, Dict
|
|||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from langflow.interface.run import process_graph
|
||||
from langflow.interface.run import process_graph_cached
|
||||
from langflow.interface.types import build_langchain_types_dict
|
||||
|
||||
# build router
|
||||
|
|
@ -19,7 +19,7 @@ def get_all():
|
|||
@router.post("/predict")
|
||||
def get_load(data: Dict[str, Any]):
|
||||
try:
|
||||
return process_graph(data)
|
||||
return process_graph_cached(data)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
|
|||
45
src/backend/langflow/cache/utils.py
vendored
45
src/backend/langflow/cache/utils.py
vendored
|
|
@ -1,11 +1,40 @@
|
|||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import dill # type: ignore
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
import hashlib
|
||||
|
||||
|
||||
def memoize_dict(maxsize=128):
|
||||
cache = OrderedDict()
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
hashed = compute_dict_hash(args[0])
|
||||
key = (func.__name__, hashed, frozenset(kwargs.items()))
|
||||
if key not in cache:
|
||||
result = func(*args, **kwargs)
|
||||
cache[key] = result
|
||||
if len(cache) > maxsize:
|
||||
cache.popitem(last=False)
|
||||
else:
|
||||
result = cache[key]
|
||||
return result
|
||||
|
||||
def clear_cache():
|
||||
cache.clear()
|
||||
|
||||
wrapper.clear_cache = clear_cache
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
PREFIX = "langflow_cache"
|
||||
|
||||
|
|
@ -24,6 +53,13 @@ def clear_old_cache_files(max_cache_size: int = 3):
|
|||
os.remove(cache_file)
|
||||
|
||||
|
||||
def compute_dict_hash(graph_data):
|
||||
graph_data = filter_json(graph_data)
|
||||
|
||||
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
|
||||
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def filter_json(json_data):
|
||||
filtered_data = json_data.copy()
|
||||
|
||||
|
|
@ -48,13 +84,6 @@ def filter_json(json_data):
|
|||
return filtered_data
|
||||
|
||||
|
||||
def compute_hash(graph_data):
|
||||
graph_data = filter_json(graph_data)
|
||||
|
||||
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
|
||||
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
|
||||
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
|
||||
with cache_path.open("wb") as cache_file:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import contextlib
|
|||
import io
|
||||
from typing import Any, Dict
|
||||
|
||||
from langflow.cache.utils import compute_hash, load_cache
|
||||
from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict
|
||||
from langflow.graph.graph import Graph
|
||||
from langflow.interface import loading
|
||||
from langflow.utils.logger import logger
|
||||
|
|
@ -12,7 +12,7 @@ def load_langchain_object(data_graph, is_first_message=False):
|
|||
"""
|
||||
Load langchain object from cache if it exists, otherwise build it.
|
||||
"""
|
||||
computed_hash = compute_hash(data_graph)
|
||||
computed_hash = compute_dict_hash(data_graph)
|
||||
if is_first_message:
|
||||
langchain_object = build_langchain_object(data_graph)
|
||||
else:
|
||||
|
|
@ -22,6 +22,32 @@ def load_langchain_object(data_graph, is_first_message=False):
|
|||
return computed_hash, langchain_object
|
||||
|
||||
|
||||
def load_or_build_langchain_object(data_graph, is_first_message=False):
|
||||
"""
|
||||
Load langchain object from cache if it exists, otherwise build it.
|
||||
"""
|
||||
if is_first_message:
|
||||
build_langchain_object_with_caching.clear_cache()
|
||||
return build_langchain_object_with_caching(data_graph)
|
||||
|
||||
|
||||
@memoize_dict(maxsize=1)
|
||||
def build_langchain_object_with_caching(data_graph):
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
"""
|
||||
|
||||
logger.debug("Building langchain object")
|
||||
nodes = data_graph["nodes"]
|
||||
# Add input variables
|
||||
# nodes = payload.extract_input_variables(nodes)
|
||||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
|
||||
return graph.build()
|
||||
|
||||
|
||||
def build_langchain_object(data_graph):
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
|
|
@ -72,6 +98,30 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
||||
def process_graph_cached(data_graph: Dict[str, Any]):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate,then run the graph and return the result and thought.
|
||||
"""
|
||||
# Load langchain object
|
||||
message = data_graph.pop("message", "")
|
||||
is_first_message = len(data_graph.get("chatHistory", [])) == 0
|
||||
langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
logger.debug("Generated result and thought")
|
||||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
||||
def get_memory_key(langchain_object):
|
||||
"""
|
||||
Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ class VectorstoreCreator(LangChainTypeCreator):
|
|||
try:
|
||||
signature = build_template_from_class(name, vectorstores_type_to_cls_dict)
|
||||
|
||||
# TODO: Use FrontendendNode class to build the signature
|
||||
signature["template"] = {
|
||||
"documents": {
|
||||
"type": "TextSplitter",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue