fix: refactor tracing service (#7015)
* fix: refactor tracing service 1. fix race condition when concurrent run flow #6899 2. make start_trace/end_trace both async 3. add session_id/user_id to trace #4274 * fix: merge * feat: unittest trace service * fix: ruff style * fix: handle tracing deactivated * fix: ruff style * feat: detect langfuse ready by health() * fix: ruff style * Update src/backend/base/langflow/api/build.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/api/build.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/graph/graph/base.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/langfuse.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: 
Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update src/backend/base/langflow/services/tracing/service.py Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com> * Update service.py * fix: code style * fix: style * feat: langwatch component get_tracer from service * fix: only get contextvar in public method, check and raise * fix: style, long arg names --------- Co-authored-by: Ítalo Johnny <italojohnnydosanjos@gmail.com>
This commit is contained in:
parent
2dfdb9f701
commit
eeeb09e7ea
11 changed files with 942 additions and 231 deletions
|
|
@ -180,9 +180,6 @@ async def generate_flow_events(
|
|||
graph.validate_stream()
|
||||
first_layer = sort_vertices(graph)
|
||||
|
||||
if inputs is not None and getattr(inputs, "session", None) is not None:
|
||||
graph.session_id = inputs.session
|
||||
|
||||
for vertex_id in first_layer:
|
||||
graph.run_manager.add_to_vertices_being_run(vertex_id)
|
||||
|
||||
|
|
@ -218,8 +215,19 @@ async def generate_flow_events(
|
|||
)
|
||||
|
||||
async def create_graph(fresh_session, flow_id_str: str) -> Graph:
|
||||
if inputs is not None and getattr(inputs, "session", None) is not None:
|
||||
effective_session_id = inputs.session
|
||||
else:
|
||||
effective_session_id = flow_id_str
|
||||
|
||||
if not data:
|
||||
return await build_graph_from_db(flow_id=flow_id, session=fresh_session, chat_service=chat_service)
|
||||
return await build_graph_from_db(
|
||||
flow_id=flow_id,
|
||||
session=fresh_session,
|
||||
chat_service=chat_service,
|
||||
user_id=str(current_user.id),
|
||||
session_id=effective_session_id,
|
||||
)
|
||||
|
||||
result = await fresh_session.exec(select(Flow.name).where(Flow.id == flow_id))
|
||||
flow_name = result.first()
|
||||
|
|
@ -229,6 +237,7 @@ async def generate_flow_events(
|
|||
payload=data.model_dump(),
|
||||
user_id=str(current_user.id),
|
||||
flow_name=flow_name,
|
||||
session_id=effective_session_id,
|
||||
)
|
||||
|
||||
def sort_vertices(graph: Graph) -> list[str]:
|
||||
|
|
@ -280,7 +289,7 @@ async def generate_flow_events(
|
|||
outputs = {output_label: OutputValue(message=message, type="error")}
|
||||
result_data_response = ResultDataResponse(results={}, outputs=outputs)
|
||||
artifacts = {}
|
||||
background_tasks.add_task(graph.end_all_traces, error=exc)
|
||||
background_tasks.add_task(graph.end_all_traces_in_context(error=exc))
|
||||
|
||||
result_data_response.message = artifacts
|
||||
|
||||
|
|
@ -314,7 +323,7 @@ async def generate_flow_events(
|
|||
next_runnable_vertices = [graph.stop_vertex]
|
||||
|
||||
if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
|
||||
background_tasks.add_task(graph.end_all_traces)
|
||||
background_tasks.add_task(graph.end_all_traces_in_context())
|
||||
|
||||
build_response = VertexBuildResponse(
|
||||
inactivated_vertices=list(set(inactivated_vertices)),
|
||||
|
|
@ -410,7 +419,7 @@ async def generate_flow_events(
|
|||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
except asyncio.CancelledError:
|
||||
background_tasks.add_task(graph.end_all_traces)
|
||||
background_tasks.add_task(graph.end_all_traces_in_context())
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error building vertices: {e}")
|
||||
|
|
|
|||
|
|
@ -154,35 +154,37 @@ async def build_graph_from_data(flow_id: uuid.UUID | str, payload: dict, **kwarg
|
|||
# Get flow name
|
||||
if "flow_name" not in kwargs:
|
||||
flow_name = await _get_flow_name(flow_id if isinstance(flow_id, uuid.UUID) else uuid.UUID(flow_id))
|
||||
kwargs["flow_name"] = flow_name
|
||||
else:
|
||||
flow_name = kwargs["flow_name"]
|
||||
str_flow_id = str(flow_id)
|
||||
graph = Graph.from_payload(payload, str_flow_id, **kwargs)
|
||||
session_id = kwargs.get("session_id") or str_flow_id
|
||||
|
||||
graph = Graph.from_payload(payload, str_flow_id, flow_name, kwargs.get("user_id"))
|
||||
for vertex_id in graph.has_session_id_vertices:
|
||||
vertex = graph.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
msg = f"Vertex {vertex_id} not found"
|
||||
raise ValueError(msg)
|
||||
if not vertex.raw_params.get("session_id"):
|
||||
vertex.update_raw_params({"session_id": str_flow_id}, overwrite=True)
|
||||
vertex.update_raw_params({"session_id": session_id}, overwrite=True)
|
||||
|
||||
run_id = uuid.uuid4()
|
||||
graph.set_run_id(run_id)
|
||||
graph.set_run_name()
|
||||
graph.session_id = session_id
|
||||
await graph.initialize_run()
|
||||
return graph
|
||||
|
||||
|
||||
async def build_graph_from_db_no_cache(flow_id: uuid.UUID, session: AsyncSession):
|
||||
async def build_graph_from_db_no_cache(flow_id: uuid.UUID, session: AsyncSession, **kwargs):
|
||||
"""Build and cache the graph."""
|
||||
flow: Flow | None = await session.get(Flow, flow_id)
|
||||
if not flow or not flow.data:
|
||||
msg = "Invalid flow ID"
|
||||
raise ValueError(msg)
|
||||
return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id))
|
||||
kwargs["user_id"] = kwargs.get("user_id") or str(flow.user_id)
|
||||
return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, **kwargs)
|
||||
|
||||
|
||||
async def build_graph_from_db(flow_id: uuid.UUID, session: AsyncSession, chat_service: ChatService):
|
||||
graph = await build_graph_from_db_no_cache(flow_id=flow_id, session=session)
|
||||
async def build_graph_from_db(flow_id: uuid.UUID, session: AsyncSession, chat_service: ChatService, **kwargs):
|
||||
graph = await build_graph_from_db_no_cache(flow_id=flow_id, session=session, **kwargs)
|
||||
await chat_service.set_cache(str(flow_id), graph)
|
||||
return graph
|
||||
|
||||
|
|
|
|||
|
|
@ -293,7 +293,7 @@ async def build_vertex(
|
|||
outputs = {output_label: OutputValue(message=message, type="error")}
|
||||
result_data_response = ResultDataResponse(results={}, outputs=outputs)
|
||||
artifacts = {}
|
||||
background_tasks.add_task(graph.end_all_traces, error=exc)
|
||||
background_tasks.add_task(graph.end_all_traces_in_context(error=exc))
|
||||
# If there's an error building the vertex
|
||||
# we need to clear the cache
|
||||
await chat_service.clear_cache(flow_id_str)
|
||||
|
|
@ -331,7 +331,7 @@ async def build_vertex(
|
|||
next_runnable_vertices = [graph.stop_vertex]
|
||||
|
||||
if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
|
||||
background_tasks.add_task(graph.end_all_traces)
|
||||
background_tasks.add_task(graph.end_all_traces_in_context())
|
||||
|
||||
build_response = VertexBuildResponse(
|
||||
inactivated_vertices=list(set(inactivated_vertices)),
|
||||
|
|
|
|||
|
|
@ -266,12 +266,10 @@ class LangWatchComponent(Component):
|
|||
"settings": {},
|
||||
}
|
||||
|
||||
if (
|
||||
self._tracing_service
|
||||
and self._tracing_service._tracers
|
||||
and "langwatch" in self._tracing_service._tracers
|
||||
):
|
||||
payload["trace_id"] = str(self._tracing_service._tracers["langwatch"].trace_id) # type: ignore[assignment]
|
||||
if self._tracing_service:
|
||||
tracer = self._tracing_service.get_tracer("langwatch")
|
||||
if tracer is not None and hasattr(tracer, "trace_id"):
|
||||
payload["settings"]["trace_id"] = str(tracer.trace_id)
|
||||
|
||||
for setting_name in self.dynamic_inputs:
|
||||
payload["settings"][setting_name] = getattr(self, setting_name, None)
|
||||
|
|
|
|||
|
|
@ -886,7 +886,7 @@ class Component(CustomComponent):
|
|||
async def _build_with_tracing(self):
|
||||
inputs = self.get_trace_as_inputs()
|
||||
metadata = self.get_trace_as_metadata()
|
||||
async with self._tracing_service.trace_context(self, self.trace_name, inputs, metadata):
|
||||
async with self._tracing_service.trace_component(self, self.trace_name, inputs, metadata):
|
||||
results, artifacts = await self._build_results()
|
||||
self._tracing_service.set_outputs(self.trace_name, results)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import contextvars
|
||||
import copy
|
||||
import json
|
||||
import queue
|
||||
|
|
@ -43,7 +44,7 @@ from langflow.services.deps import get_chat_service, get_tracing_service
|
|||
from langflow.utils.async_helpers import run_until_complete
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Iterable
|
||||
from collections.abc import Callable, Generator, Iterable
|
||||
|
||||
from langflow.api.v1.schemas import InputValueRequest
|
||||
from langflow.custom.custom_component.component import Component
|
||||
|
|
@ -636,15 +637,6 @@ class Graph:
|
|||
raise ValueError(msg)
|
||||
return self._run_id
|
||||
|
||||
def set_tracing_session_id(self) -> None:
|
||||
"""Sets the ID of the current session.
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID.
|
||||
"""
|
||||
if self.tracing_service:
|
||||
self.tracing_service.set_session_id(self._session_id)
|
||||
|
||||
def set_run_id(self, run_id: uuid.UUID | None = None) -> None:
|
||||
"""Sets the ID of the current run.
|
||||
|
||||
|
|
@ -655,29 +647,37 @@ class Graph:
|
|||
run_id = uuid.uuid4()
|
||||
|
||||
self._run_id = str(run_id)
|
||||
if self.tracing_service:
|
||||
self.tracing_service.set_run_id(run_id)
|
||||
if self._session_id and self.tracing_service is not None:
|
||||
self.tracing_service.set_session_id(self.session_id)
|
||||
|
||||
def set_run_name(self) -> None:
|
||||
# Given a flow name, flow_id
|
||||
if not self.tracing_service:
|
||||
return
|
||||
name = f"{self.flow_name} - {self.flow_id}"
|
||||
|
||||
self.set_run_id()
|
||||
self.tracing_service.set_run_name(name)
|
||||
|
||||
async def initialize_run(self) -> None:
|
||||
if not self._run_id:
|
||||
self.set_run_id()
|
||||
if self.tracing_service:
|
||||
await self.tracing_service.initialize_tracers()
|
||||
run_name = f"{self.flow_name} - {self.flow_id}"
|
||||
await self.tracing_service.start_tracers(
|
||||
run_id=uuid.UUID(self._run_id),
|
||||
run_name=run_name,
|
||||
user_id=self.user_id,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
|
||||
def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None:
|
||||
task = asyncio.create_task(self.end_all_traces(outputs, error))
|
||||
self._end_trace_tasks.add(task)
|
||||
task.add_done_callback(self._end_trace_tasks.discard)
|
||||
|
||||
def end_all_traces_in_context(
|
||||
self,
|
||||
outputs: dict[str, Any] | None = None,
|
||||
error: Exception | None = None,
|
||||
) -> Callable:
|
||||
# BackgroundTasks run in different context, so we need to copy the context
|
||||
context = contextvars.copy_context()
|
||||
|
||||
async def async_end_traces_func():
|
||||
await asyncio.create_task(self.end_all_traces(outputs, error), context=context)
|
||||
|
||||
return async_end_traces_func
|
||||
|
||||
async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None:
|
||||
if not self.tracing_service:
|
||||
return
|
||||
|
|
@ -685,7 +685,7 @@ class Graph:
|
|||
if outputs is None:
|
||||
outputs = {}
|
||||
outputs |= self.metadata
|
||||
await self.tracing_service.end(outputs, error)
|
||||
await self.tracing_service.end_tracers(outputs, error)
|
||||
|
||||
@property
|
||||
def sorted_vertices_layers(self) -> list[list[str]]:
|
||||
|
|
@ -853,6 +853,8 @@ class Graph:
|
|||
inputs_components.append([])
|
||||
if types is None:
|
||||
types = []
|
||||
if session_id:
|
||||
self.session_id = session_id
|
||||
for _ in range(len(inputs) - len(types)):
|
||||
types.append("chat") # default to chat
|
||||
for run_inputs, components, input_type in zip(inputs, inputs_components, types, strict=True):
|
||||
|
|
@ -1053,7 +1055,6 @@ class Graph:
|
|||
self.state_manager = GraphStateManager()
|
||||
self.tracing_service = get_tracing_service()
|
||||
self.set_run_id(self._run_id)
|
||||
self.set_run_name()
|
||||
|
||||
@classmethod
|
||||
def from_payload(
|
||||
|
|
@ -1527,8 +1528,6 @@ class Graph:
|
|||
to_process = deque(first_layer)
|
||||
layer_index = 0
|
||||
chat_service = get_chat_service()
|
||||
self.set_run_id()
|
||||
self.set_run_name()
|
||||
await self.initialize_run()
|
||||
lock = asyncio.Lock()
|
||||
while to_process:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,15 @@ class BaseTracer(ABC):
|
|||
trace_id: UUID
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID):
|
||||
def __init__(
|
||||
self,
|
||||
trace_name: str,
|
||||
trace_type: str,
|
||||
project_name: str,
|
||||
trace_id: UUID,
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -23,11 +23,21 @@ if TYPE_CHECKING:
|
|||
class LangFuseTracer(BaseTracer):
|
||||
flow_id: str
|
||||
|
||||
def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID):
|
||||
def __init__(
|
||||
self,
|
||||
trace_name: str,
|
||||
trace_type: str,
|
||||
project_name: str,
|
||||
trace_id: UUID,
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> None:
|
||||
self.project_name = project_name
|
||||
self.trace_name = trace_name
|
||||
self.trace_type = trace_type
|
||||
self.trace_id = trace_id
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.flow_id = trace_name.split(" - ")[-1]
|
||||
self.spans: dict = OrderedDict() # spans that are not ended
|
||||
|
||||
|
|
@ -43,7 +53,19 @@ class LangFuseTracer(BaseTracer):
|
|||
from langfuse import Langfuse
|
||||
|
||||
self._client = Langfuse(**config)
|
||||
self.trace = self._client.trace(id=str(self.trace_id), name=self.flow_id)
|
||||
try:
|
||||
from langfuse.api.core.request_options import RequestOptions
|
||||
|
||||
self._client.client.health.health(request_options=RequestOptions(timeout_in_seconds=1))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug(f"can not connect to Langfuse: {e}")
|
||||
return False
|
||||
self.trace = self._client.trace(
|
||||
id=str(self.trace_id),
|
||||
name=self.flow_id,
|
||||
user_id=self.user_id,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
logger.exception("Could not import langfuse. Please install it with `pip install langfuse`.")
|
||||
|
|
@ -69,7 +91,7 @@ class LangFuseTracer(BaseTracer):
|
|||
if not self._ready:
|
||||
return
|
||||
|
||||
metadata_: dict = {}
|
||||
metadata_: dict = {"from_langflow_component": True, "component_id": trace_id}
|
||||
metadata_ |= {"trace_type": trace_type} if trace_type else {}
|
||||
metadata_ |= metadata or {}
|
||||
|
||||
|
|
@ -81,11 +103,12 @@ class LangFuseTracer(BaseTracer):
|
|||
"start_time": start_time,
|
||||
}
|
||||
|
||||
if len(self.spans) > 0:
|
||||
last_span = next(reversed(self.spans))
|
||||
span = self.spans[last_span].span(**content_span)
|
||||
else:
|
||||
span = self.trace.span(**content_span)
|
||||
# if two component is built concurrently, will use wrong last span. just flatten now, maybe fix in future.
|
||||
# if len(self.spans) > 0:
|
||||
# last_span = next(reversed(self.spans))
|
||||
# span = self.spans[last_span].span(**content_span)
|
||||
# else:
|
||||
span = self.trace.span(**content_span)
|
||||
|
||||
self.spans[trace_id] = span
|
||||
|
||||
|
|
@ -121,7 +144,12 @@ class LangFuseTracer(BaseTracer):
|
|||
) -> None:
|
||||
if not self._ready:
|
||||
return
|
||||
|
||||
content_update = {
|
||||
"input": inputs,
|
||||
"output": outputs,
|
||||
"metadata": metadata,
|
||||
}
|
||||
self.trace.update(**content_update)
|
||||
self._client.flush()
|
||||
|
||||
def get_langchain_callback(self) -> BaseCallbackHandler | None:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import asyncio
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -46,188 +47,204 @@ def _get_arize_phoenix_tracer():
|
|||
return ArizePhoenixTracer
|
||||
|
||||
|
||||
class TracingService(Service):
|
||||
name = "tracing_service"
|
||||
trace_context_var: ContextVar[TraceContext | None] = ContextVar("trace_context", default=None)
|
||||
component_context_var: ContextVar[ComponentTraceContext | None] = ContextVar("component_trace_context", default=None)
|
||||
|
||||
def __init__(self, settings_service: SettingsService):
|
||||
self.settings_service = settings_service
|
||||
self.inputs: dict[str, dict] = defaultdict(dict)
|
||||
self.inputs_metadata: dict[str, dict] = defaultdict(dict)
|
||||
self.outputs: dict[str, dict] = defaultdict(dict)
|
||||
self.outputs_metadata: dict[str, dict] = defaultdict(dict)
|
||||
self.run_name: str | None = None
|
||||
self.run_id: UUID | None = None
|
||||
self.project_name: str | None = None
|
||||
self._tracers: dict[str, BaseTracer] = {}
|
||||
self._logs: dict[str, list[Log | dict[Any, Any]]] = defaultdict(list)
|
||||
self.end_trace_tasks: set[asyncio.Task] = set()
|
||||
self.deactivated = self.settings_service.settings.deactivate_tracing
|
||||
self.session_id: str | None = None
|
||||
|
||||
def _reset_io(self) -> None:
|
||||
self.inputs = defaultdict(dict)
|
||||
self.inputs_metadata = defaultdict(dict)
|
||||
self.outputs = defaultdict(dict)
|
||||
self.outputs_metadata = defaultdict(dict)
|
||||
class TraceContext:
|
||||
def __init__(
|
||||
self,
|
||||
run_id: UUID | None,
|
||||
run_name: str | None,
|
||||
project_name: str | None,
|
||||
user_id: str | None,
|
||||
session_id: str | None,
|
||||
):
|
||||
self.run_id: UUID | None = run_id
|
||||
self.run_name: str | None = run_name
|
||||
self.project_name: str | None = project_name
|
||||
self.user_id: str | None = user_id
|
||||
self.session_id: str | None = session_id
|
||||
self.tracers: dict[str, BaseTracer] = {}
|
||||
self.all_inputs: dict[str, dict] = defaultdict(dict)
|
||||
self.all_outputs: dict[str, dict] = defaultdict(dict)
|
||||
|
||||
async def initialize_tracers(self) -> None:
|
||||
if self.deactivated:
|
||||
return
|
||||
try:
|
||||
self._initialize_langsmith_tracer()
|
||||
self._initialize_langwatch_tracer()
|
||||
self._initialize_langfuse_tracer()
|
||||
self._initialize_arize_phoenix_tracer()
|
||||
except Exception: # noqa: BLE001
|
||||
logger.opt(exception=True).debug("Error initializing tracers")
|
||||
self.traces_queue: asyncio.Queue = asyncio.Queue()
|
||||
self.running = False
|
||||
self.worker_task: asyncio.Task | None = None
|
||||
|
||||
def _initialize_langsmith_tracer(self) -> None:
|
||||
project_name = os.getenv("LANGCHAIN_PROJECT", "Langflow")
|
||||
self.project_name = project_name
|
||||
langsmith_tracer = _get_langsmith_tracer()
|
||||
self._tracers["langsmith"] = langsmith_tracer(
|
||||
trace_name=self.run_name,
|
||||
trace_type="chain",
|
||||
project_name=self.project_name,
|
||||
trace_id=self.run_id,
|
||||
)
|
||||
|
||||
def _initialize_langwatch_tracer(self) -> None:
|
||||
if "langwatch" not in self._tracers or self._tracers["langwatch"].trace_id != self.run_id:
|
||||
langwatch_tracer = _get_langwatch_tracer()
|
||||
self._tracers["langwatch"] = langwatch_tracer(
|
||||
trace_name=self.run_name,
|
||||
trace_type="chain",
|
||||
project_name=self.project_name,
|
||||
trace_id=self.run_id,
|
||||
)
|
||||
|
||||
def _initialize_langfuse_tracer(self) -> None:
|
||||
self.project_name = os.getenv("LANGCHAIN_PROJECT", "Langflow")
|
||||
langfuse_tracer = _get_langfuse_tracer()
|
||||
self._tracers["langfuse"] = langfuse_tracer(
|
||||
trace_name=self.run_name,
|
||||
trace_type="chain",
|
||||
project_name=self.project_name,
|
||||
trace_id=self.run_id,
|
||||
)
|
||||
|
||||
def _initialize_arize_phoenix_tracer(self) -> None:
|
||||
self.project_name = os.getenv("ARIZE_PHOENIX_PROJECT", "Langflow")
|
||||
arize_phoenix_tracer = _get_arize_phoenix_tracer()
|
||||
self._tracers["arize_phoenix"] = arize_phoenix_tracer(
|
||||
trace_name=self.run_name,
|
||||
trace_type="chain",
|
||||
project_name=self.project_name,
|
||||
trace_id=self.run_id,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
|
||||
def set_run_name(self, name: str) -> None:
|
||||
self.run_name = name
|
||||
|
||||
def set_run_id(self, run_id: UUID) -> None:
|
||||
self.run_id = run_id
|
||||
|
||||
def _start_traces(
|
||||
class ComponentTraceContext:
|
||||
def __init__(
|
||||
self,
|
||||
trace_id: str,
|
||||
trace_name: str,
|
||||
trace_type: str,
|
||||
inputs: dict[str, Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
vertex: Vertex | None = None,
|
||||
) -> None:
|
||||
inputs = self._cleanup_inputs(inputs)
|
||||
self.inputs[trace_name] = inputs
|
||||
self.inputs_metadata[trace_name] = metadata or {}
|
||||
for tracer in self._tracers.values():
|
||||
if not tracer.ready:
|
||||
continue
|
||||
try:
|
||||
tracer.add_trace(trace_id, trace_name, trace_type, inputs, metadata, vertex)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(f"Error starting trace {trace_name}")
|
||||
vertex: Vertex | None,
|
||||
inputs: dict[str, dict],
|
||||
metadata: dict[str, dict] | None = None,
|
||||
):
|
||||
self.trace_id: str = trace_id
|
||||
self.trace_name: str = trace_name
|
||||
self.trace_type: str = trace_type
|
||||
self.vertex: Vertex | None = vertex
|
||||
self.inputs: dict[str, dict] = inputs
|
||||
self.inputs_metadata: dict[str, dict] = metadata or {}
|
||||
self.outputs: dict[str, dict] = defaultdict(dict)
|
||||
self.outputs_metadata: dict[str, dict] = defaultdict(dict)
|
||||
self.logs: dict[str, list[Log | dict[Any, Any]]] = defaultdict(list)
|
||||
|
||||
def _end_traces(self, trace_id: str, trace_name: str, error: Exception | None = None) -> None:
|
||||
for tracer in self._tracers.values():
|
||||
|
||||
class TracingService(Service):
|
||||
"""Tracing service.
|
||||
|
||||
To trace a graph run:
|
||||
1. start_tracers: start a trace for a graph run
|
||||
2. with trace_component: start a sub-trace for a component build, three methods are available:
|
||||
- add_log
|
||||
- set_outputs
|
||||
- get_langchain_callbacks
|
||||
3. end_tracers: end the trace for a graph run
|
||||
|
||||
check context var in public methods.
|
||||
"""
|
||||
|
||||
name = "tracing_service"
|
||||
|
||||
def __init__(self, settings_service: SettingsService):
|
||||
self.settings_service = settings_service
|
||||
self.deactivated = self.settings_service.settings.deactivate_tracing
|
||||
|
||||
async def _trace_worker(self, trace_context: TraceContext) -> None:
|
||||
while trace_context.running or not trace_context.traces_queue.empty():
|
||||
trace_func, args = await trace_context.traces_queue.get()
|
||||
try:
|
||||
trace_func(*args)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error processing trace_func")
|
||||
finally:
|
||||
trace_context.traces_queue.task_done()
|
||||
|
||||
async def _start(self, trace_context: TraceContext) -> None:
|
||||
if trace_context.running:
|
||||
return
|
||||
try:
|
||||
trace_context.running = True
|
||||
trace_context.worker_task = asyncio.create_task(self._trace_worker(trace_context))
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error starting tracing service")
|
||||
|
||||
def _initialize_langsmith_tracer(self, trace_context: TraceContext) -> None:
|
||||
langsmith_tracer = _get_langsmith_tracer()
|
||||
trace_context.tracers["langsmith"] = langsmith_tracer(
|
||||
trace_name=trace_context.run_name,
|
||||
trace_type="chain",
|
||||
project_name=trace_context.project_name,
|
||||
trace_id=trace_context.run_id,
|
||||
)
|
||||
|
||||
def _initialize_langwatch_tracer(self, trace_context: TraceContext) -> None:
|
||||
if (
|
||||
"langwatch" not in trace_context.tracers
|
||||
or trace_context.tracers["langwatch"].trace_id != trace_context.run_id
|
||||
):
|
||||
langwatch_tracer = _get_langwatch_tracer()
|
||||
trace_context.tracers["langwatch"] = langwatch_tracer(
|
||||
trace_name=trace_context.run_name,
|
||||
trace_type="chain",
|
||||
project_name=trace_context.project_name,
|
||||
trace_id=trace_context.run_id,
|
||||
)
|
||||
|
||||
def _initialize_langfuse_tracer(self, trace_context: TraceContext) -> None:
|
||||
langfuse_tracer = _get_langfuse_tracer()
|
||||
trace_context.tracers["langfuse"] = langfuse_tracer(
|
||||
trace_name=trace_context.run_name,
|
||||
trace_type="chain",
|
||||
project_name=trace_context.project_name,
|
||||
trace_id=trace_context.run_id,
|
||||
user_id=trace_context.user_id,
|
||||
session_id=trace_context.session_id,
|
||||
)
|
||||
|
||||
def _initialize_arize_phoenix_tracer(self, trace_context: TraceContext) -> None:
|
||||
arize_phoenix_tracer = _get_arize_phoenix_tracer()
|
||||
trace_context.tracers["arize_phoenix"] = arize_phoenix_tracer(
|
||||
trace_name=trace_context.run_name,
|
||||
trace_type="chain",
|
||||
project_name=trace_context.project_name,
|
||||
trace_id=trace_context.run_id,
|
||||
)
|
||||
|
||||
async def start_tracers(
|
||||
self,
|
||||
run_id: UUID,
|
||||
run_name: str,
|
||||
user_id: str | None,
|
||||
session_id: str | None,
|
||||
project_name: str | None = None,
|
||||
) -> None:
|
||||
"""Start a trace for a graph run.
|
||||
|
||||
- create a trace context
|
||||
- start a worker for this trace context
|
||||
- initialize the tracers
|
||||
"""
|
||||
if self.deactivated:
|
||||
return
|
||||
try:
|
||||
project_name = project_name or os.getenv("LANGCHAIN_PROJECT", "Langflow")
|
||||
trace_context = TraceContext(run_id, run_name, project_name, user_id, session_id)
|
||||
trace_context_var.set(trace_context)
|
||||
await self._start(trace_context)
|
||||
self._initialize_langsmith_tracer(trace_context)
|
||||
self._initialize_langwatch_tracer(trace_context)
|
||||
self._initialize_langfuse_tracer(trace_context)
|
||||
self._initialize_arize_phoenix_tracer(trace_context)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug(f"Error initializing tracers: {e}")
|
||||
|
||||
async def _stop(self, trace_context: TraceContext) -> None:
|
||||
try:
|
||||
trace_context.running = False
|
||||
# check the qeue is empty
|
||||
if not trace_context.traces_queue.empty():
|
||||
await trace_context.traces_queue.join()
|
||||
if trace_context.worker_task:
|
||||
trace_context.worker_task.cancel()
|
||||
trace_context.worker_task = None
|
||||
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error stopping tracing service")
|
||||
|
||||
def _end_all_tracers(self, trace_context: TraceContext, outputs: dict, error: Exception | None = None) -> None:
|
||||
for tracer in trace_context.tracers.values():
|
||||
if tracer.ready:
|
||||
try:
|
||||
tracer.end_trace(
|
||||
trace_id=trace_id,
|
||||
trace_name=trace_name,
|
||||
outputs=self.outputs[trace_name],
|
||||
# why all_inputs and all_outputs? why metadata=outputs?
|
||||
tracer.end(
|
||||
trace_context.all_inputs,
|
||||
outputs=trace_context.all_outputs,
|
||||
error=error,
|
||||
logs=self._logs[trace_name],
|
||||
metadata=outputs,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(f"Error ending trace {trace_name}")
|
||||
self._reset_io()
|
||||
|
||||
def _end_all_traces(self, outputs: dict, error: Exception | None = None) -> None:
|
||||
for tracer in self._tracers.values():
|
||||
if tracer.ready:
|
||||
try:
|
||||
tracer.end(self.inputs, outputs=self.outputs, error=error, metadata=outputs)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error ending all traces")
|
||||
self._reset_io()
|
||||
|
||||
async def end(self, outputs: dict, error: Exception | None = None) -> None:
|
||||
await asyncio.to_thread(self._end_all_traces, outputs, error)
|
||||
try:
|
||||
# Wait for any pending trace tasks to complete
|
||||
await asyncio.gather(*self.end_trace_tasks)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error flushing logs")
|
||||
async def end_tracers(self, outputs: dict, error: Exception | None = None) -> None:
|
||||
"""End the trace for a graph run.
|
||||
|
||||
def add_log(self, trace_name: str, log: Log) -> None:
|
||||
self._logs[trace_name].append(log)
|
||||
|
||||
@asynccontextmanager
|
||||
async def trace_context(
|
||||
self,
|
||||
component: Component,
|
||||
trace_name: str,
|
||||
inputs: dict[str, Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
- stop worker for current trace_context
|
||||
- call end for all the tracers
|
||||
"""
|
||||
if self.deactivated:
|
||||
yield self
|
||||
return
|
||||
trace_id = trace_name
|
||||
if component._vertex:
|
||||
trace_id = component._vertex.id
|
||||
trace_type = component.trace_type
|
||||
self._start_traces(
|
||||
trace_id,
|
||||
trace_name,
|
||||
trace_type,
|
||||
inputs,
|
||||
metadata,
|
||||
component._vertex,
|
||||
)
|
||||
try:
|
||||
yield self
|
||||
except Exception as e:
|
||||
self._end_and_reset(trace_id, trace_name, e)
|
||||
raise
|
||||
else:
|
||||
self._end_and_reset(trace_id, trace_name)
|
||||
|
||||
def _end_and_reset(self, trace_id: str, trace_name: str, error: Exception | None = None) -> None:
|
||||
task = asyncio.create_task(asyncio.to_thread(self._end_traces, trace_id, trace_name, error))
|
||||
self.end_trace_tasks.add(task)
|
||||
task.add_done_callback(self.end_trace_tasks.discard)
|
||||
|
||||
def set_outputs(
|
||||
self,
|
||||
trace_name: str,
|
||||
outputs: dict[str, Any],
|
||||
output_metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self.outputs[trace_name] |= outputs or {}
|
||||
self.outputs_metadata[trace_name] |= output_metadata or {}
|
||||
trace_context = trace_context_var.get()
|
||||
if trace_context is None:
|
||||
msg = "called end_tracers but no trace context found"
|
||||
raise RuntimeError(msg)
|
||||
await self._stop(trace_context)
|
||||
self._end_all_tracers(trace_context, outputs, error)
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_inputs(inputs: dict[str, Any]):
|
||||
|
|
@ -237,18 +254,152 @@ class TracingService(Service):
|
|||
inputs[key] = "*****" # avoid logging api_keys for security reasons
|
||||
return inputs
|
||||
|
||||
def _start_component_traces(
    self,
    component_trace_context: ComponentTraceContext,
    trace_context: TraceContext,
) -> None:
    """Register a component trace with every ready tracer.

    Input values whose keys look like API keys are masked before being
    handed to the tracers. A failing tracer is logged and skipped so it
    never breaks the component run.
    """
    sanitized = self._cleanup_inputs(component_trace_context.inputs)
    component_trace_context.inputs = sanitized
    if not component_trace_context.inputs_metadata:
        component_trace_context.inputs_metadata = {}
    for tracer in trace_context.tracers.values():
        if not tracer.ready:
            continue
        try:
            tracer.add_trace(
                component_trace_context.trace_id,
                component_trace_context.trace_name,
                component_trace_context.trace_type,
                sanitized,
                component_trace_context.inputs_metadata,
                component_trace_context.vertex,
            )
        except Exception:  # noqa: BLE001
            logger.exception(f"Error starting trace {component_trace_context.trace_name}")
|
||||
|
||||
def _end_component_traces(
    self,
    component_trace_context: ComponentTraceContext,
    trace_context: TraceContext,
    error: Exception | None = None,
) -> None:
    """Notify every ready tracer that a component trace finished.

    Outputs are taken from the run-level context; logs come from the
    component context. Tracer failures are logged, never raised.
    """
    for tracer in trace_context.tracers.values():
        if not tracer.ready:
            continue
        try:
            tracer.end_trace(
                trace_id=component_trace_context.trace_id,
                trace_name=component_trace_context.trace_name,
                outputs=trace_context.all_outputs[component_trace_context.trace_name],
                error=error,
                logs=component_trace_context.logs[component_trace_context.trace_name],
            )
        except Exception:  # noqa: BLE001
            logger.exception(f"Error ending trace {component_trace_context.trace_name}")
|
||||
|
||||
@asynccontextmanager
async def trace_component(
    self,
    component: Component,
    trace_name: str,
    inputs: dict[str, Any],
    metadata: dict[str, Any] | None = None,
):
    """Trace a component.

    @param component: the component to trace
    @param trace_name: component name + component id
    @param inputs: the inputs to the component
    @param metadata: the metadata to the component

    On enter, queues a start event for the worker; on exit (normal or
    exceptional), queues the matching end event. Requires a run-level
    trace context created by start_tracers.
    """
    if self.deactivated:
        # Tracing disabled: act as a no-op context manager.
        yield self
        return
    # Prefer the vertex id as trace id when the component is attached to a graph.
    trace_id = trace_name
    if component._vertex:
        trace_id = component._vertex.id
    trace_type = component.trace_type
    component_trace_context = ComponentTraceContext(
        trace_id, trace_name, trace_type, component._vertex, inputs, metadata
    )
    # Publish the component context so add_log/set_outputs can find it.
    component_context_var.set(component_trace_context)
    trace_context = trace_context_var.get()
    if trace_context is None:
        msg = "called trace_component but no trace context found"
        raise RuntimeError(msg)
    trace_context.all_inputs[trace_name] |= inputs or {}
    # Start event is processed asynchronously by the trace worker.
    await trace_context.traces_queue.put((self._start_component_traces, (component_trace_context, trace_context)))
    try:
        yield self
    except Exception as e:
        # Queue the end event with the error attached, then re-raise.
        await trace_context.traces_queue.put(
            (self._end_component_traces, (component_trace_context, trace_context, e))
        )
        raise
    else:
        await trace_context.traces_queue.put(
            (self._end_component_traces, (component_trace_context, trace_context, None))
        )
|
||||
|
||||
@property
def project_name(self):
    """Project name of the current run; env fallback when tracing is off."""
    if self.deactivated:
        return os.getenv("LANGCHAIN_PROJECT", "Langflow")
    context = trace_context_var.get()
    if context is not None:
        return context.project_name
    msg = "called project_name but no trace context found"
    raise RuntimeError(msg)
|
||||
|
||||
def add_log(self, trace_name: str, log: Log) -> None:
    """Add a log to the current component trace context."""
    if self.deactivated:
        return
    context = component_context_var.get()
    if context is not None:
        context.logs[trace_name].append(log)
        return
    msg = "called add_log but no component context found"
    raise RuntimeError(msg)
|
||||
|
||||
def set_outputs(
    self,
    trace_name: str,
    outputs: dict[str, Any],
    output_metadata: dict[str, Any] | None = None,
) -> None:
    """Set the outputs for the current component trace context.

    Outputs are merged into both the component-level context and the
    run-level context (used later when the component trace is ended).
    """
    if self.deactivated:
        return
    component_ctx = component_context_var.get()
    if component_ctx is None:
        msg = "called set_outputs but no component context found"
        raise RuntimeError(msg)
    component_ctx.outputs[trace_name] |= outputs or {}
    component_ctx.outputs_metadata[trace_name] |= output_metadata or {}
    run_ctx = trace_context_var.get()
    if run_ctx is None:
        msg = "called set_outputs but no trace context found"
        raise RuntimeError(msg)
    run_ctx.all_outputs[trace_name] |= outputs or {}
|
||||
|
||||
def get_tracer(self, tracer_name: str) -> BaseTracer | None:
    """Look up a tracer by name in the current run's trace context."""
    context = trace_context_var.get()
    if context is not None:
        return context.tracers.get(tracer_name)
    msg = "called get_tracer but no trace context found"
    raise RuntimeError(msg)
|
||||
|
||||
def get_langchain_callbacks(self) -> list[BaseCallbackHandler]:
    """Collect LangChain callback handlers from every ready tracer.

    Returns:
        A list of callback handlers, empty when tracing is deactivated
        or no tracer provides one.

    Raises:
        RuntimeError: if no run-level trace context has been set.
    """
    if self.deactivated:
        return []
    callbacks = []
    trace_context = trace_context_var.get()
    if trace_context is None:
        msg = "called get_langchain_callbacks but no trace context found"
        raise RuntimeError(msg)
    # NOTE: a stray leftover loop over the removed `self._tracers` attribute
    # was deleted here; tracers now live on the run-level trace context.
    for tracer in trace_context.tracers.values():
        if not tracer.ready:  # type: ignore[truthy-function]
            continue
        langchain_callback = tracer.get_langchain_callback()
        if langchain_callback:
            callbacks.append(langchain_callback)
    return callbacks
|
||||
|
||||
def set_session_id(self, session_id: str) -> None:
    """Record the session ID attached to traces for this service."""
    self.session_id = session_id
|
||||
|
|
|
|||
1
src/backend/tests/unit/services/tracing/__init__.py
Normal file
1
src/backend/tests/unit/services/tracing/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Services tests package."""
|
||||
515
src/backend/tests/unit/services/tracing/test_tracing_service.py
Normal file
515
src/backend/tests/unit/services/tracing/test_tracing_service.py
Normal file
|
|
@ -0,0 +1,515 @@
|
|||
import asyncio
import uuid
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from langflow.services.settings.base import Settings
from langflow.services.settings.service import SettingsService
from langflow.services.tracing.base import BaseTracer
from langflow.services.tracing.service import (
    TracingService,
    component_context_var,
    trace_context_var,
)
|
||||
|
||||
|
||||
class MockTracer(BaseTracer):
    """In-memory BaseTracer test double that records every call.

    Fix: signatures previously annotated ``dict[str, any]`` using the
    builtin ``any`` function; they now use ``typing.Any``.
    """

    def __init__(
        self,
        trace_name: str,
        trace_type: str,
        project_name: str,
        trace_id: uuid.UUID,
        user_id: str | None = None,
        session_id: str | None = None,
    ) -> None:
        self.trace_name = trace_name
        self.trace_type = trace_type
        self.project_name = project_name
        self.trace_id = trace_id
        self.user_id = user_id
        self.session_id = session_id
        self._ready = True
        # Call-recording state the tests assert on.
        self.end_called = False
        self.get_langchain_callback_called = False
        self.add_trace_list = []
        self.end_trace_list = []

    @property
    def ready(self) -> bool:
        """Always True so readiness-filtered code paths run in tests."""
        return self._ready

    def add_trace(
        self,
        trace_id: str,
        trace_name: str,
        trace_type: str,
        inputs: dict[str, Any],
        metadata: dict[str, Any] | None = None,
        vertex=None,
    ) -> None:
        """Record the component-trace start arguments."""
        self.add_trace_list.append(
            {
                "trace_id": trace_id,
                "trace_name": trace_name,
                "trace_type": trace_type,
                "inputs": inputs,
                "metadata": metadata,
                "vertex": vertex,
            }
        )

    def end_trace(
        self,
        trace_id: str,
        trace_name: str,
        outputs: dict[str, Any] | None = None,
        error: Exception | None = None,
        logs=(),
    ) -> None:
        """Record the component-trace end arguments."""
        self.end_trace_list.append(
            {
                "trace_id": trace_id,
                "trace_name": trace_name,
                "outputs": outputs,
                "error": error,
                "logs": logs,
            }
        )

    def end(
        self,
        inputs: dict[str, Any],
        outputs: dict[str, Any],
        error: Exception | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Record the run-level end arguments."""
        self.end_called = True
        self.inputs_param = inputs
        self.outputs_param = outputs
        self.error_param = error
        self.metadata_param = metadata

    def get_langchain_callback(self):
        """Return a fresh mock callback and note that it was requested."""
        self.get_langchain_callback_called = True
        return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_settings_service():
    """Build a SettingsService whose settings keep tracing active."""
    active_settings = Settings()
    active_settings.deactivate_tracing = False
    return SettingsService(active_settings, MagicMock())
|
||||
|
||||
|
||||
@pytest.fixture
def tracing_service(mock_settings_service):
    """TracingService under test, wired to the mock settings service."""
    service = TracingService(mock_settings_service)
    return service
|
||||
|
||||
|
||||
@pytest.fixture
def mock_component():
    """Mock component exposing the attributes TracingService reads."""
    fake_vertex = MagicMock()
    fake_vertex.id = "test_vertex_id"
    component = MagicMock()
    component._vertex = fake_vertex
    component.trace_type = "test_trace_type"
    return component
|
||||
|
||||
|
||||
@pytest.fixture
def mock_tracers():
    """Patch every tracer factory so the service instantiates MockTracer."""
    factory_paths = [
        "langflow.services.tracing.service._get_langsmith_tracer",
        "langflow.services.tracing.service._get_langwatch_tracer",
        "langflow.services.tracing.service._get_langfuse_tracer",
        "langflow.services.tracing.service._get_arize_phoenix_tracer",
    ]
    patchers = [patch(path, return_value=MockTracer) for path in factory_paths]
    for patcher in patchers:
        patcher.start()
    try:
        yield
    finally:
        # Undo in reverse order, mirroring nested `with` semantics.
        for patcher in reversed(patchers):
            patcher.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_start_end_tracers(tracing_service):
    """Test starting and ending tracers.

    Covers: trace context population, initialization of all four tracer
    backends, end() propagation, and worker-task teardown.
    """
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"
    outputs = {"output_key": "output_value"}

    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)
    # Verify trace_context is set correctly
    trace_context = trace_context_var.get()
    assert trace_context is not None
    assert trace_context.run_id == run_id
    assert trace_context.run_name == run_name
    assert trace_context.project_name == project_name
    assert trace_context.user_id == user_id
    assert trace_context.session_id == session_id

    # Verify tracers are initialized
    assert "langsmith" in trace_context.tracers
    assert "langwatch" in trace_context.tracers
    assert "langfuse" in trace_context.tracers
    assert "arize_phoenix" in trace_context.tracers

    await tracing_service.end_tracers(outputs)

    # Verify end method was called for all tracers
    trace_context = trace_context_var.get()
    for tracer in trace_context.tracers.values():
        assert tracer.end_called
        assert tracer.metadata_param == outputs
        assert tracer.outputs_param == trace_context.all_outputs

    # Verify worker_task is cancelled
    assert trace_context.worker_task is None
    assert not trace_context.running
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_trace_component(tracing_service, mock_component):
    """Test component tracing context manager.

    Exercises context setup, queued add_trace delivery, add_log,
    set_outputs, and the queued end_trace after the context exits.
    """
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    trace_name = "test_component_trace"
    inputs = {"input_key": "input_value"}
    metadata = {"metadata_key": "metadata_value"}

    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

    async with tracing_service.trace_component(mock_component, trace_name, inputs, metadata) as ts:
        # Verify component context is set
        component_context = component_context_var.get()
        assert component_context is not None
        assert component_context.trace_id == mock_component._vertex.id
        assert component_context.trace_name == trace_name
        assert component_context.trace_type == mock_component.trace_type
        assert component_context.vertex == mock_component._vertex
        assert component_context.inputs == inputs
        assert component_context.inputs_metadata == metadata

        # Verify add_trace method was called for tracers
        await asyncio.sleep(0.1)  # Wait for async queue processing
        trace_context = trace_context_var.get()
        for tracer in trace_context.tracers.values():
            assert tracer.add_trace_list[0]["trace_id"] == mock_component._vertex.id
            assert tracer.add_trace_list[0]["trace_name"] == trace_name
            assert tracer.add_trace_list[0]["trace_type"] == mock_component.trace_type
            assert tracer.add_trace_list[0]["inputs"] == inputs
            assert tracer.add_trace_list[0]["metadata"] == metadata
            assert tracer.add_trace_list[0]["vertex"] == mock_component._vertex

        # Test adding logs
        ts.add_log(trace_name, {"message": "test log"})
        assert {"message": "test log"} in component_context.logs[trace_name]

        # Test setting outputs
        outputs = {"output_key": "output_value"}
        output_metadata = {"output_metadata_key": "output_metadata_value"}
        ts.set_outputs(trace_name, outputs, output_metadata)
        assert component_context.outputs[trace_name] == outputs
        assert component_context.outputs_metadata[trace_name] == output_metadata
        assert trace_context.all_outputs[trace_name] == outputs

    # Verify end_trace method was called for tracers
    await asyncio.sleep(0.1)  # Wait for async queue processing
    for tracer in trace_context.tracers.values():
        assert tracer.end_trace_list[0]["trace_id"] == mock_component._vertex.id
        assert tracer.end_trace_list[0]["trace_name"] == trace_name
        assert tracer.end_trace_list[0]["outputs"] == trace_context.all_outputs[trace_name]
        assert tracer.end_trace_list[0]["error"] is None
        assert tracer.end_trace_list[0]["logs"] == component_context.logs[trace_name]

    # Cleanup
    await tracing_service.end_tracers({})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_trace_component_with_exception(tracing_service, mock_component):
    """Test component tracing context manager with exception handling.

    The raised error must be re-raised to the caller AND recorded on
    every tracer's end_trace call.
    """
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    trace_name = "test_component_trace"
    inputs = {"input_key": "input_value"}

    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

    test_exception = ValueError("Test exception")

    with pytest.raises(ValueError, match="Test exception"):
        async with tracing_service.trace_component(mock_component, trace_name, inputs):
            raise test_exception

    # Verify end_trace method was called with exception
    await asyncio.sleep(0.1)  # Wait for async queue processing
    trace_context = trace_context_var.get()
    for tracer in trace_context.tracers.values():
        assert tracer.end_trace_list[0]["error"] == test_exception

    # Cleanup
    await tracing_service.end_tracers({})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_get_langchain_callbacks(tracing_service):
    """Test getting LangChain callback handlers.

    Each of the four mocked tracers contributes exactly one callback.
    """
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

    callbacks = tracing_service.get_langchain_callbacks()

    # Verify get_langchain_callback method was called for each tracer
    trace_context = trace_context_var.get()
    for tracer in trace_context.tracers.values():
        assert tracer.get_langchain_callback_called

    # Verify returned callbacks list length
    assert len(callbacks) == 4  # Four tracers

    # Cleanup
    await tracing_service.end_tracers({})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_deactivated_tracing(mock_settings_service):
    """Test deactivated tracing functionality.

    With deactivate_tracing set, every public entry point must be a
    silent no-op and never touch the context variables.
    """
    # Set deactivate_tracing to True
    mock_settings_service.settings.deactivate_tracing = True
    tracing_service = TracingService(mock_settings_service)

    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    # Starting tracers should have no effect
    await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

    # With tracing disabled, trace_context_var may be None or uninitialized
    assert trace_context_var.get() is None
    # We don't need to check trace_context_var state, just verify tracing operations don't execute

    # Test trace_component context manager
    mock_component = MagicMock()
    trace_name = "test_component_trace"
    inputs = {"input_key": "input_value"}

    async with tracing_service.trace_component(mock_component, trace_name, inputs) as ts:
        ts.add_log(trace_name, {"message": "test log"})
        ts.set_outputs(trace_name, {"output_key": "output_value"})

    # Test getting LangChain callback handlers
    callbacks = tracing_service.get_langchain_callbacks()
    assert len(callbacks) == 0  # Should return empty list when tracing is disabled

    # Test end_tracers
    await tracing_service.end_tracers({})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cleanup_inputs():
    """Test cleaning sensitive information from input data."""
    original = {
        "normal_key": "normal_value",
        "api_key": "secret_api_key",
        "openai_api_key": "secret_openai_api_key",
        "nested_api_key": {"api_key": "nested_secret"},
    }

    masked = TracingService._cleanup_inputs(original)

    # Keys containing "api_key" are masked; everything else passes through.
    assert masked["normal_key"] == "normal_value"
    assert masked["api_key"] == "*****"
    assert masked["openai_api_key"] == "*****"
    # A matching key is masked even when its value is itself a dict.
    assert masked["nested_api_key"] == "*****"

    # Masking works on a copy: the caller's dict keeps its secrets intact.
    assert original["api_key"] == "secret_api_key"
    assert original["openai_api_key"] == "secret_openai_api_key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_tracers_with_exception(tracing_service):
    """Test starting tracers with exception handling.

    A tracer initializer blowing up must be logged, not raised, and the
    run-level trace context must still be created.
    """
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    # Mock _initialize_langsmith_tracer to raise exception
    with (
        patch.object(
            tracing_service,
            "_initialize_langsmith_tracer",
            side_effect=Exception("Mock exception"),
        ),
        patch("langflow.services.tracing.service.logger.debug") as mock_logger,
    ):
        # start_tracers should return normally even with exception
        await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

        # Verify exception was logged
        mock_logger.assert_any_call("Error initializing tracers: Mock exception")

        # Verify trace_context was set even with exception
        trace_context = trace_context_var.get()
        assert trace_context is not None
        assert trace_context.run_id == run_id
        assert trace_context.run_name == run_name

    # Cleanup
    await tracing_service.end_tracers({})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_trace_worker_with_exception(tracing_service):
    """Test trace worker exception handling.

    A queued trace function that raises must be caught and logged by the
    worker without killing the worker task.
    """
    run_id = uuid.uuid4()
    run_name = "test_run"
    user_id = "test_user"
    session_id = "test_session"
    project_name = "test_project"

    # Create a trace function that raises an exception
    def failing_trace_func():
        msg = "Mock trace function exception"
        raise ValueError(msg)

    with patch("langflow.services.tracing.service.logger.exception") as mock_logger:
        # Remove incorrect context manager usage
        await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

        # Get trace_context and add failing trace function to queue
        trace_context = trace_context_var.get()
        await trace_context.traces_queue.put((failing_trace_func, ()))

        # Wait for async queue processing
        await asyncio.sleep(0.1)

        # Verify exception was logged
        mock_logger.assert_called_with("Error processing trace_func")

    # Cleanup
    await tracing_service.end_tracers({})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_tracers")
async def test_concurrent_tracing(tracing_service, mock_component):
    """Test two tasks running start_tracers concurrently, with each task running 2 concurrent trace_component tasks."""

    # Define common task function: start tracers and run two component traces
    async def run_task(
        run_id,
        run_name,
        user_id,
        session_id,
        project_name,
        inputs,
        metadata,
        task_prefix,
        sleep_duration=0.1,
    ):
        await tracing_service.start_tracers(run_id, run_name, user_id, session_id, project_name)

        async def run_component_task(component, trace_name, component_suffix):
            async with tracing_service.trace_component(component, trace_name, inputs, metadata) as ts:
                ts.add_log(trace_name, {"message": f"{task_prefix} {component_suffix} log"})
                outputs = {"output_key": f"{task_prefix}_{component_suffix}_output"}
                await asyncio.sleep(sleep_duration)
                ts.set_outputs(trace_name, outputs)

        task1 = asyncio.create_task(run_component_task(mock_component, f"{run_id} trace_name1", f"{run_id} component1"))
        await task1
        task2 = asyncio.create_task(run_component_task(mock_component, f"{run_id} trace_name2", f"{run_id} component2"))
        await task2

        await tracing_service.end_tracers({"final_output": f"{task_prefix}_final_output"})
        trace_context = trace_context_var.get()
        return trace_context.tracers["langfuse"]

    inputs1 = {"input_key": "input_value1"}
    metadata1 = {"metadata_key": "metadata_value1"}
    inputs2 = {"input_key": "input_value2"}
    metadata2 = {"metadata_key": "metadata_value2"}

    # task1 sleeps longer (2s per component) so its run overlaps task2's,
    # exercising contextvar isolation between concurrent runs.
    task1 = asyncio.create_task(
        run_task(
            "run_id1",
            "run_name1",
            "user_id1",
            "session_id1",
            "project_name1",
            inputs1,
            metadata1,
            "task1",
            2,
        )
    )
    await asyncio.sleep(0.1)
    task2 = asyncio.create_task(
        run_task(
            "run_id2",
            "run_name2",
            "user_id2",
            "session_id2",
            "project_name2",
            inputs2,
            metadata2,
            "task2",
            0.1,
        )
    )
    tracer1 = await task1
    tracer2 = await task2

    # Verify tracer1 and tracer2 have correct trace data
    assert tracer1.trace_name == "run_name1"
    assert tracer1.project_name == "project_name1"
    assert tracer1.user_id == "user_id1"
    assert tracer1.session_id == "session_id1"
    assert dict(tracer1.outputs_param.get("run_id1 trace_name1")) == {"output_key": "task1_run_id1 component1_output"}
    assert dict(tracer1.outputs_param.get("run_id1 trace_name2")) == {"output_key": "task1_run_id1 component2_output"}

    assert tracer2.trace_name == "run_name2"
    assert tracer2.project_name == "project_name2"
    assert tracer2.user_id == "user_id2"
    assert tracer2.session_id == "session_id2"
    assert dict(tracer2.outputs_param.get("run_id2 trace_name1")) == {"output_key": "task2_run_id2 component1_output"}
    assert dict(tracer2.outputs_param.get("run_id2 trace_name2")) == {"output_key": "task2_run_id2 component2_output"}
|
||||
Loading…
Add table
Add a link
Reference in a new issue