diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 5549a2ae1..dc632e76b 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -101,45 +101,39 @@ def get_build_result(data_graph, session_id): return build_sorted_vertices_with_caching(data_graph) -def process_graph_cached( - data_graph: Dict[str, Any], - inputs: Optional[dict] = None, - clear_cache=False, - session_id=None, -): - """ - Process graph by extracting input variables and replacing ZeroShotPrompt - with PromptTemplate,then run the graph and return the result and thought. - """ - # Load langchain object +def clear_caches_if_needed(clear_cache: bool): if clear_cache: build_sorted_vertices_with_caching.clear_cache() logger.debug("Cleared cache") + +def load_langchain_object( + data_graph: Dict[str, Any], session_id: str +) -> Tuple[Union[Chain, VectorStore], Dict[str, Any]]: langchain_object, artifacts = get_build_result(data_graph, session_id) session_id = build_sorted_vertices_with_caching.hash - logger.debug("Loaded LangChain object") - if inputs is None: - inputs = {} - - # Add artifacts to inputs - # artifacts can be documents loaded when building - # the flow - for ( - key, - value, - ) in artifacts.items(): - if key not in inputs or not inputs[key]: - inputs[key] = value if langchain_object is None: - # Raise user facing error raise ValueError( "There was an error loading the langchain_object. Please, check all the nodes and try again." ) - # Generate result and thought + return langchain_object, artifacts, session_id + + +def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict: + if inputs is None: + inputs = {} + + for key, value in artifacts.items(): + if key not in inputs or not inputs[key]: + inputs[key] = value + + return inputs + + +def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict): if isinstance(langchain_object, Chain): if inputs is None: raise ValueError("Inputs must be provided for a Chain") @@ -152,6 +146,25 @@ def process_graph_cached( raise ValueError( f"Unknown langchain_object type: {type(langchain_object).__name__}" ) + + return result + + +def process_graph_cached( + data_graph: Dict[str, Any], + inputs: Optional[dict] = None, + clear_cache=False, + session_id=None, +) -> Tuple[Any, str]: + clear_caches_if_needed(clear_cache) + # If session_id is provided, load the langchain_object from the session + # else build the graph and return the result and the new session_id + langchain_object, artifacts, session_id = load_langchain_object( + data_graph, session_id + ) + processed_inputs = process_inputs(inputs, artifacts) + result = generate_result(langchain_object, processed_inputs) + return result, session_id