refactor: Update TracingService to use defaultdict for inputs and outputs
This commit is contained in:
parent
1e7b92a44c
commit
977181fd22
1 changed files with 48 additions and 23 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import os
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
|
@ -23,9 +24,10 @@ class TracingService(Service):
|
|||
def __init__(self, settings_service: "SettingsService", monitor_service: "MonitorService"):
|
||||
self.settings_service = settings_service
|
||||
self.monitor_service = monitor_service
|
||||
self.inputs = {}
|
||||
self.outputs = {}
|
||||
self.outputs_metadata = {}
|
||||
self.inputs = defaultdict(dict)
|
||||
self.inputs_metadata = defaultdict(dict)
|
||||
self.outputs = defaultdict(dict)
|
||||
self.outputs_metadata = defaultdict(dict)
|
||||
self.run_name = None
|
||||
self.run_id = None
|
||||
self.project_name = None
|
||||
|
|
@ -45,27 +47,42 @@ class TracingService(Service):
|
|||
self.logs_queue.task_done()
|
||||
|
||||
async def start(self):
|
||||
if not self.running:
|
||||
if self.running:
|
||||
return
|
||||
try:
|
||||
self.running = True
|
||||
self.worker_task = asyncio.create_task(self.log_worker())
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting tracing service: {e}")
|
||||
|
||||
async def flush(self):
|
||||
await self.logs_queue.join()
|
||||
try:
|
||||
await self.logs_queue.join()
|
||||
except Exception as e:
|
||||
logger.error(f"Error flushing logs: {e}")
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
await self.flush()
|
||||
self.worker_task.cancel()
|
||||
if self.worker_task:
|
||||
await self.worker_task
|
||||
try:
|
||||
self.running = False
|
||||
await self.flush()
|
||||
self.worker_task.cancel()
|
||||
if self.worker_task:
|
||||
await self.worker_task
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping tracing service: {e}")
|
||||
|
||||
def _reset_io(self):
|
||||
self.inputs = {}
|
||||
self.outputs = {}
|
||||
self.inputs = defaultdict(dict)
|
||||
self.inputs_metadata = defaultdict(dict)
|
||||
self.outputs = defaultdict(dict)
|
||||
self.outputs_metadata = defaultdict(dict)
|
||||
|
||||
async def initialize_tracers(self):
|
||||
await self.start()
|
||||
self._initialize_langsmith_tracer()
|
||||
try:
|
||||
await self.start()
|
||||
self._initialize_langsmith_tracer()
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tracers: {e}")
|
||||
|
||||
def _initialize_langsmith_tracer(self):
|
||||
project_name = os.getenv("LANGCHAIN_PROJECT", "Langflow")
|
||||
|
|
@ -84,6 +101,8 @@ class TracingService(Service):
|
|||
self.run_id = run_id
|
||||
|
||||
def _start_traces(self, trace_name: str, trace_type: str, inputs: Dict[str, Any], metadata: Dict[str, Any] = None):
|
||||
self.inputs[trace_name] = inputs
|
||||
self.inputs_metadata[trace_name] = metadata or {}
|
||||
for tracer in self._tracers.values():
|
||||
if not tracer.ready:
|
||||
continue
|
||||
|
|
@ -97,18 +116,21 @@ class TracingService(Service):
|
|||
if not tracer.ready:
|
||||
continue
|
||||
try:
|
||||
tracer.end_trace(trace_name=trace_name, outputs=self.outputs, error=error)
|
||||
tracer.end_trace(trace_name=trace_name, outputs=self.outputs[trace_name], error=error)
|
||||
except Exception as e:
|
||||
logger.error(f"Error ending trace {trace_name}: {e}")
|
||||
|
||||
def _end_all_traces(self, outputs: dict[str, Any], error: str | None = None):
|
||||
def _end_all_traces(self, error: str | None = None):
|
||||
for tracer in self._tracers.values():
|
||||
if not tracer.ready:
|
||||
continue
|
||||
tracer.end(outputs=outputs, error=error)
|
||||
try:
|
||||
tracer.end(self.inputs, outputs=self.outputs, error=error)
|
||||
except Exception as e:
|
||||
logger.error(f"Error ending all traces: {e}")
|
||||
|
||||
async def end(self, outputs: dict[str, Any] | None = None, error: str | None = None):
|
||||
self._end_all_traces(outputs, error)
|
||||
async def end(self, error: str | None = None):
|
||||
self._end_all_traces(error)
|
||||
self._reset_io()
|
||||
await self.stop()
|
||||
|
||||
|
|
@ -142,9 +164,10 @@ class TracingService(Service):
|
|||
self._end_traces(trace_name, None)
|
||||
self._reset_io()
|
||||
|
||||
def set_outputs(self, outputs: Dict[str, Any], output_metadata: Dict[str, Any] = None):
|
||||
self.outputs |= outputs or {}
|
||||
self.outputs_metadata |= output_metadata or {}
|
||||
def set_outputs(self, trace_name: str, outputs: Dict[str, Any], output_metadata: Dict[str, Any] = None):
|
||||
self.outputs[trace_name] |= outputs or {}
|
||||
|
||||
self.outputs_metadata[trace_name] |= output_metadata or {}
|
||||
|
||||
|
||||
class LangSmithTracer:
|
||||
|
|
@ -161,6 +184,7 @@ class LangSmithTracer:
|
|||
run_type=self.trace_type,
|
||||
id=self.trace_id,
|
||||
)
|
||||
self._run_tree.add_event({"name": "Start", "time": datetime.now(timezone.utc).isoformat()})
|
||||
self._children: dict[str, RunTree] = {}
|
||||
self._ready = self.setup_langsmith()
|
||||
|
||||
|
|
@ -243,7 +267,8 @@ class LangSmithTracer:
|
|||
log_dict = {"name": log.name, "time": datetime.now(timezone.utc).isoformat(), "message": log.message}
|
||||
self._children[trace_name].add_event(log_dict)
|
||||
|
||||
def end(self, outputs: Dict[str, Any], error: str | None = None):
|
||||
def end(self, inputs: dict[str, Any], outputs: Dict[str, Any], error: str | None = None):
|
||||
self._run_tree.add_metadata({"inputs": inputs})
|
||||
self._run_tree.end(outputs=outputs, error=error)
|
||||
self._run_tree.post()
|
||||
wait_for_all_tracers()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue