diff --git a/src/backend/langflow/api/endpoints.py b/src/backend/langflow/api/endpoints.py
index 22f548156..b8290e691 100644
--- a/src/backend/langflow/api/endpoints.py
+++ b/src/backend/langflow/api/endpoints.py
@@ -3,7 +3,7 @@ from typing import Any, Dict
 
 from fastapi import APIRouter, HTTPException
 
-from langflow.interface.run import process_graph
+from langflow.interface.run import process_graph_cached
 from langflow.interface.types import build_langchain_types_dict
 
 # build router
@@ -19,7 +19,7 @@ def get_all():
 @router.post("/predict")
 def get_load(data: Dict[str, Any]):
     try:
-        return process_graph(data)
+        return process_graph_cached(data)
     except Exception as e:
         # Log stack trace
         logger.exception(e)
diff --git a/src/backend/langflow/cache/utils.py b/src/backend/langflow/cache/utils.py
index 3c416f4d7..7dcc6b935 100644
--- a/src/backend/langflow/cache/utils.py
+++ b/src/backend/langflow/cache/utils.py
@@ -1,11 +1,40 @@
 import contextlib
-import hashlib
 import json
 import os
 import tempfile
 from pathlib import Path
 
 import dill  # type: ignore
+import functools
+from collections import OrderedDict
+import hashlib
+
+
+def memoize_dict(maxsize=128):
+    cache = OrderedDict()
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            hashed = compute_dict_hash(args[0])
+            key = (func.__name__, hashed, frozenset(kwargs.items()))
+            if key not in cache:
+                result = func(*args, **kwargs)
+                cache[key] = result
+                if len(cache) > maxsize:
+                    cache.popitem(last=False)
+            else:
+                result = cache[key]
+            return result
+
+        def clear_cache():
+            cache.clear()
+
+        wrapper.clear_cache = clear_cache
+        return wrapper
+
+    return decorator
+
 
 
 PREFIX = "langflow_cache"
@@ -24,6 +53,13 @@ def clear_old_cache_files(max_cache_size: int = 3):
                 os.remove(cache_file)
 
 
+def compute_dict_hash(graph_data):
+    graph_data = filter_json(graph_data)
+
+    cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
+    return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
+
+
 def filter_json(json_data):
     filtered_data = json_data.copy()
 
@@ -48,13 +84,6 @@
     return filtered_data
 
 
-def compute_hash(graph_data):
-    graph_data = filter_json(graph_data)
-
-    cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
-    return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
-
-
 def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
     cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
     with cache_path.open("wb") as cache_file:
diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py
index 8f7765ef2..74dd1aa8f 100644
--- a/src/backend/langflow/interface/run.py
+++ b/src/backend/langflow/interface/run.py
@@ -2,7 +2,7 @@ import contextlib
 import io
 from typing import Any, Dict
 
-from langflow.cache.utils import compute_hash, load_cache
+from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict
 from langflow.graph.graph import Graph
 from langflow.interface import loading
 from langflow.utils.logger import logger
@@ -12,7 +12,7 @@ def load_langchain_object(data_graph, is_first_message=False):
     """
     Load langchain object from cache if it exists, otherwise build it.
     """
-    computed_hash = compute_hash(data_graph)
+    computed_hash = compute_dict_hash(data_graph)
     if is_first_message:
         langchain_object = build_langchain_object(data_graph)
     else:
@@ -22,6 +22,32 @@
     return computed_hash, langchain_object
 
 
+def load_or_build_langchain_object(data_graph, is_first_message=False):
+    """
+    Load langchain object from cache if it exists, otherwise build it.
+    """
+    if is_first_message:
+        build_langchain_object_with_caching.clear_cache()
+    return build_langchain_object_with_caching(data_graph)
+
+
+@memoize_dict(maxsize=1)
+def build_langchain_object_with_caching(data_graph):
+    """
+    Build langchain object from data_graph.
+    """
+
+    logger.debug("Building langchain object")
+    nodes = data_graph["nodes"]
+    # Add input variables
+    # nodes = payload.extract_input_variables(nodes)
+    # Nodes, edges and root node
+    edges = data_graph["edges"]
+    graph = Graph(nodes, edges)
+
+    return graph.build()
+
+
 def build_langchain_object(data_graph):
     """
     Build langchain object from data_graph.
@@ -72,6 +98,30 @@ def process_graph(data_graph: Dict[str, Any]):
     return {"result": str(result), "thought": thought.strip()}
 
 
+def process_graph_cached(data_graph: Dict[str, Any]):
+    """
+    Process graph by extracting input variables and replacing ZeroShotPrompt
+    with PromptTemplate, then run the graph and return the result and thought.
+    """
+    # Load langchain object
+    message = data_graph.pop("message", "")
+    is_first_message = len(data_graph.get("chatHistory", [])) == 0
+    langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
+    logger.debug("Loaded langchain object")
+
+    if langchain_object is None:
+        # Raise user facing error
+        raise ValueError(
+            "There was an error loading the langchain_object. Please, check all the nodes and try again."
+        )
+
+    # Generate result and thought
+    logger.debug("Generating result and thought")
+    result, thought = get_result_and_thought_using_graph(langchain_object, message)
+    logger.debug("Generated result and thought")
+    return {"result": str(result), "thought": thought.strip()}
+
+
 def get_memory_key(langchain_object):
     """
     Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.
diff --git a/src/backend/langflow/interface/vectorStore/base.py b/src/backend/langflow/interface/vectorStore/base.py
index 15dfd2886..943903b34 100644
--- a/src/backend/langflow/interface/vectorStore/base.py
+++ b/src/backend/langflow/interface/vectorStore/base.py
@@ -18,6 +18,7 @@ class VectorstoreCreator(LangChainTypeCreator):
         try:
             signature = build_template_from_class(name, vectorstores_type_to_cls_dict)
 
+            # TODO: Use FrontendNode class to build the signature
             signature["template"] = {
                 "documents": {
                     "type": "TextSplitter",