📝 (chat.py): Add an empty line before setting cache to improve code readability

🔧 (TextOperator.py): Add logging when stopping with a message
🔧 (TextOperator.py): Add logging when stopping with a message
🔧 (component.py): Refactor _set_outputs method to improve code readability
🔧 (component.py): Refactor build_results method to improve code readability
🔧 (component.py): Refactor custom_repr method to improve code readability
🔧 (custom_component.py): Refactor stop method to accept output_name parameter
🔧 (utils.py): Set output as selected after adding return types
🔧 (base.py): Reset inactivated vertices in the graph before marking them as active
🔧 (base.py): Refactor mark_branch method to only mark child vertices connected through a specific output
🔧 (base.py): Add get_edge method to retrieve edge between two vertices
🔧 (base.py): Refactor mark_branch method to consider output_name when marking child vertices
🔧 (base.py): Refactor build_parent_child_map method to improve code readability
🔧 (types.py): Add _built_object_repr method to handle custom representation of built object
This commit is contained in:
ogabrielluiz 2024-06-04 22:26:18 -03:00
commit d47d254a81
7 changed files with 47 additions and 16 deletions

View file

@ -209,6 +209,7 @@ async def build_vertex(
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
graph.reset_activated_vertices()
await chat_service.set_cache(flow_id_str, graph)
# graph.stop_vertex tells us if the user asked

View file

@ -48,9 +48,11 @@ class TextOperatorComponent(Component):
]
def true_response(self) -> Union[Text, Record]:
self.stop("False Result")
return self.true_output if self.true_output else self.input_text
def false_response(self) -> Union[Text, Record]:
self.stop("True Result")
return self.false_output if self.false_output else self.input_text
def result_response(self) -> Union[Text, Record]:
@ -79,8 +81,10 @@ class TextOperatorComponent(Component):
result = input_text.endswith(match_text)
if result:
self.status = self.true_response()
return self.true_response()
response = self.true_response()
self.status = response
return response
else:
self.status = self.false_response()
return self.false_response()
response = self.false_response()
self.status = response
return response

View file

@ -58,36 +58,41 @@ class Component(CustomComponent):
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
setattr(self, key, value)
def _set_outputs(self, outputs: List[dict]):
self.outputs = [Output(**output) for output in outputs]
async def build_results(self, vertex: "Vertex"):
build_results = {}
_results = {}
if hasattr(self, "outputs"):
self._set_outputs(vertex.outputs)
for output in self.outputs:
# Build the output if it's connected to some other vertex
# or if it's not connected to any vertex
self.output = output
if not vertex.outgoing_edges or output.name in vertex.edges_source_names:
method: Callable | Awaitable = getattr(self, output.method)
result = method()
# If the method is asynchronous, we need to await it
if inspect.iscoroutinefunction(method):
result = await result
build_results[output.name] = result
self.build_results = build_results
return build_results
_results[output.name] = result
self._results = _results
return _results
def custom_repr(self):
# ! Temporary REPR
# Since all are dict, yaml.dump them
if isinstance(self.build_results, dict):
_build_results = recursive_serialize_or_str(self.build_results)
if isinstance(self._results, dict):
_build_results = recursive_serialize_or_str(self._results)
try:
custom_repr = yaml.dump(_build_results)
except Exception as e:
logger.error(f"Error while dumping build_result: {e}")
custom_repr = str(self.build_results)
custom_repr = str(self._results)
if custom_repr is None and isinstance(self.build_results, (dict, Record, str)):
custom_repr = self.build_results
if custom_repr is None and isinstance(self._results, (dict, Record, str)):
custom_repr = self._results
if not isinstance(custom_repr, str):
custom_repr = str(custom_repr)
return custom_repr

View file

@ -84,11 +84,11 @@ class CustomComponent(BaseComponent):
except Exception as e:
raise ValueError(f"Error updating state: {e}")
def stop(self):
def stop(self, output_name: str):
if not self.vertex:
raise ValueError("Vertex is not set")
try:
self.graph.mark_branch(self.vertex.id, "INACTIVE")
self.graph.mark_branch(vertex_id=self.vertex.id, output_name=output_name, state="INACTIVE")
except Exception as e:
raise ValueError(f"Error stopping {self.display_name}: {e}")

View file

@ -337,6 +337,7 @@ def build_custom_component_template_from_inputs(
return_types = custom_component.get_method_return_type(output.method)
return_types = [format_type(return_type) for return_type in return_types]
output.add_types(return_types)
output.set_selected()
# ! This should be removed when we have a better way to handle this
frontend_node.get_base_classes_from_outputs()
reorder_fields(frontend_node, custom_component._get_field_order())

View file

@ -457,6 +457,8 @@ class Graph:
"""
Resets the inactivated vertices in the graph.
"""
for vertex_id in self.inactivated_vertices.copy():
self.mark_vertex(vertex_id, "ACTIVE")
self.inactivated_vertices = []
self.inactivated_vertices = set()
@ -470,7 +472,7 @@ class Graph:
vertex = self.get_vertex(vertex_id)
vertex.set_state(state)
def mark_branch(self, vertex_id: str, state: str, visited: Optional[set] = None):
def mark_branch(self, vertex_id: str, state: str, visited: Optional[set] = None, output_name: Optional[str] = None):
"""Marks a branch of the graph."""
if visited is None:
visited = set()
@ -481,8 +483,21 @@ class Graph:
self.mark_vertex(vertex_id, state)
for child_id in self.parent_child_map[vertex_id]:
# Only child_id that have an edge with the vertex_id through the output_name
# should be marked
if output_name:
edge = self.get_edge(vertex_id, child_id)
if edge.source_handle.name != output_name:
continue
self.mark_branch(child_id, state)
def get_edge(self, source_id: str, target_id: str) -> Optional[ContractEdge]:
"""Returns the edge between two vertices."""
for edge in self.edges:
if edge.source_id == source_id and edge.target_id == target_id:
return edge
return None
def build_parent_child_map(self, vertices: List[Vertex]):
parent_child_map = defaultdict(list)
for vertex in vertices:
@ -1132,6 +1147,7 @@ class Graph:
)
layers: List[List[str]] = []
visited = set(queue)
current_layer = 0
while queue:
layers.append([]) # Start a new layer

View file

@ -309,6 +309,10 @@ class StateVertex(Vertex):
successors = self.graph.successor_map.get(self.id, [])
return successors + self.graph.activated_vertices
def _built_object_repr(self):
if self.artifacts and "repr" in self.artifacts:
return self.artifacts["repr"] or super()._built_object_repr()
def dict_to_codeblock(d: dict) -> str:
serialized = {key: serialize_field(val) for key, val in d.items()}