From 199ea3b1b84d0914fabcba38a5965c71b16fc2e5 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 16 Aug 2023 21:29:17 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix(process.py):=20refactor=20pr?= =?UTF-8?q?ocess=5Fgraph=5Fcached=20function=20for=20better=20readability?= =?UTF-8?q?=20and=20maintainability=20=E2=9C=A8=20feat(process.py):=20add?= =?UTF-8?q?=20clear=5Fcaches=5Fif=5Fneeded=20function=20to=20clear=20cache?= =?UTF-8?q?=20if=20clear=5Fcache=20flag=20is=20set=20=E2=9C=A8=20feat(proc?= =?UTF-8?q?ess.py):=20add=20load=5Flangchain=5Fobject=20function=20to=20lo?= =?UTF-8?q?ad=20langchain=5Fobject=20and=20artifacts=20from=20data=5Fgraph?= =?UTF-8?q?=20=E2=9C=A8=20feat(process.py):=20add=20process=5Finputs=20fun?= =?UTF-8?q?ction=20to=20process=20inputs=20and=20add=20artifacts=20to=20in?= =?UTF-8?q?puts=20=E2=9C=A8=20feat(process.py):=20add=20generate=5Fresult?= =?UTF-8?q?=20function=20to=20generate=20result=20and=20thought=20based=20?= =?UTF-8?q?on=20langchain=5Fobject=20and=20inputs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/processing/process.py | 65 +++++++++++++--------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 5549a2ae1..dc632e76b 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -101,45 +101,39 @@ def get_build_result(data_graph, session_id): return build_sorted_vertices_with_caching(data_graph) -def process_graph_cached( - data_graph: Dict[str, Any], - inputs: Optional[dict] = None, - clear_cache=False, - session_id=None, -): - """ - Process graph by extracting input variables and replacing ZeroShotPrompt - with PromptTemplate,then run the graph and return the result and thought. - """ - # Load langchain object +def clear_caches_if_needed(clear_cache: bool): if clear_cache: build_sorted_vertices_with_caching.clear_cache() logger.debug("Cleared cache") + +def load_langchain_object( + data_graph: Dict[str, Any], session_id: str +) -> Tuple[Union[Chain, VectorStore], Dict[str, Any]]: langchain_object, artifacts = get_build_result(data_graph, session_id) session_id = build_sorted_vertices_with_caching.hash - logger.debug("Loaded LangChain object") - if inputs is None: - inputs = {} - - # Add artifacts to inputs - # artifacts can be documents loaded when building - # the flow - for ( - key, - value, - ) in artifacts.items(): - if key not in inputs or not inputs[key]: - inputs[key] = value if langchain_object is None: - # Raise user facing error raise ValueError( "There was an error loading the langchain_object. Please, check all the nodes and try again." ) - # Generate result and thought + return langchain_object, artifacts, session_id + + +def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict: + if inputs is None: + inputs = {} + + for key, value in artifacts.items(): + if key not in inputs or not inputs[key]: + inputs[key] = value + + return inputs + + +def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict): if isinstance(langchain_object, Chain): if inputs is None: raise ValueError("Inputs must be provided for a Chain") @@ -152,6 +146,25 @@ def process_graph_cached( raise ValueError( f"Unknown langchain_object type: {type(langchain_object).__name__}" ) + + return result + + +def process_graph_cached( + data_graph: Dict[str, Any], + inputs: Optional[dict] = None, + clear_cache=False, + session_id=None, +) -> Tuple[Any, str]: + clear_caches_if_needed(clear_cache) + # If session_id is provided, load the langchain_object from the session + # else build the graph and return the result and the new session_id + langchain_object, artifacts, session_id = load_langchain_object( + data_graph, session_id + ) + processed_inputs = process_inputs(inputs, artifacts) + result = generate_result(langchain_object, processed_inputs) + return result, session_id