fix: function to try to avoid input keys erros with memory

This commit is contained in:
Gabriel Almeida 2023-04-04 21:46:05 -03:00
commit fe95790331

View file

@ -72,26 +72,41 @@ def process_graph(data_graph: Dict[str, Any]):
return {"result": str(result), "thought": thought.strip()}
def fix_memory_inputs_for_intermediate_steps(langchain_object):
def fix_memory_inputs(langchain_object):
"""
Fix memory inputs by replacing the memory key with the input key.
"""
langchain_object.return_intermediate_steps = True
langchain_object.memory.memory_key
input_key = [
key
for key in langchain_object.input_keys
if key != langchain_object.memory.memory_key
][0]
# get output_key
output_key = [
key
for key in langchain_object.output_keys
if key != langchain_object.memory.memory_key
][0]
# set input_key and output_key in memory
langchain_object.memory.input_key = input_key
langchain_object.memory.output_key = output_key
# Possible memory keys
# "chat_history", "history"
# if memory_key is "chat_history" and input_keys has "history"
# we need to replace "chat_history" with "history"
mem_key_dict = {
"chat_history": "history",
"history": "chat_history",
}
memory_key = langchain_object.memory.memory_key
possible_new_mem_key = mem_key_dict.get(memory_key)
if possible_new_mem_key is not None:
# get input_key
input_key = [
key
for key in langchain_object.input_keys
if key not in [memory_key, possible_new_mem_key]
][0]
# get output_key
output_key = [
key
for key in langchain_object.output_keys
if key not in [memory_key, possible_new_mem_key]
][0]
# set input_key and output_key in memory
langchain_object.memory.input_key = input_key
langchain_object.memory.output_key = output_key
for input_key in langchain_object.input_keys:
if input_key == possible_new_mem_key:
langchain_object.memory.memory_key = possible_new_mem_key
def get_result_and_thought_using_graph(langchain_object, message: str):
@ -117,17 +132,15 @@ def get_result_and_thought_using_graph(langchain_object, message: str):
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
if langchain_object.return_intermediate_steps:
fix_memory_inputs_for_intermediate_steps(langchain_object)
fix_memory_inputs(langchain_object)
try:
output = langchain_object(chat_input)
except ValueError as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
if hasattr(langchain_object, "memory"):
langchain_object.memory.memory_key = memory_key
output = langchain_object(chat_input)
output = langchain_object.run(chat_input)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []