fix: memory handling fixes and a clearer error message

This commit is contained in:
Gabriel Almeida 2023-04-04 19:36:08 -03:00
commit a98c1b54e5

View file

@@ -55,7 +55,7 @@ def process_graph(data_graph: Dict[str, Any]):
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the flow. Please, check all the nodes and try again."
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
@@ -72,47 +72,69 @@ def process_graph(data_graph: Dict[str, Any]):
return {"result": str(result), "thought": thought.strip()}
def get_result_and_thought_using_graph(loaded_langchain, message: str):
def fix_memory_inputs_for_intermediate_steps(langchain_object):
    """
    Prepare ``langchain_object`` so intermediate steps can be returned
    while a memory is attached.

    Enables ``return_intermediate_steps`` and pins the memory's
    ``input_key``/``output_key`` to the first input/output key that is not
    the memory key, so the memory component can still resolve which values
    to store once the extra ``intermediate_steps`` output appears.

    Args:
        langchain_object: a chain-like object with ``input_keys``,
            ``output_keys`` and a ``memory`` exposing ``memory_key``,
            ``input_key`` and ``output_key`` attributes. Mutated in place.
    """
    langchain_object.return_intermediate_steps = True
    # Hoist the memory key lookup; the original had a bare no-op
    # expression statement here instead of using the value.
    memory_key = langchain_object.memory.memory_key
    # First input key that is not the memory key feeds the chain.
    input_key = [
        key for key in langchain_object.input_keys if key != memory_key
    ][0]
    # First output key that is not the memory key holds the final answer.
    output_key = [
        key for key in langchain_object.output_keys if key != memory_key
    ][0]
    # Point the memory at the real chat input/output keys.
    langchain_object.memory.input_key = input_key
    langchain_object.memory.output_key = output_key
def get_result_and_thought_using_graph(langchain_object, message: str):
"""Get result and thought from extracted json"""
try:
if hasattr(loaded_langchain, "verbose"):
loaded_langchain.verbose = True
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
chat_input = None
memory_key = ""
if hasattr(loaded_langchain, "memory"):
mem_vars = loaded_langchain.memory.memory_variables
memory_key = mem_vars[0] if mem_vars else ""
if (
hasattr(langchain_object, "memory")
and langchain_object.memory is not None
):
memory_key = langchain_object.memory.memory_key
for key in loaded_langchain.input_keys:
if key != memory_key:
for key in langchain_object.input_keys:
if key not in [memory_key, "chat_history"]:
chat_input = {key: message}
if hasattr(loaded_langchain, "return_intermediate_steps"):
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
loaded_langchain.return_intermediate_steps = False
# I'm not sure about this yet.
function_to_call = None
if hasattr(loaded_langchain, "memory"):
function_to_call = loaded_langchain.predict
elif hasattr(loaded_langchain, "run"):
function_to_call = loaded_langchain.run
else:
function_to_call = loaded_langchain
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
if langchain_object.return_intermediate_steps:
fix_memory_inputs_for_intermediate_steps(langchain_object)
try:
output = function_to_call(chat_input)
output = langchain_object(chat_input)
except ValueError as exc:
logger.debug("Error: %s", str(exc))
output = loaded_langchain.run(chat_input)
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
if hasattr(langchain_object, "memory"):
langchain_object.memory.memory_key = memory_key
output = langchain_object(chat_input)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(loaded_langchain.output_keys[0])
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
@@ -129,16 +151,16 @@ def get_result_and_thought_using_graph(loaded_langchain, message: str):
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
"""Get result and thought from extracted json"""
try:
loaded_langchain = loading.load_langchain_type_from_config(
langchain_object = loading.load_langchain_type_from_config(
config=extracted_json
)
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
output = loaded_langchain(message)
output = langchain_object(message)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(loaded_langchain.output_keys[0])
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)