🐛 fix(loading.py): replace spaces with underscores in tool names to improve consistency and avoid issues

🐛 fix(base.py): catch and log exceptions when fixing memory inputs to prevent crashes
🐛 fix(process.py): check if memory_key attribute exists before accessing it to prevent AttributeError
🐛 fix(memories.py): hide memory_key field for ConversationEntityMemory to improve user experience
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-07-10 22:42:05 -03:00
commit a90c9fbd34
4 changed files with 19 additions and 3 deletions

View file

@ -68,7 +68,11 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
elif base_type == "prompts":
return instantiate_prompt(node_type, class_object, params)
elif base_type == "tools":
return instantiate_tool(node_type, class_object, params)
tool = instantiate_tool(node_type, class_object, params)
if hasattr(tool, "name") and isinstance(tool, BaseTool):
# tool name shouldn't contain spaces
tool.name = tool.name.replace(" ", "_")
return tool
elif base_type == "toolkits":
return instantiate_toolkit(node_type, class_object, params)
elif base_type == "embeddings":
@ -133,6 +137,9 @@ def instantiate_llm(node_type, class_object, params: Dict):
def instantiate_memory(node_type, class_object, params):
# process input_key and output_key to remove them if
# they are empty strings
if node_type == "ConversationEntityMemory":
params.pop("memory_key", None)
for key in ["input_key", "output_key"]:
if key in params and (params[key] == "" or not params[key]):
params.pop(key)

View file

@ -19,8 +19,11 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = True
try:
fix_memory_inputs(langchain_object)
except Exception as exc:
logger.error(exc)
fix_memory_inputs(langchain_object)
try:
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
output = await langchain_object.acall(inputs, callbacks=async_callbacks)

View file

@ -22,7 +22,10 @@ def fix_memory_inputs(langchain_object):
if not hasattr(langchain_object, "memory") or langchain_object.memory is None:
return
try:
if langchain_object.memory.memory_key in langchain_object.input_variables:
if (
hasattr(langchain_object.memory, "memory_key")
and langchain_object.memory.memory_key in langchain_object.input_variables
):
return
except AttributeError:
input_variables = (

View file

@ -90,6 +90,9 @@ class MemoryFrontendNode(FrontendNode):
field.show = True
if field.name == "entity_store":
field.show = False
if name == "ConversationEntityMemory" and field.name == "memory_key":
field.show = False
field.required = False
class PostgresChatMessageHistoryFrontendNode(MemoryFrontendNode):