refactor: enhance graph initialization and telemetry handling (#5721)
* refactor: asyncdbsession handling in event_generator
* refactor: simplify graph building logic in build_flow
* refactor: simplify graph initialization and sorting
* refactor: adjust graph initialization and telemetry
This commit is contained in:
parent
f22cf01d5c
commit
d413abbf90
1 changed files with 52 additions and 64 deletions
|
|
@ -160,32 +160,16 @@ async def build_flow(
|
|||
|
||||
async def build_graph_and_get_order() -> tuple[list[str], list[str], Graph]:
|
||||
start_time = time.perf_counter()
|
||||
components_count = None
|
||||
components_count = 0
|
||||
graph = None
|
||||
try:
|
||||
flow_id_str = str(flow_id)
|
||||
# Create a fresh session for database operations
|
||||
async with session_scope() as fresh_session:
|
||||
if not data:
|
||||
graph = await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service)
|
||||
else:
|
||||
result = await fresh_session.exec(select(Flow.name).where(Flow.id == flow_id))
|
||||
flow_name = result.first()
|
||||
graph = await build_graph_from_data(
|
||||
flow_id=flow_id_str,
|
||||
payload=data.model_dump(),
|
||||
user_id=str(current_user.id),
|
||||
flow_name=flow_name,
|
||||
)
|
||||
graph = await create_graph(fresh_session, flow_id_str)
|
||||
|
||||
graph.validate_stream()
|
||||
if stop_component_id or start_component_id:
|
||||
try:
|
||||
first_layer = graph.sort_vertices(stop_component_id, start_component_id)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error sorting vertices")
|
||||
first_layer = graph.sort_vertices()
|
||||
else:
|
||||
first_layer = graph.sort_vertices()
|
||||
first_layer = sort_vertices(graph)
|
||||
|
||||
if inputs is not None and hasattr(inputs, "session") and inputs.session is not None:
|
||||
graph.session_id = inputs.session
|
||||
|
|
@ -198,31 +182,53 @@ async def build_flow(
|
|||
# and return the same structure but only with the ids
|
||||
components_count = len(graph.vertices)
|
||||
vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run)))
|
||||
|
||||
await chat_service.set_cache(flow_id_str, graph)
|
||||
background_tasks.add_task(
|
||||
telemetry_service.log_package_playground,
|
||||
PlaygroundPayload(
|
||||
playground_seconds=int(time.perf_counter() - start_time),
|
||||
playground_component_count=components_count,
|
||||
playground_success=True,
|
||||
),
|
||||
)
|
||||
await log_telemetry(start_time, components_count, success=True)
|
||||
|
||||
except Exception as exc:
|
||||
background_tasks.add_task(
|
||||
telemetry_service.log_package_playground,
|
||||
PlaygroundPayload(
|
||||
playground_seconds=int(time.perf_counter() - start_time),
|
||||
playground_component_count=components_count,
|
||||
playground_success=False,
|
||||
playground_error_message=str(exc),
|
||||
),
|
||||
)
|
||||
await log_telemetry(start_time, components_count, success=False, error_message=str(exc))
|
||||
|
||||
if "stream or streaming set to True" in str(exc):
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
logger.exception("Error checking build status")
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
return first_layer, vertices_to_run, graph
|
||||
|
||||
async def log_telemetry(
    start_time: float, components_count: int, *, success: bool, error_message: str | None = None
):
    """Register a playground telemetry record with ``background_tasks``.

    Nothing is awaited here; the function only builds the payload and hands
    ``telemetry_service.log_package_playground`` to ``background_tasks.add_task``.

    Args:
        start_time: A ``time.perf_counter()`` reading taken when the build
            started; the elapsed time is reported in whole seconds.
        components_count: Number of components to report for this build.
        success: Whether the build completed without raising.
        error_message: Text describing the failure. A falsy value (``None``
            or an empty string) is reported as ``""``.
    """
    # Truncate to whole seconds, matching the integer field of the payload.
    elapsed_seconds = int(time.perf_counter() - start_time)
    reported_error = str(error_message) if error_message else ""

    payload = PlaygroundPayload(
        playground_seconds=elapsed_seconds,
        playground_component_count=components_count,
        playground_success=success,
        playground_error_message=reported_error,
    )
    background_tasks.add_task(telemetry_service.log_package_playground, payload)
|
||||
|
||||
async def create_graph(fresh_session, flow_id_str: str) -> Graph:
    """Build the graph for the current request.

    The source of the graph depends on ``data`` from the enclosing scope:
    when it is truthy the graph is built from that payload, otherwise it is
    loaded through ``build_graph_from_db``.

    Args:
        fresh_session: Database session used for the flow-name lookup or
            passed on to ``build_graph_from_db``.
        flow_id_str: String form of the flow id, used when building from
            the request payload.

    Returns:
        The graph produced by ``build_graph_from_data`` or
        ``build_graph_from_db``.
    """
    if data:
        # Only the flow's name is needed alongside the supplied payload.
        name_query = select(Flow.name).where(Flow.id == flow_id)
        name_rows = await fresh_session.exec(name_query)
        # NOTE(review): presumably None when no flow row matches - confirm.
        flow_name = name_rows.first()

        return await build_graph_from_data(
            flow_id=flow_id_str,
            payload=data.model_dump(),
            user_id=str(current_user.id),
            flow_name=flow_name,
        )

    return await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service)
|
||||
|
||||
def sort_vertices(graph: Graph) -> list[str]:
    """Sort the graph's vertices, falling back to an unconstrained sort.

    The first attempt passes ``stop_component_id`` and ``start_component_id``
    from the enclosing scope to ``graph.sort_vertices``. If that raises, the
    error is logged and the graph is sorted again with no arguments.

    Args:
        graph: The graph whose ``sort_vertices`` method is invoked.

    Returns:
        Whatever ``graph.sort_vertices`` returns (annotated as a list of
        string ids).
    """
    try:
        ordered_ids = graph.sort_vertices(stop_component_id, start_component_id)
    except Exception:  # noqa: BLE001
        # Best-effort: a failed constrained sort must not abort the build.
        logger.exception("Error sorting vertices")
        ordered_ids = graph.sort_vertices()
    return ordered_ids
|
||||
|
||||
async def _build_vertex(vertex_id: str, graph: Graph, event_manager: EventManager) -> VertexBuildResponse:
|
||||
flow_id_str = str(flow_id)
|
||||
next_runnable_vertices = []
|
||||
|
|
@ -372,33 +378,15 @@ async def build_flow(
|
|||
return
|
||||
|
||||
async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None:
|
||||
if not data:
|
||||
# using another task since the build_graph_and_get_order is now an async function
|
||||
vertices_task = asyncio.create_task(build_graph_and_get_order())
|
||||
try:
|
||||
await vertices_task
|
||||
except asyncio.CancelledError:
|
||||
vertices_task.cancel()
|
||||
return
|
||||
except Exception as e:
|
||||
error_message = ErrorMessage(
|
||||
flow_id=flow_id,
|
||||
exception=e,
|
||||
)
|
||||
event_manager.on_error(data=error_message.data)
|
||||
raise
|
||||
|
||||
ids, vertices_to_run, graph = vertices_task.result()
|
||||
else:
|
||||
try:
|
||||
ids, vertices_to_run, graph = await build_graph_and_get_order()
|
||||
except Exception as e:
|
||||
error_message = ErrorMessage(
|
||||
flow_id=flow_id,
|
||||
exception=e,
|
||||
)
|
||||
event_manager.on_error(data=error_message.data)
|
||||
raise
|
||||
try:
|
||||
ids, vertices_to_run, graph = await build_graph_and_get_order()
|
||||
except Exception as e:
|
||||
error_message = ErrorMessage(
|
||||
flow_id=flow_id,
|
||||
exception=e,
|
||||
)
|
||||
event_manager.on_error(data=error_message.data)
|
||||
raise
|
||||
event_manager.on_vertices_sorted(data={"ids": ids, "to_run": vertices_to_run})
|
||||
await client_consumed_queue.get()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue