fix: fixes for memory and better error message
This commit is contained in:
parent
f6fc9f2c3b
commit
a98c1b54e5
1 changed files with 49 additions and 27 deletions
|
|
@ -55,7 +55,7 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the flow. Please, check all the nodes and try again."
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
|
|
@ -72,47 +72,69 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
||||
def get_result_and_thought_using_graph(loaded_langchain, message: str):
|
||||
def fix_memory_inputs_for_intermediate_steps(langchain_object):
|
||||
"""
|
||||
Fix memory inputs by replacing the memory key with the input key.
|
||||
"""
|
||||
langchain_object.return_intermediate_steps = True
|
||||
langchain_object.memory.memory_key
|
||||
input_key = [
|
||||
key
|
||||
for key in langchain_object.input_keys
|
||||
if key != langchain_object.memory.memory_key
|
||||
][0]
|
||||
# get output_key
|
||||
output_key = [
|
||||
key
|
||||
for key in langchain_object.output_keys
|
||||
if key != langchain_object.memory.memory_key
|
||||
][0]
|
||||
# set input_key and output_key in memory
|
||||
langchain_object.memory.input_key = input_key
|
||||
langchain_object.memory.output_key = output_key
|
||||
|
||||
|
||||
def get_result_and_thought_using_graph(langchain_object, message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
if hasattr(loaded_langchain, "verbose"):
|
||||
loaded_langchain.verbose = True
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
chat_input = None
|
||||
memory_key = ""
|
||||
if hasattr(loaded_langchain, "memory"):
|
||||
mem_vars = loaded_langchain.memory.memory_variables
|
||||
memory_key = mem_vars[0] if mem_vars else ""
|
||||
if (
|
||||
hasattr(langchain_object, "memory")
|
||||
and langchain_object.memory is not None
|
||||
):
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
|
||||
for key in loaded_langchain.input_keys:
|
||||
if key != memory_key:
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
|
||||
if hasattr(loaded_langchain, "return_intermediate_steps"):
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
loaded_langchain.return_intermediate_steps = False
|
||||
|
||||
# I'm not sure about this yet.
|
||||
function_to_call = None
|
||||
if hasattr(loaded_langchain, "memory"):
|
||||
function_to_call = loaded_langchain.predict
|
||||
elif hasattr(loaded_langchain, "run"):
|
||||
function_to_call = loaded_langchain.run
|
||||
else:
|
||||
function_to_call = loaded_langchain
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
if langchain_object.return_intermediate_steps:
|
||||
fix_memory_inputs_for_intermediate_steps(langchain_object)
|
||||
|
||||
try:
|
||||
output = function_to_call(chat_input)
|
||||
output = langchain_object(chat_input)
|
||||
except ValueError as exc:
|
||||
logger.debug("Error: %s", str(exc))
|
||||
output = loaded_langchain.run(chat_input)
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
if hasattr(langchain_object, "memory"):
|
||||
langchain_object.memory.memory_key = memory_key
|
||||
output = langchain_object(chat_input)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
|
||||
result = (
|
||||
output.get(loaded_langchain.output_keys[0])
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
|
|
@ -129,16 +151,16 @@ def get_result_and_thought_using_graph(loaded_langchain, message: str):
|
|||
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
loaded_langchain = loading.load_langchain_type_from_config(
|
||||
langchain_object = loading.load_langchain_type_from_config(
|
||||
config=extracted_json
|
||||
)
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
output = loaded_langchain(message)
|
||||
output = langchain_object(message)
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
result = (
|
||||
output.get(loaded_langchain.output_keys[0])
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue