diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 24af55588..456d6e483 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -71,6 +71,7 @@ async def process_flow( inputs: Optional[dict] = None, tweaks: Optional[dict] = None, clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 + session_id: Annotated[str, Body(embed=True)] = None, # noqa: F821 session: Session = Depends(get_session), ): """ @@ -90,7 +91,7 @@ async def process_flow( graph_data = process_tweaks(graph_data, tweaks) except Exception as exc: logger.error(f"Error processing tweaks: {exc}") - response = process_graph_cached(graph_data, inputs, clear_cache) + response = process_graph_cached(graph_data, inputs, clear_cache, session_id) return ProcessResponse( result=response, ) diff --git a/src/backend/langflow/cache/utils.py b/src/backend/langflow/cache/utils.py index 3deabe9f4..2333eb5f4 100644 --- a/src/backend/langflow/cache/utils.py +++ b/src/backend/langflow/cache/utils.py @@ -30,6 +30,7 @@ def create_cache_folder(func): def memoize_dict(maxsize=128): cache = OrderedDict() + hash_to_key = {} # Mapping from hash to cache key def decorator(func): @functools.wraps(func) @@ -39,16 +40,29 @@ def memoize_dict(maxsize=128): if key not in cache: result = func(*args, **kwargs) cache[key] = result + hash_to_key[hashed] = key # Store the mapping if len(cache) > maxsize: - cache.popitem(last=False) + oldest_key = next(iter(cache)) + oldest_hash = oldest_key[1] + del cache[oldest_key] + del hash_to_key[oldest_hash] else: result = cache[key] + + wrapper.session_id = hashed # Store hash in the wrapper return result def clear_cache(): cache.clear() + hash_to_key.clear() + + def get_result_by_session_id(session_id): + key = hash_to_key.get(session_id) + return cache.get(key) if key is not None else None wrapper.clear_cache = clear_cache # type: ignore + wrapper.get_result_by_session_id = get_result_by_session_id # type: ignore + wrapper.hash = None wrapper.cache = cache # type: ignore return wrapper diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 8cefb1f44..6d72a736b 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -85,8 +85,27 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]: return list(inputs.values())[0] if len(inputs) == 1 else None +def get_build_result(data_graph, session_id): + # If session_id is provided, load the langchain_object from the session + # using build_sorted_vertices_with_caching.get_result_by_session_id + # if it returns something different than None, return it + # otherwise, build the graph and return the result + if session_id: + logger.debug(f"Loading LangChain object from session {session_id}") + result = build_sorted_vertices_with_caching.get_result_by_session_id(session_id) + if result is not None: + logger.debug("Loaded LangChain object") + return result + + logger.debug("Building langchain object") + return build_sorted_vertices_with_caching(data_graph) + + def process_graph_cached( - data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False + data_graph: Dict[str, Any], + inputs: Optional[dict] = None, + clear_cache=False, + session_id=None, ): """ Process graph by extracting input variables and replacing ZeroShotPrompt @@ -96,7 +115,9 @@ def process_graph_cached( if clear_cache: build_sorted_vertices_with_caching.clear_cache() logger.debug("Cleared cache") - langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph) + + langchain_object, artifacts = get_build_result(data_graph, session_id) + logger.debug("Loaded LangChain object") if inputs is None: inputs = {}