🐛 (chat.py): Add validation for stream configuration in graph to prevent two connected vertices with stream or streaming set to True from raising a ValueError

🐛 (chat.py): Raise HTTPException with status code 400 if "stream or streaming set to True" is found in the exception message
📝 (base.py): Add method `validate_stream` to Graph class to validate stream configuration and prevent connected vertices with stream or streaming set to True
📝 (base.py): Update method `get_all_successors` in Graph class to specify type hint for the `vertex` parameter and return type
📝 (base.py): Update method `get_successors` in Graph class to specify type hint for the `vertex` parameter and return type
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-05-07 20:04:31 -03:00
commit 82877d8b99
2 changed files with 24 additions and 2 deletions

View file

@ -85,6 +85,7 @@ async def retrieve_vertices_order(
graph = await build_and_cache_graph_from_data(
flow_id=flow_id, graph_data=data.model_dump(), chat_service=chat_service
)
graph.validate_stream()
if stop_component_id or start_component_id:
try:
first_layer = graph.sort_vertices(stop_component_id, start_component_id)
@ -109,6 +110,8 @@ async def retrieve_vertices_order(
return VerticesOrderResponse(ids=first_layer, run_id=run_id, vertices_to_run=vertices_to_run)
except Exception as exc:
if "stream or streaming set to True" in str(exc):
raise HTTPException(status_code=400, detail=str(exc))
logger.error(f"Error checking build status: {exc}")
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc

View file

@ -166,6 +166,25 @@ class Graph:
self.state_manager.append_state(name, record, run_id=self._run_id)
def validate_stream(self):
"""
Validates the stream configuration of the graph.
If there are two vertices in the same graph (connected by edges)
that have `stream=True` or `streaming=True`, raises a `ValueError`.
Raises:
ValueError: If two connected vertices have `stream=True` or `streaming=True`.
"""
for vertex in self.vertices:
if vertex.params.get("stream") or vertex.params.get("streaming"):
successors = self.get_all_successors(vertex)
for successor in successors:
if successor.params.get("stream") or successor.params.get("streaming"):
raise ValueError(
f"Components {vertex.display_name} and {successor.display_name} are connected and both have stream or streaming set to True"
)
@property
def run_id(self):
"""
@ -882,7 +901,7 @@ class Graph:
"""Returns the predecessors of a vertex."""
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
def get_all_successors(self, vertex, recursive=True, flat=True):
def get_all_successors(self, vertex: Vertex, recursive=True, flat=True):
# Recursively get the successors of the current vertex
# successors = vertex.successors
# if not successors:
@ -919,7 +938,7 @@ class Graph:
successors_result.append([successor])
return successors_result
def get_successors(self, vertex):
def get_successors(self, vertex: Vertex) -> List[Vertex]:
"""Returns the successors of a vertex."""
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]