Fix dict setup

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-06 15:35:36 -03:00
commit e07eae6a8d

View file

@ -59,8 +59,13 @@ class Vertex:
self.updated_raw_params = False
self.id: str = data["id"]
self.is_state = False
self.is_input = any(input_component_name in self.id for input_component_name in INPUT_COMPONENTS)
self.is_output = any(output_component_name in self.id for output_component_name in OUTPUT_COMPONENTS)
self.is_input = any(
input_component_name in self.id for input_component_name in INPUT_COMPONENTS
)
self.is_output = any(
output_component_name in self.id
for output_component_name in OUTPUT_COMPONENTS
)
self.has_session_id = None
self._custom_component = None
self.has_external_input = False
@ -101,11 +106,17 @@ class Vertex:
def set_state(self, state: str):
self.state = VertexStates[state]
if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] < 2:
if (
self.state == VertexStates.INACTIVE
and self.graph.in_degree_map[self.id] < 2
):
# If the vertex is inactive and has only one in degree
# it means that it is not a merge point in the graph
self.graph.inactivated_vertices.add(self.id)
elif self.state == VertexStates.ACTIVE and self.id in self.graph.inactivated_vertices:
elif (
self.state == VertexStates.ACTIVE
and self.id in self.graph.inactivated_vertices
):
self.graph.inactivated_vertices.remove(self.id)
@property
@ -122,7 +133,9 @@ class Vertex:
# If the Vertex.type is a power component
# then we need to return the built object
# instead of the result dict
if self.is_interface_component and not isinstance(self._built_object, UnbuiltObject):
if self.is_interface_component and not isinstance(
self._built_object, UnbuiltObject
):
result = self._built_object
# if it is not a dict or a string and hasattr model_dump then
# return the model_dump
@ -134,7 +147,11 @@ class Vertex:
if isinstance(self._built_result, UnbuiltResult):
return {}
return self._built_result if isinstance(self._built_result, dict) else {"result": self._built_result}
return (
self._built_result
if isinstance(self._built_result, dict)
else {"result": self._built_result}
)
def set_artifacts(self) -> None:
pass
@ -204,19 +221,31 @@ class Vertex:
self.selected_output_type = self.data["node"].get("selected_output_type")
self.is_input = self.data["node"].get("is_input") or self.is_input
self.is_output = self.data["node"].get("is_output") or self.is_output
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
template_dicts = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
self.has_session_id = "session_id" in template_dicts
self.required_inputs = [
template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"]
template_dicts[key]["type"]
for key, value in template_dicts.items()
if value["required"]
]
self.optional_inputs = [
template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"]
template_dicts[key]["type"]
for key, value in template_dicts.items()
if not value["required"]
]
# Add the template_dicts[key]["input_types"] to the optional_inputs
self.optional_inputs.extend(
[input_type for value in template_dicts.values() for input_type in value.get("input_types", [])]
[
input_type
for value in template_dicts.values()
for input_type in value.get("input_types", [])
]
)
template_dict = self.data["node"]["template"]
@ -263,7 +292,11 @@ class Vertex:
self.updated_raw_params = False
return
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
template_dict = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
params = {}
for edge in self.edges:
@ -284,7 +317,10 @@ class Vertex:
# we don't know the key of the dict but we need to set the value
# to the vertex that is the source of the edge
param_dict = template_dict[param_key]["value"]
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()}
params[param_key] = {
key: self.graph.get_vertex(edge.source_id)
for key in param_dict.keys()
}
else:
params[param_key] = self.graph.get_vertex(edge.source_id)
@ -320,7 +356,11 @@ class Vertex:
# list of dicts, so we need to convert it to a dict
# before passing it to the build method
if isinstance(val, list):
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
params[key] = {
k: v
for item in value.get("value", [])
for k, v in item.items()
}
elif isinstance(val, dict):
params[key] = val
elif value.get("type") == "int" and val is not None:
@ -445,7 +485,9 @@ class Vertex:
if isinstance(self._built_object, str):
self._built_result = self._built_object
result = await generate_result(self._built_object, inputs, self.has_external_output, session_id)
result = await generate_result(
self._built_object, inputs, self.has_external_output, session_id
)
self._built_result = result
async def _build_each_node_in_params_dict(self, user_id=None):
@ -461,17 +503,22 @@ class Vertex:
elif isinstance(value, list) and self._is_list_of_nodes(value):
await self._build_list_of_nodes_and_update_params(key, value, user_id)
elif isinstance(value, dict):
await self._build_dict_of_nodes_and_update_params(key, value, user_id)
await self._build_dict_and_update_params(key, value, user_id)
elif key not in self.params or self.updated_raw_params:
self.params[key] = value
async def _build_dict_of_nodes_and_update_params(self, key, nodes: Dict[str, "Vertex"], user_id=None):
async def _build_dict_and_update_params(
self, key, nodes_dict: Dict[str, "Vertex"], user_id=None
):
"""
Iterates over a dictionary of nodes, builds each and updates the params dictionary.
"""
for sub_key, node in nodes.items():
built = await node.get_result(requester=self, user_id=user_id)
self.params[key][sub_key] = built
for sub_key, value in nodes_dict.items():
if not self._is_node(value):
self.params[key][sub_key] = value
else:
built = await value.get_result(requester=self, user_id=user_id)
self.params[key][sub_key] = built
def _is_node(self, value):
"""
@ -485,7 +532,9 @@ class Vertex:
"""
return all(self._is_node(node) for node in value)
async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any:
async def get_result(
self, requester: Optional["Vertex"] = None, user_id=None, timeout=None
) -> Any:
# PLEASE REVIEW THIS IF STATEMENT
# Check if the Vertex was built already
if self._built:
@ -519,7 +568,9 @@ class Vertex:
self._extend_params_list_with_result(key, result)
self.params[key] = result
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
async def _build_list_of_nodes_and_update_params(
self, key, nodes: List["Vertex"], user_id=None
):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
"""
@ -586,7 +637,9 @@ class Vertex:
except Exception as exc:
logger.exception(exc)
raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc
raise ValueError(
f"Error building node {self.display_name}: {str(exc)}"
) from exc
def _update_built_object_and_artifacts(self, result):
"""
@ -614,7 +667,9 @@ class Vertex:
logger.warning(message)
elif isinstance(self._built_object, (Iterator, AsyncIterator)):
if self.display_name in ["Text Output"]:
raise ValueError(f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead.")
raise ValueError(
f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead."
)
def _reset(self, params_update: Optional[Dict[str, Any]] = None):
self._built = False
@ -676,16 +731,24 @@ class Vertex:
return self._built_object
# Get the requester edge
requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None)
requester_edge = next(
(edge for edge in self.edges if edge.target_id == requester.id), None
)
# Return the result of the requester edge
return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester)
return (
None
if requester_edge is None
else await requester_edge.get_result(source=self, target=requester)
)
def add_edge(self, edge: "ContractEdge") -> None:
if edge not in self.edges:
self.edges.append(edge)
def __repr__(self) -> str:
return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
return (
f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})"
)
def __eq__(self, __o: object) -> bool:
try:
@ -706,4 +769,8 @@ class Vertex:
def _built_object_repr(self):
# Add a message with an emoji, stars for sucess,
return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵‍💫"
return (
"Built sucessfully ✨"
if self._built_object is not None
else "Failed to build 😵‍💫"
)