fix: streaming now works on Agent and ChatOutput (#7833)
* fix: add event handling for AI message chunks in stream processing * fix: integrate serialization in LangFuse and LangSmith tracers for improved data handling
This commit is contained in:
parent
5f96d18fb1
commit
bc1ee21ee9
4 changed files with 16 additions and 8 deletions
|
|
@ -4,7 +4,7 @@ from time import perf_counter
|
|||
from typing import Any, Protocol
|
||||
|
||||
from langchain_core.agents import AgentFinish
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langflow.schema.content_block import ContentBlock
|
||||
|
|
@ -213,6 +213,11 @@ async def handle_on_chain_stream(
|
|||
agent_message.properties.state = "complete"
|
||||
agent_message = await send_message_method(message=agent_message)
|
||||
start_time = perf_counter()
|
||||
elif isinstance(data_chunk, AIMessageChunk):
|
||||
agent_message.text += data_chunk.content
|
||||
agent_message.properties.state = "complete"
|
||||
agent_message = await send_message_method(message=agent_message)
|
||||
start_time = perf_counter()
|
||||
return agent_message, start_time
|
||||
|
||||
|
||||
|
|
@ -244,6 +249,7 @@ CHAIN_EVENT_HANDLERS: dict[str, ChainEventHandler] = {
|
|||
"on_chain_start": handle_on_chain_start,
|
||||
"on_chain_end": handle_on_chain_end,
|
||||
"on_chain_stream": handle_on_chain_stream,
|
||||
"on_chat_model_stream": handle_on_chain_stream,
|
||||
}
|
||||
|
||||
TOOL_EVENT_HANDLERS: dict[str, ToolEventHandler] = {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
|
|||
from loguru import logger
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.serialization.serialization import serialize
|
||||
from langflow.services.tracing.base import BaseTracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -108,7 +109,7 @@ class LangFuseTracer(BaseTracer):
|
|||
# last_span = next(reversed(self.spans))
|
||||
# span = self.spans[last_span].span(**content_span)
|
||||
# else:
|
||||
span = self.trace.span(**content_span)
|
||||
span = self.trace.span(**serialize(content_span))
|
||||
|
||||
self.spans[trace_id] = span
|
||||
|
||||
|
|
@ -131,7 +132,7 @@ class LangFuseTracer(BaseTracer):
|
|||
output |= outputs or {}
|
||||
output |= {"error": str(error)} if error else {}
|
||||
output |= {"logs": list(logs)} if logs else {}
|
||||
content = {"output": output, "end_time": end_time}
|
||||
content = serialize({"output": output, "end_time": end_time})
|
||||
span.update(**content)
|
||||
|
||||
@override
|
||||
|
|
@ -149,7 +150,7 @@ class LangFuseTracer(BaseTracer):
|
|||
"output": outputs,
|
||||
"metadata": metadata,
|
||||
}
|
||||
self.trace.update(**content_update)
|
||||
self.trace.update(**serialize(content_update))
|
||||
self._client.flush()
|
||||
|
||||
def get_langchain_callback(self) -> BaseCallbackHandler | None:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from loguru import logger
|
|||
from typing_extensions import override
|
||||
|
||||
from langflow.schema.data import Data
|
||||
from langflow.serialization.serialization import serialize
|
||||
from langflow.services.tracing.base import BaseTracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -158,10 +159,10 @@ class LangSmithTracer(BaseTracer):
|
|||
) -> None:
|
||||
if not self._ready or not self._run_tree:
|
||||
return
|
||||
self._run_tree.add_metadata({"inputs": inputs})
|
||||
self._run_tree.add_metadata({"inputs": serialize(inputs)})
|
||||
if metadata:
|
||||
self._run_tree.add_metadata(metadata)
|
||||
self._run_tree.end(outputs=outputs, error=self._error_to_string(error))
|
||||
self._run_tree.add_metadata(serialize(metadata))
|
||||
self._run_tree.end(outputs=serialize(outputs), error=self._error_to_string(error))
|
||||
self._run_tree.post()
|
||||
self._run_link = self._run_tree.get_url()
|
||||
|
||||
|
|
|
|||
|
|
@ -247,7 +247,7 @@ class TracingService(Service):
|
|||
metadata=outputs,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error ending all traces")
|
||||
logger.error("Error ending all traces")
|
||||
|
||||
async def end_tracers(self, outputs: dict, error: Exception | None = None) -> None:
|
||||
"""End the trace for a graph run.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue