ref: Add ruff rules for asyncio tasks references (RUF006) (#4079)

Add ruff rules for asyncio tasks references (RUF006)

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Christophe Bornet 2024-10-10 00:12:20 +02:00 committed by GitHub
commit de055f2113
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 59 additions and 41 deletions

View file

@ -119,6 +119,7 @@ class Graph:
self._cycle_vertices: set[str] | None = None
self._call_order: list[str] = []
self._snapshots: list[dict[str, Any]] = []
self._end_trace_tasks: set[asyncio.Task] = set()
try:
self.tracing_service: TracingService | None = get_tracing_service()
except Exception: # noqa: BLE001
@ -583,6 +584,11 @@ class Graph:
if self.tracing_service:
await self.tracing_service.initialize_tracers()
def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None):
task = asyncio.create_task(self.end_all_traces(outputs, error))
self._end_trace_tasks.add(task)
task.add_done_callback(self._end_trace_tasks.discard)
async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None):
if not self.tracing_service:
return
@ -685,11 +691,11 @@ class Graph:
await self.process(start_component_id=start_component_id, fallback_to_env_vars=fallback_to_env_vars)
self.increment_run_count()
except Exception as exc:
asyncio.create_task(self.end_all_traces(error=exc))
self._end_all_traces_async(error=exc)
msg = f"Error running graph: {exc}"
raise ValueError(msg) from exc
finally:
asyncio.create_task(self.end_all_traces())
self._end_all_traces_async()
# Get the outputs
vertex_outputs = []
for vertex in self.vertices:
@ -1257,7 +1263,7 @@ class Graph:
msg = "Graph not prepared. Call prepare() first."
raise ValueError(msg)
if not self._run_queue:
asyncio.create_task(self.end_all_traces())
self._end_all_traces_async()
return Finish()
vertex_id = self.get_next_in_queue()
chat_service = get_chat_service()

View file

@ -28,6 +28,8 @@ from langflow.utils.schemas import ChatOutputResponse
from langflow.utils.util import sync_to_async, unescape_string
if TYPE_CHECKING:
from uuid import UUID
from langflow.custom import Component
from langflow.events.event_manager import EventManager
from langflow.graph.edge.base import CycleEdge, Edge
@ -101,6 +103,7 @@ class Vertex:
self.use_result = False
self.build_times: list[float] = []
self.state = VertexStates.ACTIVE
self.log_transaction_tasks: set[asyncio.Task] = set()
def set_input_value(self, name: str, value: Any):
if self._custom_component is None:
@ -625,6 +628,13 @@ class Vertex:
async with self._lock:
return await self._get_result(requester, target_handle_name)
def _log_transaction_async(
    self, flow_id: str | UUID, source: Vertex, status, target: Vertex | None = None, error=None
) -> None:
    """Schedule log_transaction as a background task without losing the handle.

    The created task is stored in self.log_transaction_tasks (strong reference,
    ruff RUF006) and removed again via the done-callback when it finishes.
    """
    pending = asyncio.create_task(log_transaction(flow_id, source, status, target, error))
    self.log_transaction_tasks.add(pending)
    pending.add_done_callback(self.log_transaction_tasks.discard)
async def _get_result(self, requester: Vertex, target_handle_name: str | None = None) -> Any:
"""
Retrieves the result of the built component.
@ -637,13 +647,13 @@ class Vertex:
flow_id = self.graph.flow_id
if not self._built:
if flow_id:
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="error"))
self._log_transaction_async(str(flow_id), source=self, target=requester, status="error")
msg = f"Component {self.display_name} has not been built yet"
raise ValueError(msg)
result = self._built_result if self.use_result else self._built_object
if flow_id:
asyncio.create_task(log_transaction(str(flow_id), source=self, target=requester, status="success"))
self._log_transaction_async(str(flow_id), source=self, target=requester, status="success")
return result
async def _build_vertex_and_update_params(self, key, vertex: Vertex):

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import contextlib
import json
from collections.abc import AsyncIterator, Generator, Iterator
@ -11,7 +10,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk
from loguru import logger
from langflow.graph.schema import CHAT_COMPONENTS, RECORDS_COMPONENTS, InterfaceComponentTypes, ResultData
from langflow.graph.utils import UnbuiltObject, log_transaction, log_vertex_build, rewrite_file_path, serialize_field
from langflow.graph.utils import UnbuiltObject, log_vertex_build, rewrite_file_path, serialize_field
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.exceptions import NoComponentInstance
from langflow.schema import Data
@ -109,9 +108,7 @@ class ComponentVertex(Vertex):
default_value = requester.get_value_from_template_dict(edge.target_param)
if flow_id:
asyncio.create_task(
log_transaction(source=self, target=requester, flow_id=str(flow_id), status="error")
)
self._log_transaction_async(source=self, target=requester, flow_id=str(flow_id), status="error")
if default_value is not UNDEFINED:
return default_value
msg = f"Component {self.display_name} has not been built yet"
@ -150,7 +147,7 @@ class ComponentVertex(Vertex):
msg = f"Result not found for {edge.source_handle.name} in {edge}"
raise ValueError(msg)
if flow_id:
asyncio.create_task(log_transaction(source=self, target=requester, flow_id=str(flow_id), status="success"))
self._log_transaction_async(source=self, target=requester, flow_id=str(flow_id), status="success")
return result
def extract_messages_from_artifacts(self, artifacts: dict[str, Any]) -> list[dict]:

View file

@ -87,6 +87,9 @@ class JavaScriptMIMETypeMiddleware(BaseHTTPMiddleware):
return response
telemetry_service_tasks = set()
def get_lifespan(fix_migration=False, socketio_server=None, version=None):
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -102,7 +105,9 @@ def get_lifespan(fix_migration=False, socketio_server=None, version=None):
initialize_super_user_if_needed()
task = asyncio.create_task(get_and_cache_all_types_dict(get_settings_service(), get_cache_service()))
await create_or_update_starter_projects(task)
asyncio.create_task(get_telemetry_service().start())
telemetry_service_task = asyncio.create_task(get_telemetry_service().start())
telemetry_service_tasks.add(telemetry_service_task)
telemetry_service_task.add_done_callback(telemetry_service_tasks.discard)
load_flows_from_directory()
yield
except Exception as exc:

View file

@ -119,7 +119,7 @@ class TelemetryService(Service):
self.running = True
self._start_time = datetime.now(timezone.utc)
self.worker_task = asyncio.create_task(self.telemetry_worker())
asyncio.create_task(self.log_package_version())
self.log_package_version_task = asyncio.create_task(self.log_package_version())
except Exception: # noqa: BLE001
logger.exception("Error starting telemetry service")

View file

@ -18,6 +18,7 @@ class BaseTracer(ABC):
def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID):
raise NotImplementedError
@property
@abstractmethod
def ready(self) -> bool:
raise NotImplementedError

View file

@ -57,6 +57,7 @@ class TracingService(Service):
self.logs_queue: asyncio.Queue = asyncio.Queue()
self.running = False
self.worker_task = None
self.end_trace_tasks: set[asyncio.Task] = set()
async def log_worker(self):
while self.running or not self.logs_queue.empty():
@ -162,7 +163,7 @@ class TracingService(Service):
self.inputs[trace_name] = inputs
self.inputs_metadata[trace_name] = metadata or {}
for tracer in self._tracers.values():
if not tracer.ready: # type: ignore[truthy-function]
if not tracer.ready:
continue
try:
tracer.add_trace(trace_id, trace_name, trace_type, inputs, metadata, vertex)
@ -171,30 +172,28 @@ class TracingService(Service):
def _end_traces(self, trace_id: str, trace_name: str, error: Exception | None = None):
for tracer in self._tracers.values():
if not tracer.ready: # type: ignore[truthy-function]
continue
try:
tracer.end_trace(
trace_id=trace_id,
trace_name=trace_name,
outputs=self.outputs[trace_name],
error=error,
logs=self._logs[trace_name],
)
except Exception: # noqa: BLE001
logger.exception(f"Error ending trace {trace_name}")
if tracer.ready:
try:
tracer.end_trace(
trace_id=trace_id,
trace_name=trace_name,
outputs=self.outputs[trace_name],
error=error,
logs=self._logs[trace_name],
)
except Exception: # noqa: BLE001
logger.exception(f"Error ending trace {trace_name}")
def _end_all_traces(self, outputs: dict, error: Exception | None = None):
for tracer in self._tracers.values():
if not tracer.ready: # type: ignore[truthy-function]
continue
try:
tracer.end(self.inputs, outputs=self.outputs, error=error, metadata=outputs)
except Exception: # noqa: BLE001
logger.exception("Error ending all traces")
if tracer.ready:
try:
tracer.end(self.inputs, outputs=self.outputs, error=error, metadata=outputs)
except Exception: # noqa: BLE001
logger.exception("Error ending all traces")
async def end(self, outputs: dict, error: Exception | None = None):
self._end_all_traces(outputs, error)
await asyncio.to_thread(self._end_all_traces, outputs, error)
self._reset_io()
await self.stop()
@ -224,13 +223,15 @@ class TracingService(Service):
try:
yield self
except Exception as e:
self._end_traces(trace_id, trace_name, e)
self._end_and_reset(trace_id, trace_name, e)
raise
finally:
asyncio.create_task(await asyncio.to_thread(self._end_and_reset, trace_id, trace_name, None))
else:
self._end_and_reset(trace_id, trace_name)
async def _end_and_reset(self, trace_id: str, trace_name: str, error: Exception | None = None):
self._end_traces(trace_id, trace_name, error)
def _end_and_reset(self, trace_id: str, trace_name: str, error: Exception | None = None):
    """Finish the traces for trace_id off the event loop, then clear I/O state.

    _end_traces is run in a worker thread via asyncio.to_thread; the wrapping
    task is retained in self.end_trace_tasks (strong reference, ruff RUF006)
    and discarded from the set once it is done.
    """
    cleanup_task = asyncio.create_task(asyncio.to_thread(self._end_traces, trace_id, trace_name, error))
    self.end_trace_tasks.add(cleanup_task)
    cleanup_task.add_done_callback(self.end_trace_tasks.discard)
    self._reset_io()
def set_outputs(

View file

@ -58,10 +58,8 @@ ignore = [
"ARG",
"D",
"DOC",
"EXE",
"FBT",
"N",
"RUF006", # Store a reference to the return value of `asyncio.create_task`
"S",
"SLF",
"T201",