From 3a931a6fc0ce4259e4ed1caeeb0c68fe087fe040 Mon Sep 17 00:00:00 2001
From: Gabriel Almeida
Date: Mon, 3 Apr 2023 09:42:05 -0300
Subject: [PATCH] fix: caching now is more reliable according to frontend info

---
 src/backend/langflow/cache/utils.py     | 34 ++++++++++++++---
 src/backend/langflow/interface/run.py   | 49 ++++++++++++++++---------
 src/backend/langflow/interface/types.py |  2 +-
 3 files changed, 60 insertions(+), 25 deletions(-)

diff --git a/src/backend/langflow/cache/utils.py b/src/backend/langflow/cache/utils.py
index 514d991e5..3c416f4d7 100644
--- a/src/backend/langflow/cache/utils.py
+++ b/src/backend/langflow/cache/utils.py
@@ -10,7 +10,7 @@ import dill  # type: ignore
 PREFIX = "langflow_cache"
 
 
-def clear_old_cache_files(max_cache_size: int = 10):
+def clear_old_cache_files(max_cache_size: int = 3):
     cache_dir = Path(tempfile.gettempdir())
     cache_files = list(cache_dir.glob(f"{PREFIX}_*.dill"))
 
@@ -24,23 +24,45 @@ def clear_old_cache_files(max_cache_size: int = 10):
             os.remove(cache_file)
 
 
-def remove_position_info(node):
-    node.pop("position", None)
+def filter_json(json_data):
+    filtered_data = json_data.copy()
+
+    # Remove 'viewport' and 'chatHistory' keys
+    if "viewport" in filtered_data:
+        del filtered_data["viewport"]
+    if "chatHistory" in filtered_data:
+        del filtered_data["chatHistory"]
+
+    # Filter nodes
+    if "nodes" in filtered_data:
+        for node in filtered_data["nodes"]:
+            if "position" in node:
+                del node["position"]
+            if "positionAbsolute" in node:
+                del node["positionAbsolute"]
+            if "selected" in node:
+                del node["selected"]
+            if "dragging" in node:
+                del node["dragging"]
+
+    return filtered_data
 
 
 def compute_hash(graph_data):
-    for node in graph_data["nodes"]:
-        remove_position_info(node)
+    graph_data = filter_json(graph_data)
 
     cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
 
     return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
 
 
-def save_cache(hash_val, chat_data):
+def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
     cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
     with cache_path.open("wb") as cache_file:
         dill.dump(chat_data, cache_file)
 
+    if clean_old_cache_files:
+        clear_old_cache_files()
+
 
 def load_cache(hash_val):
     cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
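The heart of the fix is the cache key: `filter_json` now strips everything the frontend mutates on its own (`viewport`, `chatHistory`, and each node's `position`, `positionAbsolute`, `selected`, and `dragging`) before hashing, so purely cosmetic canvas changes no longer produce a new hash and a cold cache. A minimal sketch of that invariant, condensing the two patched helpers; the sample graph payloads are hypothetical:

```python
import hashlib
import json


def filter_json(json_data):
    # Condensed version of the patched helper: drop the keys the
    # frontend changes without touching the graph's semantics.
    filtered_data = json_data.copy()  # note: shallow copy
    for key in ("viewport", "chatHistory"):
        filtered_data.pop(key, None)
    for node in filtered_data.get("nodes", []):
        for key in ("position", "positionAbsolute", "selected", "dragging"):
            node.pop(key, None)
    return filtered_data


def compute_hash(graph_data):
    # As in the patch: hash the filtered, key-sorted JSON.
    graph_data = filter_json(graph_data)
    cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
    return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()


# Hypothetical payloads: same graph, different canvas layout and zoom.
graph_a = {
    "nodes": [{"id": "llm-1", "position": {"x": 0, "y": 0}, "selected": True}],
    "edges": [],
    "viewport": {"x": 10, "y": 20, "zoom": 1.0},
    "chatHistory": [],
}
graph_b = {
    "nodes": [{"id": "llm-1", "position": {"x": 300, "y": 150}, "dragging": False}],
    "edges": [],
    "viewport": {"x": 0, "y": 0, "zoom": 0.5},
}

assert compute_hash(graph_a) == compute_hash(graph_b)  # one cache entry
```

Note that `json_data.copy()` is shallow, so the node dicts are trimmed in place in the caller's data, just as the patched `filter_json` does with `del`; callers should not expect `position` to survive a hash computation.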
diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py
index 305a7261a..34259d7ad 100644
--- a/src/backend/langflow/interface/run.py
+++ b/src/backend/langflow/interface/run.py
@@ -7,29 +7,39 @@ from langflow.cache.utils import compute_hash, load_cache, save_cache
 from langflow.graph.graph import Graph
 from langflow.interface import loading
 from langflow.utils import payload
-import logging
-
-logger = logging.getLogger(__name__)
+from langflow.utils.logger import logger
 
 
-def load_langchain_object(data_graph):
+def load_langchain_object(data_graph, is_first_message=False):
+    """
+    Load langchain object from cache if it exists, otherwise build it.
+    """
     computed_hash = compute_hash(data_graph)
-
-    # Load langchain_object from cache if it exists
-    langchain_object = load_cache(computed_hash)
-    if langchain_object is None:
-        nodes = data_graph["nodes"]
-        # Add input variables
-        nodes = payload.extract_input_variables(nodes)
-        # Nodes, edges and root node
-        edges = data_graph["edges"]
-        graph = Graph(nodes, edges)
-
-        langchain_object = graph.build()
+    if is_first_message:
+        langchain_object = build_langchain_object(data_graph)
+    else:
+        logger.debug("Loading langchain object from cache")
+        langchain_object = load_cache(computed_hash)
 
     return computed_hash, langchain_object
 
 
+def build_langchain_object(data_graph):
+    """
+    Build langchain object from data_graph.
+    """
+
+    logger.debug("Building langchain object")
+    nodes = data_graph["nodes"]
+    # Add input variables
+    nodes = payload.extract_input_variables(nodes)
+    # Nodes, edges and root node
+    edges = data_graph["edges"]
+    graph = Graph(nodes, edges)
+
+    return graph.build()
+
+
 def process_graph(data_graph: Dict[str, Any]):
     """
     Process graph by extracting input variables and replacing ZeroShotPrompt
@@ -38,7 +48,10 @@ def process_graph(data_graph: Dict[str, Any]):
     # Load langchain object
     logger.debug("Loading langchain object")
     message = data_graph.pop("message", "")
-    computed_hash, langchain_object = load_langchain_object(data_graph)
+    is_first_message = len(data_graph.get("chatHistory", [])) == 0
+    computed_hash, langchain_object = load_langchain_object(
+        data_graph, is_first_message
+    )
     logger.debug("Loaded langchain object")
 
     # Generate result and thought
@@ -50,7 +63,7 @@ def process_graph(data_graph: Dict[str, Any]):
     # We have to save it here because if the
     # memory is updated we need to keep the new values
     logger.debug("Saving langchain object to cache")
-    save_cache(computed_hash, langchain_object)
+    save_cache(computed_hash, langchain_object, is_first_message)
     logger.debug("Saved langchain object to cache")
     return {
         "result": str(result),
diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py
index 307d430f0..bf3cea372 100644
--- a/src/backend/langflow/interface/types.py
+++ b/src/backend/langflow/interface/types.py
@@ -23,7 +23,7 @@ def get_type_list():
     return all_types
 
 
-def build_langchain_types_dict():
+def build_langchain_types_dict():  # sourcery skip: dict-assign-update-to-union
     """Build a dictionary of all langchain types"""
     all_types = {}
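On the run.py side, the cache is no longer probed blindly: an empty `chatHistory` in the frontend payload marks the first message of a session, which forces a fresh build, while every later message loads the cached object. A rough sketch of that control flow, with the graph build and the dill-backed store replaced by in-memory stand-ins (all names here are illustrative, not the patched API):

```python
from typing import Any, Dict

CACHE: Dict[str, Any] = {}  # in-memory stand-in for the dill cache files


def build_langchain_object(data_graph: Dict[str, Any]) -> str:
    # Stand-in for Graph(nodes, edges).build(); returns a dummy "chain".
    return f"chain-with-{len(data_graph.get('nodes', []))}-nodes"


def process_message(data_graph: Dict[str, Any], key: str) -> Any:
    # Mirrors the patched process_graph: an empty chatHistory marks the
    # first message of a session; key stands in for compute_hash(...).
    is_first_message = len(data_graph.get("chatHistory", [])) == 0

    if is_first_message:
        langchain_object = build_langchain_object(data_graph)
    else:
        langchain_object = CACHE[key]  # later messages reuse the cache

    # Always re-save: if the chain carries memory, the updated state must
    # win; the file-count cleanup only runs on the first message.
    CACHE[key] = langchain_object
    return langchain_object


first = {"nodes": [{"id": "n1"}], "edges": [], "chatHistory": []}
later = {"nodes": [{"id": "n1"}], "edges": [], "chatHistory": [{"message": "hi"}]}

assert process_message(first, "abc") == process_message(later, "abc")
```

Re-saving after every message matters because the langchain object can carry conversation memory; passing `is_first_message` as `clean_old_cache_files` means the cleanup pass only runs once per session.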
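The cache files themselves are plain dill pickles under the system temp directory, one per graph hash. A small round-trip check of that storage scheme (requires the `dill` package; the hash value and payload are made up):

```python
import tempfile
from pathlib import Path

import dill  # the same serializer save_cache/load_cache use

PREFIX = "langflow_cache"
hash_val = "0" * 64  # made-up stand-in for a compute_hash digest

chat_data = {"memory": ["Hello", "Hi there!"]}  # stand-in for a built chain

# save_cache writes one file per graph hash into the temp dir
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
with cache_path.open("wb") as cache_file:
    dill.dump(chat_data, cache_file)

# load_cache reads it back on later messages of the same session
with cache_path.open("rb") as cache_file:
    restored = dill.load(cache_file)

assert restored == chat_data
cache_path.unlink()  # clean up the demo file
```

With `max_cache_size` lowered from 10 to 3, a first-message save now caps the temp directory at three of these cache files instead of ten.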