🐛 fix(process.py): refactor fix_memory_inputs function to improve readability and reduce nesting

The fix_memory_inputs function was refactored to reduce nesting and improve readability. The function now checks if the langchain_object has a memory attribute and if it is not None before proceeding. The try-except block was also refactored to reduce nesting.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-19 12:02:15 -03:00
commit 11185affdd

View file

@ -20,22 +20,23 @@ def fix_memory_inputs(langchain_object):
object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
get_memory_key function and updates the memory keys using the update_memory_keys function. get_memory_key function and updates the memory keys using the update_memory_keys function.
""" """
if hasattr(langchain_object, "memory") and langchain_object.memory is not None: if not hasattr(langchain_object, "memory") or langchain_object.memory is None:
try: return
if langchain_object.memory.memory_key in langchain_object.input_variables: try:
return if langchain_object.memory.memory_key in langchain_object.input_variables:
except AttributeError: return
input_variables = ( except AttributeError:
langchain_object.prompt.input_variables input_variables = (
if hasattr(langchain_object, "prompt") langchain_object.prompt.input_variables
else langchain_object.input_keys if hasattr(langchain_object, "prompt")
) else langchain_object.input_keys
if langchain_object.memory.memory_key in input_variables: )
return if langchain_object.memory.memory_key in input_variables:
return
possible_new_mem_key = get_memory_key(langchain_object) possible_new_mem_key = get_memory_key(langchain_object)
if possible_new_mem_key is not None: if possible_new_mem_key is not None:
update_memory_keys(langchain_object, possible_new_mem_key) update_memory_keys(langchain_object, possible_new_mem_key)
def format_actions(actions: List[Tuple[AgentAction, str]]) -> str: def format_actions(actions: List[Tuple[AgentAction, str]]) -> str: