diff --git a/src/backend/base/langflow/api/build.py b/src/backend/base/langflow/api/build.py index 7fb32e5c3..586d1112f 100644 --- a/src/backend/base/langflow/api/build.py +++ b/src/backend/base/langflow/api/build.py @@ -180,9 +180,6 @@ async def generate_flow_events( graph.validate_stream() first_layer = sort_vertices(graph) - if inputs is not None and getattr(inputs, "session", None) is not None: - graph.session_id = inputs.session - for vertex_id in first_layer: graph.run_manager.add_to_vertices_being_run(vertex_id) @@ -218,8 +215,19 @@ async def generate_flow_events( ) async def create_graph(fresh_session, flow_id_str: str) -> Graph: + if inputs is not None and getattr(inputs, "session", None) is not None: + effective_session_id = inputs.session + else: + effective_session_id = flow_id_str + if not data: - return await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service) + return await build_graph_from_db( + flow_id=flow_id, + session=fresh_session, + chat_service=chat_service, + user_id=str(current_user.id), + session_id=effective_session_id, + ) result = await fresh_session.exec(select(Flow.name).where(Flow.id == flow_id)) flow_name = result.first() @@ -229,6 +237,7 @@ async def generate_flow_events( payload=data.model_dump(), user_id=str(current_user.id), flow_name=flow_name, + session_id=effective_session_id, ) def sort_vertices(graph: Graph) -> list[str]: @@ -280,7 +289,7 @@ async def generate_flow_events( outputs = {output_label: OutputValue(message=message, type="error")} result_data_response = ResultDataResponse(results={}, outputs=outputs) artifacts = {} - background_tasks.add_task(graph.end_all_traces, error=exc) + background_tasks.add_task(graph.end_all_traces_in_context(error=exc)) result_data_response.message = artifacts @@ -314,7 +323,7 @@ async def generate_flow_events( next_runnable_vertices = [graph.stop_vertex] if not graph.run_manager.vertices_being_run and not next_runnable_vertices: - 
background_tasks.add_task(graph.end_all_traces) + background_tasks.add_task(graph.end_all_traces_in_context()) build_response = VertexBuildResponse( inactivated_vertices=list(set(inactivated_vertices)), @@ -410,7 +419,7 @@ async def generate_flow_events( try: await asyncio.gather(*tasks) except asyncio.CancelledError: - background_tasks.add_task(graph.end_all_traces) + background_tasks.add_task(graph.end_all_traces_in_context()) raise except Exception as e: logger.error(f"Error building vertices: {e}") diff --git a/src/backend/base/langflow/api/utils.py b/src/backend/base/langflow/api/utils.py index 409f1b9f2..48549fd57 100644 --- a/src/backend/base/langflow/api/utils.py +++ b/src/backend/base/langflow/api/utils.py @@ -154,35 +154,37 @@ async def build_graph_from_data(flow_id: uuid.UUID | str, payload: dict, **kwarg # Get flow name if "flow_name" not in kwargs: flow_name = await _get_flow_name(flow_id if isinstance(flow_id, uuid.UUID) else uuid.UUID(flow_id)) - kwargs["flow_name"] = flow_name + else: + flow_name = kwargs["flow_name"] str_flow_id = str(flow_id) - graph = Graph.from_payload(payload, str_flow_id, **kwargs) + session_id = kwargs.get("session_id") or str_flow_id + + graph = Graph.from_payload(payload, str_flow_id, flow_name, kwargs.get("user_id")) for vertex_id in graph.has_session_id_vertices: vertex = graph.get_vertex(vertex_id) if vertex is None: msg = f"Vertex {vertex_id} not found" raise ValueError(msg) if not vertex.raw_params.get("session_id"): - vertex.update_raw_params({"session_id": str_flow_id}, overwrite=True) + vertex.update_raw_params({"session_id": session_id}, overwrite=True) - run_id = uuid.uuid4() - graph.set_run_id(run_id) - graph.set_run_name() + graph.session_id = session_id await graph.initialize_run() return graph -async def build_graph_from_db_no_cache(flow_id: uuid.UUID, session: AsyncSession): +async def build_graph_from_db_no_cache(flow_id: uuid.UUID, session: AsyncSession, **kwargs): """Build and cache the graph.""" flow: 
Flow | None = await session.get(Flow, flow_id) if not flow or not flow.data: msg = "Invalid flow ID" raise ValueError(msg) - return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id)) + kwargs["user_id"] = kwargs.get("user_id") or str(flow.user_id) + return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, **kwargs) -async def build_graph_from_db(flow_id: uuid.UUID, session: AsyncSession, chat_service: ChatService): - graph = await build_graph_from_db_no_cache(flow_id=flow_id, session=session) +async def build_graph_from_db(flow_id: uuid.UUID, session: AsyncSession, chat_service: ChatService, **kwargs): + graph = await build_graph_from_db_no_cache(flow_id=flow_id, session=session, **kwargs) await chat_service.set_cache(str(flow_id), graph) return graph diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index db4bd7ce3..0f2fed846 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -293,7 +293,7 @@ async def build_vertex( outputs = {output_label: OutputValue(message=message, type="error")} result_data_response = ResultDataResponse(results={}, outputs=outputs) artifacts = {} - background_tasks.add_task(graph.end_all_traces, error=exc) + background_tasks.add_task(graph.end_all_traces_in_context(error=exc)) # If there's an error building the vertex # we need to clear the cache await chat_service.clear_cache(flow_id_str) @@ -331,7 +331,7 @@ async def build_vertex( next_runnable_vertices = [graph.stop_vertex] if not graph.run_manager.vertices_being_run and not next_runnable_vertices: - background_tasks.add_task(graph.end_all_traces) + background_tasks.add_task(graph.end_all_traces_in_context()) build_response = VertexBuildResponse( inactivated_vertices=list(set(inactivated_vertices)), diff --git a/src/backend/base/langflow/components/langwatch/langwatch.py 
b/src/backend/base/langflow/components/langwatch/langwatch.py index f35b1265b..da36b529e 100644 --- a/src/backend/base/langflow/components/langwatch/langwatch.py +++ b/src/backend/base/langflow/components/langwatch/langwatch.py @@ -266,12 +266,10 @@ class LangWatchComponent(Component): "settings": {}, } - if ( - self._tracing_service - and self._tracing_service._tracers - and "langwatch" in self._tracing_service._tracers - ): - payload["trace_id"] = str(self._tracing_service._tracers["langwatch"].trace_id) # type: ignore[assignment] + if self._tracing_service: + tracer = self._tracing_service.get_tracer("langwatch") + if tracer is not None and hasattr(tracer, "trace_id"): + payload["settings"]["trace_id"] = str(tracer.trace_id) for setting_name in self.dynamic_inputs: payload["settings"][setting_name] = getattr(self, setting_name, None) diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index c2fd5f369..d319a23d4 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -886,7 +886,7 @@ class Component(CustomComponent): async def _build_with_tracing(self): inputs = self.get_trace_as_inputs() metadata = self.get_trace_as_metadata() - async with self._tracing_service.trace_context(self, self.trace_name, inputs, metadata): + async with self._tracing_service.trace_component(self, self.trace_name, inputs, metadata): results, artifacts = await self._build_results() self._tracing_service.set_outputs(self.trace_name, results) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 45a3eeddc..df79ac575 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import contextlib +import contextvars import copy import json 
import queue @@ -43,7 +44,7 @@ from langflow.services.deps import get_chat_service, get_tracing_service from langflow.utils.async_helpers import run_until_complete if TYPE_CHECKING: - from collections.abc import Generator, Iterable + from collections.abc import Callable, Generator, Iterable from langflow.api.v1.schemas import InputValueRequest from langflow.custom.custom_component.component import Component @@ -636,15 +637,6 @@ class Graph: raise ValueError(msg) return self._run_id - def set_tracing_session_id(self) -> None: - """Sets the ID of the current session. - - Args: - session_id (str): The session ID. - """ - if self.tracing_service: - self.tracing_service.set_session_id(self._session_id) - def set_run_id(self, run_id: uuid.UUID | None = None) -> None: """Sets the ID of the current run. @@ -655,29 +647,37 @@ class Graph: run_id = uuid.uuid4() self._run_id = str(run_id) - if self.tracing_service: - self.tracing_service.set_run_id(run_id) - if self._session_id and self.tracing_service is not None: - self.tracing_service.set_session_id(self.session_id) - - def set_run_name(self) -> None: - # Given a flow name, flow_id - if not self.tracing_service: - return - name = f"{self.flow_name} - {self.flow_id}" - - self.set_run_id() - self.tracing_service.set_run_name(name) async def initialize_run(self) -> None: + if not self._run_id: + self.set_run_id() if self.tracing_service: - await self.tracing_service.initialize_tracers() + run_name = f"{self.flow_name} - {self.flow_id}" + await self.tracing_service.start_tracers( + run_id=uuid.UUID(self._run_id), + run_name=run_name, + user_id=self.user_id, + session_id=self.session_id, + ) def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None: task = asyncio.create_task(self.end_all_traces(outputs, error)) self._end_trace_tasks.add(task) task.add_done_callback(self._end_trace_tasks.discard) + def end_all_traces_in_context( + self, + outputs: dict[str, Any] | None = 
None, + error: Exception | None = None, + ) -> Callable: + # BackgroundTasks run in different context, so we need to copy the context + context = contextvars.copy_context() + + async def async_end_traces_func(): + await asyncio.create_task(self.end_all_traces(outputs, error), context=context) + + return async_end_traces_func + async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None: if not self.tracing_service: return @@ -685,7 +685,7 @@ class Graph: if outputs is None: outputs = {} outputs |= self.metadata - await self.tracing_service.end(outputs, error) + await self.tracing_service.end_tracers(outputs, error) @property def sorted_vertices_layers(self) -> list[list[str]]: @@ -853,6 +853,8 @@ class Graph: inputs_components.append([]) if types is None: types = [] + if session_id: + self.session_id = session_id for _ in range(len(inputs) - len(types)): types.append("chat") # default to chat for run_inputs, components, input_type in zip(inputs, inputs_components, types, strict=True): @@ -1053,7 +1055,6 @@ class Graph: self.state_manager = GraphStateManager() self.tracing_service = get_tracing_service() self.set_run_id(self._run_id) - self.set_run_name() @classmethod def from_payload( @@ -1527,8 +1528,6 @@ class Graph: to_process = deque(first_layer) layer_index = 0 chat_service = get_chat_service() - self.set_run_id() - self.set_run_name() await self.initialize_run() lock = asyncio.Lock() while to_process: diff --git a/src/backend/base/langflow/services/tracing/base.py b/src/backend/base/langflow/services/tracing/base.py index 9b51c6e38..38c90894b 100644 --- a/src/backend/base/langflow/services/tracing/base.py +++ b/src/backend/base/langflow/services/tracing/base.py @@ -17,7 +17,15 @@ class BaseTracer(ABC): trace_id: UUID @abstractmethod - def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID): + def __init__( + self, + trace_name: str, + trace_type: str, + project_name: str, + 
trace_id: UUID, + user_id: str | None = None, + session_id: str | None = None, + ) -> None: raise NotImplementedError @property diff --git a/src/backend/base/langflow/services/tracing/langfuse.py b/src/backend/base/langflow/services/tracing/langfuse.py index 9d8f474aa..58fe92a5e 100644 --- a/src/backend/base/langflow/services/tracing/langfuse.py +++ b/src/backend/base/langflow/services/tracing/langfuse.py @@ -23,11 +23,21 @@ if TYPE_CHECKING: class LangFuseTracer(BaseTracer): flow_id: str - def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID): + def __init__( + self, + trace_name: str, + trace_type: str, + project_name: str, + trace_id: UUID, + user_id: str | None = None, + session_id: str | None = None, + ) -> None: self.project_name = project_name self.trace_name = trace_name self.trace_type = trace_type self.trace_id = trace_id + self.user_id = user_id + self.session_id = session_id self.flow_id = trace_name.split(" - ")[-1] self.spans: dict = OrderedDict() # spans that are not ended @@ -43,7 +53,19 @@ class LangFuseTracer(BaseTracer): from langfuse import Langfuse self._client = Langfuse(**config) - self.trace = self._client.trace(id=str(self.trace_id), name=self.flow_id) + try: + from langfuse.api.core.request_options import RequestOptions + + self._client.client.health.health(request_options=RequestOptions(timeout_in_seconds=1)) + except Exception as e: # noqa: BLE001 + logger.debug(f"can not connect to Langfuse: {e}") + return False + self.trace = self._client.trace( + id=str(self.trace_id), + name=self.flow_id, + user_id=self.user_id, + session_id=self.session_id, + ) except ImportError: logger.exception("Could not import langfuse. 
Please install it with `pip install langfuse`.") @@ -69,7 +91,7 @@ class LangFuseTracer(BaseTracer): if not self._ready: return - metadata_: dict = {} + metadata_: dict = {"from_langflow_component": True, "component_id": trace_id} metadata_ |= {"trace_type": trace_type} if trace_type else {} metadata_ |= metadata or {} @@ -81,11 +103,12 @@ class LangFuseTracer(BaseTracer): "start_time": start_time, } - if len(self.spans) > 0: - last_span = next(reversed(self.spans)) - span = self.spans[last_span].span(**content_span) - else: - span = self.trace.span(**content_span) + # if two component is built concurrently, will use wrong last span. just flatten now, maybe fix in future. + # if len(self.spans) > 0: + # last_span = next(reversed(self.spans)) + # span = self.spans[last_span].span(**content_span) + # else: + span = self.trace.span(**content_span) self.spans[trace_id] = span @@ -121,7 +144,12 @@ class LangFuseTracer(BaseTracer): ) -> None: if not self._ready: return - + content_update = { + "input": inputs, + "output": outputs, + "metadata": metadata, + } + self.trace.update(**content_update) self._client.flush() def get_langchain_callback(self) -> BaseCallbackHandler | None: diff --git a/src/backend/base/langflow/services/tracing/service.py b/src/backend/base/langflow/services/tracing/service.py index cb3882d10..a9626c146 100644 --- a/src/backend/base/langflow/services/tracing/service.py +++ b/src/backend/base/langflow/services/tracing/service.py @@ -4,6 +4,7 @@ import asyncio import os from collections import defaultdict from contextlib import asynccontextmanager +from contextvars import ContextVar from typing import TYPE_CHECKING, Any from loguru import logger @@ -46,188 +47,204 @@ def _get_arize_phoenix_tracer(): return ArizePhoenixTracer -class TracingService(Service): - name = "tracing_service" +trace_context_var: ContextVar[TraceContext | None] = ContextVar("trace_context", default=None) +component_context_var: ContextVar[ComponentTraceContext | None] = 
ContextVar("component_trace_context", default=None) - def __init__(self, settings_service: SettingsService): - self.settings_service = settings_service - self.inputs: dict[str, dict] = defaultdict(dict) - self.inputs_metadata: dict[str, dict] = defaultdict(dict) - self.outputs: dict[str, dict] = defaultdict(dict) - self.outputs_metadata: dict[str, dict] = defaultdict(dict) - self.run_name: str | None = None - self.run_id: UUID | None = None - self.project_name: str | None = None - self._tracers: dict[str, BaseTracer] = {} - self._logs: dict[str, list[Log | dict[Any, Any]]] = defaultdict(list) - self.end_trace_tasks: set[asyncio.Task] = set() - self.deactivated = self.settings_service.settings.deactivate_tracing - self.session_id: str | None = None - def _reset_io(self) -> None: - self.inputs = defaultdict(dict) - self.inputs_metadata = defaultdict(dict) - self.outputs = defaultdict(dict) - self.outputs_metadata = defaultdict(dict) +class TraceContext: + def __init__( + self, + run_id: UUID | None, + run_name: str | None, + project_name: str | None, + user_id: str | None, + session_id: str | None, + ): + self.run_id: UUID | None = run_id + self.run_name: str | None = run_name + self.project_name: str | None = project_name + self.user_id: str | None = user_id + self.session_id: str | None = session_id + self.tracers: dict[str, BaseTracer] = {} + self.all_inputs: dict[str, dict] = defaultdict(dict) + self.all_outputs: dict[str, dict] = defaultdict(dict) - async def initialize_tracers(self) -> None: - if self.deactivated: - return - try: - self._initialize_langsmith_tracer() - self._initialize_langwatch_tracer() - self._initialize_langfuse_tracer() - self._initialize_arize_phoenix_tracer() - except Exception: # noqa: BLE001 - logger.opt(exception=True).debug("Error initializing tracers") + self.traces_queue: asyncio.Queue = asyncio.Queue() + self.running = False + self.worker_task: asyncio.Task | None = None - def _initialize_langsmith_tracer(self) -> None: - 
project_name = os.getenv("LANGCHAIN_PROJECT", "Langflow") - self.project_name = project_name - langsmith_tracer = _get_langsmith_tracer() - self._tracers["langsmith"] = langsmith_tracer( - trace_name=self.run_name, - trace_type="chain", - project_name=self.project_name, - trace_id=self.run_id, - ) - def _initialize_langwatch_tracer(self) -> None: - if "langwatch" not in self._tracers or self._tracers["langwatch"].trace_id != self.run_id: - langwatch_tracer = _get_langwatch_tracer() - self._tracers["langwatch"] = langwatch_tracer( - trace_name=self.run_name, - trace_type="chain", - project_name=self.project_name, - trace_id=self.run_id, - ) - - def _initialize_langfuse_tracer(self) -> None: - self.project_name = os.getenv("LANGCHAIN_PROJECT", "Langflow") - langfuse_tracer = _get_langfuse_tracer() - self._tracers["langfuse"] = langfuse_tracer( - trace_name=self.run_name, - trace_type="chain", - project_name=self.project_name, - trace_id=self.run_id, - ) - - def _initialize_arize_phoenix_tracer(self) -> None: - self.project_name = os.getenv("ARIZE_PHOENIX_PROJECT", "Langflow") - arize_phoenix_tracer = _get_arize_phoenix_tracer() - self._tracers["arize_phoenix"] = arize_phoenix_tracer( - trace_name=self.run_name, - trace_type="chain", - project_name=self.project_name, - trace_id=self.run_id, - session_id=self.session_id, - ) - - def set_run_name(self, name: str) -> None: - self.run_name = name - - def set_run_id(self, run_id: UUID) -> None: - self.run_id = run_id - - def _start_traces( +class ComponentTraceContext: + def __init__( self, trace_id: str, trace_name: str, trace_type: str, - inputs: dict[str, Any], - metadata: dict[str, Any] | None = None, - vertex: Vertex | None = None, - ) -> None: - inputs = self._cleanup_inputs(inputs) - self.inputs[trace_name] = inputs - self.inputs_metadata[trace_name] = metadata or {} - for tracer in self._tracers.values(): - if not tracer.ready: - continue - try: - tracer.add_trace(trace_id, trace_name, trace_type, inputs, metadata, 
vertex) - except Exception: # noqa: BLE001 - logger.exception(f"Error starting trace {trace_name}") + vertex: Vertex | None, + inputs: dict[str, dict], + metadata: dict[str, dict] | None = None, + ): + self.trace_id: str = trace_id + self.trace_name: str = trace_name + self.trace_type: str = trace_type + self.vertex: Vertex | None = vertex + self.inputs: dict[str, dict] = inputs + self.inputs_metadata: dict[str, dict] = metadata or {} + self.outputs: dict[str, dict] = defaultdict(dict) + self.outputs_metadata: dict[str, dict] = defaultdict(dict) + self.logs: dict[str, list[Log | dict[Any, Any]]] = defaultdict(list) - def _end_traces(self, trace_id: str, trace_name: str, error: Exception | None = None) -> None: - for tracer in self._tracers.values(): + +class TracingService(Service): + """Tracing service. + + To trace a graph run: + 1. start_tracers: start a trace for a graph run + 2. with trace_component: start a sub-trace for a component build, three methods are available: + - add_log + - set_outputs + - get_langchain_callbacks + 3. end_tracers: end the trace for a graph run + + check context var in public methods. 
+ """ + + name = "tracing_service" + + def __init__(self, settings_service: SettingsService): + self.settings_service = settings_service + self.deactivated = self.settings_service.settings.deactivate_tracing + + async def _trace_worker(self, trace_context: TraceContext) -> None: + while trace_context.running or not trace_context.traces_queue.empty(): + trace_func, args = await trace_context.traces_queue.get() + try: + trace_func(*args) + except Exception: # noqa: BLE001 + logger.exception("Error processing trace_func") + finally: + trace_context.traces_queue.task_done() + + async def _start(self, trace_context: TraceContext) -> None: + if trace_context.running: + return + try: + trace_context.running = True + trace_context.worker_task = asyncio.create_task(self._trace_worker(trace_context)) + except Exception: # noqa: BLE001 + logger.exception("Error starting tracing service") + + def _initialize_langsmith_tracer(self, trace_context: TraceContext) -> None: + langsmith_tracer = _get_langsmith_tracer() + trace_context.tracers["langsmith"] = langsmith_tracer( + trace_name=trace_context.run_name, + trace_type="chain", + project_name=trace_context.project_name, + trace_id=trace_context.run_id, + ) + + def _initialize_langwatch_tracer(self, trace_context: TraceContext) -> None: + if ( + "langwatch" not in trace_context.tracers + or trace_context.tracers["langwatch"].trace_id != trace_context.run_id + ): + langwatch_tracer = _get_langwatch_tracer() + trace_context.tracers["langwatch"] = langwatch_tracer( + trace_name=trace_context.run_name, + trace_type="chain", + project_name=trace_context.project_name, + trace_id=trace_context.run_id, + ) + + def _initialize_langfuse_tracer(self, trace_context: TraceContext) -> None: + langfuse_tracer = _get_langfuse_tracer() + trace_context.tracers["langfuse"] = langfuse_tracer( + trace_name=trace_context.run_name, + trace_type="chain", + project_name=trace_context.project_name, + trace_id=trace_context.run_id, + 
user_id=trace_context.user_id, + session_id=trace_context.session_id, + ) + + def _initialize_arize_phoenix_tracer(self, trace_context: TraceContext) -> None: + arize_phoenix_tracer = _get_arize_phoenix_tracer() + trace_context.tracers["arize_phoenix"] = arize_phoenix_tracer( + trace_name=trace_context.run_name, + trace_type="chain", + project_name=trace_context.project_name, + trace_id=trace_context.run_id, + ) + + async def start_tracers( + self, + run_id: UUID, + run_name: str, + user_id: str | None, + session_id: str | None, + project_name: str | None = None, + ) -> None: + """Start a trace for a graph run. + + - create a trace context + - start a worker for this trace context + - initialize the tracers + """ + if self.deactivated: + return + try: + project_name = project_name or os.getenv("LANGCHAIN_PROJECT", "Langflow") + trace_context = TraceContext(run_id, run_name, project_name, user_id, session_id) + trace_context_var.set(trace_context) + await self._start(trace_context) + self._initialize_langsmith_tracer(trace_context) + self._initialize_langwatch_tracer(trace_context) + self._initialize_langfuse_tracer(trace_context) + self._initialize_arize_phoenix_tracer(trace_context) + except Exception as e: # noqa: BLE001 + logger.debug(f"Error initializing tracers: {e}") + + async def _stop(self, trace_context: TraceContext) -> None: + try: + trace_context.running = False + # check the qeue is empty + if not trace_context.traces_queue.empty(): + await trace_context.traces_queue.join() + if trace_context.worker_task: + trace_context.worker_task.cancel() + trace_context.worker_task = None + + except Exception: # noqa: BLE001 + logger.exception("Error stopping tracing service") + + def _end_all_tracers(self, trace_context: TraceContext, outputs: dict, error: Exception | None = None) -> None: + for tracer in trace_context.tracers.values(): if tracer.ready: try: - tracer.end_trace( - trace_id=trace_id, - trace_name=trace_name, - outputs=self.outputs[trace_name], + # 
why all_inputs and all_outputs? why metadata=outputs? + tracer.end( + trace_context.all_inputs, + outputs=trace_context.all_outputs, error=error, - logs=self._logs[trace_name], + metadata=outputs, ) - except Exception: # noqa: BLE001 - logger.exception(f"Error ending trace {trace_name}") - self._reset_io() - - def _end_all_traces(self, outputs: dict, error: Exception | None = None) -> None: - for tracer in self._tracers.values(): - if tracer.ready: - try: - tracer.end(self.inputs, outputs=self.outputs, error=error, metadata=outputs) except Exception: # noqa: BLE001 logger.exception("Error ending all traces") - self._reset_io() - async def end(self, outputs: dict, error: Exception | None = None) -> None: - await asyncio.to_thread(self._end_all_traces, outputs, error) - try: - # Wait for any pending trace tasks to complete - await asyncio.gather(*self.end_trace_tasks) - except Exception: # noqa: BLE001 - logger.exception("Error flushing logs") + async def end_tracers(self, outputs: dict, error: Exception | None = None) -> None: + """End the trace for a graph run. 
- def add_log(self, trace_name: str, log: Log) -> None: - self._logs[trace_name].append(log) - - @asynccontextmanager - async def trace_context( - self, - component: Component, - trace_name: str, - inputs: dict[str, Any], - metadata: dict[str, Any] | None = None, - ): + - stop worker for current trace_context + - call end for all the tracers + """ if self.deactivated: - yield self return - trace_id = trace_name - if component._vertex: - trace_id = component._vertex.id - trace_type = component.trace_type - self._start_traces( - trace_id, - trace_name, - trace_type, - inputs, - metadata, - component._vertex, - ) - try: - yield self - except Exception as e: - self._end_and_reset(trace_id, trace_name, e) - raise - else: - self._end_and_reset(trace_id, trace_name) - - def _end_and_reset(self, trace_id: str, trace_name: str, error: Exception | None = None) -> None: - task = asyncio.create_task(asyncio.to_thread(self._end_traces, trace_id, trace_name, error)) - self.end_trace_tasks.add(task) - task.add_done_callback(self.end_trace_tasks.discard) - - def set_outputs( - self, - trace_name: str, - outputs: dict[str, Any], - output_metadata: dict[str, Any] | None = None, - ) -> None: - self.outputs[trace_name] |= outputs or {} - self.outputs_metadata[trace_name] |= output_metadata or {} + trace_context = trace_context_var.get() + if trace_context is None: + msg = "called end_tracers but no trace context found" + raise RuntimeError(msg) + await self._stop(trace_context) + self._end_all_tracers(trace_context, outputs, error) @staticmethod def _cleanup_inputs(inputs: dict[str, Any]): @@ -237,18 +254,152 @@ class TracingService(Service): inputs[key] = "*****" # avoid logging api_keys for security reasons return inputs + def _start_component_traces( + self, + component_trace_context: ComponentTraceContext, + trace_context: TraceContext, + ) -> None: + inputs = self._cleanup_inputs(component_trace_context.inputs) + component_trace_context.inputs = inputs + 
component_trace_context.inputs_metadata = component_trace_context.inputs_metadata or {} + for tracer in trace_context.tracers.values(): + if not tracer.ready: + continue + try: + tracer.add_trace( + component_trace_context.trace_id, + component_trace_context.trace_name, + component_trace_context.trace_type, + inputs, + component_trace_context.inputs_metadata, + component_trace_context.vertex, + ) + except Exception: # noqa: BLE001 + logger.exception(f"Error starting trace {component_trace_context.trace_name}") + + def _end_component_traces( + self, + component_trace_context: ComponentTraceContext, + trace_context: TraceContext, + error: Exception | None = None, + ) -> None: + for tracer in trace_context.tracers.values(): + if tracer.ready: + try: + tracer.end_trace( + trace_id=component_trace_context.trace_id, + trace_name=component_trace_context.trace_name, + outputs=trace_context.all_outputs[component_trace_context.trace_name], + error=error, + logs=component_trace_context.logs[component_trace_context.trace_name], + ) + except Exception: # noqa: BLE001 + logger.exception(f"Error ending trace {component_trace_context.trace_name}") + + @asynccontextmanager + async def trace_component( + self, + component: Component, + trace_name: str, + inputs: dict[str, Any], + metadata: dict[str, Any] | None = None, + ): + """Trace a component. 
+ + @param component: the component to trace + @param trace_name: component name + component id + @param inputs: the inputs to the component + @param metadata: the metadata to the component + """ + if self.deactivated: + yield self + return + trace_id = trace_name + if component._vertex: + trace_id = component._vertex.id + trace_type = component.trace_type + component_trace_context = ComponentTraceContext( + trace_id, trace_name, trace_type, component._vertex, inputs, metadata + ) + component_context_var.set(component_trace_context) + trace_context = trace_context_var.get() + if trace_context is None: + msg = "called trace_component but no trace context found" + raise RuntimeError(msg) + trace_context.all_inputs[trace_name] |= inputs or {} + await trace_context.traces_queue.put((self._start_component_traces, (component_trace_context, trace_context))) + try: + yield self + except Exception as e: + await trace_context.traces_queue.put( + (self._end_component_traces, (component_trace_context, trace_context, e)) + ) + raise + else: + await trace_context.traces_queue.put( + (self._end_component_traces, (component_trace_context, trace_context, None)) + ) + + @property + def project_name(self): + if self.deactivated: + return os.getenv("LANGCHAIN_PROJECT", "Langflow") + trace_context = trace_context_var.get() + if trace_context is None: + msg = "called project_name but no trace context found" + raise RuntimeError(msg) + return trace_context.project_name + + def add_log(self, trace_name: str, log: Log) -> None: + """Add a log to the current component trace context.""" + if self.deactivated: + return + component_context = component_context_var.get() + if component_context is None: + msg = "called add_log but no component context found" + raise RuntimeError(msg) + component_context.logs[trace_name].append(log) + + def set_outputs( + self, + trace_name: str, + outputs: dict[str, Any], + output_metadata: dict[str, Any] | None = None, + ) -> None: + """Set the outputs for the 
current component trace context.""" + if self.deactivated: + return + component_context = component_context_var.get() + if component_context is None: + msg = "called set_outputs but no component context found" + raise RuntimeError(msg) + component_context.outputs[trace_name] |= outputs or {} + component_context.outputs_metadata[trace_name] |= output_metadata or {} + trace_context = trace_context_var.get() + if trace_context is None: + msg = "called set_outputs but no trace context found" + raise RuntimeError(msg) + trace_context.all_outputs[trace_name] |= outputs or {} + + def get_tracer(self, tracer_name: str) -> BaseTracer | None: + trace_context = trace_context_var.get() + if trace_context is None: + msg = "called get_tracer but no trace context found" + raise RuntimeError(msg) + return trace_context.tracers.get(tracer_name) + def get_langchain_callbacks(self) -> list[BaseCallbackHandler]: if self.deactivated: return [] callbacks = [] - for tracer in self._tracers.values(): + trace_context = trace_context_var.get() + if trace_context is None: + msg = "called get_langchain_callbacks but no trace context found" + raise RuntimeError(msg) + for tracer in trace_context.tracers.values(): if not tracer.ready: # type: ignore[truthy-function] continue langchain_callback = tracer.get_langchain_callback() if langchain_callback: callbacks.append(langchain_callback) return callbacks - - def set_session_id(self, session_id: str) -> None: - """Set the session ID for tracing.""" - self.session_id = session_id diff --git a/src/backend/tests/unit/services/tracing/__init__.py b/src/backend/tests/unit/services/tracing/__init__.py new file mode 100644 index 000000000..18f517ba6 --- /dev/null +++ b/src/backend/tests/unit/services/tracing/__init__.py @@ -0,0 +1 @@ +"""Services tests package.""" diff --git a/src/backend/tests/unit/services/tracing/test_tracing_service.py b/src/backend/tests/unit/services/tracing/test_tracing_service.py new file mode 100644 index 000000000..28796d142 
"""Unit tests for TracingService: tracer lifecycle, component tracing, deactivation, and context isolation."""

import asyncio
import uuid
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from langflow.services.settings.base import Settings
from langflow.services.settings.service import SettingsService
from langflow.services.tracing.base import BaseTracer
from langflow.services.tracing.service import (
    TracingService,
    component_context_var,
    trace_context_var,
)


class MockTracer(BaseTracer):
    """In-memory tracer that records every call so tests can assert on exact payloads."""

    def __init__(
        self,
        trace_name: str,
        trace_type: str,
        project_name: str,
        trace_id: uuid.UUID,
        user_id: str | None = None,
        session_id: str | None = None,
    ) -> None:
        self.trace_name = trace_name
        self.trace_type = trace_type
        self.project_name = project_name
        self.trace_id = trace_id
        self.user_id = user_id
        self.session_id = session_id
        self._ready = True
        self.end_called = False
        self.get_langchain_callback_called = False
        self.add_trace_list = []
        self.end_trace_list = []

    @property
    def ready(self) -> bool:
        return self._ready

    def add_trace(
        self,
        trace_id: str,
        trace_name: str,
        trace_type: str,
        inputs: dict[str, Any],
        metadata: dict[str, Any] | None = None,
        vertex=None,
    ) -> None:
        # Record the call instead of forwarding to an external tracing backend.
        self.add_trace_list.append(
            {
                "trace_id": trace_id,
                "trace_name": trace_name,
                "trace_type": trace_type,
                "inputs": inputs,
                "metadata": metadata,
                "vertex": vertex,
            }
        )

    def end_trace(
        self,
        trace_id: str,
        trace_name: str,
        outputs: dict[str, Any] | None = None,
        error: Exception | None = None,
        logs=(),
    ) -> None:
        self.end_trace_list.append(
            {
                "trace_id": trace_id,
                "trace_name": trace_name,
                "outputs": outputs,
                "error": error,
                "logs": logs,
            }
        )

    def end(
        self,
        inputs: dict[str, Any],
        outputs: dict[str, Any],
        error: Exception | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        self.end_called = True
        self.inputs_param = inputs
        self.outputs_param = outputs
        self.error_param = error
        self.metadata_param = metadata

    def get_langchain_callback(self):
        self.get_langchain_callback_called = True
        return MagicMock()


@pytest.fixture
def mock_settings_service():
    settings = Settings()
    settings.deactivate_tracing = False
    return SettingsService(settings, MagicMock())


@pytest.fixture
def tracing_service(mock_settings_service):
    return TracingService(mock_settings_service)


@pytest.fixture
def mock_component():
    component = MagicMock()
    component._vertex = MagicMock()
    component._vertex.id = "test_vertex_id"
    component.trace_type = "test_trace_type"
    return component


@pytest.fixture
def mock_tracers():
    # Swap every provider-specific tracer factory for MockTracer so no network is touched.
    with (
        patch(
            "langflow.services.tracing.service._get_langsmith_tracer",
            return_value=MockTracer,
        ),
        patch(
            "langflow.services.tracing.service._get_langwatch_tracer",
            return_value=MockTracer,
        ),
        patch(
            "langflow.services.tracing.service._get_langfuse_tracer",
            return_value=MockTracer,
        ),
        patch(
            "langflow.services.tracing.service._get_arize_phoenix_tracer",
            return_value=MockTracer,
        ),
    ):
        yield


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_start_end_tracers(tracing_service):
    """Test starting and ending tracers."""
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"
    outputs = {"output_key": "output_value"}

    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)
    # Verify trace_context is set correctly
    trace_context = trace_context_var.get()
    assert trace_context is not None
    assert trace_context.run_id == run_id
    assert trace_context.run_name == run_name
    assert trace_context.project_name == project_name
    assert trace_context.user_id == user_id
    assert trace_context.session_id == session_id

    # Verify tracers are initialized
    assert "langsmith" in trace_context.tracers
    assert "langwatch" in trace_context.tracers
    assert "langfuse" in trace_context.tracers
    assert "arize_phoenix" in trace_context.tracers

    await tracing_service.end_tracers(outputs)

    # Verify end method was called for all tracers
    trace_context = trace_context_var.get()
    for tracer in trace_context.tracers.values():
        assert tracer.end_called
        assert tracer.metadata_param == outputs
        assert tracer.outputs_param == trace_context.all_outputs

    # Verify worker_task is cancelled
    assert trace_context.worker_task is None
    assert not trace_context.running


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_trace_component(tracing_service, mock_component):
    """Test component tracing context manager."""
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    trace_name = "test_component_trace"
    inputs = {"input_key": "input_value"}
    metadata = {"metadata_key": "metadata_value"}

    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

    async with tracing_service.trace_component(mock_component, trace_name, inputs, metadata) as ts:
        # Verify component context is set
        component_context = component_context_var.get()
        assert component_context is not None
        assert component_context.trace_id == mock_component._vertex.id
        assert component_context.trace_name == trace_name
        assert component_context.trace_type == mock_component.trace_type
        assert component_context.vertex == mock_component._vertex
        assert component_context.inputs == inputs
        assert component_context.inputs_metadata == metadata

        # Verify add_trace method was called for tracers
        await asyncio.sleep(0.1)  # Wait for async queue processing
        trace_context = trace_context_var.get()
        for tracer in trace_context.tracers.values():
            assert tracer.add_trace_list[0]["trace_id"] == mock_component._vertex.id
            assert tracer.add_trace_list[0]["trace_name"] == trace_name
            assert tracer.add_trace_list[0]["trace_type"] == mock_component.trace_type
            assert tracer.add_trace_list[0]["inputs"] == inputs
            assert tracer.add_trace_list[0]["metadata"] == metadata
            assert tracer.add_trace_list[0]["vertex"] == mock_component._vertex

        # Test adding logs
        ts.add_log(trace_name, {"message": "test log"})
        assert {"message": "test log"} in component_context.logs[trace_name]

        # Test setting outputs
        outputs = {"output_key": "output_value"}
        output_metadata = {"output_metadata_key": "output_metadata_value"}
        ts.set_outputs(trace_name, outputs, output_metadata)
        assert component_context.outputs[trace_name] == outputs
        assert component_context.outputs_metadata[trace_name] == output_metadata
        assert trace_context.all_outputs[trace_name] == outputs

    # Verify end_trace method was called for tracers
    await asyncio.sleep(0.1)  # Wait for async queue processing
    for tracer in trace_context.tracers.values():
        assert tracer.end_trace_list[0]["trace_id"] == mock_component._vertex.id
        assert tracer.end_trace_list[0]["trace_name"] == trace_name
        assert tracer.end_trace_list[0]["outputs"] == trace_context.all_outputs[trace_name]
        assert tracer.end_trace_list[0]["error"] is None
        assert tracer.end_trace_list[0]["logs"] == component_context.logs[trace_name]

    # Cleanup
    await tracing_service.end_tracers({})


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_trace_component_with_exception(tracing_service, mock_component):
    """Test component tracing context manager with exception handling."""
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    trace_name = "test_component_trace"
    inputs = {"input_key": "input_value"}

    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

    test_exception = ValueError("Test exception")

    with pytest.raises(ValueError, match="Test exception"):
        async with tracing_service.trace_component(mock_component, trace_name, inputs):
            raise test_exception

    # Verify end_trace method was called with exception
    await asyncio.sleep(0.1)  # Wait for async queue processing
    trace_context = trace_context_var.get()
    for tracer in trace_context.tracers.values():
        assert tracer.end_trace_list[0]["error"] == test_exception

    # Cleanup
    await tracing_service.end_tracers({})


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_get_langchain_callbacks(tracing_service):
    """Test getting LangChain callback handlers."""
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

    callbacks = tracing_service.get_langchain_callbacks()

    # Verify get_langchain_callback method was called for each tracer
    trace_context = trace_context_var.get()
    for tracer in trace_context.tracers.values():
        assert tracer.get_langchain_callback_called

    # Verify returned callbacks list length
    assert len(callbacks) == 4  # Four tracers

    # Cleanup
    await tracing_service.end_tracers({})


@pytest.mark.asyncio
async def test_deactivated_tracing(mock_settings_service):
    """Test deactivated tracing functionality."""
    # Set deactivate_tracing to True
    mock_settings_service.settings.deactivate_tracing = True
    tracing_service = TracingService(mock_settings_service)

    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    # Starting tracers should have no effect
    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

    # With tracing disabled, no trace context is ever installed.
    assert trace_context_var.get() is None

    # Tracing operations below must all be no-ops rather than raising.
    mock_component = MagicMock()
    trace_name = "test_component_trace"
    inputs = {"input_key": "input_value"}

    async with tracing_service.trace_component(mock_component, trace_name, inputs) as ts:
        ts.add_log(trace_name, {"message": "test log"})
        ts.set_outputs(trace_name, {"output_key": "output_value"})

    # Test getting LangChain callback handlers
    callbacks = tracing_service.get_langchain_callbacks()
    assert len(callbacks) == 0  # Should return empty list when tracing is disabled

    # Test end_tracers
    await tracing_service.end_tracers({})


@pytest.mark.asyncio
async def test_cleanup_inputs():
    """Test cleaning sensitive information from input data."""
    inputs = {
        "normal_key": "normal_value",
        "api_key": "secret_api_key",
        "openai_api_key": "secret_openai_api_key",
        "nested_api_key": {"api_key": "nested_secret"},
    }

    cleaned_inputs = TracingService._cleanup_inputs(inputs)

    # Verify values for keys containing api_key are replaced with *****
    assert cleaned_inputs["normal_key"] == "normal_value"
    assert cleaned_inputs["api_key"] == "*****"
    assert cleaned_inputs["openai_api_key"] == "*****"

    # The top-level key name itself contains "api_key", so its whole value
    # (here a dict) is masked — no recursion into nested dicts is involved.
    assert cleaned_inputs["nested_api_key"] == "*****"

    # Verify original input is not modified
    assert inputs["api_key"] == "secret_api_key"
    assert inputs["openai_api_key"] == "secret_openai_api_key"


@pytest.mark.asyncio
async def test_start_tracers_with_exception(tracing_service):
    """Test starting tracers with exception handling."""
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    # Mock _initialize_langsmith_tracer to raise exception
    with (
        patch.object(
            tracing_service,
            "_initialize_langsmith_tracer",
            side_effect=Exception("Mock exception"),
        ),
        patch("langflow.services.tracing.service.logger.debug") as mock_logger,
    ):
        # start_tracers should return normally even with exception
        await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

        # Verify exception was logged
        mock_logger.assert_any_call("Error initializing tracers: Mock exception")

        # Verify trace_context was set even with exception
        trace_context = trace_context_var.get()
        assert trace_context is not None
        assert trace_context.run_id == run_id
        assert trace_context.run_name == run_name

    # Cleanup
    await tracing_service.end_tracers({})


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_trace_worker_with_exception(tracing_service):
    """Test trace worker exception handling."""
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    # Create a trace function that raises an exception
    def failing_trace_func():
        msg = "Mock trace function exception"
        raise ValueError(msg)

    with patch("langflow.services.tracing.service.logger.exception") as mock_logger:
        await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

        # Get trace_context and add failing trace function to queue
        trace_context = trace_context_var.get()
        await trace_context.traces_queue.put((failing_trace_func, ()))

        # Wait for async queue processing
        await asyncio.sleep(0.1)

        # Verify exception was logged
        mock_logger.assert_called_with("Error processing trace_func")

    # Cleanup
    await tracing_service.end_tracers({})


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_concurrent_tracing(tracing_service, mock_component):
    """Test two tasks running start_tracers concurrently, with each task running 2 concurrent trace_component tasks."""

    # Define common task function: start tracers and run two component traces
    async def run_task(
        run_id,
        run_name,
        user_id,
        session_id,
        project_name,
        inputs,
        metadata,
        task_prefix,
        sleep_duration=0.1,
    ):
        await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

        async def run_component_task(component, trace_name, component_suffix):
            async with tracing_service.trace_component(component, trace_name, inputs, metadata) as ts:
                ts.add_log(trace_name, {"message": f"{task_prefix} {component_suffix} log"})
                outputs = {"output_key": f"{task_prefix}_{component_suffix}_output"}
                await asyncio.sleep(sleep_duration)
                ts.set_outputs(trace_name, outputs)

        # Create both component tasks before awaiting so they genuinely
        # run concurrently, as the docstring promises.
        task1 = asyncio.create_task(run_component_task(mock_component, f"{run_id} trace_name1", f"{run_id} component1"))
        task2 = asyncio.create_task(run_component_task(mock_component, f"{run_id} trace_name2", f"{run_id} component2"))
        await asyncio.gather(task1, task2)

        await tracing_service.end_tracers({"final_output": f"{task_prefix}_final_output"})
        trace_context = trace_context_var.get()
        return trace_context.tracers["langfuse"]

    inputs1 = {"input_key": "input_value1"}
    metadata1 = {"metadata_key": "metadata_value1"}
    inputs2 = {"input_key": "input_value2"}
    metadata2 = {"metadata_key": "metadata_value2"}

    task1 = asyncio.create_task(
        run_task(
            "run_id1",
            "run_name1",
            "user_id1",
            "session_id1",
            "project_name1",
            inputs1,
            metadata1,
            "task1",
            2,
        )
    )
    await asyncio.sleep(0.1)
    task2 = asyncio.create_task(
        run_task(
            "run_id2",
            "run_name2",
            "user_id2",
            "session_id2",
            "project_name2",
            inputs2,
            metadata2,
            "task2",
            0.1,
        )
    )
    tracer1 = await task1
    tracer2 = await task2

    # Verify tracer1 and tracer2 have correct trace data
    assert tracer1.trace_name == "run_name1"
    assert tracer1.project_name == "project_name1"
    assert tracer1.user_id == "user_id1"
    assert tracer1.session_id == "session_id1"
    assert dict(tracer1.outputs_param.get("run_id1 trace_name1")) == {"output_key": "task1_run_id1 component1_output"}
    assert dict(tracer1.outputs_param.get("run_id1 trace_name2")) == {"output_key": "task1_run_id1 component2_output"}

    assert tracer2.trace_name == "run_name2"
    assert tracer2.project_name == "project_name2"
    assert tracer2.user_id == "user_id2"
    assert tracer2.session_id == "session_id2"
    assert dict(tracer2.outputs_param.get("run_id2 trace_name1")) == {"output_key": "task2_run_id2 component1_output"}
    assert dict(tracer2.outputs_param.get("run_id2 trace_name2")) == {"output_key": "task2_run_id2 component2_output"}