diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 30e709664..549e0c2b7 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -85,6 +85,7 @@ async def retrieve_vertices_order( graph = await build_and_cache_graph_from_data( flow_id=flow_id, graph_data=data.model_dump(), chat_service=chat_service ) + graph.validate_stream() if stop_component_id or start_component_id: try: first_layer = graph.sort_vertices(stop_component_id, start_component_id) @@ -109,6 +110,8 @@ async def retrieve_vertices_order( return VerticesOrderResponse(ids=first_layer, run_id=run_id, vertices_to_run=vertices_to_run) except Exception as exc: + if "stream or streaming set to True" in str(exc): + raise HTTPException(status_code=400, detail=str(exc)) logger.error(f"Error checking build status: {exc}") logger.exception(exc) raise HTTPException(status_code=500, detail=str(exc)) from exc diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 1744ac687..01bbab251 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -166,6 +166,25 @@ class Graph: self.state_manager.append_state(name, record, run_id=self._run_id) + def validate_stream(self): + """ + Validates the stream configuration of the graph. + + If there are two vertices in the same graph (connected by edges) + that have `stream=True` or `streaming=True`, raises a `ValueError`. + + Raises: + ValueError: If two connected vertices have `stream=True` or `streaming=True`. + """ + for vertex in self.vertices: + if vertex.params.get("stream") or vertex.params.get("streaming"): + successors = self.get_all_successors(vertex) + for successor in successors: + if successor.params.get("stream") or successor.params.get("streaming"): + raise ValueError( + f"Components {vertex.display_name} and {successor.display_name} are connected and both have stream or streaming set to True" + ) + @property def run_id(self): """ @@ -882,7 +901,7 @@ class Graph: """Returns the predecessors of a vertex.""" return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])] - def get_all_successors(self, vertex, recursive=True, flat=True): + def get_all_successors(self, vertex: Vertex, recursive=True, flat=True): # Recursively get the successors of the current vertex # successors = vertex.successors # if not successors: @@ -919,7 +938,7 @@ class Graph: successors_result.append([successor]) return successors_result - def get_successors(self, vertex): + def get_successors(self, vertex: Vertex) -> List[Vertex]: """Returns the successors of a vertex.""" return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]