diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 060308216..773b21b7f 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -72,26 +72,41 @@ def process_graph(data_graph: Dict[str, Any]): return {"result": str(result), "thought": thought.strip()} -def fix_memory_inputs_for_intermediate_steps(langchain_object): +def fix_memory_inputs(langchain_object): """ Fix memory inputs by replacing the memory key with the input key. """ - langchain_object.return_intermediate_steps = True - langchain_object.memory.memory_key - input_key = [ - key - for key in langchain_object.input_keys - if key != langchain_object.memory.memory_key - ][0] - # get output_key - output_key = [ - key - for key in langchain_object.output_keys - if key != langchain_object.memory.memory_key - ][0] - # set input_key and output_key in memory - langchain_object.memory.input_key = input_key - langchain_object.memory.output_key = output_key + # Possible memory keys + # "chat_history", "history" + # if memory_key is "chat_history" and input_keys has "history" + # we need to replace "chat_history" with "history" + mem_key_dict = { + "chat_history": "history", + "history": "chat_history", + } + memory_key = langchain_object.memory.memory_key + possible_new_mem_key = mem_key_dict.get(memory_key) + if possible_new_mem_key is not None: + # get input_key + input_key = [ + key + for key in langchain_object.input_keys + if key not in [memory_key, possible_new_mem_key] + ][0] + + # get output_key + output_key = [ + key + for key in langchain_object.output_keys + if key not in [memory_key, possible_new_mem_key] + ][0] + + # set input_key and output_key in memory + langchain_object.memory.input_key = input_key + langchain_object.memory.output_key = output_key + for input_key in langchain_object.input_keys: + if input_key == possible_new_mem_key: + langchain_object.memory.memory_key = possible_new_mem_key def get_result_and_thought_using_graph(langchain_object, message: str): @@ -117,17 +132,15 @@ def get_result_and_thought_using_graph(langchain_object, message: str): # Deactivating until we have a frontend solution # to display intermediate steps langchain_object.return_intermediate_steps = False - if langchain_object.return_intermediate_steps: - fix_memory_inputs_for_intermediate_steps(langchain_object) + + fix_memory_inputs(langchain_object) try: output = langchain_object(chat_input) except ValueError as exc: # make the error message more informative logger.debug(f"Error: {str(exc)}") - if hasattr(langchain_object, "memory"): - langchain_object.memory.memory_key = memory_key - output = langchain_object(chat_input) + output = langchain_object.run(chat_input) intermediate_steps = ( output.get("intermediate_steps", []) if isinstance(output, dict) else []