Add load_from_db_fields attribute to Vertex class

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-06 21:16:27 -03:00
commit 937a50498a

View file

@ -36,6 +36,7 @@ class Vertex:
self.is_task = is_task
self.params = params or {}
self.parent_node_id: Optional[str] = self._data.get("parent_node_id")
self.load_from_db_fields: List[str] = []
self.parent_is_top_level = False
@property
@ -53,6 +54,7 @@ class Vertex:
"_built": False,
"parent_node_id": self.parent_node_id,
"parent_is_top_level": self.parent_is_top_level,
"load_from_db_fields": self.load_from_db_fields,
}
def __setstate__(self, state):
@ -72,6 +74,7 @@ class Vertex:
self.task_id: Optional[str] = None
self.parent_node_id = state["parent_node_id"]
self.parent_is_top_level = state["parent_is_top_level"]
self.load_from_db_fields = state["load_from_db_fields"]
def set_top_level(self, top_level_vertices: List[str]) -> None:
self.parent_is_top_level = self.parent_node_id in top_level_vertices
@ -151,60 +154,65 @@ class Vertex:
elif edge.target_id == self.id:
params[param_key] = self.graph.get_vertex(edge.source_id)
for key, value in template_dict.items():
if key in params:
load_from_db_fields = []
for field_name, field in template_dict.items():
if field_name in params:
continue
# Skip _type and any value that has show == False and is not code
# If we don't want to show code but we want to use it
if key == "_type" or (not value.get("show") and key != "code"):
if field_name == "_type" or (not field.get("show") and field_name != "code"):
continue
# If the type is not transformable to a python base class
# then we need to get the edge that connects to this node
if value.get("type") == "file":
if field.get("type") == "file":
# Load the type in value.get('fileTypes') using
# what is inside value.get('content')
# value.get('value') is the file name
if file_path := value.get("file_path"):
params[key] = file_path
if file_path := field.get("file_path"):
params[field_name] = file_path
else:
raise ValueError(f"File path not found for {self.vertex_type}")
elif value.get("type") in DIRECT_TYPES and params.get(key) is None:
val = value.get("value")
if value.get("type") == "code":
elif field.get("type") in DIRECT_TYPES and params.get(field_name) is None:
val = field.get("value")
if field.get("type") == "code":
try:
params[key] = ast.literal_eval(val) if val else None
params[field_name] = ast.literal_eval(val) if val else None
except Exception as exc:
logger.debug(f"Error parsing code: {exc}")
params[key] = val
elif value.get("type") in ["dict", "NestedDict"]:
params[field_name] = val
elif field.get("type") in ["dict", "NestedDict"]:
# When dict comes from the frontend it comes as a
# list of dicts, so we need to convert it to a dict
# before passing it to the build method
if isinstance(val, list):
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
params[field_name] = {k: v for item in field.get("value", []) for k, v in item.items()}
elif isinstance(val, dict):
params[key] = val
elif value.get("type") == "int" and val is not None:
params[field_name] = val
elif field.get("type") == "int" and val is not None:
try:
params[key] = int(val)
params[field_name] = int(val)
except ValueError:
params[key] = val
elif value.get("type") == "float" and val is not None:
params[field_name] = val
elif field.get("type") == "float" and val is not None:
try:
params[key] = float(val)
params[field_name] = float(val)
except ValueError:
params[key] = val
elif val is not None and val != "":
params[key] = val
params[field_name] = val
if not value.get("required") and params.get(key) is None:
if value.get("default"):
params[key] = value.get("default")
elif val is not None and val != "":
params[field_name] = val
if field.get("load_from_db"):
load_from_db_fields.append(field_name)
if not field.get("required") and params.get(field_name) is None:
if field.get("default"):
params[field_name] = field.get("default")
else:
params.pop(key, None)
params.pop(field_name, None)
# Add _type to params
self._raw_params = params
self.params = params
self.load_from_db_fields = load_from_db_fields
async def _build(self, user_id=None):
"""
@ -321,6 +329,7 @@ class Vertex:
result = await loading.instantiate_class(
node_type=self.vertex_type,
base_type=self.base_type,
load_from_db_fields=self.load_from_db_fields,
params=self.params,
user_id=user_id,
)