style: 🐛 mypy correction

This commit is contained in:
Gabriel Almeida 2023-03-27 17:31:40 -03:00
commit 37c83a9de2

View file

@ -183,7 +183,7 @@ def get_tool_by_name(name: str):
return tools[name]
def get_tool_params(tool, **kwargs):
def get_tool_params(tool, **kwargs) -> Dict | None:
# Parse the function code into an abstract syntax tree
# Define if it is a function or a class
if inspect.isfunction(tool):
@ -192,9 +192,11 @@ def get_tool_params(tool, **kwargs):
# Get the parameters necessary to
# instantiate the class
return get_class_tool_params(tool, **kwargs)
else:
raise ValueError("Tool must be a function or class.")
def get_func_tool_params(func, **kwargs):
def get_func_tool_params(func, **kwargs) -> Dict | None:
tree = ast.parse(inspect.getsource(func))
# Iterate over the statements in the abstract syntax tree
@ -237,7 +239,7 @@ def get_func_tool_params(func, **kwargs):
return None
def get_class_tool_params(cls, **kwargs):
def get_class_tool_params(cls, **kwargs) -> Dict | None:
tree = ast.parse(inspect.getsource(cls))
tool_params = {}
@ -259,12 +261,14 @@ def get_class_tool_params(cls, **kwargs):
# If there is not default value, set it to an empty string
else:
try:
tool_params[arg.arg] = ast.literal_eval(arg.annotation)
annotation = ast.literal_eval(arg.annotation) # type: ignore
tool_params[arg.arg] = annotation
except ValueError:
tool_params[arg.arg] = ""
elif not cls == Tool and isinstance(stmt, ast.AnnAssign):
# Get the attribute name and the annotation
elif cls != Tool and isinstance(stmt, ast.AnnAssign):
# Get the attribute name and the annotation
tool_params[stmt.target.id] = ""
tool_params[stmt.target.id] = "" # type: ignore
return tool_params