fix: Execute event manager callbacks in asyncio thread (#5150)
Execute event manager callbacks in asyncio thread
This commit is contained in:
parent
bb703f6e0d
commit
b692ef7848
4 changed files with 33 additions and 19 deletions
|
|
@ -169,7 +169,7 @@ class LCAgentComponent(Component):
|
|||
except ExceptionWithMessageError as e:
|
||||
msg_id = e.agent_message.id
|
||||
await delete_message(id_=msg_id)
|
||||
self._send_message_event(e.agent_message, category="remove_message")
|
||||
await self._send_message_event(e.agent_message, category="remove_message")
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -134,11 +134,11 @@ def _build_output_async_function(
|
|||
async def output_function(*args, **kwargs):
|
||||
try:
|
||||
if event_manager:
|
||||
event_manager.on_build_start(data={"id": component._id})
|
||||
await asyncio.to_thread(event_manager.on_build_start, data={"id": component._id})
|
||||
component.set(*args, **kwargs)
|
||||
result = await output_method()
|
||||
if event_manager:
|
||||
event_manager.on_build_end(data={"id": component._id})
|
||||
await asyncio.to_thread(event_manager.on_build_end, data={"id": component._id})
|
||||
except Exception as e:
|
||||
raise ToolException(e) from e
|
||||
if isinstance(result, Message):
|
||||
|
|
|
|||
|
|
@ -1034,7 +1034,7 @@ class Component(CustomComponent):
|
|||
stored_message = await self._update_stored_message(stored_message)
|
||||
else:
|
||||
# Only send message event for non-streaming messages
|
||||
self._send_message_event(stored_message, id_=id_)
|
||||
await self._send_message_event(stored_message, id_=id_)
|
||||
except Exception:
|
||||
# remove the message from the database
|
||||
await delete_message(stored_message.id)
|
||||
|
|
@ -1051,19 +1051,23 @@ class Component(CustomComponent):
|
|||
|
||||
return messages[0]
|
||||
|
||||
def _send_message_event(self, message: Message, id_: str | None = None, category: str | None = None) -> None:
|
||||
async def _send_message_event(self, message: Message, id_: str | None = None, category: str | None = None) -> None:
|
||||
if hasattr(self, "_event_manager") and self._event_manager:
|
||||
data_dict = message.data.copy() if hasattr(message, "data") else message.model_dump()
|
||||
if id_ and not data_dict.get("id"):
|
||||
data_dict["id"] = id_
|
||||
category = category or data_dict.get("category", None)
|
||||
match category:
|
||||
case "error":
|
||||
self._event_manager.on_error(data=data_dict)
|
||||
case "remove_message":
|
||||
self._event_manager.on_remove_message(data={"id": data_dict["id"]})
|
||||
case _:
|
||||
self._event_manager.on_message(data=data_dict)
|
||||
|
||||
def _send_event():
|
||||
match category:
|
||||
case "error":
|
||||
self._event_manager.on_error(data=data_dict)
|
||||
case "remove_message":
|
||||
self._event_manager.on_remove_message(data={"id": data_dict["id"]})
|
||||
case _:
|
||||
self._event_manager.on_message(data=data_dict)
|
||||
|
||||
await asyncio.to_thread(_send_event)
|
||||
|
||||
def _should_stream_message(self, stored_message: Message, original_message: Message) -> bool:
|
||||
return bool(
|
||||
|
|
@ -1092,7 +1096,7 @@ class Component(CustomComponent):
|
|||
complete_message = ""
|
||||
first_chunk = True
|
||||
for chunk in iterator:
|
||||
complete_message = self._process_chunk(
|
||||
complete_message = await self._process_chunk(
|
||||
chunk.content, complete_message, message.id, message, first_chunk=first_chunk
|
||||
)
|
||||
first_chunk = False
|
||||
|
|
@ -1105,13 +1109,13 @@ class Component(CustomComponent):
|
|||
complete_message = ""
|
||||
first_chunk = True
|
||||
async for chunk in iterator:
|
||||
complete_message = self._process_chunk(
|
||||
complete_message = await self._process_chunk(
|
||||
chunk.content, complete_message, message_id, message, first_chunk=first_chunk
|
||||
)
|
||||
first_chunk = False
|
||||
return complete_message
|
||||
|
||||
def _process_chunk(
|
||||
async def _process_chunk(
|
||||
self, chunk: str, complete_message: str, message_id: str, message: Message, *, first_chunk: bool = False
|
||||
) -> str:
|
||||
complete_message += chunk
|
||||
|
|
@ -1120,12 +1124,13 @@ class Component(CustomComponent):
|
|||
# Send the initial message only on the first chunk
|
||||
msg_copy = message.model_copy()
|
||||
msg_copy.text = complete_message
|
||||
self._send_message_event(msg_copy, id_=message_id)
|
||||
self._event_manager.on_token(
|
||||
await self._send_message_event(msg_copy, id_=message_id)
|
||||
await asyncio.to_thread(
|
||||
self._event_manager.on_token,
|
||||
data={
|
||||
"chunk": chunk,
|
||||
"id": str(message_id),
|
||||
}
|
||||
},
|
||||
)
|
||||
return complete_message
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
|
@ -17,6 +18,11 @@ async def create_event_queue():
|
|||
return asyncio.Queue()
|
||||
|
||||
|
||||
def blocking_cb(manager, event_type, data):
|
||||
time.sleep(0.01)
|
||||
manager.send_event(event_type=event_type, data=data)
|
||||
|
||||
|
||||
class ComponentForTesting(Component):
|
||||
"""Test component that implements basic functionality."""
|
||||
|
||||
|
|
@ -39,6 +45,8 @@ async def test_component_message_sending():
|
|||
queue = await create_event_queue()
|
||||
event_manager = EventManager(queue)
|
||||
|
||||
event_manager.register_event("on_message", "message", callback=blocking_cb)
|
||||
|
||||
# Create component
|
||||
component = ComponentForTesting()
|
||||
component.set_event_manager(event_manager)
|
||||
|
|
@ -196,7 +204,8 @@ async def test_component_streaming_message():
|
|||
"""Test component's streaming message functionality."""
|
||||
queue = await create_event_queue()
|
||||
event_manager = EventManager(queue)
|
||||
event_manager.register_event("on_token", "token")
|
||||
|
||||
event_manager.register_event("on_token", "token", blocking_cb)
|
||||
|
||||
# Create a proper mock vertex with graph and flow_id
|
||||
vertex = MagicMock()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue