🔧 fix(endpoints.py): add session_id parameter to process_flow endpoint to support session-based caching
🔧 fix(utils.py): add hash_to_key mapping to memoize_dict decorator to support retrieving cache result by session_id 🔧 fix(process.py): add session_id parameter to process_graph_cached function to support session-based caching Implements #730
This commit is contained in:
parent
aff3d53021
commit
fc5670f534
3 changed files with 40 additions and 4 deletions
|
|
@ -71,6 +71,7 @@ async def process_flow(
|
|||
inputs: Optional[dict] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[str, Body(embed=True)] = None, # noqa: F821
|
||||
session: Session = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
|
|
@ -90,7 +91,7 @@ async def process_flow(
|
|||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing tweaks: {exc}")
|
||||
response = process_graph_cached(graph_data, inputs, clear_cache)
|
||||
response = process_graph_cached(graph_data, inputs, clear_cache, session_id)
|
||||
return ProcessResponse(
|
||||
result=response,
|
||||
)
|
||||
|
|
|
|||
16
src/backend/langflow/cache/utils.py
vendored
16
src/backend/langflow/cache/utils.py
vendored
|
|
@ -30,6 +30,7 @@ def create_cache_folder(func):
|
|||
|
||||
def memoize_dict(maxsize=128):
|
||||
cache = OrderedDict()
|
||||
hash_to_key = {} # Mapping from hash to cache key
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
|
|
@ -39,16 +40,29 @@ def memoize_dict(maxsize=128):
|
|||
if key not in cache:
|
||||
result = func(*args, **kwargs)
|
||||
cache[key] = result
|
||||
hash_to_key[hashed] = key # Store the mapping
|
||||
if len(cache) > maxsize:
|
||||
cache.popitem(last=False)
|
||||
oldest_key = next(iter(cache))
|
||||
oldest_hash = oldest_key[1]
|
||||
del cache[oldest_key]
|
||||
del hash_to_key[oldest_hash]
|
||||
else:
|
||||
result = cache[key]
|
||||
|
||||
wrapper.session_id = hashed # Store hash in the wrapper
|
||||
return result
|
||||
|
||||
def clear_cache():
|
||||
cache.clear()
|
||||
hash_to_key.clear()
|
||||
|
||||
def get_result_by_session_id(session_id):
|
||||
key = hash_to_key.get(session_id)
|
||||
return cache.get(key) if key is not None else None
|
||||
|
||||
wrapper.clear_cache = clear_cache # type: ignore
|
||||
wrapper.get_result_by_session_id = get_result_by_session_id # type: ignore
|
||||
wrapper.hash = None
|
||||
wrapper.cache = cache # type: ignore
|
||||
return wrapper
|
||||
|
||||
|
|
|
|||
|
|
@ -85,8 +85,27 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]:
|
|||
return list(inputs.values())[0] if len(inputs) == 1 else None
|
||||
|
||||
|
||||
def get_build_result(data_graph, session_id):
|
||||
# If session_id is provided, load the langchain_object from the session
|
||||
# using build_sorted_vertices_with_caching.get_result_by_session_id
|
||||
# if it returns something different than None, return it
|
||||
# otherwise, build the graph and return the result
|
||||
if session_id:
|
||||
logger.debug(f"Loading LangChain object from session {session_id}")
|
||||
result = build_sorted_vertices_with_caching.get_result_by_session_id(session_id)
|
||||
if result is not None:
|
||||
logger.debug("Loaded LangChain object")
|
||||
return result
|
||||
|
||||
logger.debug("Building langchain object")
|
||||
return build_sorted_vertices_with_caching(data_graph)
|
||||
|
||||
|
||||
def process_graph_cached(
|
||||
data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False
|
||||
data_graph: Dict[str, Any],
|
||||
inputs: Optional[dict] = None,
|
||||
clear_cache=False,
|
||||
session_id=None,
|
||||
):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
|
|
@ -96,7 +115,9 @@ def process_graph_cached(
|
|||
if clear_cache:
|
||||
build_sorted_vertices_with_caching.clear_cache()
|
||||
logger.debug("Cleared cache")
|
||||
langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph)
|
||||
|
||||
langchain_object, artifacts = get_build_result(data_graph, session_id)
|
||||
|
||||
logger.debug("Loaded LangChain object")
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue