From a98c1b54e5a317a059cd7b02185db87fdec63743 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 4 Apr 2023 19:36:08 -0300 Subject: [PATCH] fix: fixes for memory and better error message --- src/backend/langflow/interface/run.py | 76 +++++++++++++++++---------- 1 file changed, 49 insertions(+), 27 deletions(-) diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 28afb4a27..060308216 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -55,7 +55,7 @@ def process_graph(data_graph: Dict[str, Any]): if langchain_object is None: # Raise user facing error raise ValueError( - "There was an error loading the flow. Please, check all the nodes and try again." + "There was an error loading the langchain_object. Please, check all the nodes and try again." ) # Generate result and thought @@ -72,47 +72,69 @@ def process_graph(data_graph: Dict[str, Any]): return {"result": str(result), "thought": thought.strip()} -def get_result_and_thought_using_graph(loaded_langchain, message: str): +def fix_memory_inputs_for_intermediate_steps(langchain_object): + """ + Fix memory inputs by replacing the memory key with the input key. + """ + langchain_object.return_intermediate_steps = True + langchain_object.memory.memory_key + input_key = [ + key + for key in langchain_object.input_keys + if key != langchain_object.memory.memory_key + ][0] + # get output_key + output_key = [ + key + for key in langchain_object.output_keys + if key != langchain_object.memory.memory_key + ][0] + # set input_key and output_key in memory + langchain_object.memory.input_key = input_key + langchain_object.memory.output_key = output_key + + +def get_result_and_thought_using_graph(langchain_object, message: str): """Get result and thought from extracted json""" try: - if hasattr(loaded_langchain, "verbose"): - loaded_langchain.verbose = True + if hasattr(langchain_object, "verbose"): + langchain_object.verbose = True with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): chat_input = None memory_key = "" - if hasattr(loaded_langchain, "memory"): - mem_vars = loaded_langchain.memory.memory_variables - memory_key = mem_vars[0] if mem_vars else "" + if ( + hasattr(langchain_object, "memory") + and langchain_object.memory is not None + ): + memory_key = langchain_object.memory.memory_key - for key in loaded_langchain.input_keys: - if key != memory_key: + for key in langchain_object.input_keys: + if key not in [memory_key, "chat_history"]: chat_input = {key: message} - if hasattr(loaded_langchain, "return_intermediate_steps"): + if hasattr(langchain_object, "return_intermediate_steps"): # https://github.com/hwchase17/langchain/issues/2068 - loaded_langchain.return_intermediate_steps = False - - # I'm not sure about this yet. - function_to_call = None - if hasattr(loaded_langchain, "memory"): - function_to_call = loaded_langchain.predict - elif hasattr(loaded_langchain, "run"): - function_to_call = loaded_langchain.run - else: - function_to_call = loaded_langchain + # Deactivating until we have a frontend solution + # to display intermediate steps + langchain_object.return_intermediate_steps = False + if langchain_object.return_intermediate_steps: + fix_memory_inputs_for_intermediate_steps(langchain_object) try: - output = function_to_call(chat_input) + output = langchain_object(chat_input) except ValueError as exc: - logger.debug("Error: %s", str(exc)) - output = loaded_langchain.run(chat_input) + # make the error message more informative + logger.debug(f"Error: {str(exc)}") + if hasattr(langchain_object, "memory"): + langchain_object.memory.memory_key = memory_key + output = langchain_object(chat_input) intermediate_steps = ( output.get("intermediate_steps", []) if isinstance(output, dict) else [] ) result = ( - output.get(loaded_langchain.output_keys[0]) + output.get(langchain_object.output_keys[0]) if isinstance(output, dict) else output ) @@ -129,16 +151,16 @@ def get_result_and_thought_using_graph(loaded_langchain, message: str): def get_result_and_thought(extracted_json: Dict[str, Any], message: str): """Get result and thought from extracted json""" try: - loaded_langchain = loading.load_langchain_type_from_config( + langchain_object = loading.load_langchain_type_from_config( config=extracted_json ) with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): - output = loaded_langchain(message) + output = langchain_object(message) intermediate_steps = ( output.get("intermediate_steps", []) if isinstance(output, dict) else [] ) result = ( - output.get(loaded_langchain.output_keys[0]) + output.get(langchain_object.output_keys[0]) if isinstance(output, dict) else output )