📝 (chat.py): Add an empty line before setting cache to improve code readability
🔧 (TextOperator.py): Add logging when stopping with a message 🔧 (TextOperator.py): Add logging when stopping with a message 🔧 (component.py): Refactor _set_outputs method to improve code readability 🔧 (component.py): Refactor build_results method to improve code readability 🔧 (component.py): Refactor custom_repr method to improve code readability 🔧 (custom_component.py): Refactor stop method to accept output_name parameter 🔧 (utils.py): Set output as selected after adding return types 🔧 (base.py): Reset inactivated vertices in the graph before marking them as active 🔧 (base.py): Refactor mark_branch method to only mark child vertices connected through a specific output 🔧 (base.py): Add get_edge method to retrieve edge between two vertices 🔧 (base.py): Refactor mark_branch method to consider output_name when marking child vertices 🔧 (base.py): Refactor build_parent_child_map method to improve code readability 🔧 (types.py): Add _built_object_repr method to handle custom representation of built object
This commit is contained in:
parent
e693fba120
commit
d47d254a81
7 changed files with 47 additions and 16 deletions
|
|
@ -209,6 +209,7 @@ async def build_vertex(
|
|||
inactivated_vertices = list(graph.inactivated_vertices)
|
||||
graph.reset_inactivated_vertices()
|
||||
graph.reset_activated_vertices()
|
||||
|
||||
await chat_service.set_cache(flow_id_str, graph)
|
||||
|
||||
# graph.stop_vertex tells us if the user asked
|
||||
|
|
|
|||
|
|
@ -48,9 +48,11 @@ class TextOperatorComponent(Component):
|
|||
]
|
||||
|
||||
def true_response(self) -> Union[Text, Record]:
|
||||
self.stop("False Result")
|
||||
return self.true_output if self.true_output else self.input_text
|
||||
|
||||
def false_response(self) -> Union[Text, Record]:
|
||||
self.stop("True Result")
|
||||
return self.false_output if self.false_output else self.input_text
|
||||
|
||||
def result_response(self) -> Union[Text, Record]:
|
||||
|
|
@ -79,8 +81,10 @@ class TextOperatorComponent(Component):
|
|||
result = input_text.endswith(match_text)
|
||||
|
||||
if result:
|
||||
self.status = self.true_response()
|
||||
return self.true_response()
|
||||
response = self.true_response()
|
||||
self.status = response
|
||||
return response
|
||||
else:
|
||||
self.status = self.false_response()
|
||||
return self.false_response()
|
||||
response = self.false_response()
|
||||
self.status = response
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -58,36 +58,41 @@ class Component(CustomComponent):
|
|||
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
|
||||
setattr(self, key, value)
|
||||
|
||||
def _set_outputs(self, outputs: List[dict]):
|
||||
self.outputs = [Output(**output) for output in outputs]
|
||||
|
||||
async def build_results(self, vertex: "Vertex"):
|
||||
build_results = {}
|
||||
_results = {}
|
||||
|
||||
if hasattr(self, "outputs"):
|
||||
self._set_outputs(vertex.outputs)
|
||||
for output in self.outputs:
|
||||
# Build the output if it's connected to some other vertex
|
||||
# or if it's not connected to any vertex
|
||||
self.output = output
|
||||
if not vertex.outgoing_edges or output.name in vertex.edges_source_names:
|
||||
method: Callable | Awaitable = getattr(self, output.method)
|
||||
result = method()
|
||||
# If the method is asynchronous, we need to await it
|
||||
if inspect.iscoroutinefunction(method):
|
||||
result = await result
|
||||
build_results[output.name] = result
|
||||
self.build_results = build_results
|
||||
return build_results
|
||||
_results[output.name] = result
|
||||
self._results = _results
|
||||
return _results
|
||||
|
||||
def custom_repr(self):
|
||||
# ! Temporary REPR
|
||||
# Since all are dict, yaml.dump them
|
||||
if isinstance(self.build_results, dict):
|
||||
_build_results = recursive_serialize_or_str(self.build_results)
|
||||
if isinstance(self._results, dict):
|
||||
_build_results = recursive_serialize_or_str(self._results)
|
||||
try:
|
||||
custom_repr = yaml.dump(_build_results)
|
||||
except Exception as e:
|
||||
logger.error(f"Error while dumping build_result: {e}")
|
||||
custom_repr = str(self.build_results)
|
||||
custom_repr = str(self._results)
|
||||
|
||||
if custom_repr is None and isinstance(self.build_results, (dict, Record, str)):
|
||||
custom_repr = self.build_results
|
||||
if custom_repr is None and isinstance(self._results, (dict, Record, str)):
|
||||
custom_repr = self._results
|
||||
if not isinstance(custom_repr, str):
|
||||
custom_repr = str(custom_repr)
|
||||
return custom_repr
|
||||
|
|
|
|||
|
|
@ -84,11 +84,11 @@ class CustomComponent(BaseComponent):
|
|||
except Exception as e:
|
||||
raise ValueError(f"Error updating state: {e}")
|
||||
|
||||
def stop(self):
|
||||
def stop(self, output_name: str):
|
||||
if not self.vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
try:
|
||||
self.graph.mark_branch(self.vertex.id, "INACTIVE")
|
||||
self.graph.mark_branch(vertex_id=self.vertex.id, output_name=output_name, state="INACTIVE")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error stopping {self.display_name}: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -337,6 +337,7 @@ def build_custom_component_template_from_inputs(
|
|||
return_types = custom_component.get_method_return_type(output.method)
|
||||
return_types = [format_type(return_type) for return_type in return_types]
|
||||
output.add_types(return_types)
|
||||
output.set_selected()
|
||||
# ! This should be removed when we have a better way to handle this
|
||||
frontend_node.get_base_classes_from_outputs()
|
||||
reorder_fields(frontend_node, custom_component._get_field_order())
|
||||
|
|
|
|||
|
|
@ -457,6 +457,8 @@ class Graph:
|
|||
"""
|
||||
Resets the inactivated vertices in the graph.
|
||||
"""
|
||||
for vertex_id in self.inactivated_vertices.copy():
|
||||
self.mark_vertex(vertex_id, "ACTIVE")
|
||||
self.inactivated_vertices = []
|
||||
self.inactivated_vertices = set()
|
||||
|
||||
|
|
@ -470,7 +472,7 @@ class Graph:
|
|||
vertex = self.get_vertex(vertex_id)
|
||||
vertex.set_state(state)
|
||||
|
||||
def mark_branch(self, vertex_id: str, state: str, visited: Optional[set] = None):
|
||||
def mark_branch(self, vertex_id: str, state: str, visited: Optional[set] = None, output_name: Optional[str] = None):
|
||||
"""Marks a branch of the graph."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
|
@ -481,8 +483,21 @@ class Graph:
|
|||
self.mark_vertex(vertex_id, state)
|
||||
|
||||
for child_id in self.parent_child_map[vertex_id]:
|
||||
# Only child_id that have an edge with the vertex_id through the output_name
|
||||
# should be marked
|
||||
if output_name:
|
||||
edge = self.get_edge(vertex_id, child_id)
|
||||
if edge.source_handle.name != output_name:
|
||||
continue
|
||||
self.mark_branch(child_id, state)
|
||||
|
||||
def get_edge(self, source_id: str, target_id: str) -> Optional[ContractEdge]:
|
||||
"""Returns the edge between two vertices."""
|
||||
for edge in self.edges:
|
||||
if edge.source_id == source_id and edge.target_id == target_id:
|
||||
return edge
|
||||
return None
|
||||
|
||||
def build_parent_child_map(self, vertices: List[Vertex]):
|
||||
parent_child_map = defaultdict(list)
|
||||
for vertex in vertices:
|
||||
|
|
@ -1132,6 +1147,7 @@ class Graph:
|
|||
)
|
||||
layers: List[List[str]] = []
|
||||
visited = set(queue)
|
||||
|
||||
current_layer = 0
|
||||
while queue:
|
||||
layers.append([]) # Start a new layer
|
||||
|
|
|
|||
|
|
@ -309,6 +309,10 @@ class StateVertex(Vertex):
|
|||
successors = self.graph.successor_map.get(self.id, [])
|
||||
return successors + self.graph.activated_vertices
|
||||
|
||||
def _built_object_repr(self):
|
||||
if self.artifacts and "repr" in self.artifacts:
|
||||
return self.artifacts["repr"] or super()._built_object_repr()
|
||||
|
||||
|
||||
def dict_to_codeblock(d: dict) -> str:
|
||||
serialized = {key: serialize_field(val) for key, val in d.items()}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue