fix: traces inputs and outputs not being sent to tracing services (#4669)

* Fix node inputs not being captured, add runtime inputs as well to be captured by the tracers properly

* Fix outputs missing on traces due to them being reset before ending the traces because of race conditions

* Fallback to project name if none

* Remove 'dynamic inputs' to stop sending the component code every time

* fix: Add async flow name retrieval in graph building process

* fix: Retrieve flow name from database when building graph from data

* Fix: make session.exec call awaitable in chat API

* Refactor `_get_flow_name` to manage session internally

* Refactor session handling to use `async_session_scope` in chat API

* Refactor test cases to remove unnecessary async usage in mock functions

* [autofix.ci] apply automated fixes

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Rogério Chaves 2024-11-22 02:41:39 +01:00 committed by GitHub
commit 1532da59f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 32 additions and 25 deletions

View file

@ -17,7 +17,7 @@ from langflow.services.database.models import User
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.transactions.model import TransactionTable
from langflow.services.database.models.vertex_builds.model import VertexBuildTable
from langflow.services.deps import get_async_session, get_session
from langflow.services.deps import async_session_scope, get_async_session, get_session
from langflow.services.store.utils import get_lf_version_from_pypi
if TYPE_CHECKING:
@ -142,8 +142,21 @@ def format_elapsed_time(elapsed_time: float) -> str:
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"
async def _get_flow_name(flow_id: str) -> str:
async with async_session_scope() as session:
flow = await session.get(Flow, flow_id)
if flow is None:
msg = f"Flow {flow_id} not found"
raise ValueError(msg)
return flow.name
async def build_graph_from_data(flow_id: str, payload: dict, **kwargs):
"""Build and cache the graph."""
# Get flow name
if "flow_name" not in kwargs:
flow_name = await _get_flow_name(flow_id)
kwargs["flow_name"] = flow_name
graph = Graph.from_payload(payload, flow_id, **kwargs)
for vertex_id in graph.has_session_id_vertices:
vertex = graph.get_vertex(vertex_id)

View file

@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from sqlmodel import select
from starlette.background import BackgroundTask
from starlette.responses import ContentStream
from starlette.types import Receive
@ -42,7 +43,8 @@ from langflow.graph.utils import log_vertex_build
from langflow.schema.schema import OutputValue
from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_async_session, get_chat_service, get_telemetry_service
from langflow.services.database.models.flow.model import Flow
from langflow.services.deps import async_session_scope, get_async_session, get_chat_service, get_telemetry_service
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
if TYPE_CHECKING:
@ -166,7 +168,12 @@ async def build_flow(
if not data:
graph = await build_graph_from_db_no_cache(flow_id=flow_id_str, session=session)
else:
graph = await build_graph_from_data(flow_id_str, data.model_dump(), user_id=str(current_user.id))
async with async_session_scope() as new_session:
result = await new_session.exec(select(Flow.name).where(Flow.id == flow_id_str))
flow_name = result.first()
graph = await build_graph_from_data(
flow_id_str, data.model_dump(), user_id=str(current_user.id), flow_name=flow_name
)
graph.validate_stream()
if stop_component_id or start_component_id:
try:

View file

@ -801,9 +801,9 @@ class Component(CustomComponent):
for input_ in self.inputs
if hasattr(input_, "trace_as_input") and input_.trace_as_input
}
# Dynamic inputs
dynamic_inputs = {key: value for key, value in self._attributes.items() if key not in predefined_inputs}
return {**predefined_inputs, **dynamic_inputs}
# Runtime inputs
runtime_inputs = {name: input_.value for name, input_ in self._inputs.items() if hasattr(input_, "value")}
return {**predefined_inputs, **runtime_inputs}
def get_trace_as_metadata(self):
return {

View file

@ -41,6 +41,7 @@ class LangWatchTracer(BaseTracer):
self.spans: dict[str, ContextSpan] = {}
name_without_id = " - ".join(trace_name.split(" - ")[0:-1])
name_without_id = project_name if name_without_id == "None" else name_without_id
self.trace.root_span.update(
# nanoid to make the span_id globally unique, which is required for LangWatch for now
span_id=f"{self.flow_id}-{nanoid.generate(size=6)}",

View file

@ -184,6 +184,7 @@ class TracingService(Service):
)
except Exception: # noqa: BLE001
logger.exception(f"Error ending trace {trace_name}")
self._reset_io()
def _end_all_traces(self, outputs: dict, error: Exception | None = None) -> None:
for tracer in self._tracers.values():
@ -192,10 +193,10 @@ class TracingService(Service):
tracer.end(self.inputs, outputs=self.outputs, error=error, metadata=outputs)
except Exception: # noqa: BLE001
logger.exception("Error ending all traces")
self._reset_io()
async def end(self, outputs: dict, error: Exception | None = None) -> None:
await asyncio.to_thread(self._end_all_traces, outputs, error)
self._reset_io()
await self.stop()
def add_log(self, trace_name: str, log: Log) -> None:
@ -236,7 +237,6 @@ class TracingService(Service):
task = asyncio.create_task(asyncio.to_thread(self._end_traces, trace_id, trace_name, error))
self.end_trace_tasks.add(task)
task.add_done_callback(self.end_trace_tasks.discard)
self._reset_io()
def set_outputs(
self,

View file

@ -36,21 +36,6 @@ class TestEventManager:
assert "on_test_event" in manager.events
assert manager.events["on_test_event"].func == manager.send_event
# Sending an event with valid event_type and data using pytest-asyncio plugin
async def test_sending_event_with_valid_type_and_data_asyncio_plugin(self):
async def mock_queue_put_nowait(item):
await queue.put(item)
queue = asyncio.Queue()
queue.put_nowait = mock_queue_put_nowait
manager = EventManager(queue)
manager.register_event("on_test_event", "test_type", manager.noop)
event_type = "test_type"
data = "test_data"
manager.send_event(event_type=event_type, data=data)
await queue.join()
assert queue.empty()
# Accessing a non-registered event callback via __getattr__ with the recommended fix
def test_accessing_non_registered_event_callback_with_recommended_fix(self):
queue = asyncio.Queue()
@ -70,7 +55,7 @@ class TestEventManager:
# Handling a large number of events in the queue
def test_handling_large_number_of_events(self):
async def mock_queue_put_nowait(item):
def mock_queue_put_nowait(item):
pass
queue = asyncio.Queue()
@ -97,6 +82,7 @@ class TestEventManager:
# Sending an event with complex data and verifying successful event transmission
async def test_sending_event_with_complex_data(self):
queue = asyncio.Queue()
manager = EventManager(queue)
manager.register_event("on_test_event", "test_type", manager.noop)
data = {"key": "value", "nested": [1, 2, 3]}
@ -134,7 +120,7 @@ class TestEventManager:
# Checking the performance impact of frequent event registrations
def test_performance_impact_frequent_registrations(self):
async def mock_callback(event_type: str, data: LoggableType):
def mock_callback(event_type: str, data: LoggableType):
pass
queue = asyncio.Queue()