fix: caching is now more reliable according to frontend info
parent 5abc0f8486
commit 3a931a6fc0
3 changed files with 60 additions and 25 deletions
src/backend/langflow/cache/utils.py (34 changes)
@@ -10,7 +10,7 @@ import dill # type: ignore
 PREFIX = "langflow_cache"
 
 
-def clear_old_cache_files(max_cache_size: int = 10):
+def clear_old_cache_files(max_cache_size: int = 3):
     cache_dir = Path(tempfile.gettempdir())
     cache_files = list(cache_dir.glob(f"{PREFIX}_*.dill"))
 
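The eviction logic itself sits between these two hunks and is not shown; a minimal sketch of the usual pattern, assuming oldest-first removal by modification time:

import os
import tempfile
from pathlib import Path

PREFIX = "langflow_cache"


def clear_old_cache_files(max_cache_size: int = 3):
    # Sketch only: keep the newest `max_cache_size` cache files,
    # assuming eviction is oldest-first by modification time.
    cache_dir = Path(tempfile.gettempdir())
    cache_files = list(cache_dir.glob(f"{PREFIX}_*.dill"))
    if len(cache_files) > max_cache_size:
        cache_files.sort(key=lambda f: f.stat().st_mtime)
        for stale in cache_files[:-max_cache_size]:
            os.remove(stale)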
@@ -24,23 +24,45 @@ def clear_old_cache_files(max_cache_size: int = 10):
             os.remove(cache_file)
 
 
-def remove_position_info(node):
-    node.pop("position", None)
+def filter_json(json_data):
+    filtered_data = json_data.copy()
+
+    # Remove 'viewport' and 'chatHistory' keys
+    if "viewport" in filtered_data:
+        del filtered_data["viewport"]
+    if "chatHistory" in filtered_data:
+        del filtered_data["chatHistory"]
+
+    # Filter nodes
+    if "nodes" in filtered_data:
+        for node in filtered_data["nodes"]:
+            if "position" in node:
+                del node["position"]
+            if "positionAbsolute" in node:
+                del node["positionAbsolute"]
+            if "selected" in node:
+                del node["selected"]
+            if "dragging" in node:
+                del node["dragging"]
+
+    return filtered_data
 
 
 def compute_hash(graph_data):
-    for node in graph_data["nodes"]:
-        remove_position_info(node)
+    graph_data = filter_json(graph_data)
+
     cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
     return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
 
 
-def save_cache(hash_val, chat_data):
+def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
     cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
     with cache_path.open("wb") as cache_file:
         dill.dump(chat_data, cache_file)
 
+    if clean_old_cache_files:
+        clear_old_cache_files()
+
 
 def load_cache(hash_val):
     cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
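The point of the filtering: the frontend mutates layout fields (position, positionAbsolute, selected, dragging, viewport) on every interaction, so hashing the raw payload produced a different key each time and the cache almost never hit. After this change, payloads that differ only in UI state hash identically:

from langflow.cache.utils import compute_hash

# Two payloads that differ only in UI state should now hash identically.
graph_a = {
    "nodes": [{"id": "node-1", "position": {"x": 0, "y": 0}, "selected": True}],
    "edges": [],
    "viewport": {"x": 10, "y": 20, "zoom": 1.5},
}
graph_b = {
    "nodes": [{"id": "node-1", "position": {"x": 300, "y": 80}, "selected": False}],
    "edges": [],
    "viewport": {"x": 0, "y": 0, "zoom": 1.0},
}
assert compute_hash(graph_a) == compute_hash(graph_b)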
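The tail of load_cache is cut off in this hunk; assuming it mirrors save_cache (a dill.load on the same path, returning None when the file does not exist), a round trip through the helpers looks like this, with data_graph and langchain_object standing in for real values:

from langflow.cache.utils import compute_hash, load_cache, save_cache

graph_hash = compute_hash(data_graph)
save_cache(graph_hash, langchain_object, clean_old_cache_files=True)

cached = load_cache(graph_hash)  # assumed: dill.load from the path, None on a miss
if cached is not None:
    langchain_object = cached    # reuse the deserialized object, memory included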
@@ -7,29 +7,39 @@ from langflow.cache.utils import compute_hash, load_cache, save_cache
 from langflow.graph.graph import Graph
 from langflow.interface import loading
 from langflow.utils import payload
-import logging
-
-logger = logging.getLogger(__name__)
+from langflow.utils.logger import logger
 
 
-def load_langchain_object(data_graph):
+def load_langchain_object(data_graph, is_first_message=False):
     """
     Load langchain object from cache if it exists, otherwise build it.
     """
     computed_hash = compute_hash(data_graph)
-
-    # Load langchain_object from cache if it exists
-    langchain_object = load_cache(computed_hash)
-    if langchain_object is None:
-        nodes = data_graph["nodes"]
-        # Add input variables
-        nodes = payload.extract_input_variables(nodes)
-        # Nodes, edges and root node
-        edges = data_graph["edges"]
-        graph = Graph(nodes, edges)
-
-        langchain_object = graph.build()
+    if is_first_message:
+        langchain_object = build_langchain_object(data_graph)
+    else:
+        logger.debug("Loading langchain object from cache")
+        langchain_object = load_cache(computed_hash)
 
     return computed_hash, langchain_object
 
 
+def build_langchain_object(data_graph):
+    """
+    Build langchain object from data_graph.
+    """
+
+    logger.debug("Building langchain object")
+    nodes = data_graph["nodes"]
+    # Add input variables
+    nodes = payload.extract_input_variables(nodes)
+    # Nodes, edges and root node
+    edges = data_graph["edges"]
+    graph = Graph(nodes, edges)
+
+    return graph.build()
+
+
 def process_graph(data_graph: Dict[str, Any]):
     """
     Process graph by extracting input variables and replacing ZeroShotPrompt
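This split is the core of the fix: the first message of a conversation builds the object fresh, while follow-up messages reuse the cached one so that conversation memory is preserved. A usage sketch:

# First message of a conversation: build the object fresh.
computed_hash, langchain_object = load_langchain_object(data_graph, is_first_message=True)

# Follow-up message with an unchanged graph: reuse the cached object so that
# accumulated conversation memory survives between calls.
computed_hash, langchain_object = load_langchain_object(data_graph, is_first_message=False)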
@@ -38,7 +48,10 @@ def process_graph(data_graph: Dict[str, Any]):
     # Load langchain object
     logger.debug("Loading langchain object")
     message = data_graph.pop("message", "")
-    computed_hash, langchain_object = load_langchain_object(data_graph)
+    is_first_message = len(data_graph.get("chatHistory", [])) == 0
+    computed_hash, langchain_object = load_langchain_object(
+        data_graph, is_first_message
+    )
     logger.debug("Loaded langchain object")
 
     # Generate result and thought
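This is where the "frontend info" from the commit title comes in: the chatHistory field sent by the frontend decides whether a conversation is just starting. Both a missing key and an empty list count as a first message:

data_graph = {"nodes": [], "edges": []}                 # no chatHistory key at all
assert len(data_graph.get("chatHistory", [])) == 0      # treated as a first message

data_graph = {"nodes": [], "edges": [], "chatHistory": [{"message": "hi"}]}
assert len(data_graph.get("chatHistory", [])) != 0      # follow-up message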
@@ -50,7 +63,7 @@ def process_graph(data_graph: Dict[str, Any]):
     # We have to save it here because if the
     # memory is updated we need to keep the new values
     logger.debug("Saving langchain object to cache")
-    save_cache(computed_hash, langchain_object)
+    save_cache(computed_hash, langchain_object, is_first_message)
     logger.debug("Saved langchain object to cache")
     return {
         "result": str(result),
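Reusing is_first_message as the clean_old_cache_files flag means the temp directory is pruned only when a new conversation starts; mid-conversation saves simply overwrite the same hash-named file:

save_cache(computed_hash, langchain_object, True)   # first message: also prune old cache files
save_cache(computed_hash, langchain_object, False)  # later messages: overwrite in place, no pruning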
@@ -23,7 +23,7 @@ def get_type_list():
     return all_types
 
 
-def build_langchain_types_dict():
+def build_langchain_types_dict():  # sourcery skip: dict-assign-update-to-union
     """Build a dictionary of all langchain types"""
 
     all_types = {}
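For context: the added directive silences a Sourcery suggestion. As I understand the dict-assign-update-to-union rule, it rewrites build-then-update code into a dict union expression, which requires Python 3.9+. A sketch with placeholder data (the real function body is outside this hunk):

chains_types = {"chains": ["LLMChain"]}         # placeholder data
agents_types = {"agents": ["ZeroShotAgent"]}    # placeholder data

# Build-then-update style, as the function presumably uses:
all_types = {}
all_types.update(chains_types)
all_types.update(agents_types)

# What the suppressed rule would rewrite it to (PEP 584 dict union, Python 3.9+):
all_types = chains_types | agents_types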