chore: Refactor Component class to include tracing functionality
This commit is contained in:
parent
a6e9972f4b
commit
935aefcdea
2 changed files with 41 additions and 15 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, ClassVar, Generator, Iterator, List, Optional, Union
|
||||
from typing import Any, AsyncIterator, Callable, ClassVar, Generator, Iterator, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
|
|
@ -13,9 +13,6 @@ from langflow.template.field.base import UNDEFINED, Output
|
|||
|
||||
from .custom_component import CustomComponent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
|
||||
def recursive_serialize_or_str(obj):
|
||||
try:
|
||||
|
|
@ -47,6 +44,8 @@ class Component(CustomComponent):
|
|||
self._inputs: dict[str, InputTypes] = {}
|
||||
self._results: dict[str, Any] = {}
|
||||
self._attributes: dict[str, Any] = {}
|
||||
if not hasattr(self, "trace_type"):
|
||||
self.trace_type = "chain"
|
||||
if self.inputs is not None:
|
||||
self.map_inputs(self.inputs)
|
||||
|
||||
|
|
@ -57,12 +56,6 @@ class Component(CustomComponent):
|
|||
return self.__dict__["_inputs"][name].value
|
||||
raise AttributeError(f"{name} not found in {self.__class__.__name__}")
|
||||
|
||||
# def __getattribute__(self, name: str) -> Any:
|
||||
# try:
|
||||
# return super().__getattribute__(name)
|
||||
# except AttributeError:
|
||||
# return self.__getattr__(name)
|
||||
|
||||
def map_inputs(self, inputs: List[InputTypes]):
|
||||
self.inputs = inputs
|
||||
for input_ in inputs:
|
||||
|
|
@ -105,15 +98,40 @@ class Component(CustomComponent):
|
|||
for output in self.outputs:
|
||||
setattr(self, output.name, output)
|
||||
|
||||
async def build_results(self, vertex: "Vertex"):
|
||||
def get_trace_as_inputs(self):
|
||||
return {
|
||||
input_.name: input_.value
|
||||
for input_ in self.inputs
|
||||
if hasattr(input_, "trace_as_input") and input_.trace_as_input
|
||||
}
|
||||
|
||||
def get_trace_as_metadata(self):
|
||||
return {
|
||||
input_.name: input_.value
|
||||
for input_ in self.inputs
|
||||
if hasattr(input_, "trace_as_metadata") and input_.trace_as_metadata
|
||||
}
|
||||
|
||||
async def build_results(self):
|
||||
inputs = self.get_trace_as_inputs()
|
||||
metadata = self.get_trace_as_metadata()
|
||||
with self.tracing_service.trace_context(
|
||||
f"{self.display_name} ({self.vertex.id})", self.trace_type, inputs, metadata
|
||||
):
|
||||
_results, _artifacts = await self._build_results()
|
||||
self.tracing_service.set_outputs(_results)
|
||||
|
||||
return _results, _artifacts
|
||||
|
||||
async def _build_results(self):
|
||||
_results = {}
|
||||
_artifacts = {}
|
||||
if hasattr(self, "outputs"):
|
||||
self._set_outputs(vertex.outputs)
|
||||
self._set_outputs(self.vertex.outputs)
|
||||
for output in self.outputs:
|
||||
# Build the output if it's connected to some other vertex
|
||||
# or if it's not connected to any vertex
|
||||
if not vertex.outgoing_edges or output.name in vertex.edges_source_names:
|
||||
if not self.vertex.outgoing_edges or output.name in self.vertex.edges_source_names:
|
||||
if output.method is None:
|
||||
raise ValueError(f"Output {output.name} does not have a method defined.")
|
||||
method: Callable = getattr(self, output.method)
|
||||
|
|
@ -124,8 +142,12 @@ class Component(CustomComponent):
|
|||
# If the method is asynchronous, we need to await it
|
||||
if inspect.iscoroutinefunction(method):
|
||||
result = await result
|
||||
if isinstance(result, Message) and result.flow_id is None and vertex.graph.flow_id is not None:
|
||||
result.set_flow_id(vertex.graph.flow_id)
|
||||
if (
|
||||
isinstance(result, Message)
|
||||
and result.flow_id is None
|
||||
and self.vertex.graph.flow_id is not None
|
||||
):
|
||||
result.set_flow_id(self.vertex.graph.flow_id)
|
||||
_results[output.name] = result
|
||||
output.value = result
|
||||
custom_repr = self.custom_repr()
|
||||
|
|
@ -155,6 +177,8 @@ class Component(CustomComponent):
|
|||
_artifacts[output.name] = artifact
|
||||
self._artifacts = _artifacts
|
||||
self._results = _results
|
||||
if self.tracing_service:
|
||||
self.tracing_service.set_outputs(_results)
|
||||
return _results, _artifacts
|
||||
|
||||
def custom_repr(self):
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ if TYPE_CHECKING:
|
|||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.services.tracing.service import TracingService
|
||||
|
||||
|
||||
LoggableType = Union[str, dict, list, int, float, bool, None, Data, Message]
|
||||
|
|
@ -82,6 +83,7 @@ class CustomComponent(BaseComponent):
|
|||
"""The status of the component. This is displayed on the frontend. Defaults to None."""
|
||||
_flows_data: Optional[List[Data]] = None
|
||||
_logs: List[Log] = []
|
||||
tracing_service: Optional["TracingService"] = None
|
||||
|
||||
def update_state(self, name: str, value: Any):
|
||||
if not self.vertex:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue