chore: Refactor Component class to include tracing functionality

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-19 19:00:59 -03:00
commit 935aefcdea
2 changed files with 41 additions and 15 deletions

View file

@ -1,5 +1,5 @@
import inspect
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, ClassVar, Generator, Iterator, List, Optional, Union
from typing import Any, AsyncIterator, Callable, ClassVar, Generator, Iterator, List, Optional, Union
from uuid import UUID
import yaml
@ -13,9 +13,6 @@ from langflow.template.field.base import UNDEFINED, Output
from .custom_component import CustomComponent
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
def recursive_serialize_or_str(obj):
try:
@ -47,6 +44,8 @@ class Component(CustomComponent):
self._inputs: dict[str, InputTypes] = {}
self._results: dict[str, Any] = {}
self._attributes: dict[str, Any] = {}
if not hasattr(self, "trace_type"):
self.trace_type = "chain"
if self.inputs is not None:
self.map_inputs(self.inputs)
@ -57,12 +56,6 @@ class Component(CustomComponent):
return self.__dict__["_inputs"][name].value
raise AttributeError(f"{name} not found in {self.__class__.__name__}")
# def __getattribute__(self, name: str) -> Any:
# try:
# return super().__getattribute__(name)
# except AttributeError:
# return self.__getattr__(name)
def map_inputs(self, inputs: List[InputTypes]):
self.inputs = inputs
for input_ in inputs:
@ -105,15 +98,40 @@ class Component(CustomComponent):
for output in self.outputs:
setattr(self, output.name, output)
async def build_results(self, vertex: "Vertex"):
def get_trace_as_inputs(self):
return {
input_.name: input_.value
for input_ in self.inputs
if hasattr(input_, "trace_as_input") and input_.trace_as_input
}
def get_trace_as_metadata(self):
return {
input_.name: input_.value
for input_ in self.inputs
if hasattr(input_, "trace_as_metadata") and input_.trace_as_metadata
}
async def build_results(self):
inputs = self.get_trace_as_inputs()
metadata = self.get_trace_as_metadata()
with self.tracing_service.trace_context(
f"{self.display_name} ({self.vertex.id})", self.trace_type, inputs, metadata
):
_results, _artifacts = await self._build_results()
self.tracing_service.set_outputs(_results)
return _results, _artifacts
async def _build_results(self):
_results = {}
_artifacts = {}
if hasattr(self, "outputs"):
self._set_outputs(vertex.outputs)
self._set_outputs(self.vertex.outputs)
for output in self.outputs:
# Build the output if it's connected to some other vertex
# or if it's not connected to any vertex
if not vertex.outgoing_edges or output.name in vertex.edges_source_names:
if not self.vertex.outgoing_edges or output.name in self.vertex.edges_source_names:
if output.method is None:
raise ValueError(f"Output {output.name} does not have a method defined.")
method: Callable = getattr(self, output.method)
@ -124,8 +142,12 @@ class Component(CustomComponent):
# If the method is asynchronous, we need to await it
if inspect.iscoroutinefunction(method):
result = await result
if isinstance(result, Message) and result.flow_id is None and vertex.graph.flow_id is not None:
result.set_flow_id(vertex.graph.flow_id)
if (
isinstance(result, Message)
and result.flow_id is None
and self.vertex.graph.flow_id is not None
):
result.set_flow_id(self.vertex.graph.flow_id)
_results[output.name] = result
output.value = result
custom_repr = self.custom_repr()
@ -155,6 +177,8 @@ class Component(CustomComponent):
_artifacts[output.name] = artifact
self._artifacts = _artifacts
self._results = _results
if self.tracing_service:
self.tracing_service.set_outputs(_results)
return _results, _artifacts
def custom_repr(self):

View file

@ -26,6 +26,7 @@ if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.services.storage.service import StorageService
from langflow.services.tracing.service import TracingService
LoggableType = Union[str, dict, list, int, float, bool, None, Data, Message]
@ -82,6 +83,7 @@ class CustomComponent(BaseComponent):
"""The status of the component. This is displayed on the frontend. Defaults to None."""
_flows_data: Optional[List[Data]] = None
_logs: List[Log] = []
tracing_service: Optional["TracingService"] = None
def update_state(self, name: str, value: Any):
if not self.vertex: