fix(tracing_service): use correct trace name and attribute name (#2484)

* fix(tracing_service): use correct trace name and attribute name

* fix: update tracing_service attribute name in LCModelComponent

* feat: add trace_name property in CustomComponent

* fix: update trace_name usage in Component build method

* feat: update log method in CustomComponent to use trace_name

* fix: update trace_name usage in Component build method

* feat(custom_component): add logs to CustomComponent

The `CustomComponent` class now includes a `_logs` attribute to store log messages. This attribute is initialized as an empty list in the constructor. The `log` method has been updated to accept an optional `name` parameter, which allows specifying a custom name for the log message. If no name is provided, a default name is generated based on the number of logs already stored.

This change enhances the logging functionality of the `CustomComponent` and provides more flexibility in managing log messages.

* feat(tracing_service): add logs to TracingService

This commit adds a `_logs` attribute to the `TracingService` class to store log messages. The attribute is initialized as a defaultdict of lists in the constructor. The `add_log` method has been updated to append logs to the corresponding trace name in the `_logs` dictionary. This change enhances the logging functionality of the `TracingService` and allows for better management of log messages.

* chore(tracing_service): improve error handling in stop method

* refactor(tracing/service.py): update _logs data structure to support both Log objects and generic dictionaries for flexibility in handling different types of data

refactor(tracing/service.py): add conditional check to only add metadata if it is provided for better control over the information being added

* refactor: update build_model method return type annotation

* refactor(CohereModel.py): update return type of build_model method to only LanguageModel for clarity and consistency

* chore(GroqModel.py): add stop_sequences parameter to GroqModel query method

* refactor(AstraDB.py): reorganize imports and update cached_vectorstore type

* refactor: update cached_vectorstore type and input order in CassandraVectorStoreComponent

* chore(GroqModel.py): remove unused stop_sequences parameter in GroqModel query method
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-07-02 15:49:27 -03:00 committed by GitHub
commit 5f0e74e5d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 54 additions and 53 deletions

View file

@ -145,7 +145,7 @@ class LCModelComponent(Component):
inputs: Union[list, dict] = messages or {}
try:
runnable = runnable.with_config( # type: ignore
{"run_name": self.display_name, "project_name": self._tracing_service.project_name} # type: ignore
{"run_name": self.display_name, "project_name": self.tracing_service.project_name} # type: ignore
)
if stream:
return runnable.stream(inputs) # type: ignore
@ -167,7 +167,7 @@ class LCModelComponent(Component):
raise e
@abstractmethod
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
"""
Implement this method to build the model.
"""

View file

@ -1,5 +1,4 @@
from langchain_cohere import ChatCohere
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic.v1 import SecretStr
from langflow.base.constants import STREAM_INFO_TEXT
@ -33,7 +32,7 @@ class CohereComponent(LCModelComponent):
),
]
def build_model(self) -> LanguageModel | BaseChatModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
cohere_api_key = self.cohere_api_key
temperature = self.temperature

View file

@ -1,9 +1,9 @@
from langchain_core.vectorstores import VectorStore
from loguru import logger
from langchain_core.vectorstores import VectorStore
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.helpers import docs_to_data
from langflow.inputs import FloatInput, DictInput
from langflow.inputs import DictInput, FloatInput
from langflow.io import (
BoolInput,
DataInput,
@ -23,7 +23,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb"
icon: str = "AstraDB"
_cached_vectorstore: VectorStore = None
_cached_vectorstore: VectorStore | None = None
inputs = [
StrInput(

View file

@ -1,10 +1,11 @@
from typing import List
from langchain_community.vectorstores import Cassandra
from loguru import logger
from langflow.base.vectorstores.model import LCVectorStoreComponent
from langflow.helpers.data import docs_to_data
from langflow.inputs import DictInput, FloatInput, BoolInput
from langflow.inputs import BoolInput, DictInput, FloatInput
from langflow.io import (
DataInput,
DropdownInput,
@ -15,7 +16,6 @@ from langflow.io import (
SecretStrInput,
)
from langflow.schema import Data
from loguru import logger
class CassandraVectorStoreComponent(LCVectorStoreComponent):
@ -24,7 +24,7 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
documentation = "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/cassandra"
icon = "Cassandra"
_cached_vectorstore: Cassandra = None
_cached_vectorstore: Cassandra | None = None
inputs = [
MessageTextInput(

View file

@ -118,18 +118,24 @@ class Component(CustomComponent):
if hasattr(input_, "trace_as_metadata") and input_.trace_as_metadata
}
async def build_results(self):
async def _build_with_tracing(self):
inputs = self.get_trace_as_inputs()
metadata = self.get_trace_as_metadata()
async with self._tracing_service.trace_context(
f"{self.display_name} ({self.vertex.id})", self.trace_type, inputs, metadata
):
async with self.tracing_service.trace_context(self.trace_name, self.trace_type, inputs, metadata):
_results, _artifacts = await self._build_results()
trace_name = self._tracing_service.run_name
self._tracing_service.set_outputs(trace_name, _results)
trace_name = self.tracing_service.run_name
self.tracing_service.set_outputs(trace_name, _results)
return _results, _artifacts
async def _build_without_tracing(self):
return await self._build_results()
async def build_results(self):
if self.tracing_service:
return await self._build_with_tracing()
return await self._build_without_tracing()
async def _build_results(self):
_results = {}
_artifacts = {}
@ -184,9 +190,8 @@ class Component(CustomComponent):
_artifacts[output.name] = artifact
self._artifacts = _artifacts
self._results = _results
if self._tracing_service:
trace_name = self._tracing_service.run_name
self._tracing_service.set_outputs(trace_name, _results)
if self.tracing_service:
self.tracing_service.set_outputs(self.trace_name, _results)
return _results, _artifacts
def custom_repr(self):

View file

@ -85,6 +85,10 @@ class CustomComponent(BaseComponent):
_logs: List[Log] = []
tracing_service: Optional["TracingService"] = None
@property
def trace_name(self):
return f"{self.display_name} ({self.vertex.id})"
def update_state(self, name: str, value: Any):
if not self.vertex:
raise ValueError("Vertex is not set")
@ -131,6 +135,7 @@ class CustomComponent(BaseComponent):
**data: Additional keyword arguments to initialize the custom component.
"""
self.cache = TTLCache(maxsize=1024, ttl=60)
self._logs = []
super().__init__(**data)
@staticmethod
@ -481,21 +486,19 @@ class CustomComponent(BaseComponent):
"""
raise NotImplementedError
def log(self, message: LoggableType | list[LoggableType], name: str | None = None):
def log(self, message: LoggableType | list[LoggableType], name: Optional[str] = None):
"""
Logs a message.
Args:
message (LoggableType | list[LoggableType]): The message to log.
"""
if name is None and self.display_name:
name = self.display_name
else:
name = self.__class__.__name__
if name is None:
name = f"Log {len(self._logs) + 1}"
log = Log(message=message, type=get_artifact_type(message), name=name)
self._logs.append(log)
if self.tracing_service and self.vertex:
self.tracing_service.add_log(trace_name=self.vertex.id, log=log)
self.tracing_service.add_log(trace_name=self.trace_name, log=log)
def post_code_processing(self, new_build_config: dict, current_build_config: dict):
"""

View file

@ -59,7 +59,7 @@ async def build_component_and_get_results(
# Remove code from params
class_object: Type["CustomComponent" | "Component"] = eval_custom_component_code(params_copy.pop("code"))
custom_component: "CustomComponent" | "Component" = class_object(
user_id=user_id, parameters=params_copy, vertex=vertex, _tracing_service=tracing_service
user_id=user_id, parameters=params_copy, vertex=vertex, tracing_service=tracing_service
)
params_copy = update_params_with_load_from_db_fields(
custom_component, params_copy, vertex.load_from_db_fields, fallback_to_env_vars

View file

@ -34,6 +34,7 @@ class TracingService(Service):
self.run_id: UUID | None = None
self.project_name = None
self._tracers: dict[str, LangSmithTracer] = {}
self._logs: dict[str, list[Log | dict[Any, Any]]] = defaultdict(list)
self.logs_queue: asyncio.Queue = asyncio.Queue()
self.running = False
self.worker_task = None
@ -70,8 +71,9 @@ class TracingService(Service):
# check the qeue is empty
if not self.logs_queue.empty():
await self.logs_queue.join()
self.worker_task.cancel()
self.worker_task = None
if self.worker_task:
self.worker_task.cancel()
self.worker_task = None
except Exception as e:
logger.error(f"Error stopping tracing service: {e}")
@ -123,7 +125,9 @@ class TracingService(Service):
if not tracer.ready:
continue
try:
tracer.end_trace(trace_name=trace_name, outputs=self.outputs[trace_name], error=error)
tracer.end_trace(
trace_name=trace_name, outputs=self.outputs[trace_name], error=error, logs=self._logs[trace_name]
)
except Exception as e:
logger.error(f"Error ending trace {trace_name}: {e}")
@ -141,19 +145,8 @@ class TracingService(Service):
self._reset_io()
await self.stop()
async def _add_log(self, trace_name: str, log: Log):
for tracer in self._tracers.values():
if not tracer.ready:
continue
try:
tracer.add_log(trace_name, log)
except Exception as e:
logger.error(f"Error adding log to trace {trace_name}: {e}")
def add_log(self, trace_name: str, log: Log):
if not self.running:
asyncio.run(self.start())
self.logs_queue.put_nowait((self._add_log, (trace_name, log)))
self._logs[trace_name].append(log)
@asynccontextmanager
async def trace_context(
@ -177,7 +170,6 @@ class TracingService(Service):
def set_outputs(self, trace_name: str, outputs: Dict[str, Any], output_metadata: Dict[str, Any] | None = None):
self.outputs[trace_name] |= outputs or {}
self.outputs_metadata[trace_name] |= output_metadata or {}
@ -235,7 +227,7 @@ class LangSmithTracer(BaseTracer):
inputs=processed_inputs,
)
if metadata:
child.add_metadata(raw_inputs)
child.add_metadata(metadata)
self._children[trace_name] = child
self._child_link: dict[str, str] = {}
@ -265,13 +257,21 @@ class LangSmithTracer(BaseTracer):
value = value.to_lc_document()
return value
def end_trace(self, trace_name: str, outputs: Dict[str, Any] | None = None, error: str | None = None):
def end_trace(
self,
trace_name: str,
outputs: Dict[str, Any] | None = None,
error: str | None = None,
logs: list[Log | dict] = [],
):
child = self._children[trace_name]
raw_outputs = {}
processed_outputs = {}
if outputs:
raw_outputs = outputs
processed_outputs = self._convert_to_langchain_types(outputs)
if logs:
child.add_metadata({"logs": {log.get("name"): log for log in logs}})
child.add_metadata({"outputs": raw_outputs})
child.end(outputs=processed_outputs, error=error)
if error:
@ -280,14 +280,6 @@ class LangSmithTracer(BaseTracer):
child.post()
self._child_link[trace_name] = child.get_url()
def add_log(self, trace_name: str, log: Log):
log_dict = {
"name": log.get("name"),
"time": datetime.now(timezone.utc).isoformat(),
"message": log.get("message"),
}
self._children[trace_name].add_event(log_dict)
def end(
self,
inputs: dict[str, Any],
@ -295,7 +287,9 @@ class LangSmithTracer(BaseTracer):
error: str | None = None,
metadata: dict[str, Any] | None = None,
):
self._run_tree.add_metadata({"inputs": inputs, "metadata": metadata or {}})
self._run_tree.add_metadata({"inputs": inputs})
if metadata:
self._run_tree.add_metadata(metadata)
self._run_tree.end(outputs=outputs, error=error)
self._run_tree.post()
wait_for_all_tracers()