fix: function to try to avoid input keys erros with memory
This commit is contained in:
parent
64fb056ba9
commit
fe95790331
1 changed files with 35 additions and 22 deletions
|
|
@ -72,26 +72,41 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
||||
def fix_memory_inputs_for_intermediate_steps(langchain_object):
|
||||
def fix_memory_inputs(langchain_object):
|
||||
"""
|
||||
Fix memory inputs by replacing the memory key with the input key.
|
||||
"""
|
||||
langchain_object.return_intermediate_steps = True
|
||||
langchain_object.memory.memory_key
|
||||
input_key = [
|
||||
key
|
||||
for key in langchain_object.input_keys
|
||||
if key != langchain_object.memory.memory_key
|
||||
][0]
|
||||
# get output_key
|
||||
output_key = [
|
||||
key
|
||||
for key in langchain_object.output_keys
|
||||
if key != langchain_object.memory.memory_key
|
||||
][0]
|
||||
# set input_key and output_key in memory
|
||||
langchain_object.memory.input_key = input_key
|
||||
langchain_object.memory.output_key = output_key
|
||||
# Possible memory keys
|
||||
# "chat_history", "history"
|
||||
# if memory_key is "chat_history" and input_keys has "history"
|
||||
# we need to replace "chat_history" with "history"
|
||||
mem_key_dict = {
|
||||
"chat_history": "history",
|
||||
"history": "chat_history",
|
||||
}
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
possible_new_mem_key = mem_key_dict.get(memory_key)
|
||||
if possible_new_mem_key is not None:
|
||||
# get input_key
|
||||
input_key = [
|
||||
key
|
||||
for key in langchain_object.input_keys
|
||||
if key not in [memory_key, possible_new_mem_key]
|
||||
][0]
|
||||
|
||||
# get output_key
|
||||
output_key = [
|
||||
key
|
||||
for key in langchain_object.output_keys
|
||||
if key not in [memory_key, possible_new_mem_key]
|
||||
][0]
|
||||
|
||||
# set input_key and output_key in memory
|
||||
langchain_object.memory.input_key = input_key
|
||||
langchain_object.memory.output_key = output_key
|
||||
for input_key in langchain_object.input_keys:
|
||||
if input_key == possible_new_mem_key:
|
||||
langchain_object.memory.memory_key = possible_new_mem_key
|
||||
|
||||
|
||||
def get_result_and_thought_using_graph(langchain_object, message: str):
|
||||
|
|
@ -117,17 +132,15 @@ def get_result_and_thought_using_graph(langchain_object, message: str):
|
|||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
if langchain_object.return_intermediate_steps:
|
||||
fix_memory_inputs_for_intermediate_steps(langchain_object)
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
|
||||
try:
|
||||
output = langchain_object(chat_input)
|
||||
except ValueError as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
if hasattr(langchain_object, "memory"):
|
||||
langchain_object.memory.memory_key = memory_key
|
||||
output = langchain_object(chat_input)
|
||||
output = langchain_object.run(chat_input)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue