fix: make caching more reliable based on frontend info

This commit is contained in:
Gabriel Almeida 2023-04-03 09:42:05 -03:00
commit 3a931a6fc0
3 changed files with 60 additions and 25 deletions

View file

@@ -10,7 +10,7 @@ import dill # type: ignore
PREFIX = "langflow_cache"
def clear_old_cache_files(max_cache_size: int = 10):
def clear_old_cache_files(max_cache_size: int = 3):
cache_dir = Path(tempfile.gettempdir())
cache_files = list(cache_dir.glob(f"{PREFIX}_*.dill"))
@@ -24,23 +24,45 @@ def clear_old_cache_files(max_cache_size: int = 10):
os.remove(cache_file)
def remove_position_info(node):
    """Strip the layout-only "position" key from a node dict, if present."""
    if "position" in node:
        del node["position"]
def filter_json(json_data):
    """Return a copy of json_data with UI-only keys removed.

    Drops the top-level "viewport" and "chatHistory" keys and the per-node
    layout/UI state keys ("position", "positionAbsolute", "selected",
    "dragging"), so the result reflects only the graph's semantic content
    (suitable for hashing). The input dict and its node dicts are left
    unmodified.
    """
    filtered_data = json_data.copy()
    filtered_data.pop("viewport", None)
    filtered_data.pop("chatHistory", None)
    if "nodes" in filtered_data:
        # BUG FIX: json_data.copy() is shallow, so deleting keys from the
        # original node dicts (as the previous code did) mutated the
        # caller's data. Rebuild each node without the UI-only keys instead.
        ui_keys = {"position", "positionAbsolute", "selected", "dragging"}
        filtered_data["nodes"] = [
            {key: value for key, value in node.items() if key not in ui_keys}
            for node in filtered_data["nodes"]
        ]
    return filtered_data
def compute_hash(graph_data):
    """Return a stable SHA-256 hex digest of the graph's semantic content.

    UI-only keys (viewport, chat history, node positions, selection state)
    are filtered out first, and keys are sorted so that equivalent graphs
    hash identically regardless of layout or key order.
    """
    # NOTE: the old in-place remove_position_info() loop was redundant —
    # filter_json already strips "position" — and it mutated the caller's
    # node dicts as a side effect; filtering alone avoids both problems.
    graph_data = filter_json(graph_data)
    cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
    return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
    """Serialize chat_data with dill to a temp-dir cache file keyed by hash_val.

    When clean_old_cache_files is True, old cache files are pruned after
    the new one is written (see clear_old_cache_files).
    """
    cache_file_name = f"{PREFIX}_{hash_val}.dill"
    cache_path = Path(tempfile.gettempdir()) / cache_file_name
    with cache_path.open("wb") as fp:
        dill.dump(chat_data, fp)
    if clean_old_cache_files:
        clear_old_cache_files()
def load_cache(hash_val):
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"

View file

@@ -7,29 +7,39 @@ from langflow.cache.utils import compute_hash, load_cache, save_cache
from langflow.graph.graph import Graph
from langflow.interface import loading
from langflow.utils import payload
import logging
logger = logging.getLogger(__name__)
from langflow.utils.logger import logger
def load_langchain_object(data_graph, is_first_message=False):
    """Return (hash, langchain_object) for data_graph, using the cache.

    On the first message of a conversation the object is always built fresh;
    otherwise it is loaded from the dill cache keyed by the graph's hash.

    BUG FIX: a cache miss on a non-first message (e.g. after the entry was
    evicted by clear_old_cache_files) previously returned None as the
    langchain object; now we fall back to rebuilding it, as the pre-cache
    code path did.
    """
    computed_hash = compute_hash(data_graph)
    if is_first_message:
        langchain_object = build_langchain_object(data_graph)
    else:
        logger.debug("Loading langchain object from cache")
        langchain_object = load_cache(computed_hash)
        if langchain_object is None:
            logger.debug("Cache miss; rebuilding langchain object")
            langchain_object = build_langchain_object(data_graph)
    return computed_hash, langchain_object
def build_langchain_object(data_graph):
    """Build a langchain object from the nodes and edges in data_graph."""
    logger.debug("Building langchain object")
    # Inject extracted input variables into the nodes before building.
    nodes = payload.extract_input_variables(data_graph["nodes"])
    graph = Graph(nodes, data_graph["edges"])
    return graph.build()
def process_graph(data_graph: Dict[str, Any]):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
@@ -38,7 +48,10 @@ def process_graph(data_graph: Dict[str, Any]):
# Load langchain object
logger.debug("Loading langchain object")
message = data_graph.pop("message", "")
computed_hash, langchain_object = load_langchain_object(data_graph)
is_first_message = len(data_graph.get("chatHistory", [])) == 0
computed_hash, langchain_object = load_langchain_object(
data_graph, is_first_message
)
logger.debug("Loaded langchain object")
# Generate result and thought
@@ -50,7 +63,7 @@ def process_graph(data_graph: Dict[str, Any]):
# We have to save it here because if the
# memory is updated we need to keep the new values
logger.debug("Saving langchain object to cache")
save_cache(computed_hash, langchain_object)
save_cache(computed_hash, langchain_object, is_first_message)
logger.debug("Saved langchain object to cache")
return {
"result": str(result),

View file

@@ -23,7 +23,7 @@ def get_type_list():
return all_types
def build_langchain_types_dict():
def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
"""Build a dictionary of all langchain types"""
all_types = {}