🔧 refactor(process.py): remove unused imports and variables, refactor function names for clarity

 feat(process.py): introduce SessionManager to handle loading and updating langchain_object sessions
🐛 fix(process.py): update cache with the updated langchain_object after processing graph
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-18 08:57:38 -03:00
commit 7f217bd518

View file

@ -2,10 +2,11 @@ from pathlib import Path
from langchain.schema import AgentAction
import json
from langflow.interface.run import (
build_sorted_vertices_with_caching,
build_sorted_vertices,
get_memory_key,
update_memory_keys,
)
from langflow.services.utils import get_session_manager
from langflow.utils.logger import logger
from langflow.graph import Graph
from langchain.chains.base import Chain
@ -92,18 +93,18 @@ def get_build_result(data_graph, session_id):
# otherwise, build the graph and return the result
if session_id:
logger.debug(f"Loading LangChain object from session {session_id}")
result = build_sorted_vertices_with_caching.get_result_by_session_id(session_id)
result = build_sorted_vertices(data_graph=data_graph, session_id=session_id)
if result is not None:
logger.debug("Loaded LangChain object")
return result
logger.debug("Building langchain object")
return build_sorted_vertices_with_caching(data_graph)
return build_sorted_vertices(data_graph, session_id)
def clear_caches_if_needed(clear_cache: bool):
if clear_cache:
build_sorted_vertices_with_caching.clear_cache()
build_sorted_vertices.clear_cache()
logger.debug("Cleared cache")
@ -111,7 +112,6 @@ def load_langchain_object(
data_graph: Dict[str, Any], session_id: str
) -> Tuple[Union[Chain, VectorStore], Dict[str, Any], str]:
langchain_object, artifacts = get_build_result(data_graph, session_id)
session_id = build_sorted_vertices_with_caching.session_id
logger.debug("Loaded LangChain object")
if langchain_object is None:
@ -139,6 +139,7 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
raise ValueError("Inputs must be provided for a Chain")
logger.debug("Generating result and thought")
result = get_result_and_thought(langchain_object, inputs)
logger.debug("Generated result and thought")
elif isinstance(langchain_object, VectorStore):
result = langchain_object.search(**inputs)
@ -150,6 +151,28 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
return result
# def process_graph_cached(
# data_graph: Dict[str, Any],
# inputs: Optional[dict] = None,
# clear_cache=False,
# session_id=None,
# ) -> Tuple[Any, str]:
# clear_caches_if_needed(clear_cache)
# # If session_id is provided, load the langchain_object from the session
# # else build the graph and return the result and the new session_id
# langchain_object, artifacts, session_id = load_langchain_object(
# data_graph, session_id
# )
# processed_inputs = process_inputs(inputs, artifacts)
# result = generate_result(langchain_object, processed_inputs)
# if result:
# # we need to update the cache with the updated langchain_object
# build_sorted_vertices_with_caching.update_cache(
# session_id, (langchain_object, artifacts)
# )
# return result, session_id
def process_graph_cached(
data_graph: Dict[str, Any],
inputs: Optional[dict] = None,
@ -157,13 +180,16 @@ def process_graph_cached(
session_id=None,
) -> Tuple[Any, str]:
clear_caches_if_needed(clear_cache)
# If session_id is provided, load the langchain_object from the session
# else build the graph and return the result and the new session_id
langchain_object, artifacts, session_id = load_langchain_object(
data_graph, session_id
)
session_manager = get_session_manager()
# Load the graph using SessionManager
langchain_object, artifacts = session_manager.load_session(session_id, data_graph)
processed_inputs = process_inputs(inputs, artifacts)
result = generate_result(langchain_object, processed_inputs)
# langchain_object is now updated with the new memory
# we need to update the cache with the updated langchain_object
session_manager.update_session(
session_id, data_graph, (langchain_object, artifacts)
)
return result, session_id