fix: refactor tracing service (#7015)

* fix: refactor tracing service
1. fix race condition when concurrently running flows #6899
2. make start_trace/end_trace both async
3. add session_id/user_id to trace #4274

* fix: merge

* feat: unittest trace service

* fix: ruff style

* fix: handle tracing deactivated

* fix: ruff style

* feat: detect langfuse ready by health()

* fix: ruff style

* Update src/backend/base/langflow/api/build.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/api/build.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/graph/graph/base.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/langfuse.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update src/backend/base/langflow/services/tracing/service.py

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>

* Update service.py

* fix: code style

* fix: style

* feat: langwatch component get_tracer from service

* fix: only get contextvar in public method, check and raise

* fix: style, long arg names

---------

Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
tianzhipeng 2025-03-15 01:01:34 +08:00 committed by GitHub
commit eeeb09e7ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 942 additions and 231 deletions

View file

@ -180,9 +180,6 @@ async def generate_flow_events(
graph.validate_stream()
first_layer = sort_vertices(graph)
if inputs is not None and getattr(inputs, "session", None) is not None:
graph.session_id = inputs.session
for vertex_id in first_layer:
graph.run_manager.add_to_vertices_being_run(vertex_id)
@ -218,8 +215,19 @@ async def generate_flow_events(
)
async def create_graph(fresh_session, flow_id_str: str) -> Graph:
if inputs is not None and getattr(inputs, "session", None) is not None:
effective_session_id = inputs.session
else:
effective_session_id = flow_id_str
if not data:
return await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service)
return await build_graph_from_db(
flow_id=flow_id,
session=fresh_session,
chat_service=chat_service,
user_id=str(current_user.id),
session_id=effective_session_id,
)
result = await fresh_session.exec(select(Flow.name).where(Flow.id == flow_id))
flow_name = result.first()
@ -229,6 +237,7 @@ async def generate_flow_events(
payload=data.model_dump(),
user_id=str(current_user.id),
flow_name=flow_name,
session_id=effective_session_id,
)
def sort_vertices(graph: Graph) -> list[str]:
@ -280,7 +289,7 @@ async def generate_flow_events(
outputs = {output_label: OutputValue(message=message, type="error")}
result_data_response = ResultDataResponse(results={}, outputs=outputs)
artifacts = {}
background_tasks.add_task(graph.end_all_traces, error=exc)
background_tasks.add_task(graph.end_all_traces_in_context(error=exc))
result_data_response.message = artifacts
@ -314,7 +323,7 @@ async def generate_flow_events(
next_runnable_vertices = [graph.stop_vertex]
if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
background_tasks.add_task(graph.end_all_traces)
background_tasks.add_task(graph.end_all_traces_in_context())
build_response = VertexBuildResponse(
inactivated_vertices=list(set(inactivated_vertices)),
@ -410,7 +419,7 @@ async def generate_flow_events(
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
background_tasks.add_task(graph.end_all_traces)
background_tasks.add_task(graph.end_all_traces_in_context())
raise
except Exception as e:
logger.error(f"Error building vertices: {e}")

View file

@ -154,35 +154,37 @@ async def build_graph_from_data(flow_id: uuid.UUID | str, payload: dict, **kwarg
# Get flow name
if "flow_name" not in kwargs:
flow_name = await _get_flow_name(flow_id if isinstance(flow_id, uuid.UUID) else uuid.UUID(flow_id))
kwargs["flow_name"] = flow_name
else:
flow_name = kwargs["flow_name"]
str_flow_id = str(flow_id)
graph = Graph.from_payload(payload, str_flow_id, **kwargs)
session_id = kwargs.get("session_id") or str_flow_id
graph = Graph.from_payload(payload, str_flow_id, flow_name, kwargs.get("user_id"))
for vertex_id in graph.has_session_id_vertices:
vertex = graph.get_vertex(vertex_id)
if vertex is None:
msg = f"Vertex {vertex_id} not found"
raise ValueError(msg)
if not vertex.raw_params.get("session_id"):
vertex.update_raw_params({"session_id": str_flow_id}, overwrite=True)
vertex.update_raw_params({"session_id": session_id}, overwrite=True)
run_id = uuid.uuid4()
graph.set_run_id(run_id)
graph.set_run_name()
graph.session_id = session_id
await graph.initialize_run()
return graph
async def build_graph_from_db_no_cache(flow_id: uuid.UUID, session: AsyncSession):
async def build_graph_from_db_no_cache(flow_id: uuid.UUID, session: AsyncSession, **kwargs):
"""Build and cache the graph."""
flow: Flow | None = await session.get(Flow, flow_id)
if not flow or not flow.data:
msg = "Invalid flow ID"
raise ValueError(msg)
return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id))
kwargs["user_id"] = kwargs.get("user_id") or str(flow.user_id)
return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, **kwargs)
async def build_graph_from_db(flow_id: uuid.UUID, session: AsyncSession, chat_service: ChatService):
graph = await build_graph_from_db_no_cache(flow_id=flow_id, session=session)
async def build_graph_from_db(flow_id: uuid.UUID, session: AsyncSession, chat_service: ChatService, **kwargs):
graph = await build_graph_from_db_no_cache(flow_id=flow_id, session=session, **kwargs)
await chat_service.set_cache(str(flow_id), graph)
return graph

View file

@ -293,7 +293,7 @@ async def build_vertex(
outputs = {output_label: OutputValue(message=message, type="error")}
result_data_response = ResultDataResponse(results={}, outputs=outputs)
artifacts = {}
background_tasks.add_task(graph.end_all_traces, error=exc)
background_tasks.add_task(graph.end_all_traces_in_context(error=exc))
# If there's an error building the vertex
# we need to clear the cache
await chat_service.clear_cache(flow_id_str)
@ -331,7 +331,7 @@ async def build_vertex(
next_runnable_vertices = [graph.stop_vertex]
if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
background_tasks.add_task(graph.end_all_traces)
background_tasks.add_task(graph.end_all_traces_in_context())
build_response = VertexBuildResponse(
inactivated_vertices=list(set(inactivated_vertices)),

View file

@ -266,12 +266,10 @@ class LangWatchComponent(Component):
"settings": {},
}
if (
self._tracing_service
and self._tracing_service._tracers
and "langwatch" in self._tracing_service._tracers
):
payload["trace_id"] = str(self._tracing_service._tracers["langwatch"].trace_id) # type: ignore[assignment]
if self._tracing_service:
tracer = self._tracing_service.get_tracer("langwatch")
if tracer is not None and hasattr(tracer, "trace_id"):
payload["settings"]["trace_id"] = str(tracer.trace_id)
for setting_name in self.dynamic_inputs:
payload["settings"][setting_name] = getattr(self, setting_name, None)

View file

@ -886,7 +886,7 @@ class Component(CustomComponent):
async def _build_with_tracing(self):
inputs = self.get_trace_as_inputs()
metadata = self.get_trace_as_metadata()
async with self._tracing_service.trace_context(self, self.trace_name, inputs, metadata):
async with self._tracing_service.trace_component(self, self.trace_name, inputs, metadata):
results, artifacts = await self._build_results()
self._tracing_service.set_outputs(self.trace_name, results)

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import contextlib
import contextvars
import copy
import json
import queue
@ -43,7 +44,7 @@ from langflow.services.deps import get_chat_service, get_tracing_service
from langflow.utils.async_helpers import run_until_complete
if TYPE_CHECKING:
from collections.abc import Generator, Iterable
from collections.abc import Callable, Generator, Iterable
from langflow.api.v1.schemas import InputValueRequest
from langflow.custom.custom_component.component import Component
@ -636,15 +637,6 @@ class Graph:
raise ValueError(msg)
return self._run_id
def set_tracing_session_id(self) -> None:
"""Sets the ID of the current session.
Args:
session_id (str): The session ID.
"""
if self.tracing_service:
self.tracing_service.set_session_id(self._session_id)
def set_run_id(self, run_id: uuid.UUID | None = None) -> None:
"""Sets the ID of the current run.
@ -655,29 +647,37 @@ class Graph:
run_id = uuid.uuid4()
self._run_id = str(run_id)
if self.tracing_service:
self.tracing_service.set_run_id(run_id)
if self._session_id and self.tracing_service is not None:
self.tracing_service.set_session_id(self.session_id)
def set_run_name(self) -> None:
# Given a flow name, flow_id
if not self.tracing_service:
return
name = f"{self.flow_name} - {self.flow_id}"
self.set_run_id()
self.tracing_service.set_run_name(name)
async def initialize_run(self) -> None:
if not self._run_id:
self.set_run_id()
if self.tracing_service:
await self.tracing_service.initialize_tracers()
run_name = f"{self.flow_name} - {self.flow_id}"
await self.tracing_service.start_tracers(
run_id=uuid.UUID(self._run_id),
run_name=run_name,
user_id=self.user_id,
session_id=self.session_id,
)
def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None:
task = asyncio.create_task(self.end_all_traces(outputs, error))
self._end_trace_tasks.add(task)
task.add_done_callback(self._end_trace_tasks.discard)
def end_all_traces_in_context(
self,
outputs: dict[str, Any] | None = None,
error: Exception | None = None,
) -> Callable:
# BackgroundTasks run in different context, so we need to copy the context
context = contextvars.copy_context()
async def async_end_traces_func():
await asyncio.create_task(self.end_all_traces(outputs, error), context=context)
return async_end_traces_func
async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None:
if not self.tracing_service:
return
@ -685,7 +685,7 @@ class Graph:
if outputs is None:
outputs = {}
outputs |= self.metadata
await self.tracing_service.end(outputs, error)
await self.tracing_service.end_tracers(outputs, error)
@property
def sorted_vertices_layers(self) -> list[list[str]]:
@ -853,6 +853,8 @@ class Graph:
inputs_components.append([])
if types is None:
types = []
if session_id:
self.session_id = session_id
for _ in range(len(inputs) - len(types)):
types.append("chat") # default to chat
for run_inputs, components, input_type in zip(inputs, inputs_components, types, strict=True):
@ -1053,7 +1055,6 @@ class Graph:
self.state_manager = GraphStateManager()
self.tracing_service = get_tracing_service()
self.set_run_id(self._run_id)
self.set_run_name()
@classmethod
def from_payload(
@ -1527,8 +1528,6 @@ class Graph:
to_process = deque(first_layer)
layer_index = 0
chat_service = get_chat_service()
self.set_run_id()
self.set_run_name()
await self.initialize_run()
lock = asyncio.Lock()
while to_process:

View file

@ -17,7 +17,15 @@ class BaseTracer(ABC):
trace_id: UUID
@abstractmethod
def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID):
def __init__(
self,
trace_name: str,
trace_type: str,
project_name: str,
trace_id: UUID,
user_id: str | None = None,
session_id: str | None = None,
) -> None:
raise NotImplementedError
@property

View file

@ -23,11 +23,21 @@ if TYPE_CHECKING:
class LangFuseTracer(BaseTracer):
flow_id: str
def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID):
def __init__(
self,
trace_name: str,
trace_type: str,
project_name: str,
trace_id: UUID,
user_id: str | None = None,
session_id: str | None = None,
) -> None:
self.project_name = project_name
self.trace_name = trace_name
self.trace_type = trace_type
self.trace_id = trace_id
self.user_id = user_id
self.session_id = session_id
self.flow_id = trace_name.split(" - ")[-1]
self.spans: dict = OrderedDict() # spans that are not ended
@ -43,7 +53,19 @@ class LangFuseTracer(BaseTracer):
from langfuse import Langfuse
self._client = Langfuse(**config)
self.trace = self._client.trace(id=str(self.trace_id), name=self.flow_id)
try:
from langfuse.api.core.request_options import RequestOptions
self._client.client.health.health(request_options=RequestOptions(timeout_in_seconds=1))
except Exception as e: # noqa: BLE001
logger.debug(f"can not connect to Langfuse: {e}")
return False
self.trace = self._client.trace(
id=str(self.trace_id),
name=self.flow_id,
user_id=self.user_id,
session_id=self.session_id,
)
except ImportError:
logger.exception("Could not import langfuse. Please install it with `pip install langfuse`.")
@ -69,7 +91,7 @@ class LangFuseTracer(BaseTracer):
if not self._ready:
return
metadata_: dict = {}
metadata_: dict = {"from_langflow_component": True, "component_id": trace_id}
metadata_ |= {"trace_type": trace_type} if trace_type else {}
metadata_ |= metadata or {}
@ -81,11 +103,12 @@ class LangFuseTracer(BaseTracer):
"start_time": start_time,
}
if len(self.spans) > 0:
last_span = next(reversed(self.spans))
span = self.spans[last_span].span(**content_span)
else:
span = self.trace.span(**content_span)
# If two components are built concurrently, the wrong last span may be used; flatten for now, maybe fix in the future.
# if len(self.spans) > 0:
# last_span = next(reversed(self.spans))
# span = self.spans[last_span].span(**content_span)
# else:
span = self.trace.span(**content_span)
self.spans[trace_id] = span
@ -121,7 +144,12 @@ class LangFuseTracer(BaseTracer):
) -> None:
if not self._ready:
return
content_update = {
"input": inputs,
"output": outputs,
"metadata": metadata,
}
self.trace.update(**content_update)
self._client.flush()
def get_langchain_callback(self) -> BaseCallbackHandler | None:

View file

@ -4,6 +4,7 @@ import asyncio
import os
from collections import defaultdict
from contextlib import asynccontextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from loguru import logger
@ -46,188 +47,204 @@ def _get_arize_phoenix_tracer():
return ArizePhoenixTracer
class TracingService(Service):
name = "tracing_service"
trace_context_var: ContextVar[TraceContext | None] = ContextVar("trace_context", default=None)
component_context_var: ContextVar[ComponentTraceContext | None] = ContextVar("component_trace_context", default=None)
def __init__(self, settings_service: SettingsService):
self.settings_service = settings_service
self.inputs: dict[str, dict] = defaultdict(dict)
self.inputs_metadata: dict[str, dict] = defaultdict(dict)
self.outputs: dict[str, dict] = defaultdict(dict)
self.outputs_metadata: dict[str, dict] = defaultdict(dict)
self.run_name: str | None = None
self.run_id: UUID | None = None
self.project_name: str | None = None
self._tracers: dict[str, BaseTracer] = {}
self._logs: dict[str, list[Log | dict[Any, Any]]] = defaultdict(list)
self.end_trace_tasks: set[asyncio.Task] = set()
self.deactivated = self.settings_service.settings.deactivate_tracing
self.session_id: str | None = None
def _reset_io(self) -> None:
self.inputs = defaultdict(dict)
self.inputs_metadata = defaultdict(dict)
self.outputs = defaultdict(dict)
self.outputs_metadata = defaultdict(dict)
class TraceContext:
def __init__(
self,
run_id: UUID | None,
run_name: str | None,
project_name: str | None,
user_id: str | None,
session_id: str | None,
):
self.run_id: UUID | None = run_id
self.run_name: str | None = run_name
self.project_name: str | None = project_name
self.user_id: str | None = user_id
self.session_id: str | None = session_id
self.tracers: dict[str, BaseTracer] = {}
self.all_inputs: dict[str, dict] = defaultdict(dict)
self.all_outputs: dict[str, dict] = defaultdict(dict)
async def initialize_tracers(self) -> None:
if self.deactivated:
return
try:
self._initialize_langsmith_tracer()
self._initialize_langwatch_tracer()
self._initialize_langfuse_tracer()
self._initialize_arize_phoenix_tracer()
except Exception: # noqa: BLE001
logger.opt(exception=True).debug("Error initializing tracers")
self.traces_queue: asyncio.Queue = asyncio.Queue()
self.running = False
self.worker_task: asyncio.Task | None = None
def _initialize_langsmith_tracer(self) -> None:
project_name = os.getenv("LANGCHAIN_PROJECT", "Langflow")
self.project_name = project_name
langsmith_tracer = _get_langsmith_tracer()
self._tracers["langsmith"] = langsmith_tracer(
trace_name=self.run_name,
trace_type="chain",
project_name=self.project_name,
trace_id=self.run_id,
)
def _initialize_langwatch_tracer(self) -> None:
if "langwatch" not in self._tracers or self._tracers["langwatch"].trace_id != self.run_id:
langwatch_tracer = _get_langwatch_tracer()
self._tracers["langwatch"] = langwatch_tracer(
trace_name=self.run_name,
trace_type="chain",
project_name=self.project_name,
trace_id=self.run_id,
)
def _initialize_langfuse_tracer(self) -> None:
self.project_name = os.getenv("LANGCHAIN_PROJECT", "Langflow")
langfuse_tracer = _get_langfuse_tracer()
self._tracers["langfuse"] = langfuse_tracer(
trace_name=self.run_name,
trace_type="chain",
project_name=self.project_name,
trace_id=self.run_id,
)
def _initialize_arize_phoenix_tracer(self) -> None:
self.project_name = os.getenv("ARIZE_PHOENIX_PROJECT", "Langflow")
arize_phoenix_tracer = _get_arize_phoenix_tracer()
self._tracers["arize_phoenix"] = arize_phoenix_tracer(
trace_name=self.run_name,
trace_type="chain",
project_name=self.project_name,
trace_id=self.run_id,
session_id=self.session_id,
)
def set_run_name(self, name: str) -> None:
self.run_name = name
def set_run_id(self, run_id: UUID) -> None:
self.run_id = run_id
def _start_traces(
class ComponentTraceContext:
def __init__(
self,
trace_id: str,
trace_name: str,
trace_type: str,
inputs: dict[str, Any],
metadata: dict[str, Any] | None = None,
vertex: Vertex | None = None,
) -> None:
inputs = self._cleanup_inputs(inputs)
self.inputs[trace_name] = inputs
self.inputs_metadata[trace_name] = metadata or {}
for tracer in self._tracers.values():
if not tracer.ready:
continue
try:
tracer.add_trace(trace_id, trace_name, trace_type, inputs, metadata, vertex)
except Exception: # noqa: BLE001
logger.exception(f"Error starting trace {trace_name}")
vertex: Vertex | None,
inputs: dict[str, dict],
metadata: dict[str, dict] | None = None,
):
self.trace_id: str = trace_id
self.trace_name: str = trace_name
self.trace_type: str = trace_type
self.vertex: Vertex | None = vertex
self.inputs: dict[str, dict] = inputs
self.inputs_metadata: dict[str, dict] = metadata or {}
self.outputs: dict[str, dict] = defaultdict(dict)
self.outputs_metadata: dict[str, dict] = defaultdict(dict)
self.logs: dict[str, list[Log | dict[Any, Any]]] = defaultdict(list)
def _end_traces(self, trace_id: str, trace_name: str, error: Exception | None = None) -> None:
for tracer in self._tracers.values():
class TracingService(Service):
"""Tracing service.
To trace a graph run:
1. start_tracers: start a trace for a graph run
2. with trace_component: start a sub-trace for a component build, three methods are available:
- add_log
- set_outputs
- get_langchain_callbacks
3. end_tracers: end the trace for a graph run
check context var in public methods.
"""
name = "tracing_service"
def __init__(self, settings_service: SettingsService):
self.settings_service = settings_service
self.deactivated = self.settings_service.settings.deactivate_tracing
async def _trace_worker(self, trace_context: TraceContext) -> None:
while trace_context.running or not trace_context.traces_queue.empty():
trace_func, args = await trace_context.traces_queue.get()
try:
trace_func(*args)
except Exception: # noqa: BLE001
logger.exception("Error processing trace_func")
finally:
trace_context.traces_queue.task_done()
async def _start(self, trace_context: TraceContext) -> None:
if trace_context.running:
return
try:
trace_context.running = True
trace_context.worker_task = asyncio.create_task(self._trace_worker(trace_context))
except Exception: # noqa: BLE001
logger.exception("Error starting tracing service")
def _initialize_langsmith_tracer(self, trace_context: TraceContext) -> None:
langsmith_tracer = _get_langsmith_tracer()
trace_context.tracers["langsmith"] = langsmith_tracer(
trace_name=trace_context.run_name,
trace_type="chain",
project_name=trace_context.project_name,
trace_id=trace_context.run_id,
)
def _initialize_langwatch_tracer(self, trace_context: TraceContext) -> None:
if (
"langwatch" not in trace_context.tracers
or trace_context.tracers["langwatch"].trace_id != trace_context.run_id
):
langwatch_tracer = _get_langwatch_tracer()
trace_context.tracers["langwatch"] = langwatch_tracer(
trace_name=trace_context.run_name,
trace_type="chain",
project_name=trace_context.project_name,
trace_id=trace_context.run_id,
)
def _initialize_langfuse_tracer(self, trace_context: TraceContext) -> None:
langfuse_tracer = _get_langfuse_tracer()
trace_context.tracers["langfuse"] = langfuse_tracer(
trace_name=trace_context.run_name,
trace_type="chain",
project_name=trace_context.project_name,
trace_id=trace_context.run_id,
user_id=trace_context.user_id,
session_id=trace_context.session_id,
)
def _initialize_arize_phoenix_tracer(self, trace_context: TraceContext) -> None:
arize_phoenix_tracer = _get_arize_phoenix_tracer()
trace_context.tracers["arize_phoenix"] = arize_phoenix_tracer(
trace_name=trace_context.run_name,
trace_type="chain",
project_name=trace_context.project_name,
trace_id=trace_context.run_id,
)
async def start_tracers(
self,
run_id: UUID,
run_name: str,
user_id: str | None,
session_id: str | None,
project_name: str | None = None,
) -> None:
"""Start a trace for a graph run.
- create a trace context
- start a worker for this trace context
- initialize the tracers
"""
if self.deactivated:
return
try:
project_name = project_name or os.getenv("LANGCHAIN_PROJECT", "Langflow")
trace_context = TraceContext(run_id, run_name, project_name, user_id, session_id)
trace_context_var.set(trace_context)
await self._start(trace_context)
self._initialize_langsmith_tracer(trace_context)
self._initialize_langwatch_tracer(trace_context)
self._initialize_langfuse_tracer(trace_context)
self._initialize_arize_phoenix_tracer(trace_context)
except Exception as e: # noqa: BLE001
logger.debug(f"Error initializing tracers: {e}")
async def _stop(self, trace_context: TraceContext) -> None:
try:
trace_context.running = False
# check the queue is empty
if not trace_context.traces_queue.empty():
await trace_context.traces_queue.join()
if trace_context.worker_task:
trace_context.worker_task.cancel()
trace_context.worker_task = None
except Exception: # noqa: BLE001
logger.exception("Error stopping tracing service")
def _end_all_tracers(self, trace_context: TraceContext, outputs: dict, error: Exception | None = None) -> None:
for tracer in trace_context.tracers.values():
if tracer.ready:
try:
tracer.end_trace(
trace_id=trace_id,
trace_name=trace_name,
outputs=self.outputs[trace_name],
# why all_inputs and all_outputs? why metadata=outputs?
tracer.end(
trace_context.all_inputs,
outputs=trace_context.all_outputs,
error=error,
logs=self._logs[trace_name],
metadata=outputs,
)
except Exception: # noqa: BLE001
logger.exception(f"Error ending trace {trace_name}")
self._reset_io()
def _end_all_traces(self, outputs: dict, error: Exception | None = None) -> None:
for tracer in self._tracers.values():
if tracer.ready:
try:
tracer.end(self.inputs, outputs=self.outputs, error=error, metadata=outputs)
except Exception: # noqa: BLE001
logger.exception("Error ending all traces")
self._reset_io()
async def end(self, outputs: dict, error: Exception | None = None) -> None:
await asyncio.to_thread(self._end_all_traces, outputs, error)
try:
# Wait for any pending trace tasks to complete
await asyncio.gather(*self.end_trace_tasks)
except Exception: # noqa: BLE001
logger.exception("Error flushing logs")
async def end_tracers(self, outputs: dict, error: Exception | None = None) -> None:
"""End the trace for a graph run.
def add_log(self, trace_name: str, log: Log) -> None:
self._logs[trace_name].append(log)
@asynccontextmanager
async def trace_context(
self,
component: Component,
trace_name: str,
inputs: dict[str, Any],
metadata: dict[str, Any] | None = None,
):
- stop worker for current trace_context
- call end for all the tracers
"""
if self.deactivated:
yield self
return
trace_id = trace_name
if component._vertex:
trace_id = component._vertex.id
trace_type = component.trace_type
self._start_traces(
trace_id,
trace_name,
trace_type,
inputs,
metadata,
component._vertex,
)
try:
yield self
except Exception as e:
self._end_and_reset(trace_id, trace_name, e)
raise
else:
self._end_and_reset(trace_id, trace_name)
def _end_and_reset(self, trace_id: str, trace_name: str, error: Exception | None = None) -> None:
task = asyncio.create_task(asyncio.to_thread(self._end_traces, trace_id, trace_name, error))
self.end_trace_tasks.add(task)
task.add_done_callback(self.end_trace_tasks.discard)
def set_outputs(
self,
trace_name: str,
outputs: dict[str, Any],
output_metadata: dict[str, Any] | None = None,
) -> None:
self.outputs[trace_name] |= outputs or {}
self.outputs_metadata[trace_name] |= output_metadata or {}
trace_context = trace_context_var.get()
if trace_context is None:
msg = "called end_tracers but no trace context found"
raise RuntimeError(msg)
await self._stop(trace_context)
self._end_all_tracers(trace_context, outputs, error)
@staticmethod
def _cleanup_inputs(inputs: dict[str, Any]):
@ -237,18 +254,152 @@ class TracingService(Service):
inputs[key] = "*****" # avoid logging api_keys for security reasons
return inputs
def _start_component_traces(
self,
component_trace_context: ComponentTraceContext,
trace_context: TraceContext,
) -> None:
inputs = self._cleanup_inputs(component_trace_context.inputs)
component_trace_context.inputs = inputs
component_trace_context.inputs_metadata = component_trace_context.inputs_metadata or {}
for tracer in trace_context.tracers.values():
if not tracer.ready:
continue
try:
tracer.add_trace(
component_trace_context.trace_id,
component_trace_context.trace_name,
component_trace_context.trace_type,
inputs,
component_trace_context.inputs_metadata,
component_trace_context.vertex,
)
except Exception: # noqa: BLE001
logger.exception(f"Error starting trace {component_trace_context.trace_name}")
def _end_component_traces(
self,
component_trace_context: ComponentTraceContext,
trace_context: TraceContext,
error: Exception | None = None,
) -> None:
for tracer in trace_context.tracers.values():
if tracer.ready:
try:
tracer.end_trace(
trace_id=component_trace_context.trace_id,
trace_name=component_trace_context.trace_name,
outputs=trace_context.all_outputs[component_trace_context.trace_name],
error=error,
logs=component_trace_context.logs[component_trace_context.trace_name],
)
except Exception: # noqa: BLE001
logger.exception(f"Error ending trace {component_trace_context.trace_name}")
@asynccontextmanager
async def trace_component(
self,
component: Component,
trace_name: str,
inputs: dict[str, Any],
metadata: dict[str, Any] | None = None,
):
"""Trace a component.
@param component: the component to trace
@param trace_name: component name + component id
@param inputs: the inputs to the component
@param metadata: the metadata to the component
"""
if self.deactivated:
yield self
return
trace_id = trace_name
if component._vertex:
trace_id = component._vertex.id
trace_type = component.trace_type
component_trace_context = ComponentTraceContext(
trace_id, trace_name, trace_type, component._vertex, inputs, metadata
)
component_context_var.set(component_trace_context)
trace_context = trace_context_var.get()
if trace_context is None:
msg = "called trace_component but no trace context found"
raise RuntimeError(msg)
trace_context.all_inputs[trace_name] |= inputs or {}
await trace_context.traces_queue.put((self._start_component_traces, (component_trace_context, trace_context)))
try:
yield self
except Exception as e:
await trace_context.traces_queue.put(
(self._end_component_traces, (component_trace_context, trace_context, e))
)
raise
else:
await trace_context.traces_queue.put(
(self._end_component_traces, (component_trace_context, trace_context, None))
)
@property
def project_name(self):
if self.deactivated:
return os.getenv("LANGCHAIN_PROJECT", "Langflow")
trace_context = trace_context_var.get()
if trace_context is None:
msg = "called project_name but no trace context found"
raise RuntimeError(msg)
return trace_context.project_name
def add_log(self, trace_name: str, log: Log) -> None:
"""Add a log to the current component trace context."""
if self.deactivated:
return
component_context = component_context_var.get()
if component_context is None:
msg = "called add_log but no component context found"
raise RuntimeError(msg)
component_context.logs[trace_name].append(log)
def set_outputs(
self,
trace_name: str,
outputs: dict[str, Any],
output_metadata: dict[str, Any] | None = None,
) -> None:
"""Set the outputs for the current component trace context."""
if self.deactivated:
return
component_context = component_context_var.get()
if component_context is None:
msg = "called set_outputs but no component context found"
raise RuntimeError(msg)
component_context.outputs[trace_name] |= outputs or {}
component_context.outputs_metadata[trace_name] |= output_metadata or {}
trace_context = trace_context_var.get()
if trace_context is None:
msg = "called set_outputs but no trace context found"
raise RuntimeError(msg)
trace_context.all_outputs[trace_name] |= outputs or {}
def get_tracer(self, tracer_name: str) -> BaseTracer | None:
trace_context = trace_context_var.get()
if trace_context is None:
msg = "called get_tracer but no trace context found"
raise RuntimeError(msg)
return trace_context.tracers.get(tracer_name)
def get_langchain_callbacks(self) -> list[BaseCallbackHandler]:
if self.deactivated:
return []
callbacks = []
for tracer in self._tracers.values():
trace_context = trace_context_var.get()
if trace_context is None:
msg = "called get_langchain_callbacks but no trace context found"
raise RuntimeError(msg)
for tracer in trace_context.tracers.values():
if not tracer.ready: # type: ignore[truthy-function]
continue
langchain_callback = tracer.get_langchain_callback()
if langchain_callback:
callbacks.append(langchain_callback)
return callbacks
def set_session_id(self, session_id: str) -> None:
    """Record the session identifier to attach to emitted traces."""
    self.session_id = session_id

View file

@ -0,0 +1 @@
"""Services tests package."""

View file

@ -0,0 +1,515 @@
import asyncio
import uuid
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from langflow.services.settings.base import Settings
from langflow.services.settings.service import SettingsService
from langflow.services.tracing.base import BaseTracer
from langflow.services.tracing.service import (
    TracingService,
    component_context_var,
    trace_context_var,
)
class MockTracer(BaseTracer):
    """Recording test double for BaseTracer.

    Stores every ``add_trace``/``end_trace``/``end`` invocation so tests can
    assert on the exact payloads the tracing service dispatched.
    """

    def __init__(
        self,
        trace_name: str,
        trace_type: str,
        project_name: str,
        trace_id: uuid.UUID,
        user_id: str | None = None,
        session_id: str | None = None,
    ) -> None:
        self.trace_name = trace_name
        self.trace_type = trace_type
        self.project_name = project_name
        self.trace_id = trace_id
        self.user_id = user_id
        self.session_id = session_id
        self._ready = True
        # Call markers and recorded payloads inspected by the tests.
        self.end_called = False
        self.get_langchain_callback_called = False
        self.add_trace_list = []
        self.end_trace_list = []

    @property
    def ready(self) -> bool:
        # Always ready so the service never skips this tracer.
        return self._ready

    def add_trace(
        self,
        trace_id: str,
        trace_name: str,
        trace_type: str,
        inputs: dict[str, Any],
        metadata: dict[str, Any] | None = None,
        vertex=None,
    ) -> None:
        """Record the start-of-component-trace call."""
        self.add_trace_list.append(
            {
                "trace_id": trace_id,
                "trace_name": trace_name,
                "trace_type": trace_type,
                "inputs": inputs,
                "metadata": metadata,
                "vertex": vertex,
            }
        )

    def end_trace(
        self,
        trace_id: str,
        trace_name: str,
        outputs: dict[str, Any] | None = None,
        error: Exception | None = None,
        logs=(),
    ) -> None:
        """Record the end-of-component-trace call."""
        self.end_trace_list.append(
            {
                "trace_id": trace_id,
                "trace_name": trace_name,
                "outputs": outputs,
                "error": error,
                "logs": logs,
            }
        )

    def end(
        self,
        inputs: dict[str, Any],
        outputs: dict[str, Any],
        error: Exception | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Record the end-of-run call and its arguments."""
        self.end_called = True
        self.inputs_param = inputs
        self.outputs_param = outputs
        self.error_param = error
        self.metadata_param = metadata

    def get_langchain_callback(self):
        """Return a fresh mock callback and mark that it was requested."""
        self.get_langchain_callback_called = True
        return MagicMock()
@pytest.fixture
def mock_settings_service():
    """SettingsService with tracing enabled and mocked auth settings."""
    base_settings = Settings()
    base_settings.deactivate_tracing = False
    return SettingsService(base_settings, MagicMock())
@pytest.fixture
def tracing_service(mock_settings_service):
    """Fresh TracingService wired to the mocked settings service."""
    return TracingService(mock_settings_service)
@pytest.fixture
def mock_component():
    """Component stub exposing the attributes the tracing service reads."""
    vertex = MagicMock()
    vertex.id = "test_vertex_id"
    stub = MagicMock()
    stub._vertex = vertex
    stub.trace_type = "test_trace_type"
    return stub
@pytest.fixture
def mock_tracers():
    """Patch every tracer factory so the service builds MockTracer instances."""
    factory_paths = (
        "langflow.services.tracing.service._get_langsmith_tracer",
        "langflow.services.tracing.service._get_langwatch_tracer",
        "langflow.services.tracing.service._get_langfuse_tracer",
        "langflow.services.tracing.service._get_arize_phoenix_tracer",
    )
    patchers = [patch(path, return_value=MockTracer) for path in factory_paths]
    for patcher in patchers:
        patcher.start()
    try:
        yield
    finally:
        # Undo the patches even if the test body raised.
        for patcher in patchers:
            patcher.stop()
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_start_end_tracers(tracing_service):
    """start_tracers populates the trace context; end_tracers finalizes it."""
    flow_run_id = uuid.uuid4()
    final_outputs = {"output_key": "output_value"}

    await tracing_service.start_tracers(flow_run_id, "test_run", "test_user", "test_session", "test_project")

    ctx = trace_context_var.get()
    assert ctx is not None
    assert ctx.run_id == flow_run_id
    assert ctx.run_name == "test_run"
    assert ctx.project_name == "test_project"
    assert ctx.user_id == "test_user"
    assert ctx.session_id == "test_session"
    # All four built-in tracer backends must be registered.
    for backend in ("langsmith", "langwatch", "langfuse", "arize_phoenix"):
        assert backend in ctx.tracers

    await tracing_service.end_tracers(final_outputs)

    ctx = trace_context_var.get()
    for tracer in ctx.tracers.values():
        assert tracer.end_called
        assert tracer.metadata_param == final_outputs
        assert tracer.outputs_param == ctx.all_outputs
    # The background worker must have been shut down cleanly.
    assert ctx.worker_task is None
    assert not ctx.running
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_trace_component(tracing_service, mock_component):
    """Test component tracing context manager.

    Covers: component-context setup on entry, add_trace dispatch, log and
    output recording through the yielded service, and end_trace dispatch on
    exit of the context manager.
    """
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"
    trace_name = "test_component_trace"
    inputs = {"input_key": "input_value"}
    metadata = {"metadata_key": "metadata_value"}
    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)
    async with tracing_service.trace_component(mock_component, trace_name, inputs, metadata) as ts:
        # Verify component context is set from the component's vertex/trace info
        component_context = component_context_var.get()
        assert component_context is not None
        assert component_context.trace_id == mock_component._vertex.id
        assert component_context.trace_name == trace_name
        assert component_context.trace_type == mock_component.trace_type
        assert component_context.vertex == mock_component._vertex
        assert component_context.inputs == inputs
        assert component_context.inputs_metadata == metadata
        # Verify add_trace method was called for tracers
        await asyncio.sleep(0.1)  # Wait for async queue processing by the worker
        trace_context = trace_context_var.get()
        for tracer in trace_context.tracers.values():
            assert tracer.add_trace_list[0]["trace_id"] == mock_component._vertex.id
            assert tracer.add_trace_list[0]["trace_name"] == trace_name
            assert tracer.add_trace_list[0]["trace_type"] == mock_component.trace_type
            assert tracer.add_trace_list[0]["inputs"] == inputs
            assert tracer.add_trace_list[0]["metadata"] == metadata
            assert tracer.add_trace_list[0]["vertex"] == mock_component._vertex
        # Test adding logs through the yielded tracing service
        ts.add_log(trace_name, {"message": "test log"})
        assert {"message": "test log"} in component_context.logs[trace_name]
        # Test setting outputs: recorded in both component and trace contexts
        outputs = {"output_key": "output_value"}
        output_metadata = {"output_metadata_key": "output_metadata_value"}
        ts.set_outputs(trace_name, outputs, output_metadata)
        assert component_context.outputs[trace_name] == outputs
        assert component_context.outputs_metadata[trace_name] == output_metadata
        assert trace_context.all_outputs[trace_name] == outputs
    # Verify end_trace method was called for tracers (fires on context exit)
    await asyncio.sleep(0.1)  # Wait for async queue processing
    for tracer in trace_context.tracers.values():
        assert tracer.end_trace_list[0]["trace_id"] == mock_component._vertex.id
        assert tracer.end_trace_list[0]["trace_name"] == trace_name
        assert tracer.end_trace_list[0]["outputs"] == trace_context.all_outputs[trace_name]
        assert tracer.end_trace_list[0]["error"] is None
        assert tracer.end_trace_list[0]["logs"] == component_context.logs[trace_name]
    # Cleanup
    await tracing_service.end_tracers({})
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_trace_component_with_exception(tracing_service, mock_component):
    """An exception raised inside trace_component propagates and is recorded."""
    trace_name = "test_component_trace"
    await tracing_service.start_tracers(
        uuid.uuid4(), "test_run", "test_user", "test_session", "test_project"
    )

    boom = ValueError("Test exception")
    with pytest.raises(ValueError, match="Test exception"):
        async with tracing_service.trace_component(mock_component, trace_name, {"input_key": "input_value"}):
            raise boom

    # Give the background worker time to drain the queue.
    await asyncio.sleep(0.1)
    ctx = trace_context_var.get()
    for tracer in ctx.tracers.values():
        # The propagated exception must reach every tracer's end_trace call.
        assert tracer.end_trace_list[0]["error"] == boom

    await tracing_service.end_tracers({})
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_get_langchain_callbacks(tracing_service):
    """Every ready tracer contributes one LangChain callback handler."""
    await tracing_service.start_tracers(
        uuid.uuid4(), "test_run", "test_user", "test_session", "test_project"
    )

    handlers = tracing_service.get_langchain_callbacks()

    ctx = trace_context_var.get()
    assert all(t.get_langchain_callback_called for t in ctx.tracers.values())
    # One handler per mocked backend: langsmith, langwatch, langfuse, arize_phoenix.
    assert len(handlers) == 4

    await tracing_service.end_tracers({})
@pytest.mark.asyncio
async def test_deactivated_tracing(mock_settings_service):
    """With deactivate_tracing=True every tracing API is a harmless no-op."""
    mock_settings_service.settings.deactivate_tracing = True
    service = TracingService(mock_settings_service)

    await service.start_tracers(
        uuid.uuid4(), "test_run", "test_user", "test_session", "test_project"
    )
    # No trace context should ever be created while tracing is disabled.
    assert trace_context_var.get() is None

    # The component context manager and its helpers must not raise.
    component = MagicMock()
    trace_name = "test_component_trace"
    async with service.trace_component(component, trace_name, {"input_key": "input_value"}) as ts:
        ts.add_log(trace_name, {"message": "test log"})
        ts.set_outputs(trace_name, {"output_key": "output_value"})

    # No LangChain callbacks are produced when tracing is disabled.
    assert service.get_langchain_callbacks() == []

    await service.end_tracers({})
@pytest.mark.asyncio
async def test_cleanup_inputs():
    """_cleanup_inputs masks any value whose key mentions 'api_key'."""
    original = {
        "normal_key": "normal_value",
        "api_key": "secret_api_key",
        "openai_api_key": "secret_openai_api_key",
        "nested_api_key": {"api_key": "nested_secret"},
    }

    sanitized = TracingService._cleanup_inputs(original)

    assert sanitized["normal_key"] == "normal_value"
    # Keys containing "api_key" are masked, including the dict-valued entry.
    assert sanitized["api_key"] == "*****"
    assert sanitized["openai_api_key"] == "*****"
    assert sanitized["nested_api_key"] == "*****"
    # The caller's dict is left untouched.
    assert original["api_key"] == "secret_api_key"
    assert original["openai_api_key"] == "secret_openai_api_key"
@pytest.mark.asyncio
async def test_start_tracers_with_exception(tracing_service):
    """Tracer-initialization failures are logged and never break start_tracers."""
    flow_run_id = uuid.uuid4()

    with (
        patch.object(
            tracing_service,
            "_initialize_langsmith_tracer",
            side_effect=Exception("Mock exception"),
        ),
        patch("langflow.services.tracing.service.logger.debug") as debug_mock,
    ):
        # Must return normally despite the tracer blowing up during init.
        await tracing_service.start_tracers(flow_run_id, "test_run", "test_user", "test_session", "test_project")
        debug_mock.assert_any_call("Error initializing tracers: Mock exception")

    # The trace context is still established despite the failed tracer.
    ctx = trace_context_var.get()
    assert ctx is not None
    assert ctx.run_id == flow_run_id
    assert ctx.run_name == "test_run"

    await tracing_service.end_tracers({})
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_trace_worker_with_exception(tracing_service):
    """A trace function that raises is logged by the worker, not propagated."""

    def exploding_trace_func():
        msg = "Mock trace function exception"
        raise ValueError(msg)

    with patch("langflow.services.tracing.service.logger.exception") as exception_mock:
        await tracing_service.start_tracers(
            uuid.uuid4(), "test_run", "test_user", "test_session", "test_project"
        )
        ctx = trace_context_var.get()
        await ctx.traces_queue.put((exploding_trace_func, ()))
        # Let the worker pick the item off the queue and hit the exception.
        await asyncio.sleep(0.1)
        exception_mock.assert_called_with("Error processing trace_func")

    await tracing_service.end_tracers({})
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_concurrent_tracing(tracing_service, mock_component):
    """Test two tasks running start_tracers concurrently, with each task running 2 concurrent trace_component tasks.

    Guards against the race condition where concurrent flow runs leaked each
    other's trace data; each run must see only its own context (contextvars
    isolate the per-task trace state).
    """

    # Define common task function: start tracers and run two component traces
    async def run_task(
        run_id,
        run_name,
        user_id,
        session_id,
        project_name,
        inputs,
        metadata,
        task_prefix,
        sleep_duration=0.1,
    ):
        await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

        async def run_component_task(component, trace_name, component_suffix):
            async with tracing_service.trace_component(component, trace_name, inputs, metadata) as ts:
                ts.add_log(trace_name, {"message": f"{task_prefix} {component_suffix} log"})
                outputs = {"output_key": f"{task_prefix}_{component_suffix}_output"}
                # Sleep inside the component trace so the two runs overlap in time.
                await asyncio.sleep(sleep_duration)
                ts.set_outputs(trace_name, outputs)

        task1 = asyncio.create_task(run_component_task(mock_component, f"{run_id} trace_name1", f"{run_id} component1"))
        await task1
        task2 = asyncio.create_task(run_component_task(mock_component, f"{run_id} trace_name2", f"{run_id} component2"))
        await task2
        await tracing_service.end_tracers({"final_output": f"{task_prefix}_final_output"})
        # Return this run's langfuse tracer so outputs can be inspected below.
        trace_context = trace_context_var.get()
        return trace_context.tracers["langfuse"]

    inputs1 = {"input_key": "input_value1"}
    metadata1 = {"metadata_key": "metadata_value1"}
    inputs2 = {"input_key": "input_value2"}
    metadata2 = {"metadata_key": "metadata_value2"}
    # First run sleeps longer (2s) so the second run starts and finishes while
    # the first is still in flight.
    task1 = asyncio.create_task(
        run_task(
            "run_id1",
            "run_name1",
            "user_id1",
            "session_id1",
            "project_name1",
            inputs1,
            metadata1,
            "task1",
            2,
        )
    )
    await asyncio.sleep(0.1)
    task2 = asyncio.create_task(
        run_task(
            "run_id2",
            "run_name2",
            "user_id2",
            "session_id2",
            "project_name2",
            inputs2,
            metadata2,
            "task2",
            0.1,
        )
    )
    tracer1 = await task1
    tracer2 = await task2
    # Verify tracer1 and tracer2 each hold only their own run's trace data
    assert tracer1.trace_name == "run_name1"
    assert tracer1.project_name == "project_name1"
    assert tracer1.user_id == "user_id1"
    assert tracer1.session_id == "session_id1"
    assert dict(tracer1.outputs_param.get("run_id1 trace_name1")) == {"output_key": "task1_run_id1 component1_output"}
    assert dict(tracer1.outputs_param.get("run_id1 trace_name2")) == {"output_key": "task1_run_id1 component2_output"}
    assert tracer2.trace_name == "run_name2"
    assert tracer2.project_name == "project_name2"
    assert tracer2.user_id == "user_id2"
    assert tracer2.session_id == "session_id2"
    assert dict(tracer2.outputs_param.get("run_id2 trace_name1")) == {"output_key": "task2_run_id2 component1_output"}
    assert dict(tracer2.outputs_param.get("run_id2 trace_name2")) == {"output_key": "task2_run_id2 component2_output"}