From 79b03ba1330fcd3198e1b7ed68a8a742ac4b884f Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 6 Dec 2024 17:25:59 +0100 Subject: [PATCH] fix: Use AsyncSession in memory (#4665) --- .../base/langflow/base/agents/agent.py | 3 +- .../base/langflow/base/agents/events.py | 37 ++- src/backend/base/langflow/base/io/chat.py | 9 +- .../langflow/base/tools/component_tool.py | 2 +- .../base/langflow/components/agents/agent.py | 6 +- .../components/deactivated/store_message.py | 8 +- .../langflow/components/helpers/memory.py | 12 +- .../components/helpers/store_message.py | 14 +- .../base/langflow/components/inputs/chat.py | 7 +- .../base/langflow/components/outputs/chat.py | 4 +- .../custom/custom_component/component.py | 34 +-- .../base/langflow/graph/vertex/types.py | 4 +- src/backend/base/langflow/memory.py | 202 +++++++++++-- src/backend/base/langflow/schema/log.py | 2 +- src/backend/base/langflow/schema/message.py | 8 + src/backend/tests/conftest.py | 12 + .../components/inputs/test_chat_input.py | 6 +- .../components/outputs/test_chat_output.py | 6 +- .../components/agents/test_agent_events.py | 68 ++--- .../inputs/test_input_components.py | 24 +- .../custom_component/test_component_events.py | 9 +- .../tests/unit/graph/graph/test_base.py | 1 + .../graph/graph/test_graph_state_model.py | 4 +- src/backend/tests/unit/test_chat_endpoint.py | 12 +- src/backend/tests/unit/test_messages.py | 273 +++++++++++++++++- .../tests/unit/test_messages_endpoints.py | 16 +- 26 files changed, 610 insertions(+), 173 deletions(-) diff --git a/src/backend/base/langflow/base/agents/agent.py b/src/backend/base/langflow/base/agents/agent.py index 16d977572..4480eb464 100644 --- a/src/backend/base/langflow/base/agents/agent.py +++ b/src/backend/base/langflow/base/agents/agent.py @@ -1,4 +1,3 @@ -import asyncio import re from abc import abstractmethod from typing import TYPE_CHECKING, cast @@ -168,7 +167,7 @@ class LCAgentComponent(Component): ) except ExceptionWithMessageError as e: msg_id = e.agent_message.id - await asyncio.to_thread(delete_message, id_=msg_id) + await delete_message(id_=msg_id) self._send_message_event(e.agent_message, category="remove_message") raise except Exception: diff --git a/src/backend/base/langflow/base/agents/events.py b/src/backend/base/langflow/base/agents/events.py index 2ea4299a6..d7bfde54b 100644 --- a/src/backend/base/langflow/base/agents/events.py +++ b/src/backend/base/langflow/base/agents/events.py @@ -1,5 +1,4 @@ # Add helper functions for each event type -import asyncio from collections.abc import AsyncIterator from time import perf_counter from typing import Any, Protocol @@ -53,7 +52,7 @@ def _calculate_duration(start_time: float) -> int: return result -def handle_on_chain_start( +async def handle_on_chain_start( event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float ) -> tuple[Message, float]: # Create content blocks if they don't exist @@ -75,7 +74,7 @@ def handle_on_chain_start( header={"title": "Input", "icon": "MessageSquare"}, ) agent_message.content_blocks[0].contents.append(text_content) - agent_message = send_message_method(message=agent_message) + agent_message = await send_message_method(message=agent_message) start_time = perf_counter() return agent_message, start_time @@ -91,7 +90,7 @@ def _extract_output_text(output: str | list) -> str: return text -def handle_on_chain_end( +async def handle_on_chain_end( event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float ) -> tuple[Message, float]: data_output = event["data"].get("output") @@ -110,12 +109,12 @@ def handle_on_chain_end( header={"title": "Output", "icon": "MessageSquare"}, ) agent_message.content_blocks[0].contents.append(text_content) - agent_message = send_message_method(message=agent_message) + agent_message = await send_message_method(message=agent_message) start_time = perf_counter() return agent_message, start_time -def handle_on_tool_start( +async def handle_on_tool_start( event: dict[str, Any], agent_message: Message, tool_blocks_map: dict[str, ToolContent], @@ -149,12 +148,12 @@ def handle_on_tool_start( tool_blocks_map[tool_key] = tool_content agent_message.content_blocks[0].contents.append(tool_content) - agent_message = send_message_method(message=agent_message) + agent_message = await send_message_method(message=agent_message) tool_blocks_map[tool_key] = agent_message.content_blocks[0].contents[-1] return agent_message, new_start_time -def handle_on_tool_end( +async def handle_on_tool_end( event: dict[str, Any], agent_message: Message, tool_blocks_map: dict[str, ToolContent], @@ -172,13 +171,13 @@ def handle_on_tool_end( tool_content.duration = duration tool_content.header = {"title": f"Executed **{tool_content.name}**", "icon": "Hammer"} - agent_message = send_message_method(message=agent_message) + agent_message = await send_message_method(message=agent_message) new_start_time = perf_counter() # Get new start time for next operation return agent_message, new_start_time return agent_message, start_time -def handle_on_tool_error( +async def handle_on_tool_error( event: dict[str, Any], agent_message: Message, tool_blocks_map: dict[str, ToolContent], @@ -194,12 +193,12 @@ def handle_on_tool_error( tool_content.error = event["data"].get("error", "Unknown error") tool_content.duration = _calculate_duration(start_time) tool_content.header = {"title": f"Error using **{tool_content.name}**", "icon": "Hammer"} - agent_message = send_message_method(message=agent_message) + agent_message = await send_message_method(message=agent_message) start_time = perf_counter() return agent_message, start_time -def handle_on_chain_stream( +async def handle_on_chain_stream( event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, @@ -211,13 +210,13 @@ def handle_on_chain_stream( if output and isinstance(output, str | list): agent_message.text = _extract_output_text(output) agent_message.properties.state = "complete" - agent_message = send_message_method(message=agent_message) + agent_message = await send_message_method(message=agent_message) start_time = perf_counter() return agent_message, start_time class ToolEventHandler(Protocol): - def __call__( + async def __call__( self, event: dict[str, Any], agent_message: Message, @@ -228,7 +227,7 @@ class ToolEventHandler(Protocol): class ChainEventHandler(Protocol): - def __call__( + async def __call__( self, event: dict[str, Any], agent_message: Message, @@ -265,7 +264,7 @@ async def process_agent_events( agent_message.properties.icon = "Bot" agent_message.properties.state = "partial" # Store the initial message - agent_message = await asyncio.to_thread(send_message_method, message=agent_message) + agent_message = await send_message_method(message=agent_message) try: # Create a mapping of run_ids to tool contents tool_blocks_map: dict[str, ToolContent] = {} @@ -273,14 +272,14 @@ async def process_agent_events( async for event in agent_executor: if event["event"] in TOOL_EVENT_HANDLERS: tool_handler = TOOL_EVENT_HANDLERS[event["event"]] - agent_message, start_time = tool_handler( + agent_message, start_time = await tool_handler( event, agent_message, tool_blocks_map, send_message_method, start_time ) elif event["event"] in CHAIN_EVENT_HANDLERS: chain_handler = CHAIN_EVENT_HANDLERS[event["event"]] - agent_message, start_time = chain_handler(event, agent_message, send_message_method, start_time) + agent_message, start_time = await chain_handler(event, agent_message, send_message_method, start_time) agent_message.properties.state = "complete" except Exception as e: raise ExceptionWithMessageError(agent_message) from e - return Message(**agent_message.model_dump()) + return await Message.create(**agent_message.model_dump()) diff --git a/src/backend/base/langflow/base/io/chat.py b/src/backend/base/langflow/base/io/chat.py index 965854a76..0e88a9e5d 100644 --- a/src/backend/base/langflow/base/io/chat.py +++ b/src/backend/base/langflow/base/io/chat.py @@ -1,7 +1,8 @@ +import asyncio from typing import cast from langflow.custom import Component -from langflow.memory import store_message +from langflow.memory import astore_message from langflow.schema import Data from langflow.schema.message import Message @@ -10,7 +11,7 @@ class ChatComponent(Component): display_name = "Chat Component" description = "Use as base for chat components." - def build_with_data( + async def build_with_data( self, *, sender: str | None = "User", @@ -20,13 +21,13 @@ class ChatComponent(Component): session_id: str | None = None, return_message: bool = False, ) -> str | Message: - message = self._create_message(input_value, sender, sender_name, files, session_id) + message = await asyncio.to_thread(self._create_message, input_value, sender, sender_name, files, session_id) message_text = message.text if not return_message else message self.status = message_text if session_id and isinstance(message, Message) and isinstance(message.text, str): flow_id = self.graph.flow_id if hasattr(self, "graph") else None - messages = store_message(message, flow_id=flow_id) + messages = await astore_message(message, flow_id=flow_id) self.status = messages self._send_messages_events(messages) diff --git a/src/backend/base/langflow/base/tools/component_tool.py b/src/backend/base/langflow/base/tools/component_tool.py index 9f098792c..fd86d0b8a 100644 --- a/src/backend/base/langflow/base/tools/component_tool.py +++ b/src/backend/base/langflow/base/tools/component_tool.py @@ -56,7 +56,7 @@ def build_description(component: Component, output: Output) -> str: return f"{output.method}({args}) - {component.description}" -def send_message_noop( +async def send_message_noop( message: Message, text: str | None = None, # noqa: ARG001 background_color: str | None = None, # noqa: ARG001 diff --git a/src/backend/base/langflow/components/agents/agent.py b/src/backend/base/langflow/components/agents/agent.py index 5fd33df29..ddd6b4022 100644 --- a/src/backend/base/langflow/components/agents/agent.py +++ b/src/backend/base/langflow/components/agents/agent.py @@ -66,7 +66,7 @@ class AgentComponent(ToolCallingAgentComponent): if llm_model is None: msg = "No language model selected" raise ValueError(msg) - self.chat_history = self.get_memory_data() + self.chat_history = await self.get_memory_data() if self.add_current_date_tool: if not isinstance(self.tools, list): # type: ignore[has-type] @@ -92,12 +92,12 @@ class AgentComponent(ToolCallingAgentComponent): agent = self.create_agent_runnable() return await self.run_agent(agent) - def get_memory_data(self): + async def get_memory_data(self): memory_kwargs = { component_input.name: getattr(self, f"{component_input.name}") for component_input in self.memory_inputs } - return MemoryComponent().set(**memory_kwargs).retrieve_messages() + return await MemoryComponent().set(**memory_kwargs).retrieve_messages() def get_llm(self): if isinstance(self.agent_llm, str): diff --git a/src/backend/base/langflow/components/deactivated/store_message.py b/src/backend/base/langflow/components/deactivated/store_message.py index e983f4e06..0d68ece68 100644 --- a/src/backend/base/langflow/components/deactivated/store_message.py +++ b/src/backend/base/langflow/components/deactivated/store_message.py @@ -1,5 +1,5 @@ from langflow.custom import CustomComponent -from langflow.memory import get_messages, store_message +from langflow.memory import aget_messages, astore_message from langflow.schema.message import Message @@ -13,12 +13,12 @@ class StoreMessageComponent(CustomComponent): "message": {"display_name": "Message"}, } - def build( + async def build( self, message: Message, ) -> Message: flow_id = self.graph.flow_id if hasattr(self, "graph") else None - store_message(message, flow_id=flow_id) - self.status = get_messages() + await astore_message(message, flow_id=flow_id) + self.status = await aget_messages() return message diff --git a/src/backend/base/langflow/components/helpers/memory.py b/src/backend/base/langflow/components/helpers/memory.py index 323c2c8c4..7dfebafd2 100644 --- a/src/backend/base/langflow/components/helpers/memory.py +++ b/src/backend/base/langflow/components/helpers/memory.py @@ -5,7 +5,7 @@ from langflow.field_typing import BaseChatMemory from langflow.helpers.data import data_to_text from langflow.inputs import HandleInput from langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output -from langflow.memory import LCBuiltinChatMemory, get_messages +from langflow.memory import LCBuiltinChatMemory, aget_messages from langflow.schema import Data from langflow.schema.message import Message from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER @@ -74,7 +74,7 @@ class MemoryComponent(Component): Output(display_name="Text", name="messages_text", method="retrieve_messages_as_text"), ] - def retrieve_messages(self) -> Data: + async def retrieve_messages(self) -> Data: sender = self.sender sender_name = self.sender_name session_id = self.session_id @@ -88,7 +88,7 @@ class MemoryComponent(Component): # override session_id self.memory.session_id = session_id - stored = self.memory.messages + stored = await self.memory.aget_messages() # langchain memories are supposed to return messages in ascending order if order == "DESC": stored = stored[::-1] @@ -99,7 +99,7 @@ class MemoryComponent(Component): expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER stored = [m for m in stored if m.type == expected_type] else: - stored = get_messages( + stored = await aget_messages( sender=sender, sender_name=sender_name, session_id=session_id, @@ -109,8 +109,8 @@ class MemoryComponent(Component): self.status = stored return stored - def retrieve_messages_as_text(self) -> Message: - stored_text = data_to_text(self.template, self.retrieve_messages()) + async def retrieve_messages_as_text(self) -> Message: + stored_text = data_to_text(self.template, await self.retrieve_messages()) self.status = stored_text return Message(text=stored_text) diff --git a/src/backend/base/langflow/components/helpers/store_message.py b/src/backend/base/langflow/components/helpers/store_message.py index 388d8a476..f9178e17e 100644 --- a/src/backend/base/langflow/components/helpers/store_message.py +++ b/src/backend/base/langflow/components/helpers/store_message.py @@ -1,7 +1,7 @@ from langflow.custom import Component from langflow.inputs import HandleInput, MessageInput from langflow.inputs.inputs import MessageTextInput -from langflow.memory import get_messages, store_message +from langflow.memory import aget_messages, astore_message from langflow.schema.message import Message from langflow.template import Output from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI @@ -47,7 +47,7 @@ class StoreMessageComponent(Component): Output(display_name="Stored Messages", name="stored_messages", method="store_message"), ] - def store_message(self) -> Message: + async def store_message(self) -> Message: message = self.message message.session_id = self.session_id or message.session_id @@ -58,13 +58,15 @@ class StoreMessageComponent(Component): # override session_id self.memory.session_id = message.session_id lc_message = message.to_lc_message() - self.memory.add_messages([lc_message]) - stored = self.memory.messages + await self.memory.aadd_messages([lc_message]) + stored = await self.memory.aget_messages() stored = [Message.from_lc_message(m) for m in stored] if message.sender: stored = [m for m in stored if m.sender == message.sender] else: - store_message(message, flow_id=self.graph.flow_id) - stored = get_messages(session_id=message.session_id, sender_name=message.sender_name, sender=message.sender) + await astore_message(message, flow_id=self.graph.flow_id) + stored = await aget_messages( + session_id=message.session_id, sender_name=message.sender_name, sender=message.sender + ) self.status = stored return stored diff --git a/src/backend/base/langflow/components/inputs/chat.py b/src/backend/base/langflow/components/inputs/chat.py index be8e656cf..ef861c320 100644 --- a/src/backend/base/langflow/components/inputs/chat.py +++ b/src/backend/base/langflow/components/inputs/chat.py @@ -78,11 +78,12 @@ class ChatInput(ChatComponent): Output(display_name="Message", name="message", method="message_response"), ] - def message_response(self) -> Message: + async def message_response(self) -> Message: _background_color = self.background_color _text_color = self.text_color _icon = self.chat_icon - message = Message( + + message = await Message.create( text=self.input_value, sender=self.sender, sender_name=self.sender_name, @@ -91,7 +92,7 @@ class ChatInput(ChatComponent): properties={"background_color": _background_color, "text_color": _text_color, "icon": _icon}, ) if self.session_id and isinstance(message, Message) and self.should_store_message: - stored_message = self.send_message( + stored_message = await self.send_message( message, ) self.message.value = stored_message diff --git a/src/backend/base/langflow/components/outputs/chat.py b/src/backend/base/langflow/components/outputs/chat.py index b6751c231..141e09367 100644 --- a/src/backend/base/langflow/components/outputs/chat.py +++ b/src/backend/base/langflow/components/outputs/chat.py @@ -90,7 +90,7 @@ class ChatOutput(ChatComponent): source_dict["source"] = source return Source(**source_dict) - def message_response(self) -> Message: + async def message_response(self) -> Message: _source, _icon, _display_name, _source_id = self.get_properties_from_source_component() _background_color = self.background_color _text_color = self.text_color @@ -106,7 +106,7 @@ class ChatOutput(ChatComponent): message.properties.background_color = _background_color message.properties.text_color = _text_color if self.session_id and isinstance(message, Message) and self.should_store_message: - stored_message = self.send_message( + stored_message = await self.send_message( message, ) self.message.value = stored_message diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 11a1c15bf..0f1da2f0e 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -24,7 +24,7 @@ from langflow.exceptions.component import StreamingError from langflow.field_typing import Tool # noqa: TCH001 Needed by _add_toolkit_output from langflow.graph.state.model import create_state_model from langflow.helpers.custom import format_type -from langflow.memory import delete_message, store_message, update_messages +from langflow.memory import astore_message, aupdate_messages, delete_message from langflow.schema.artifact import get_artifact_type, post_process_raw from langflow.schema.data import Data from langflow.schema.message import ErrorMessage, Message @@ -847,7 +847,7 @@ class Component(CustomComponent): return await self._build_with_tracing() return await self._build_without_tracing() except StreamingError as e: - self.send_error( + await self.send_error( exception=e.cause, session_id=session_id, trace_name=getattr(self, "trace_name", None), @@ -855,7 +855,7 @@ class Component(CustomComponent): ) raise e.cause # noqa: B904 except Exception as e: - self.send_error( + await self.send_error( exception=e, session_id=session_id, source=Source(id=self._id, display_name=self.display_name, source=self.display_name), @@ -1016,10 +1016,10 @@ class Component(CustomComponent): ) ) - def send_message(self, message: Message, id_: str | None = None): + async def send_message(self, message: Message, id_: str | None = None): if (hasattr(self, "graph") and self.graph.session_id) and (message is not None and not message.session_id): message.session_id = self.graph.session_id - stored_message = self._store_message(message) + stored_message = await self._store_message(message) self._stored_message_id = stored_message.id try: @@ -1029,22 +1029,22 @@ class Component(CustomComponent): and message is not None and isinstance(message.text, AsyncIterator | Iterator) ): - complete_message = self._stream_message(message.text, stored_message) + complete_message = await self._stream_message(message.text, stored_message) stored_message.text = complete_message - stored_message = self._update_stored_message(stored_message) + stored_message = await self._update_stored_message(stored_message) else: # Only send message event for non-streaming messages self._send_message_event(stored_message, id_=id_) except Exception: # remove the message from the database - delete_message(stored_message.id) + await delete_message(stored_message.id) raise self.status = stored_message return stored_message - def _store_message(self, message: Message) -> Message: + async def _store_message(self, message: Message) -> Message: flow_id = self.graph.flow_id if hasattr(self, "graph") else None - messages = store_message(message, flow_id=flow_id) + messages = await astore_message(message, flow_id=flow_id) if len(messages) != 1: msg = "Only one message can be stored at a time." raise ValueError(msg) @@ -1073,21 +1073,21 @@ class Component(CustomComponent): and not isinstance(original_message.text, str) ) - def _update_stored_message(self, stored_message: Message) -> Message: - message_tables = update_messages(stored_message) + async def _update_stored_message(self, stored_message: Message) -> Message: + message_tables = await aupdate_messages(stored_message) if len(message_tables) != 1: msg = "Only one message can be updated at a time." raise ValueError(msg) message_table = message_tables[0] - return Message(**message_table.model_dump()) + return await Message.create(**message_table.model_dump()) - def _stream_message(self, iterator: AsyncIterator | Iterator, message: Message) -> str: + async def _stream_message(self, iterator: AsyncIterator | Iterator, message: Message) -> str: if not isinstance(iterator, AsyncIterator | Iterator): msg = "The message must be an iterator or an async iterator." raise TypeError(msg) if isinstance(iterator, AsyncIterator): - return run_until_complete(self._handle_async_iterator(iterator, message.id, message)) + return await self._handle_async_iterator(iterator, message.id, message) try: complete_message = "" first_chunk = True @@ -1129,7 +1129,7 @@ class Component(CustomComponent): ) return complete_message - def send_error( + async def send_error( self, exception: Exception, session_id: str, @@ -1145,7 +1145,7 @@ class Component(CustomComponent): trace_name=trace_name, source=source, ) - self.send_message(error_message) + await self.send_message(error_message) return error_message def _append_tool_to_outputs_map(self): diff --git a/src/backend/base/langflow/graph/vertex/types.py b/src/backend/base/langflow/graph/vertex/types.py index d57039939..0db434f7f 100644 --- a/src/backend/base/langflow/graph/vertex/types.py +++ b/src/backend/base/langflow/graph/vertex/types.py @@ -401,7 +401,7 @@ class InterfaceVertex(ComponentVertex): type=ArtifactType.OBJECT.value, ).model_dump() - message = Message( + message = await Message.create( text=complete_message, sender=self.params.get("sender", ""), sender_name=self.params.get("sender_name", ""), @@ -434,7 +434,7 @@ class InterfaceVertex(ComponentVertex): and hasattr(self.custom_component, "should_store_message") and hasattr(self.custom_component, "store_message") ): - self.custom_component.store_message(message) + await self.custom_component.store_message(message) await log_vertex_build( flow_id=self.graph.flow_id, vertex_id=self.id, diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index 7b293c2b5..43d77b439 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -7,13 +7,40 @@ from langchain_core.messages import BaseMessage from loguru import logger from sqlalchemy import delete from sqlmodel import Session, col, select +from sqlmodel.ext.asyncio.session import AsyncSession from langflow.schema.message import Message from langflow.services.database.models.message.model import MessageRead, MessageTable -from langflow.services.deps import session_scope +from langflow.services.deps import async_session_scope, session_scope from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER +def _get_variable_query( + sender: str | None = None, + sender_name: str | None = None, + session_id: str | None = None, + order_by: str | None = "timestamp", + order: str | None = "DESC", + flow_id: UUID | None = None, + limit: int | None = None, +): + stmt = select(MessageTable).where(MessageTable.error == False) # noqa: E712 + if sender: + stmt = stmt.where(MessageTable.sender == sender) + if sender_name: + stmt = stmt.where(MessageTable.sender_name == sender_name) + if session_id: + stmt = stmt.where(MessageTable.session_id == session_id) + if flow_id: + stmt = stmt.where(MessageTable.flow_id == flow_id) + if order_by: + col = getattr(MessageTable, order_by).desc() if order == "DESC" else getattr(MessageTable, order_by).asc() + stmt = stmt.order_by(col) + if limit: + stmt = stmt.limit(limit) + return stmt + + def get_messages( sender: str | None = None, sender_name: str | None = None, @@ -38,24 +65,40 @@ def get_messages( List[Data]: A list of Data objects representing the retrieved messages. """ with session_scope() as session: - stmt = select(MessageTable).where(MessageTable.error == False) # noqa: E712 - if sender: - stmt = stmt.where(MessageTable.sender == sender) - if sender_name: - stmt = stmt.where(MessageTable.sender_name == sender_name) - if session_id: - stmt = stmt.where(MessageTable.session_id == session_id) - if flow_id: - stmt = stmt.where(MessageTable.flow_id == flow_id) - if order_by: - col = getattr(MessageTable, order_by).desc() if order == "DESC" else getattr(MessageTable, order_by).asc() - stmt = stmt.order_by(col) - if limit: - stmt = stmt.limit(limit) + stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit) messages = session.exec(stmt) return [Message(**d.model_dump()) for d in messages] +async def aget_messages( + sender: str | None = None, + sender_name: str | None = None, + session_id: str | None = None, + order_by: str | None = "timestamp", + order: str | None = "DESC", + flow_id: UUID | None = None, + limit: int | None = None, +) -> list[Message]: + """Retrieves messages from the monitor service based on the provided filters. + + Args: + sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User") + sender_name (Optional[str]): The name of the sender. + session_id (Optional[str]): The session ID associated with the messages. + order_by (Optional[str]): The field to order the messages by. Defaults to "timestamp". + order (Optional[str]): The order in which to retrieve the messages. Defaults to "DESC". + flow_id (Optional[UUID]): The flow ID associated with the messages. + limit (Optional[int]): The maximum number of messages to retrieve. + + Returns: + List[Data]: A list of Data objects representing the retrieved messages. + """ + async with async_session_scope() as session: + stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit) + messages = await session.exec(stmt) + return [await Message.create(**d.model_dump()) for d in messages] + + def add_messages(messages: Message | list[Message], flow_id: str | None = None): """Add a message to the monitor service.""" if not isinstance(messages, list): @@ -76,6 +119,26 @@ def add_messages(messages: Message | list[Message], flow_id: str | None = None): raise +async def aadd_messages(messages: Message | list[Message], flow_id: str | None = None): + """Add a message to the monitor service.""" + if not isinstance(messages, list): + messages = [messages] + + if not all(isinstance(message, Message) for message in messages): + types = ", ".join([str(type(message)) for message in messages]) + msg = f"The messages must be instances of Message. Found: {types}" + raise ValueError(msg) + + try: + messages_models = [MessageTable.from_message(msg, flow_id=flow_id) for msg in messages] + async with async_session_scope() as session: + messages_models = await aadd_messagetables(messages_models, session) + return [await Message.create(**message.model_dump()) for message in messages_models] + except Exception as e: + logger.exception(e) + raise + + def update_messages(messages: Message | list[Message]) -> list[Message]: if not isinstance(messages, list): messages = [messages] @@ -95,6 +158,25 @@ def update_messages(messages: Message | list[Message]) -> list[Message]: return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages] +async def aupdate_messages(messages: Message | list[Message]) -> list[Message]: + if not isinstance(messages, list): + messages = [messages] + + async with async_session_scope() as session: + updated_messages: list[MessageTable] = [] + for message in messages: + msg = await session.get(MessageTable, message.id) + if msg: + msg.sqlmodel_update(message.model_dump(exclude_unset=True, exclude_none=True)) + session.add(msg) + await session.commit() + await session.refresh(msg) + updated_messages.append(msg) + else: + logger.warning(f"Message with id {message.id} not found") + return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages] + + def add_messagetables(messages: list[MessageTable], session: Session): for message in messages: try: @@ -115,6 +197,27 @@ def add_messagetables(messages: list[MessageTable], session: Session): return [MessageRead.model_validate(message, from_attributes=True) for message in new_messages] +async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession): + try: + for message in messages: + session.add(message) + await session.commit() + for message in messages: + await session.refresh(message) + except Exception as e: + logger.exception(e) + raise + + new_messages = [] + for msg in messages: + msg.properties = json.loads(msg.properties) if isinstance(msg.properties, str) else msg.properties # type: ignore[arg-type] + msg.content_blocks = [json.loads(j) if isinstance(j, str) else j for j in msg.content_blocks] # type: ignore[arg-type] + msg.category = msg.category or "" + new_messages.append(msg) + + return [MessageRead.model_validate(message, from_attributes=True) for message in new_messages] + + def delete_messages(session_id: str) -> None: """Delete messages from the monitor service based on the provided session ID. @@ -129,17 +232,32 @@ def delete_messages(session_id: str) -> None: ) -def delete_message(id_: str) -> None: +async def adelete_messages(session_id: str) -> None: + """Delete messages from the monitor service based on the provided session ID. + + Args: + session_id (str): The session ID associated with the messages to delete. + """ + async with async_session_scope() as session: + stmt = ( + delete(MessageTable) + .where(col(MessageTable.session_id) == session_id) + .execution_options(synchronize_session="fetch") + ) + await session.exec(stmt) + + +async def delete_message(id_: str) -> None: """Delete a message from the monitor service based on the provided ID. Args: id_ (str): The ID of the message to delete. """ - with session_scope() as session: - message = session.get(MessageTable, id_) + async with async_session_scope() as session: + message = await session.get(MessageTable, id_) if message: - session.delete(message) - session.commit() + await session.delete(message) + await session.commit() def store_message( @@ -182,6 +300,35 @@ def store_message( return add_messages([message], flow_id=flow_id) +async def astore_message( + message: Message, + flow_id: str | None = None, +) -> list[Message]: + """Stores a message in the memory. + + Args: + message (Message): The message to store. + flow_id (Optional[str]): The flow ID associated with the message. + When running from the CustomComponent you can access this using `self.graph.flow_id`. + + Returns: + List[Message]: A list of data containing the stored message. + + Raises: + ValueError: If any of the required parameters (session_id, sender, sender_name) is not provided. + """ + if not message: + logger.warning("No message provided.") + return [] + + if not message.session_id or not message.sender or not message.sender_name: + msg = "All of session_id, sender, and sender_name must be provided." + raise ValueError(msg) + if hasattr(message, "id") and message.id: + return await aupdate_messages([message]) + return await aadd_messages([message], flow_id=flow_id) + + class LCBuiltinChatMemory(BaseChatMessageHistory): def __init__( self, @@ -198,11 +345,26 @@ class LCBuiltinChatMemory(BaseChatMessageHistory): ) return [m.to_lc_message() for m in messages if not m.error] # Exclude error messages + async def aget_messages(self) -> list[BaseMessage]: + messages = await aget_messages( + session_id=self.session_id, + ) + return [m.to_lc_message() for m in messages if not m.error] # Exclude error messages + def add_messages(self, messages: Sequence[BaseMessage]) -> None: for lc_message in messages: message = Message.from_lc_message(lc_message) message.session_id = self.session_id store_message(message, flow_id=self.flow_id) + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + for lc_message in messages: + message = Message.from_lc_message(lc_message) + message.session_id = self.session_id + await astore_message(message, flow_id=self.flow_id) + def clear(self) -> None: delete_messages(self.session_id) + + async def aclear(self) -> None: + await adelete_messages(self.session_id) diff --git a/src/backend/base/langflow/schema/log.py b/src/backend/base/langflow/schema/log.py index e4a272d8c..402106a50 100644 --- a/src/backend/base/langflow/schema/log.py +++ b/src/backend/base/langflow/schema/log.py @@ -14,7 +14,7 @@ class LogFunctionType(Protocol): class SendMessageFunctionType(Protocol): - def __call__( + async def __call__( self, message: Message | None = None, text: str | None = None, diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index e3ef7e1e7..7b5e0f520 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import re import traceback @@ -267,6 +268,13 @@ class Message(Data): instance.messages = instance.prompt.get("kwargs", {}).get("messages", []) return instance + @classmethod + async def create(cls, **kwargs): + """If files are present, create the message in a separate thread as is_image_file is blocking.""" + if "files" in kwargs: + return await asyncio.to_thread(cls, **kwargs) + return cls(**kwargs) + class DefaultModel(BaseModel): class Config: diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index a855fb988..fa149895a 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -29,6 +29,7 @@ from langflow.services.database.models.vertex_builds.crud import delete_vertex_b from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service from loguru import logger +from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import selectinload from sqlmodel import Session, SQLModel, create_engine, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -151,6 +152,17 @@ def session_fixture(): SQLModel.metadata.drop_all(engine) # Add this line to clean up tables +@pytest.fixture +async def async_session(): + engine = create_async_engine("sqlite+aiosqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + async with AsyncSession(engine) as session: + yield session + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.drop_all) + + class Config: broker_url = "redis://localhost:6379/0" result_backend = "redis://localhost:6379/0" diff --git a/src/backend/tests/integration/components/inputs/test_chat_input.py b/src/backend/tests/integration/components/inputs/test_chat_input.py index 698e8d266..dbad196fa 100644 --- a/src/backend/tests/integration/components/inputs/test_chat_input.py +++ b/src/backend/tests/integration/components/inputs/test_chat_input.py @@ -1,5 +1,5 @@ from langflow.components.inputs import ChatInput -from langflow.memory import get_messages +from langflow.memory import aget_messages from langflow.schema.message import Message from tests.integration.utils import run_single_component @@ -38,7 +38,7 @@ async def test_do_not_store_messages(): assert outputs["message"].text == "hello" assert outputs["message"].session_id == session_id - assert len(get_messages(session_id=session_id)) == 1 + assert len(await aget_messages(session_id=session_id)) == 1 session_id = "test-session-id-another" outputs = await run_single_component( @@ -48,4 +48,4 @@ async def test_do_not_store_messages(): assert outputs["message"].text == "hello" assert outputs["message"].session_id == session_id - assert len(get_messages(session_id=session_id)) == 0 + assert len(await aget_messages(session_id=session_id)) == 0 diff --git a/src/backend/tests/integration/components/outputs/test_chat_output.py b/src/backend/tests/integration/components/outputs/test_chat_output.py index dfe113cc9..df3f00046 100644 --- a/src/backend/tests/integration/components/outputs/test_chat_output.py +++ b/src/backend/tests/integration/components/outputs/test_chat_output.py @@ -1,5 +1,5 @@ from langflow.components.outputs import ChatOutput -from langflow.memory import get_messages +from langflow.memory import aget_messages from langflow.schema.message import Message from tests.integration.utils import run_single_component @@ -29,7 +29,7 @@ async def test_do_not_store_message(): assert isinstance(outputs["message"], Message) assert outputs["message"].text == "hello" - assert len(get_messages(session_id=session_id)) == 1 + assert len(await aget_messages(session_id=session_id)) == 1 session_id = "test-session-id-another" outputs = await run_single_component( @@ -38,4 +38,4 @@ async def test_do_not_store_message(): assert isinstance(outputs["message"], Message) assert outputs["message"].text == "hello" - assert len(get_messages(session_id=session_id)) == 0 + assert len(await aget_messages(session_id=session_id)) == 0 diff --git a/src/backend/tests/unit/components/agents/test_agent_events.py b/src/backend/tests/unit/components/agents/test_agent_events.py index e1f76bafc..c1342135a 100644 --- a/src/backend/tests/unit/components/agents/test_agent_events.py +++ b/src/backend/tests/unit/components/agents/test_agent_events.py @@ -1,6 +1,6 @@ from collections.abc import AsyncIterator from typing import Any -from unittest.mock import MagicMock +from unittest.mock import AsyncMock from langchain_core.agents import AgentFinish from langflow.base.agents.agent import process_agent_events @@ -26,7 +26,7 @@ async def create_event_iterator(events: list[dict[str, Any]]) -> AsyncIterator[d async def test_chain_start_event(): """Test handling of on_chain_start event.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) events = [ {"event": "on_chain_start", "data": {"input": {"input": "test input", "chat_history": []}}, "start_time": 0} @@ -51,7 +51,7 @@ async def test_chain_start_event(): async def test_chain_end_event(): """Test handling of on_chain_end event.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) # Create a mock AgentFinish output output = AgentFinish(return_values={"output": "final output"}, log="test log") @@ -77,7 +77,7 @@ async def test_chain_end_event(): async def test_tool_start_event(): """Test handling of on_tool_start event.""" - send_message = MagicMock() + send_message = AsyncMock() # Set up the send_message mock to return the modified message def update_message(message): @@ -116,7 +116,7 @@ async def test_tool_start_event(): async def test_tool_end_event(): """Test handling of on_tool_end event.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) events = [ { @@ -151,7 +151,7 @@ async def test_tool_end_event(): async def test_tool_error_event(): """Test handling of on_tool_error event.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) events = [ { @@ -187,7 +187,7 @@ async def test_tool_error_event(): async def test_chain_stream_event(): """Test handling of on_chain_stream event.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) events = [{"event": "on_chain_stream", "data": {"chunk": {"output": "streamed output"}}, "start_time": 0}] agent_message = Message( @@ -205,7 +205,7 @@ async def test_chain_stream_event(): async def test_multiple_events(): """Test handling of multiple events in sequence.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) # Create a mock AgentFinish output instead of MockOutput output = AgentFinish(return_values={"output": "final output"}, log="test log") @@ -248,7 +248,7 @@ async def test_multiple_events(): async def test_unknown_event(): """Test handling of unknown event type.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) agent_message = Message( sender=MESSAGE_SENDER_AI, sender_name="Agent", @@ -273,7 +273,7 @@ async def test_unknown_event(): async def test_handle_on_chain_start_with_input(): """Test handle_on_chain_start with input.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) agent_message = Message( sender=MESSAGE_SENDER_AI, sender_name="Agent", @@ -282,7 +282,7 @@ async def test_handle_on_chain_start_with_input(): ) event = {"event": "on_chain_start", "data": {"input": {"input": "test input", "chat_history": []}}, "start_time": 0} - updated_message, start_time = handle_on_chain_start(event, agent_message, send_message, 0.0) + updated_message, start_time = await handle_on_chain_start(event, agent_message, send_message, 0.0) assert updated_message.properties.icon == "Bot" assert len(updated_message.content_blocks) == 1 @@ -292,7 +292,7 @@ async def test_handle_on_chain_start_with_input(): async def test_handle_on_chain_start_no_input(): """Test handle_on_chain_start without input.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) agent_message = Message( sender=MESSAGE_SENDER_AI, sender_name="Agent", @@ -301,7 +301,7 @@ async def test_handle_on_chain_start_no_input(): ) event = {"event": "on_chain_start", "data": {}, "start_time": 0} - updated_message, start_time = handle_on_chain_start(event, agent_message, send_message, 0.0) + updated_message, start_time = await handle_on_chain_start(event, agent_message, send_message, 0.0) assert updated_message.properties.icon == "Bot" assert len(updated_message.content_blocks) == 1 @@ -311,7 +311,7 @@ async def test_handle_on_chain_start_no_input(): async def test_handle_on_chain_end_with_output(): """Test handle_on_chain_end with output.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) agent_message = Message( sender=MESSAGE_SENDER_AI, sender_name="Agent", @@ -322,7 +322,7 @@ async def test_handle_on_chain_end_with_output(): output = AgentFinish(return_values={"output": "final output"}, log="test log") event = {"event": "on_chain_end", "data": {"output": output}, "start_time": 0} - updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0) + updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0) assert updated_message.properties.icon == "Bot" assert updated_message.properties.state == "complete" @@ -332,7 +332,7 @@ async def test_handle_on_chain_end_with_output(): async def test_handle_on_chain_end_no_output(): """Test handle_on_chain_end without output key in data.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) agent_message = Message( sender=MESSAGE_SENDER_AI, sender_name="Agent", @@ -341,7 +341,7 @@ async def test_handle_on_chain_end_no_output(): ) event = {"event": "on_chain_end", "data": {}, "start_time": 0} - updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0) + updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0) assert updated_message.properties.icon == "Bot" assert updated_message.properties.state == "partial" @@ -351,7 +351,7 @@ async def test_handle_on_chain_end_no_output(): async def test_handle_on_chain_end_empty_data(): """Test handle_on_chain_end with empty data.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) agent_message = Message( sender=MESSAGE_SENDER_AI, sender_name="Agent", @@ -360,7 +360,7 @@ async def test_handle_on_chain_end_empty_data(): ) event = {"event": "on_chain_end", "data": {"output": None}, "start_time": 0} - updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0) + updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0) assert updated_message.properties.icon == "Bot" assert updated_message.properties.state == "partial" @@ -370,7 +370,7 @@ async def test_handle_on_chain_end_empty_data(): async def test_handle_on_chain_end_with_empty_return_values(): """Test handle_on_chain_end with empty return_values.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) agent_message = Message( sender=MESSAGE_SENDER_AI, sender_name="Agent", @@ -384,7 +384,7 @@ async def test_handle_on_chain_end_with_empty_return_values(): event = {"event": "on_chain_end", "data": {"output": MockOutputEmptyReturnValues()}, "start_time": 0} - updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0) + updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0) assert updated_message.properties.icon == "Bot" assert updated_message.properties.state == "partial" @@ -394,7 +394,7 @@ async def test_handle_on_chain_end_with_empty_return_values(): async def test_handle_on_tool_start(): """Test handle_on_tool_start event.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) tool_blocks_map = {} agent_message = Message( sender=MESSAGE_SENDER_AI, @@ -410,7 +410,7 @@ async def test_handle_on_tool_start(): "start_time": 0, } - updated_message, start_time = handle_on_tool_start(event, agent_message, tool_blocks_map, send_message, 0.0) + updated_message, start_time = await handle_on_tool_start(event, agent_message, tool_blocks_map, send_message, 0.0) assert len(updated_message.content_blocks) == 1 assert len(updated_message.content_blocks[0].contents) > 0 @@ -426,7 +426,7 @@ async def test_handle_on_tool_start(): async def test_handle_on_tool_end(): """Test handle_on_tool_end event.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) tool_blocks_map = {} agent_message = Message( sender=MESSAGE_SENDER_AI, @@ -441,7 +441,7 @@ async def test_handle_on_tool_end(): "run_id": "test_run", "data": {"input": {"query": "tool input"}}, } - agent_message, _ = handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0) + agent_message, _ = await handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0) end_event = { "event": "on_tool_end", @@ -451,7 +451,7 @@ async def test_handle_on_tool_end(): "start_time": 0, } - updated_message, start_time = handle_on_tool_end(end_event, agent_message, tool_blocks_map, send_message, 0.0) + updated_message, start_time = await handle_on_tool_end(end_event, agent_message, tool_blocks_map, send_message, 0.0) f"{end_event['name']}_{end_event['run_id']}" tool_content = updated_message.content_blocks[0].contents[-1] @@ -463,7 +463,7 @@ async def test_handle_on_tool_end(): async def test_handle_on_tool_error(): """Test handle_on_tool_error event.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) tool_blocks_map = {} agent_message = Message( sender=MESSAGE_SENDER_AI, @@ -478,7 +478,7 @@ async def test_handle_on_tool_error(): "run_id": "test_run", "data": {"input": {"query": "tool input"}}, } - agent_message, _ = handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0) + agent_message, _ = await handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0) error_event = { "event": "on_tool_error", @@ -488,7 +488,9 @@ async def test_handle_on_tool_error(): "start_time": 0, } - updated_message, start_time = handle_on_tool_error(error_event, agent_message, tool_blocks_map, send_message, 0.0) + updated_message, start_time = await handle_on_tool_error( + error_event, agent_message, tool_blocks_map, send_message, 0.0 + ) tool_content = updated_message.content_blocks[0].contents[-1] assert tool_content.name == "test_tool" @@ -500,7 +502,7 @@ async def test_handle_on_tool_error(): async def test_handle_on_chain_stream_with_output(): """Test handle_on_chain_stream with output.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) agent_message = Message( sender=MESSAGE_SENDER_AI, sender_name="Agent", @@ -512,7 +514,7 @@ async def test_handle_on_chain_stream_with_output(): "data": {"chunk": {"output": "streamed output"}}, } - updated_message, start_time = handle_on_chain_stream(event, agent_message, send_message, 0.0) + updated_message, start_time = await handle_on_chain_stream(event, agent_message, send_message, 0.0) assert updated_message.text == "streamed output" assert updated_message.properties.state == "complete" @@ -521,7 +523,7 @@ async def test_handle_on_chain_stream_with_output(): async def test_handle_on_chain_stream_no_output(): """Test handle_on_chain_stream without output.""" - send_message = MagicMock(side_effect=lambda message: message) + send_message = AsyncMock(side_effect=lambda message: message) agent_message = Message( sender=MESSAGE_SENDER_AI, sender_name="Agent", @@ -534,7 +536,7 @@ async def test_handle_on_chain_stream_no_output(): "data": {"chunk": {}}, } - updated_message, start_time = handle_on_chain_stream(event, agent_message, send_message, 0.0) + updated_message, start_time = await handle_on_chain_stream(event, agent_message, send_message, 0.0) assert updated_message.text == "" assert updated_message.properties.state == "partial" diff --git a/src/backend/tests/unit/components/inputs/test_input_components.py b/src/backend/tests/unit/components/inputs/test_input_components.py index d8dcc7fd7..552b9e78b 100644 --- a/src/backend/tests/unit/components/inputs/test_input_components.py +++ b/src/backend/tests/unit/components/inputs/test_input_components.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from langflow.components.inputs import ChatInput, TextInputComponent from langflow.schema.message import Message @@ -36,10 +38,10 @@ class TestChatInput(ComponentTestBaseWithClient): {"version": "1.0.19", "module": "inputs", "file_name": "ChatInput"}, ] - def test_message_response(self, component_class, default_kwargs): + async def test_message_response(self, component_class, default_kwargs): """Test that the message_response method returns a valid Message object.""" component = component_class(**default_kwargs) - message = component.message_response() + message = await component.message_response() assert isinstance(message, Message) assert message.text == default_kwargs["input_value"] @@ -58,7 +60,7 @@ class TestChatInput(ComponentTestBaseWithClient): "targets": [], } - def test_message_response_ai_sender(self, component_class): + async def test_message_response_ai_sender(self, component_class): """Test message response with AI sender type.""" kwargs = { "input_value": "I am an AI assistant", @@ -67,13 +69,13 @@ class TestChatInput(ComponentTestBaseWithClient): "session_id": "test_session_123", } component = component_class(**kwargs) - message = component.message_response() + message = await component.message_response() assert isinstance(message, Message) assert message.sender == MESSAGE_SENDER_AI assert message.sender_name == "AI Assistant" - def test_message_response_without_session(self, component_class): + async def test_message_response_without_session(self, component_class): """Test message response without session ID.""" kwargs = { "input_value": "Test message", @@ -82,16 +84,16 @@ class TestChatInput(ComponentTestBaseWithClient): "session_id": "", # Empty session ID } component = component_class(**kwargs) - message = component.message_response() + message = await component.message_response() assert isinstance(message, Message) assert message.session_id == "" - def test_message_response_with_files(self, component_class, tmp_path): + async def test_message_response_with_files(self, component_class, tmp_path): """Test message response with file attachments.""" # Create a temporary test file test_file = tmp_path / "test.txt" - test_file.write_text("Test content") + await asyncio.to_thread(test_file.write_text, "Test content") kwargs = { "input_value": "Message with file", @@ -101,13 +103,13 @@ class TestChatInput(ComponentTestBaseWithClient): "files": [str(test_file)], } component = component_class(**kwargs) - message = component.message_response() + message = await component.message_response() assert isinstance(message, Message) assert len(message.files) == 1 assert message.files[0] == str(test_file) - def test_message_storage_disabled(self, component_class): + async def test_message_storage_disabled(self, component_class): """Test message response when storage is disabled.""" kwargs = { "input_value": "Test message", @@ -117,7 +119,7 @@ class TestChatInput(ComponentTestBaseWithClient): "session_id": "test_session_123", } component = component_class(**kwargs) - message = component.message_response() + message = await component.message_response() assert isinstance(message, Message) # The message should still be created but not stored diff --git a/src/backend/tests/unit/custom/custom_component/test_component_events.py b/src/backend/tests/unit/custom/custom_component/test_component_events.py index a987fd1af..52490526f 100644 --- a/src/backend/tests/unit/custom/custom_component/test_component_events.py +++ b/src/backend/tests/unit/custom/custom_component/test_component_events.py @@ -52,7 +52,7 @@ async def test_component_message_sending(): ) # Send the message - sent_message = await asyncio.to_thread(component.send_message, message) + sent_message = await component.send_message(message) # Verify the message was sent assert sent_message.id is not None @@ -85,7 +85,7 @@ async def test_component_tool_output(): ) # Send the message - sent_message = await asyncio.to_thread(component.send_message, message) + sent_message = await component.send_message(message) # Verify the message was stored and processed assert sent_message.id is not None @@ -112,8 +112,7 @@ async def test_component_error_handling(): msg = "Test error" raise CustomError(msg) except CustomError as e: - sent_message = await asyncio.to_thread( - component.send_error, + sent_message = await component.send_error( exception=e, session_id="test_session", trace_name="test_trace", @@ -227,7 +226,7 @@ async def test_component_streaming_message(): ) # Send the streaming message - sent_message = await asyncio.to_thread(component.send_message, message) + sent_message = await component.send_message(message) # Verify the message assert sent_message.id is not None diff --git a/src/backend/tests/unit/graph/graph/test_base.py b/src/backend/tests/unit/graph/graph/test_base.py index 75b9aa9fe..c32ef7f47 100644 --- a/src/backend/tests/unit/graph/graph/test_base.py +++ b/src/backend/tests/unit/graph/graph/test_base.py @@ -55,6 +55,7 @@ async def test_graph_with_edge(): async def test_graph_functional(): chat_input = ChatInput(_id="chat_input") + chat_input.set(should_store_message=False) chat_output = ChatOutput(input_value="test", _id="chat_output") chat_output.set(sender_name=chat_input.message_response) graph = await asyncio.to_thread(Graph, chat_input, chat_output) diff --git a/src/backend/tests/unit/graph/graph/test_graph_state_model.py b/src/backend/tests/unit/graph/graph/test_graph_state_model.py index 41a4e145d..f9fea89a8 100644 --- a/src/backend/tests/unit/graph/graph/test_graph_state_model.py +++ b/src/backend/tests/unit/graph/graph/test_graph_state_model.py @@ -75,9 +75,9 @@ def test_graph_functional_start_graph_state_update(): def test_graph_state_model_serialization(): chat_input = ChatInput(_id="chat_input") - chat_input.set(input_value="Test Sender Name") + chat_input.set(input_value="Test Sender Name", should_store_message=False) chat_output = ChatOutput(input_value="test", _id="chat_output") - chat_output.set(sender_name=chat_input.message_response) + chat_output.set(sender_name=chat_input.message_response, should_store_message=False) graph = Graph(chat_input, chat_output) graph.prepare() diff --git a/src/backend/tests/unit/test_chat_endpoint.py b/src/backend/tests/unit/test_chat_endpoint.py index c892da463..7a745b239 100644 --- a/src/backend/tests/unit/test_chat_endpoint.py +++ b/src/backend/tests/unit/test_chat_endpoint.py @@ -2,7 +2,7 @@ import json from uuid import UUID import pytest -from langflow.memory import get_messages +from langflow.memory import aget_messages from langflow.services.database.models.flow import FlowCreate, FlowUpdate from orjson import orjson @@ -14,7 +14,7 @@ async def test_build_flow(client, json_memory_chatbot_no_llm, logged_in_headers) async with client.stream("POST", f"api/v1/build/{flow_id}/flow", json={}, headers=logged_in_headers) as r: await consume_and_assert_stream(r) - check_messages(flow_id) + await check_messages(flow_id) @pytest.mark.benchmark @@ -28,7 +28,7 @@ async def test_build_flow_from_request_data(client, json_memory_chatbot_no_llm, ) as r: await consume_and_assert_stream(r) - check_messages(flow_id) + await check_messages(flow_id) async def test_build_flow_with_frozen_path(client, json_memory_chatbot_no_llm, logged_in_headers): @@ -47,11 +47,11 @@ async def test_build_flow_with_frozen_path(client, json_memory_chatbot_no_llm, l async with client.stream("POST", f"api/v1/build/{flow_id}/flow", json={}, headers=logged_in_headers) as r: await consume_and_assert_stream(r) - check_messages(flow_id) + await check_messages(flow_id) -def check_messages(flow_id): - messages = get_messages(flow_id=UUID(flow_id), order="ASC") +async def check_messages(flow_id): + messages = await aget_messages(flow_id=UUID(flow_id), order="ASC") assert len(messages) == 2 assert messages[0].session_id == flow_id assert messages[0].sender == "User" diff --git a/src/backend/tests/unit/test_messages.py b/src/backend/tests/unit/test_messages.py index 579016391..808caa141 100644 --- a/src/backend/tests/unit/test_messages.py +++ b/src/backend/tests/unit/test_messages.py @@ -3,8 +3,14 @@ from uuid import UUID, uuid4 import pytest from langflow.memory import ( + aadd_messages, + aadd_messagetables, add_messages, add_messagetables, + adelete_messages, + aget_messages, + astore_message, + aupdate_messages, delete_messages, get_messages, store_message, @@ -18,29 +24,29 @@ from langflow.schema.properties import Properties, Source # Assuming you have these imports available from langflow.services.database.models.message import MessageCreate, MessageRead from langflow.services.database.models.message.model import MessageTable -from langflow.services.deps import session_scope +from langflow.services.deps import async_session_scope from langflow.services.tracing.utils import convert_to_langchain_type @pytest.fixture -def created_message(): - with session_scope() as session: +async def created_message(): + async with async_session_scope() as session: message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") messagetable = MessageTable.model_validate(message, from_attributes=True) - messagetables = add_messagetables([messagetable], session) + messagetables = await aadd_messagetables([messagetable], session) return MessageRead.model_validate(messagetables[0], from_attributes=True) @pytest.fixture -def created_messages(session): # noqa: ARG001 - with session_scope() as _session: +async def created_messages(session): # noqa: ARG001 + async with async_session_scope() as _session: messages = [ MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), ] messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] - messagetables = add_messagetables(messagetables, _session) + messagetables = await aadd_messagetables(messagetables, _session) return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables] @@ -58,6 +64,20 @@ def test_get_messages(): assert messages[1].text == "Test message 2" +@pytest.mark.usefixtures("client") +async def test_aget_messages(): + await aadd_messages( + [ + Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), + Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), + ] + ) + messages = await aget_messages(sender="User", session_id="session_id2", limit=2) + assert len(messages) == 2 + assert messages[0].text == "Test message 1" + assert messages[1].text == "Test message 2" + + @pytest.mark.usefixtures("client") def test_add_messages(): message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id") @@ -66,6 +86,14 @@ def test_add_messages(): assert messages[0].text == "New Test message" +@pytest.mark.usefixtures("client") +async def test_aadd_messages(): + message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id") + messages = await aadd_messages(message) + assert len(messages) == 1 + assert messages[0].text == "New Test message" + + @pytest.mark.usefixtures("client") def test_add_messagetables(session): messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")] @@ -75,17 +103,53 @@ def test_add_messagetables(session): @pytest.mark.usefixtures("client") -def test_delete_messages(session): - session_id = "session_id2" +async def test_aadd_messagetables(async_session): + messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")] + added_messages = await aadd_messagetables(messages, async_session) + assert len(added_messages) == 1 + assert added_messages[0].text == "New Test message" + + +@pytest.mark.usefixtures("client") +def test_delete_messages(): + session_id = "new_session_id" + message = Message(text="New Test message", sender="User", sender_name="User", session_id=session_id) + add_messages([message]) + messages = get_messages(sender="User", session_id=session_id) + assert len(messages) == 1 delete_messages(session_id) - messages = session.query(MessageTable).filter(MessageTable.session_id == session_id).all() + messages = get_messages(sender="User", session_id=session_id) + assert len(messages) == 0 + + +@pytest.mark.usefixtures("client") +async def test_adelete_messages(): + session_id = "new_session_id" + message = Message(text="New Test message", sender="User", sender_name="User", session_id=session_id) + await aadd_messages([message]) + messages = await aget_messages(sender="User", session_id=session_id) + assert len(messages) == 1 + await adelete_messages(session_id) + messages = await aget_messages(sender="User", session_id=session_id) assert len(messages) == 0 @pytest.mark.usefixtures("client") def test_store_message(): - message = Message(text="Stored message", sender="User", sender_name="User", session_id="stored_session_id") - stored_messages = store_message(message) + session_id = "stored_session_id" + message = Message(text="Stored message", sender="User", sender_name="User", session_id=session_id) + store_message(message) + stored_messages = get_messages(sender="User", session_id=session_id) + assert len(stored_messages) == 1 + assert stored_messages[0].text == "Stored message" + + +@pytest.mark.usefixtures("client") +async def test_astore_message(): + session_id = "stored_session_id" + message = Message(text="Stored message", sender="User", sender_name="User", session_id=session_id) + await astore_message(message) + stored_messages = await aget_messages(sender="User", session_id=session_id) assert len(stored_messages) == 1 assert stored_messages[0].text == "Stored message" @@ -298,3 +362,188 @@ def test_update_message_with_nested_properties(created_message): assert updated[0].properties.allow_markdown is True assert updated[0].properties.state == "complete" assert updated[0].properties.targets == [] + + +@pytest.mark.usefixtures("client") +async def test_aupdate_single_message(created_message): + # Modify the message + created_message.text = "Updated message" + updated = await aupdate_messages(created_message) + + assert len(updated) == 1 + assert updated[0].text == "Updated message" + assert updated[0].id == created_message.id + + +@pytest.mark.usefixtures("client") +async def test_aupdate_multiple_messages(created_messages): + # Modify the messages + for i, message in enumerate(created_messages): + message.text = f"Updated message {i}" + + updated = await aupdate_messages(created_messages) + + assert len(updated) == len(created_messages) + for i, message in enumerate(updated): + assert message.text == f"Updated message {i}" + assert message.id == created_messages[i].id + + +@pytest.mark.usefixtures("client") +async def test_aupdate_nonexistent_message(): + # Create a message with a non-existent UUID + message = MessageRead( + id=uuid4(), # Generate a random UUID that won't exist in the database + text="Test message", + sender="User", + sender_name="User", + session_id="session_id", + flow_id=uuid4(), + ) + + updated = await aupdate_messages(message) + assert len(updated) == 0 + + +@pytest.mark.usefixtures("client") +async def test_aupdate_mixed_messages(created_messages): + # Create a mix of existing and non-existing messages + nonexistent_message = MessageRead( + id=uuid4(), # Generate a random UUID that won't exist in the database + text="Test message", + sender="User", + sender_name="User", + session_id="session_id", + flow_id=uuid4(), + ) + + messages_to_update = created_messages[:1] + [nonexistent_message] + created_messages[0].text = "Updated existing message" + + updated = await aupdate_messages(messages_to_update) + + assert len(updated) == 1 + assert updated[0].text == "Updated existing message" + assert updated[0].id == created_messages[0].id + assert isinstance(updated[0].id, UUID) # Verify ID is UUID type + + +@pytest.mark.usefixtures("client") +async def test_aupdate_message_with_timestamp(created_message): + # Set a specific timestamp + new_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + created_message.timestamp = new_timestamp + created_message.text = "Updated message with timestamp" + + updated = await aupdate_messages(created_message) + + assert len(updated) == 1 + assert updated[0].text == "Updated message with timestamp" + + # Compare timestamps without timezone info since DB doesn't preserve it + assert updated[0].timestamp.replace(tzinfo=None) == new_timestamp.replace(tzinfo=None) + assert updated[0].id == created_message.id + + +@pytest.mark.usefixtures("client") +async def test_aupdate_multiple_messages_with_timestamps(created_messages): + # Modify messages with different timestamps + for i, message in enumerate(created_messages): + message.text = f"Updated message {i}" + message.timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc) + + updated = await aupdate_messages(created_messages) + + assert len(updated) == len(created_messages) + for i, message in enumerate(updated): + assert message.text == f"Updated message {i}" + # Compare timestamps without timezone info + expected_timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc) + assert message.timestamp.replace(tzinfo=None) == expected_timestamp.replace(tzinfo=None) + assert message.id == created_messages[i].id + + +@pytest.mark.usefixtures("client") +async def test_aupdate_message_with_content_blocks(created_message): + # Create a content block using proper models + text_content = TextContent( + type="text", text="Test content", duration=5, header={"title": "Test Header", "icon": "TestIcon"} + ) + + tool_content = ToolContent(type="tool_use", name="test_tool", tool_input={"param": "value"}, duration=10) + + content_block = ContentBlock(title="Test Block", contents=[text_content, tool_content], allow_markdown=True) + + created_message.content_blocks = [content_block] + created_message.text = "Message with content blocks" + + updated = await aupdate_messages(created_message) + + assert len(updated) == 1 + assert updated[0].text == "Message with content blocks" + assert len(updated[0].content_blocks) == 1 + + # Verify the content block structure + updated_block = updated[0].content_blocks[0] + assert updated_block.title == "Test Block" + assert len(updated_block.contents) == 2 + + # Verify text content + text_content = updated_block.contents[0] + assert text_content.type == "text" + assert text_content.text == "Test content" + assert text_content.duration == 5 + assert text_content.header["title"] == "Test Header" + + # Verify tool content + tool_content = updated_block.contents[1] + assert tool_content.type == "tool_use" + assert tool_content.name == "test_tool" + assert tool_content.tool_input == {"param": "value"} + assert tool_content.duration == 10 + + +@pytest.mark.usefixtures("client") +async def test_aupdate_message_with_nested_properties(created_message): + # Create a text content with nested properties + text_content = TextContent( + type="text", text="Test content", header={"title": "Test Header", "icon": "TestIcon"}, duration=15 + ) + + content_block = ContentBlock( + title="Test Properties", + contents=[text_content], + allow_markdown=True, + media_url=["http://example.com/image.jpg"], + ) + + # Set properties according to the Properties model structure + created_message.properties = Properties( + text_color="blue", + background_color="white", + edited=False, + source=Source(id="test_id", display_name="Test Source", source="test"), + icon="TestIcon", + allow_markdown=True, + state="complete", + targets=[], + ) + created_message.text = "Message with nested properties" + created_message.content_blocks = [content_block] + + updated = await aupdate_messages(created_message) + + assert len(updated) == 1 + assert updated[0].text == "Message with nested properties" + + # Verify the properties were properly serialized and stored + assert updated[0].properties.text_color == "blue" + assert updated[0].properties.background_color == "white" + assert updated[0].properties.edited is False + assert updated[0].properties.source.id == "test_id" + assert updated[0].properties.source.display_name == "Test Source" + assert updated[0].properties.source.source == "test" + assert updated[0].properties.icon == "TestIcon" + assert updated[0].properties.allow_markdown is True + assert updated[0].properties.state == "complete" + assert updated[0].properties.targets == [] diff --git a/src/backend/tests/unit/test_messages_endpoints.py b/src/backend/tests/unit/test_messages_endpoints.py index a5bddee9c..0c155de71 100644 --- a/src/backend/tests/unit/test_messages_endpoints.py +++ b/src/backend/tests/unit/test_messages_endpoints.py @@ -2,33 +2,33 @@ from uuid import UUID import pytest from httpx import AsyncClient -from langflow.memory import add_messagetables +from langflow.memory import aadd_messagetables # Assuming you have these imports available from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate from langflow.services.database.models.message.model import MessageTable -from langflow.services.deps import session_scope +from langflow.services.deps import async_session_scope @pytest.fixture -def created_message(): - with session_scope() as session: +async def created_message(): + async with async_session_scope() as session: message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") messagetable = MessageTable.model_validate(message, from_attributes=True) - messagetables = add_messagetables([messagetable], session) + messagetables = await aadd_messagetables([messagetable], session) return MessageRead.model_validate(messagetables[0], from_attributes=True) @pytest.fixture -def created_messages(session): # noqa: ARG001 - with session_scope() as _session: +async def created_messages(session): # noqa: ARG001 + async with async_session_scope() as _session: messages = [ MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), ] messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] - return add_messagetables(messagetables, _session) + return await aadd_messagetables(messagetables, _session) @pytest.mark.api_key_required