From fc5670f53447f85963c184434c23d568497cb387 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 8 Aug 2023 07:16:41 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix(endpoints.py):=20add=20sessi?= =?UTF-8?q?on=5Fid=20parameter=20to=20process=5Fflow=20endpoint=20to=20sup?= =?UTF-8?q?port=20session-based=20caching=20=F0=9F=94=A7=20fix(utils.py):?= =?UTF-8?q?=20add=20hash=5Fto=5Fkey=20mapping=20to=20memoize=5Fdict=20deco?= =?UTF-8?q?rator=20to=20support=20retrieving=20cache=20result=20by=20sessi?= =?UTF-8?q?on=5Fid=20=F0=9F=94=A7=20fix(process.py):=20add=20session=5Fid?= =?UTF-8?q?=20parameter=20to=20process=5Fgraph=5Fcached=20function=20to=20?= =?UTF-8?q?support=20session-based=20caching?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements #730 --- src/backend/langflow/api/v1/endpoints.py | 3 ++- src/backend/langflow/cache/utils.py | 16 +++++++++++++- src/backend/langflow/processing/process.py | 25 ++++++++++++++++++++-- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 24af55588..456d6e483 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -71,6 +71,7 @@ async def process_flow( inputs: Optional[dict] = None, tweaks: Optional[dict] = None, clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 + session_id: Annotated[str, Body(embed=True)] = None, # noqa: F821 session: Session = Depends(get_session), ): """ @@ -90,7 +91,7 @@ async def process_flow( graph_data = process_tweaks(graph_data, tweaks) except Exception as exc: logger.error(f"Error processing tweaks: {exc}") - response = process_graph_cached(graph_data, inputs, clear_cache) + response = process_graph_cached(graph_data, inputs, clear_cache, session_id) return ProcessResponse( result=response, ) diff --git a/src/backend/langflow/cache/utils.py b/src/backend/langflow/cache/utils.py index 3deabe9f4..2333eb5f4 100644 --- a/src/backend/langflow/cache/utils.py +++ b/src/backend/langflow/cache/utils.py @@ -30,6 +30,7 @@ def create_cache_folder(func): def memoize_dict(maxsize=128): cache = OrderedDict() + hash_to_key = {} # Mapping from hash to cache key def decorator(func): @functools.wraps(func) @@ -39,16 +40,29 @@ def memoize_dict(maxsize=128): if key not in cache: result = func(*args, **kwargs) cache[key] = result + hash_to_key[hashed] = key # Store the mapping if len(cache) > maxsize: - cache.popitem(last=False) + oldest_key = next(iter(cache)) + oldest_hash = oldest_key[1] + del cache[oldest_key] + del hash_to_key[oldest_hash] else: result = cache[key] + + wrapper.session_id = hashed # Store hash in the wrapper return result def clear_cache(): cache.clear() + hash_to_key.clear() + + def get_result_by_session_id(session_id): + key = hash_to_key.get(session_id) + return cache.get(key) if key is not None else None wrapper.clear_cache = clear_cache # type: ignore + wrapper.get_result_by_session_id = get_result_by_session_id # type: ignore + wrapper.hash = None wrapper.cache = cache # type: ignore return wrapper diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 8cefb1f44..6d72a736b 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -85,8 +85,27 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]: return list(inputs.values())[0] if len(inputs) == 1 else None +def get_build_result(data_graph, session_id): + # If session_id is provided, load the langchain_object from the session + # using build_sorted_vertices_with_caching.get_result_by_session_id + # if it returns something different than None, return it + # otherwise, build the graph and return the result + if session_id: + logger.debug(f"Loading LangChain object from session {session_id}") + result = build_sorted_vertices_with_caching.get_result_by_session_id(session_id) + if result is not None: + logger.debug("Loaded LangChain object") + return result + + logger.debug("Building langchain object") + return build_sorted_vertices_with_caching(data_graph) + + def process_graph_cached( - data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False + data_graph: Dict[str, Any], + inputs: Optional[dict] = None, + clear_cache=False, + session_id=None, ): """ Process graph by extracting input variables and replacing ZeroShotPrompt @@ -96,7 +115,9 @@ def process_graph_cached( if clear_cache: build_sorted_vertices_with_caching.clear_cache() logger.debug("Cleared cache") - langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph) + + langchain_object, artifacts = get_build_result(data_graph, session_id) + logger.debug("Loaded LangChain object") if inputs is None: inputs = {}