fix: Execute event manager callbacks in asyncio thread (#5150)

Execute event manager callbacks in asyncio thread
This commit is contained in:
Christophe Bornet 2024-12-09 17:23:02 +01:00 committed by GitHub
commit b692ef7848
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 33 additions and 19 deletions

View file

@ -169,7 +169,7 @@ class LCAgentComponent(Component):
except ExceptionWithMessageError as e:
msg_id = e.agent_message.id
await delete_message(id_=msg_id)
self._send_message_event(e.agent_message, category="remove_message")
await self._send_message_event(e.agent_message, category="remove_message")
raise
except Exception:
raise

View file

@ -134,11 +134,11 @@ def _build_output_async_function(
async def output_function(*args, **kwargs):
try:
if event_manager:
event_manager.on_build_start(data={"id": component._id})
await asyncio.to_thread(event_manager.on_build_start, data={"id": component._id})
component.set(*args, **kwargs)
result = await output_method()
if event_manager:
event_manager.on_build_end(data={"id": component._id})
await asyncio.to_thread(event_manager.on_build_end, data={"id": component._id})
except Exception as e:
raise ToolException(e) from e
if isinstance(result, Message):

View file

@ -1034,7 +1034,7 @@ class Component(CustomComponent):
stored_message = await self._update_stored_message(stored_message)
else:
# Only send message event for non-streaming messages
self._send_message_event(stored_message, id_=id_)
await self._send_message_event(stored_message, id_=id_)
except Exception:
# remove the message from the database
await delete_message(stored_message.id)
@ -1051,19 +1051,23 @@ class Component(CustomComponent):
return messages[0]
def _send_message_event(self, message: Message, id_: str | None = None, category: str | None = None) -> None:
async def _send_message_event(self, message: Message, id_: str | None = None, category: str | None = None) -> None:
if hasattr(self, "_event_manager") and self._event_manager:
data_dict = message.data.copy() if hasattr(message, "data") else message.model_dump()
if id_ and not data_dict.get("id"):
data_dict["id"] = id_
category = category or data_dict.get("category", None)
match category:
case "error":
self._event_manager.on_error(data=data_dict)
case "remove_message":
self._event_manager.on_remove_message(data={"id": data_dict["id"]})
case _:
self._event_manager.on_message(data=data_dict)
def _send_event():
match category:
case "error":
self._event_manager.on_error(data=data_dict)
case "remove_message":
self._event_manager.on_remove_message(data={"id": data_dict["id"]})
case _:
self._event_manager.on_message(data=data_dict)
await asyncio.to_thread(_send_event)
def _should_stream_message(self, stored_message: Message, original_message: Message) -> bool:
return bool(
@ -1092,7 +1096,7 @@ class Component(CustomComponent):
complete_message = ""
first_chunk = True
for chunk in iterator:
complete_message = self._process_chunk(
complete_message = await self._process_chunk(
chunk.content, complete_message, message.id, message, first_chunk=first_chunk
)
first_chunk = False
@ -1105,13 +1109,13 @@ class Component(CustomComponent):
complete_message = ""
first_chunk = True
async for chunk in iterator:
complete_message = self._process_chunk(
complete_message = await self._process_chunk(
chunk.content, complete_message, message_id, message, first_chunk=first_chunk
)
first_chunk = False
return complete_message
def _process_chunk(
async def _process_chunk(
self, chunk: str, complete_message: str, message_id: str, message: Message, *, first_chunk: bool = False
) -> str:
complete_message += chunk
@ -1120,12 +1124,13 @@ class Component(CustomComponent):
# Send the initial message only on the first chunk
msg_copy = message.model_copy()
msg_copy.text = complete_message
self._send_message_event(msg_copy, id_=message_id)
self._event_manager.on_token(
await self._send_message_event(msg_copy, id_=message_id)
await asyncio.to_thread(
self._event_manager.on_token,
data={
"chunk": chunk,
"id": str(message_id),
}
},
)
return complete_message

View file

@ -1,4 +1,5 @@
import asyncio
import time
from typing import Any
from unittest.mock import MagicMock
@ -17,6 +18,11 @@ async def create_event_queue():
return asyncio.Queue()
def blocking_cb(manager, event_type, data):
time.sleep(0.01)
manager.send_event(event_type=event_type, data=data)
class ComponentForTesting(Component):
"""Test component that implements basic functionality."""
@ -39,6 +45,8 @@ async def test_component_message_sending():
queue = await create_event_queue()
event_manager = EventManager(queue)
event_manager.register_event("on_message", "message", callback=blocking_cb)
# Create component
component = ComponentForTesting()
component.set_event_manager(event_manager)
@ -196,7 +204,8 @@ async def test_component_streaming_message():
"""Test component's streaming message functionality."""
queue = await create_event_queue()
event_manager = EventManager(queue)
event_manager.register_event("on_token", "token")
event_manager.register_event("on_token", "token", blocking_cb)
# Create a proper mock vertex with graph and flow_id
vertex = MagicMock()