From b692ef7848a8786d4e9f8a358d4a995bf540ff10 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 9 Dec 2024 17:23:02 +0100 Subject: [PATCH] fix: Execute event manager callbacks in asyncio thread (#5150) Execute event manager callbacks in asyncio thread --- .../base/langflow/base/agents/agent.py | 2 +- .../langflow/base/tools/component_tool.py | 4 +-- .../custom/custom_component/component.py | 35 +++++++++++-------- .../custom_component/test_component_events.py | 11 +++++- 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/src/backend/base/langflow/base/agents/agent.py b/src/backend/base/langflow/base/agents/agent.py index e0a030b0d..e154b1968 100644 --- a/src/backend/base/langflow/base/agents/agent.py +++ b/src/backend/base/langflow/base/agents/agent.py @@ -169,7 +169,7 @@ class LCAgentComponent(Component): except ExceptionWithMessageError as e: msg_id = e.agent_message.id await delete_message(id_=msg_id) - self._send_message_event(e.agent_message, category="remove_message") + await self._send_message_event(e.agent_message, category="remove_message") raise except Exception: raise diff --git a/src/backend/base/langflow/base/tools/component_tool.py b/src/backend/base/langflow/base/tools/component_tool.py index dc832cb53..e3f82d8d6 100644 --- a/src/backend/base/langflow/base/tools/component_tool.py +++ b/src/backend/base/langflow/base/tools/component_tool.py @@ -134,11 +134,11 @@ def _build_output_async_function( async def output_function(*args, **kwargs): try: if event_manager: - event_manager.on_build_start(data={"id": component._id}) + await asyncio.to_thread(event_manager.on_build_start, data={"id": component._id}) component.set(*args, **kwargs) result = await output_method() if event_manager: - event_manager.on_build_end(data={"id": component._id}) + await asyncio.to_thread(event_manager.on_build_end, data={"id": component._id}) except Exception as e: raise ToolException(e) from e if isinstance(result, Message): diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 00b210cec..3185fce6c 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -1034,7 +1034,7 @@ class Component(CustomComponent): stored_message = await self._update_stored_message(stored_message) else: # Only send message event for non-streaming messages - self._send_message_event(stored_message, id_=id_) + await self._send_message_event(stored_message, id_=id_) except Exception: # remove the message from the database await delete_message(stored_message.id) @@ -1051,19 +1051,23 @@ class Component(CustomComponent): return messages[0] - def _send_message_event(self, message: Message, id_: str | None = None, category: str | None = None) -> None: + async def _send_message_event(self, message: Message, id_: str | None = None, category: str | None = None) -> None: if hasattr(self, "_event_manager") and self._event_manager: data_dict = message.data.copy() if hasattr(message, "data") else message.model_dump() if id_ and not data_dict.get("id"): data_dict["id"] = id_ category = category or data_dict.get("category", None) - match category: - case "error": - self._event_manager.on_error(data=data_dict) - case "remove_message": - self._event_manager.on_remove_message(data={"id": data_dict["id"]}) - case _: - self._event_manager.on_message(data=data_dict) + + def _send_event(): + match category: + case "error": + self._event_manager.on_error(data=data_dict) + case "remove_message": + self._event_manager.on_remove_message(data={"id": data_dict["id"]}) + case _: + self._event_manager.on_message(data=data_dict) + + await asyncio.to_thread(_send_event) def _should_stream_message(self, stored_message: Message, original_message: Message) -> bool: return bool( @@ -1092,7 +1096,7 @@ class Component(CustomComponent): complete_message = "" first_chunk = True for chunk in iterator: - complete_message = self._process_chunk( + complete_message = await self._process_chunk( chunk.content, complete_message, message.id, message, first_chunk=first_chunk ) first_chunk = False @@ -1105,13 +1109,13 @@ class Component(CustomComponent): complete_message = "" first_chunk = True async for chunk in iterator: - complete_message = self._process_chunk( + complete_message = await self._process_chunk( chunk.content, complete_message, message_id, message, first_chunk=first_chunk ) first_chunk = False return complete_message - def _process_chunk( + async def _process_chunk( self, chunk: str, complete_message: str, message_id: str, message: Message, *, first_chunk: bool = False ) -> str: complete_message += chunk @@ -1120,12 +1124,13 @@ class Component(CustomComponent): # Send the initial message only on the first chunk msg_copy = message.model_copy() msg_copy.text = complete_message - self._send_message_event(msg_copy, id_=message_id) - self._event_manager.on_token( + await self._send_message_event(msg_copy, id_=message_id) + await asyncio.to_thread( + self._event_manager.on_token, data={ "chunk": chunk, "id": str(message_id), - } + }, ) return complete_message diff --git a/src/backend/tests/unit/custom/custom_component/test_component_events.py b/src/backend/tests/unit/custom/custom_component/test_component_events.py index 52490526f..8b433191a 100644 --- a/src/backend/tests/unit/custom/custom_component/test_component_events.py +++ b/src/backend/tests/unit/custom/custom_component/test_component_events.py @@ -1,4 +1,5 @@ import asyncio +import time from typing import Any from unittest.mock import MagicMock @@ -17,6 +18,11 @@ async def create_event_queue(): return asyncio.Queue() +def blocking_cb(manager, event_type, data): + time.sleep(0.01) + manager.send_event(event_type=event_type, data=data) + + class ComponentForTesting(Component): """Test component that implements basic functionality.""" @@ -39,6 +45,8 @@ async def test_component_message_sending(): queue = await create_event_queue() event_manager = EventManager(queue) + event_manager.register_event("on_message", "message", callback=blocking_cb) + # Create component component = ComponentForTesting() component.set_event_manager(event_manager) @@ -196,7 +204,8 @@ async def test_component_streaming_message(): """Test component's streaming message functionality.""" queue = await create_event_queue() event_manager = EventManager(queue) - event_manager.register_event("on_token", "token") + + event_manager.register_event("on_token", "token", blocking_cb) # Create a proper mock vertex with graph and flow_id vertex = MagicMock()