From 2fa4ebd03695eb63f728209951f28448f470c995 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sat, 22 Jun 2024 00:38:16 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20(service.py):=20Add=20Optional?= =?UTF-8?q?=20import=20from=20typing=20to=20allow=20for=20Optional=20type?= =?UTF-8?q?=20hint=20=E2=99=BB=EF=B8=8F=20(service.py):=20Refactor=20varia?= =?UTF-8?q?ble=20declarations=20and=20type=20annotations=20for=20better=20?= =?UTF-8?q?code=20readability=20and=20maintainability=20=F0=9F=93=9D=20(se?= =?UTF-8?q?rvice.py):=20Update=20method=20signatures=20and=20type=20annota?= =?UTF-8?q?tions=20for=20better=20clarity=20and=20consistency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../base/langflow/services/tracing/service.py | 67 ++++++++++++------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/src/backend/base/langflow/services/tracing/service.py b/src/backend/base/langflow/services/tracing/service.py index 2acec2c20..9c73890e8 100644 --- a/src/backend/base/langflow/services/tracing/service.py +++ b/src/backend/base/langflow/services/tracing/service.py @@ -4,7 +4,8 @@ import traceback from collections import defaultdict from contextlib import asynccontextmanager from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional +from uuid import UUID from langchain.callbacks.tracers.langchain import wait_for_all_tracers from loguru import logger @@ -24,15 +25,15 @@ class TracingService(Service): def __init__(self, settings_service: "SettingsService", monitor_service: "MonitorService"): self.settings_service = settings_service self.monitor_service = monitor_service - self.inputs = defaultdict(dict) - self.inputs_metadata = defaultdict(dict) - self.outputs = defaultdict(dict) - self.outputs_metadata = defaultdict(dict) - self.run_name = None - self.run_id = None + self.inputs: dict[str, dict] = defaultdict(dict) + self.inputs_metadata: dict[str, dict] = defaultdict(dict) + self.outputs: dict[str, dict] = defaultdict(dict) + self.outputs_metadata: dict[str, dict] = defaultdict(dict) + self.run_name: str | None = None + self.run_id: UUID | None = None self.project_name = None self._tracers: dict[str, LangSmithTracer] = {} - self.logs_queue = asyncio.Queue() + self.logs_queue: asyncio.Queue = asyncio.Queue() self.running = False self.worker_task = None @@ -97,10 +98,12 @@ class TracingService(Service): def set_run_name(self, name: str): self.run_name = name - def set_run_id(self, run_id: str): + def set_run_id(self, run_id: UUID): self.run_id = run_id - def _start_traces(self, trace_name: str, trace_type: str, inputs: Dict[str, Any], metadata: Dict[str, Any] = None): + def _start_traces( + self, trace_name: str, trace_type: str, inputs: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None + ): self.inputs[trace_name] = inputs self.inputs_metadata[trace_name] = metadata or {} for tracer in self._tracers.values(): @@ -120,17 +123,17 @@ class TracingService(Service): except Exception as e: logger.error(f"Error ending trace {trace_name}: {e}") - def _end_all_traces(self, error: str | None = None): + def _end_all_traces(self, outputs: dict, error: str | None = None): for tracer in self._tracers.values(): if not tracer.ready: continue try: - tracer.end(self.inputs, outputs=self.outputs, error=error) + tracer.end(self.inputs, outputs=self.outputs, error=error, metadata=outputs) except Exception as e: logger.error(f"Error ending all traces: {e}") - async def end(self, error: str | None = None): - self._end_all_traces(error) + async def end(self, outputs: dict, error: str | None = None): + self._end_all_traces(outputs, error) self._reset_io() await self.stop() @@ -150,7 +153,11 @@ class TracingService(Service): @asynccontextmanager async def trace_context( - self, trace_name: str, trace_type: str, inputs: Dict[str, Any] = None, metadata: Dict[str, Any] = None + self, + trace_name: str, + trace_type: str, + inputs: Dict[str, Any], + metadata: Optional[Dict[str, Any]] = None, ): self._start_traces(trace_name, trace_type, inputs, metadata) try: @@ -164,14 +171,14 @@ class TracingService(Service): self._end_traces(trace_name, None) self._reset_io() - def set_outputs(self, trace_name: str, outputs: Dict[str, Any], output_metadata: Dict[str, Any] = None): + def set_outputs(self, trace_name: str, outputs: Dict[str, Any], output_metadata: Dict[str, Any] | None = None): self.outputs[trace_name] |= outputs or {} self.outputs_metadata[trace_name] |= output_metadata or {} class LangSmithTracer: - def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: str): + def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID): from langsmith.run_trees import RunTree self.trace_name = trace_name @@ -203,7 +210,9 @@ class LangSmithTracer: os.environ["LANGCHAIN_TRACING_V2"] = "true" return True - def add_trace(self, trace_name: str, trace_type: str, inputs: Dict[str, Any], metadata: Dict[str, Any] = None): + def add_trace( + self, trace_name: str, trace_type: str, inputs: Dict[str, Any], metadata: Dict[str, Any] | None = None + ): if not self._ready: return raw_inputs = {} @@ -214,13 +223,13 @@ class LangSmithTracer: processed_inputs = self._convert_to_langchain_types(inputs) child = self._run_tree.create_child( name=trace_name, - run_type=trace_type, + run_type=trace_type, # type: ignore[arg-type] inputs=processed_inputs, ) if metadata: child.add_metadata(raw_inputs) self._children[trace_name] = child - self._child_link = {} + self._child_link: dict[str, str] = {} def _convert_to_langchain_types(self, io_dict: Dict[str, Any]): converted = {} @@ -248,7 +257,7 @@ class LangSmithTracer: value = value.to_lc_document() return value - def end_trace(self, trace_name: str, outputs: Dict[str, Any] = None, error: str = None): + def end_trace(self, trace_name: str, outputs: Dict[str, Any] | None = None, error: str | None = None): child = self._children[trace_name] raw_outputs = {} processed_outputs = {} @@ -264,11 +273,21 @@ class LangSmithTracer: self._child_link[trace_name] = child.get_url() def add_log(self, trace_name: str, log: Log): - log_dict = {"name": log.name, "time": datetime.now(timezone.utc).isoformat(), "message": log.message} + log_dict = { + "name": log.get("name"), + "time": datetime.now(timezone.utc).isoformat(), + "message": log.get("message"), + } self._children[trace_name].add_event(log_dict) - def end(self, inputs: dict[str, Any], outputs: Dict[str, Any], error: str | None = None): - self._run_tree.add_metadata({"inputs": inputs}) + def end( + self, + inputs: dict[str, Any], + outputs: Dict[str, Any], + error: str | None = None, + metadata: Optional[dict[str, Any]] = None, + ): + self._run_tree.add_metadata({"inputs": inputs, "metadata": metadata or {}}) self._run_tree.end(outputs=outputs, error=error) self._run_tree.post() wait_for_all_tracers()