🔧 fix(endpoints.py): add session_id parameter to process_flow endpoint to support session-based caching

🔧 fix(utils.py): add hash_to_key mapping to memoize_dict decorator to support retrieving cache result by session_id
🔧 fix(process.py): add session_id parameter to process_graph_cached function to support session-based caching

Implements #730
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-08 07:16:41 -03:00
commit fc5670f534
3 changed files with 40 additions and 4 deletions

View file

@ -71,6 +71,7 @@ async def process_flow(
inputs: Optional[dict] = None,
tweaks: Optional[dict] = None,
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[str, Body(embed=True)] = None, # noqa: F821
session: Session = Depends(get_session),
):
"""
@ -90,7 +91,7 @@ async def process_flow(
graph_data = process_tweaks(graph_data, tweaks)
except Exception as exc:
logger.error(f"Error processing tweaks: {exc}")
response = process_graph_cached(graph_data, inputs, clear_cache)
response = process_graph_cached(graph_data, inputs, clear_cache, session_id)
return ProcessResponse(
result=response,
)

View file

@ -30,6 +30,7 @@ def create_cache_folder(func):
def memoize_dict(maxsize=128):
cache = OrderedDict()
hash_to_key = {} # Mapping from hash to cache key
def decorator(func):
@functools.wraps(func)
@ -39,16 +40,29 @@ def memoize_dict(maxsize=128):
if key not in cache:
result = func(*args, **kwargs)
cache[key] = result
hash_to_key[hashed] = key # Store the mapping
if len(cache) > maxsize:
cache.popitem(last=False)
oldest_key = next(iter(cache))
oldest_hash = oldest_key[1]
del cache[oldest_key]
del hash_to_key[oldest_hash]
else:
result = cache[key]
wrapper.session_id = hashed # Store hash in the wrapper
return result
def clear_cache():
cache.clear()
hash_to_key.clear()
def get_result_by_session_id(session_id):
key = hash_to_key.get(session_id)
return cache.get(key) if key is not None else None
wrapper.clear_cache = clear_cache # type: ignore
wrapper.get_result_by_session_id = get_result_by_session_id # type: ignore
wrapper.hash = None
wrapper.cache = cache # type: ignore
return wrapper

View file

@ -85,8 +85,27 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]:
return list(inputs.values())[0] if len(inputs) == 1 else None
def get_build_result(data_graph, session_id):
    """Return the build result for ``data_graph``, preferring a session hit.

    When ``session_id`` is truthy, a previously cached result is looked up
    through ``build_sorted_vertices_with_caching.get_result_by_session_id``
    and returned if it is not ``None``. In every other case (no session id,
    or no cached result for it) the graph is built by calling
    ``build_sorted_vertices_with_caching`` with ``data_graph``.
    """
    cached = None
    if session_id:
        logger.debug(f"Loading LangChain object from session {session_id}")
        cached = build_sorted_vertices_with_caching.get_result_by_session_id(session_id)

    # Nothing usable came from the session lookup: build the graph.
    if cached is None:
        logger.debug("Building langchain object")
        return build_sorted_vertices_with_caching(data_graph)

    logger.debug("Loaded LangChain object")
    return cached
def process_graph_cached(
data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False
data_graph: Dict[str, Any],
inputs: Optional[dict] = None,
clear_cache=False,
session_id=None,
):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
@ -96,7 +115,9 @@ def process_graph_cached(
if clear_cache:
build_sorted_vertices_with_caching.clear_cache()
logger.debug("Cleared cache")
langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph)
langchain_object, artifacts = get_build_result(data_graph, session_id)
logger.debug("Loaded LangChain object")
if inputs is None:
inputs = {}