Add load_from_db_fields attribute to Vertex class
This commit is contained in:
parent
84f4c32076
commit
937a50498a
1 changed files with 35 additions and 26 deletions
|
|
@ -36,6 +36,7 @@ class Vertex:
|
|||
self.is_task = is_task
|
||||
self.params = params or {}
|
||||
self.parent_node_id: Optional[str] = self._data.get("parent_node_id")
|
||||
self.load_from_db_fields: List[str] = []
|
||||
self.parent_is_top_level = False
|
||||
|
||||
@property
|
||||
|
|
@ -53,6 +54,7 @@ class Vertex:
|
|||
"_built": False,
|
||||
"parent_node_id": self.parent_node_id,
|
||||
"parent_is_top_level": self.parent_is_top_level,
|
||||
"load_from_db_fields": self.load_from_db_fields,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
|
@ -72,6 +74,7 @@ class Vertex:
|
|||
self.task_id: Optional[str] = None
|
||||
self.parent_node_id = state["parent_node_id"]
|
||||
self.parent_is_top_level = state["parent_is_top_level"]
|
||||
self.load_from_db_fields = state["load_from_db_fields"]
|
||||
|
||||
def set_top_level(self, top_level_vertices: List[str]) -> None:
|
||||
self.parent_is_top_level = self.parent_node_id in top_level_vertices
|
||||
|
|
@ -151,60 +154,65 @@ class Vertex:
|
|||
elif edge.target_id == self.id:
|
||||
params[param_key] = self.graph.get_vertex(edge.source_id)
|
||||
|
||||
for key, value in template_dict.items():
|
||||
if key in params:
|
||||
load_from_db_fields = []
|
||||
for field_name, field in template_dict.items():
|
||||
if field_name in params:
|
||||
continue
|
||||
# Skip _type and any value that has show == False and is not code
|
||||
# If we don't want to show code but we want to use it
|
||||
if key == "_type" or (not value.get("show") and key != "code"):
|
||||
if field_name == "_type" or (not field.get("show") and field_name != "code"):
|
||||
continue
|
||||
# If the type is not transformable to a python base class
|
||||
# then we need to get the edge that connects to this node
|
||||
if value.get("type") == "file":
|
||||
if field.get("type") == "file":
|
||||
# Load the type in value.get('fileTypes') using
|
||||
# what is inside value.get('content')
|
||||
# value.get('value') is the file name
|
||||
if file_path := value.get("file_path"):
|
||||
params[key] = file_path
|
||||
if file_path := field.get("file_path"):
|
||||
params[field_name] = file_path
|
||||
else:
|
||||
raise ValueError(f"File path not found for {self.vertex_type}")
|
||||
elif value.get("type") in DIRECT_TYPES and params.get(key) is None:
|
||||
val = value.get("value")
|
||||
if value.get("type") == "code":
|
||||
elif field.get("type") in DIRECT_TYPES and params.get(field_name) is None:
|
||||
val = field.get("value")
|
||||
if field.get("type") == "code":
|
||||
try:
|
||||
params[key] = ast.literal_eval(val) if val else None
|
||||
params[field_name] = ast.literal_eval(val) if val else None
|
||||
except Exception as exc:
|
||||
logger.debug(f"Error parsing code: {exc}")
|
||||
params[key] = val
|
||||
elif value.get("type") in ["dict", "NestedDict"]:
|
||||
params[field_name] = val
|
||||
elif field.get("type") in ["dict", "NestedDict"]:
|
||||
# When dict comes from the frontend it comes as a
|
||||
# list of dicts, so we need to convert it to a dict
|
||||
# before passing it to the build method
|
||||
if isinstance(val, list):
|
||||
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
|
||||
params[field_name] = {k: v for item in field.get("value", []) for k, v in item.items()}
|
||||
elif isinstance(val, dict):
|
||||
params[key] = val
|
||||
elif value.get("type") == "int" and val is not None:
|
||||
params[field_name] = val
|
||||
elif field.get("type") == "int" and val is not None:
|
||||
try:
|
||||
params[key] = int(val)
|
||||
params[field_name] = int(val)
|
||||
except ValueError:
|
||||
params[key] = val
|
||||
elif value.get("type") == "float" and val is not None:
|
||||
params[field_name] = val
|
||||
elif field.get("type") == "float" and val is not None:
|
||||
try:
|
||||
params[key] = float(val)
|
||||
params[field_name] = float(val)
|
||||
except ValueError:
|
||||
params[key] = val
|
||||
elif val is not None and val != "":
|
||||
params[key] = val
|
||||
params[field_name] = val
|
||||
|
||||
if not value.get("required") and params.get(key) is None:
|
||||
if value.get("default"):
|
||||
params[key] = value.get("default")
|
||||
elif val is not None and val != "":
|
||||
params[field_name] = val
|
||||
if field.get("load_from_db"):
|
||||
load_from_db_fields.append(field_name)
|
||||
|
||||
if not field.get("required") and params.get(field_name) is None:
|
||||
if field.get("default"):
|
||||
params[field_name] = field.get("default")
|
||||
else:
|
||||
params.pop(key, None)
|
||||
params.pop(field_name, None)
|
||||
# Add _type to params
|
||||
self._raw_params = params
|
||||
self.params = params
|
||||
self.load_from_db_fields = load_from_db_fields
|
||||
|
||||
async def _build(self, user_id=None):
|
||||
"""
|
||||
|
|
@ -321,6 +329,7 @@ class Vertex:
|
|||
result = await loading.instantiate_class(
|
||||
node_type=self.vertex_type,
|
||||
base_type=self.base_type,
|
||||
load_from_db_fields=self.load_from_db_fields,
|
||||
params=self.params,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue