refactor: Add error handling to TracingService methods

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-20 17:59:07 -03:00
commit 44a23ebef4

View file

@ -55,22 +55,28 @@ class TracingService(Service):
for tracer in self._tracers.values():
if not tracer.ready:
continue
tracer.add_trace(trace_name, trace_type, inputs, metadata)
try:
tracer.add_trace(trace_name, trace_type, inputs, metadata)
except Exception as e:
logger.error(f"Error starting trace {trace_name}: {e}")
def _end_traces(self, trace_name: str, error: str | None = None):
for tracer in self._tracers.values():
if not tracer.ready:
continue
tracer.end_trace(trace_name=trace_name, outputs=self.outputs, error=error)
try:
tracer.end_trace(trace_name=trace_name, outputs=self.outputs, error=error)
except Exception as e:
logger.error(f"Error ending trace {trace_name}: {e}")
def _end_all_traces(self, outputs: dict[str, Any]):
def _end_all_traces(self, outputs: dict[str, Any], error: str | None = None):
for tracer in self._tracers.values():
if not tracer.ready:
continue
tracer.end(outputs=outputs)
tracer.end(outputs=outputs, error=error)
def end(self, outputs: dict[str, Any] | None = None):
self._end_all_traces(outputs)
async def end(self, outputs: dict[str, Any] | None = None, error: str | None = None):
self._end_all_traces(outputs, error)
self._reset_io()
@contextmanager
@ -90,8 +96,8 @@ class TracingService(Service):
self._reset_io()
def set_outputs(self, outputs: Dict[str, Any], output_metadata: Dict[str, Any] = None):
self.outputs = outputs
self.outputs_metadata = output_metadata
self.outputs |= outputs or {}
self.outputs_metadata |= output_metadata or {}
class LangSmithTracer:
@ -129,37 +135,46 @@ class LangSmithTracer:
def add_trace(self, trace_name: str, trace_type: str, inputs: Dict[str, Any], metadata: Dict[str, Any] = None):
if not self._ready:
return
inputs = self._convert_to_langchain_types(inputs)
raw_inputs = {}
processed_inputs = {}
if inputs:
raw_inputs = inputs.copy()
raw_inputs |= metadata or {}
processed_inputs = self._convert_to_langchain_types(inputs)
child = self._run_tree.create_child(
name=trace_name,
run_type=trace_type,
inputs=inputs,
inputs=processed_inputs,
)
if metadata:
child.add_metadata(metadata)
child.add_metadata(raw_inputs)
self._children[trace_name] = child
def _convert_to_langchain_types(self, io_dict: Dict[str, Any]):
converted = {}
for key, value in io_dict.items():
converted[key] = self._convert_to_langchain_type(value)
return converted
def _convert_to_langchain_type(self, value):
from langflow.schema.message import Message
_converted = {}
for key, value in io_dict.items():
if isinstance(value, dict):
_converted[key] = self._convert_to_langchain_types(value)
elif isinstance(value, list):
_converted[key] = [self._convert_to_langchain_types(v) for v in value]
elif isinstance(value, Message):
if value.sender:
_converted[key] = value.to_lc_message()
elif "prompt" in value:
_converted[key] = value.load_lc_prompt()
else:
_converted[key] = value.to_lc_document()
elif isinstance(value, Data):
_converted[key] = value.to_lc_document()
if isinstance(value, dict):
for key, _value in value.copy().items():
_value = self._convert_to_langchain_type(_value)
value[key] = _value
elif isinstance(value, list):
value = [self._convert_to_langchain_type(v) for v in value]
elif isinstance(value, Message):
if "prompt" in value:
value = value.load_lc_prompt()
elif value.sender:
value = value.to_lc_message()
else:
_converted[key] = value
return _converted
value = value.to_lc_document()
elif isinstance(value, Data):
value = value.to_lc_document()
return value
def end_trace(self, trace_name: str, outputs: Dict[str, Any] = None, error: str = None):
child = self._children[trace_name]
@ -175,8 +190,8 @@ class LangSmithTracer:
else:
child.post()
def end(self, outputs: Dict[str, Any]):
self._run_tree.end(outputs=outputs)
def end(self, outputs: Dict[str, Any], error: str | None = None):
self._run_tree.end(outputs=outputs, error=error)
self._run_tree.post()
wait_for_all_tracers()
self._run_link = self._run_tree.get_url()