refactor: enhance graph initialization and telemetry handling (#5721)

* refactor: asyncdbsession handling in event_generator

* refactor: simplify graph building logic in build_flow

* refactor: simplify graph initialization and sorting

* refactor: adjust graph initialization and telemetry
This commit is contained in:
Ítalo Johnny 2025-01-19 22:25:54 -03:00 committed by GitHub
commit d413abbf90
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -160,32 +160,16 @@ async def build_flow(
async def build_graph_and_get_order() -> tuple[list[str], list[str], Graph]:
start_time = time.perf_counter()
components_count = None
components_count = 0
graph = None
try:
flow_id_str = str(flow_id)
# Create a fresh session for database operations
async with session_scope() as fresh_session:
if not data:
graph = await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service)
else:
result = await fresh_session.exec(select(Flow.name).where(Flow.id == flow_id))
flow_name = result.first()
graph = await build_graph_from_data(
flow_id=flow_id_str,
payload=data.model_dump(),
user_id=str(current_user.id),
flow_name=flow_name,
)
graph = await create_graph(fresh_session, flow_id_str)
graph.validate_stream()
if stop_component_id or start_component_id:
try:
first_layer = graph.sort_vertices(stop_component_id, start_component_id)
except Exception: # noqa: BLE001
logger.exception("Error sorting vertices")
first_layer = graph.sort_vertices()
else:
first_layer = graph.sort_vertices()
first_layer = sort_vertices(graph)
if inputs is not None and hasattr(inputs, "session") and inputs.session is not None:
graph.session_id = inputs.session
@ -198,31 +182,53 @@ async def build_flow(
# and return the same structure but only with the ids
components_count = len(graph.vertices)
vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run)))
await chat_service.set_cache(flow_id_str, graph)
background_tasks.add_task(
telemetry_service.log_package_playground,
PlaygroundPayload(
playground_seconds=int(time.perf_counter() - start_time),
playground_component_count=components_count,
playground_success=True,
),
)
await log_telemetry(start_time, components_count, success=True)
except Exception as exc:
background_tasks.add_task(
telemetry_service.log_package_playground,
PlaygroundPayload(
playground_seconds=int(time.perf_counter() - start_time),
playground_component_count=components_count,
playground_success=False,
playground_error_message=str(exc),
),
)
await log_telemetry(start_time, components_count, success=False, error_message=str(exc))
if "stream or streaming set to True" in str(exc):
raise HTTPException(status_code=400, detail=str(exc)) from exc
logger.exception("Error checking build status")
raise HTTPException(status_code=500, detail=str(exc)) from exc
return first_layer, vertices_to_run, graph
async def log_telemetry(
    start_time: float, components_count: int, *, success: bool, error_message: str | None = None
):
    """Schedule a playground telemetry record for a graph build attempt.

    The payload is not sent inline: it is registered on ``background_tasks``
    together with ``telemetry_service.log_package_playground`` as the callable.
    Nothing is awaited inside this coroutine.

    Args:
        start_time: Earlier ``time.perf_counter()`` reading; the elapsed time
            since then is reported, truncated to whole seconds.
        components_count: Value reported as the playground component count.
        success: Whether the attempt is reported as successful.
        error_message: Optional failure description. ``None`` or an empty
            value is reported as an empty string.
    """
    elapsed_seconds = int(time.perf_counter() - start_time)
    # Always hand the payload a string, never None.
    reported_error = str(error_message) if error_message else ""
    payload = PlaygroundPayload(
        playground_seconds=elapsed_seconds,
        playground_component_count=components_count,
        playground_success=success,
        playground_error_message=reported_error,
    )
    background_tasks.add_task(telemetry_service.log_package_playground, payload)
async def create_graph(fresh_session, flow_id_str: str) -> Graph:
    """Build the graph for the current request.

    When the enclosing scope's ``data`` is truthy, the graph is built from
    that payload and the flow's name is looked up through ``fresh_session``.
    Otherwise the graph is built from the database via ``build_graph_from_db``.

    Args:
        fresh_session: Session used for the name lookup or passed on to
            ``build_graph_from_db``.
        flow_id_str: String form of the flow id, forwarded to
            ``build_graph_from_data``.

    Returns:
        The graph produced by whichever builder was used.
    """
    if data:
        name_query = select(Flow.name).where(Flow.id == flow_id)
        name_rows = await fresh_session.exec(name_query)
        # first() result is forwarded as-is; presumably None when no row matches — confirm.
        stored_name = name_rows.first()
        return await build_graph_from_data(
            flow_id=flow_id_str,
            payload=data.model_dump(),
            user_id=str(current_user.id),
            flow_name=stored_name,
        )
    # No request payload: build from the database instead.
    return await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service)
def sort_vertices(graph: Graph) -> list[str]:
    """Return the result of ``graph.sort_vertices`` for the requested stop/start ids.

    The first attempt passes the enclosing scope's ``stop_component_id`` and
    ``start_component_id``. If that raises any ``Exception``, the error is
    logged and the sort is retried with no arguments.

    Args:
        graph: Graph whose ``sort_vertices`` method is called.

    Returns:
        Whatever ``graph.sort_vertices`` returns (annotated as a list of ids).
    """
    try:
        ordered_ids = graph.sort_vertices(stop_component_id, start_component_id)
    except Exception:  # noqa: BLE001
        # Best-effort fallback: log the failure and retry with no arguments.
        logger.exception("Error sorting vertices")
        ordered_ids = graph.sort_vertices()
    return ordered_ids
async def _build_vertex(vertex_id: str, graph: Graph, event_manager: EventManager) -> VertexBuildResponse:
flow_id_str = str(flow_id)
next_runnable_vertices = []
@ -372,33 +378,15 @@ async def build_flow(
return
async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None:
if not data:
# using another task since the build_graph_and_get_order is now an async function
vertices_task = asyncio.create_task(build_graph_and_get_order())
try:
await vertices_task
except asyncio.CancelledError:
vertices_task.cancel()
return
except Exception as e:
error_message = ErrorMessage(
flow_id=flow_id,
exception=e,
)
event_manager.on_error(data=error_message.data)
raise
ids, vertices_to_run, graph = vertices_task.result()
else:
try:
ids, vertices_to_run, graph = await build_graph_and_get_order()
except Exception as e:
error_message = ErrorMessage(
flow_id=flow_id,
exception=e,
)
event_manager.on_error(data=error_message.data)
raise
try:
ids, vertices_to_run, graph = await build_graph_and_get_order()
except Exception as e:
error_message = ErrorMessage(
flow_id=flow_id,
exception=e,
)
event_manager.on_error(data=error_message.data)
raise
event_manager.on_vertices_sorted(data={"ids": ids, "to_run": vertices_to_run})
await client_consumed_queue.get()