fix: Use AsyncSession in memory (#4665)

This commit is contained in:
Christophe Bornet 2024-12-06 17:25:59 +01:00 committed by GitHub
commit 79b03ba133
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 610 additions and 173 deletions

View file

@ -1,4 +1,3 @@
import asyncio
import re
from abc import abstractmethod
from typing import TYPE_CHECKING, cast
@ -168,7 +167,7 @@ class LCAgentComponent(Component):
)
except ExceptionWithMessageError as e:
msg_id = e.agent_message.id
await asyncio.to_thread(delete_message, id_=msg_id)
await delete_message(id_=msg_id)
self._send_message_event(e.agent_message, category="remove_message")
raise
except Exception:

View file

@ -1,5 +1,4 @@
# Add helper functions for each event type
import asyncio
from collections.abc import AsyncIterator
from time import perf_counter
from typing import Any, Protocol
@ -53,7 +52,7 @@ def _calculate_duration(start_time: float) -> int:
return result
def handle_on_chain_start(
async def handle_on_chain_start(
event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float
) -> tuple[Message, float]:
# Create content blocks if they don't exist
@ -75,7 +74,7 @@ def handle_on_chain_start(
header={"title": "Input", "icon": "MessageSquare"},
)
agent_message.content_blocks[0].contents.append(text_content)
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
start_time = perf_counter()
return agent_message, start_time
@ -91,7 +90,7 @@ def _extract_output_text(output: str | list) -> str:
return text
def handle_on_chain_end(
async def handle_on_chain_end(
event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float
) -> tuple[Message, float]:
data_output = event["data"].get("output")
@ -110,12 +109,12 @@ def handle_on_chain_end(
header={"title": "Output", "icon": "MessageSquare"},
)
agent_message.content_blocks[0].contents.append(text_content)
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
start_time = perf_counter()
return agent_message, start_time
def handle_on_tool_start(
async def handle_on_tool_start(
event: dict[str, Any],
agent_message: Message,
tool_blocks_map: dict[str, ToolContent],
@ -149,12 +148,12 @@ def handle_on_tool_start(
tool_blocks_map[tool_key] = tool_content
agent_message.content_blocks[0].contents.append(tool_content)
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
tool_blocks_map[tool_key] = agent_message.content_blocks[0].contents[-1]
return agent_message, new_start_time
def handle_on_tool_end(
async def handle_on_tool_end(
event: dict[str, Any],
agent_message: Message,
tool_blocks_map: dict[str, ToolContent],
@ -172,13 +171,13 @@ def handle_on_tool_end(
tool_content.duration = duration
tool_content.header = {"title": f"Executed **{tool_content.name}**", "icon": "Hammer"}
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
new_start_time = perf_counter() # Get new start time for next operation
return agent_message, new_start_time
return agent_message, start_time
def handle_on_tool_error(
async def handle_on_tool_error(
event: dict[str, Any],
agent_message: Message,
tool_blocks_map: dict[str, ToolContent],
@ -194,12 +193,12 @@ def handle_on_tool_error(
tool_content.error = event["data"].get("error", "Unknown error")
tool_content.duration = _calculate_duration(start_time)
tool_content.header = {"title": f"Error using **{tool_content.name}**", "icon": "Hammer"}
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
start_time = perf_counter()
return agent_message, start_time
def handle_on_chain_stream(
async def handle_on_chain_stream(
event: dict[str, Any],
agent_message: Message,
send_message_method: SendMessageFunctionType,
@ -211,13 +210,13 @@ def handle_on_chain_stream(
if output and isinstance(output, str | list):
agent_message.text = _extract_output_text(output)
agent_message.properties.state = "complete"
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
start_time = perf_counter()
return agent_message, start_time
class ToolEventHandler(Protocol):
def __call__(
async def __call__(
self,
event: dict[str, Any],
agent_message: Message,
@ -228,7 +227,7 @@ class ToolEventHandler(Protocol):
class ChainEventHandler(Protocol):
def __call__(
async def __call__(
self,
event: dict[str, Any],
agent_message: Message,
@ -265,7 +264,7 @@ async def process_agent_events(
agent_message.properties.icon = "Bot"
agent_message.properties.state = "partial"
# Store the initial message
agent_message = await asyncio.to_thread(send_message_method, message=agent_message)
agent_message = await send_message_method(message=agent_message)
try:
# Create a mapping of run_ids to tool contents
tool_blocks_map: dict[str, ToolContent] = {}
@ -273,14 +272,14 @@ async def process_agent_events(
async for event in agent_executor:
if event["event"] in TOOL_EVENT_HANDLERS:
tool_handler = TOOL_EVENT_HANDLERS[event["event"]]
agent_message, start_time = tool_handler(
agent_message, start_time = await tool_handler(
event, agent_message, tool_blocks_map, send_message_method, start_time
)
elif event["event"] in CHAIN_EVENT_HANDLERS:
chain_handler = CHAIN_EVENT_HANDLERS[event["event"]]
agent_message, start_time = chain_handler(event, agent_message, send_message_method, start_time)
agent_message, start_time = await chain_handler(event, agent_message, send_message_method, start_time)
agent_message.properties.state = "complete"
except Exception as e:
raise ExceptionWithMessageError(agent_message) from e
return Message(**agent_message.model_dump())
return await Message.create(**agent_message.model_dump())

View file

@ -1,7 +1,8 @@
import asyncio
from typing import cast
from langflow.custom import Component
from langflow.memory import store_message
from langflow.memory import astore_message
from langflow.schema import Data
from langflow.schema.message import Message
@ -10,7 +11,7 @@ class ChatComponent(Component):
display_name = "Chat Component"
description = "Use as base for chat components."
def build_with_data(
async def build_with_data(
self,
*,
sender: str | None = "User",
@ -20,13 +21,13 @@ class ChatComponent(Component):
session_id: str | None = None,
return_message: bool = False,
) -> str | Message:
message = self._create_message(input_value, sender, sender_name, files, session_id)
message = await asyncio.to_thread(self._create_message, input_value, sender, sender_name, files, session_id)
message_text = message.text if not return_message else message
self.status = message_text
if session_id and isinstance(message, Message) and isinstance(message.text, str):
flow_id = self.graph.flow_id if hasattr(self, "graph") else None
messages = store_message(message, flow_id=flow_id)
messages = await astore_message(message, flow_id=flow_id)
self.status = messages
self._send_messages_events(messages)

View file

@ -56,7 +56,7 @@ def build_description(component: Component, output: Output) -> str:
return f"{output.method}({args}) - {component.description}"
def send_message_noop(
async def send_message_noop(
message: Message,
text: str | None = None, # noqa: ARG001
background_color: str | None = None, # noqa: ARG001

View file

@ -66,7 +66,7 @@ class AgentComponent(ToolCallingAgentComponent):
if llm_model is None:
msg = "No language model selected"
raise ValueError(msg)
self.chat_history = self.get_memory_data()
self.chat_history = await self.get_memory_data()
if self.add_current_date_tool:
if not isinstance(self.tools, list): # type: ignore[has-type]
@ -92,12 +92,12 @@ class AgentComponent(ToolCallingAgentComponent):
agent = self.create_agent_runnable()
return await self.run_agent(agent)
def get_memory_data(self):
async def get_memory_data(self):
memory_kwargs = {
component_input.name: getattr(self, f"{component_input.name}") for component_input in self.memory_inputs
}
return MemoryComponent().set(**memory_kwargs).retrieve_messages()
return await MemoryComponent().set(**memory_kwargs).retrieve_messages()
def get_llm(self):
if isinstance(self.agent_llm, str):

View file

@ -1,5 +1,5 @@
from langflow.custom import CustomComponent
from langflow.memory import get_messages, store_message
from langflow.memory import aget_messages, astore_message
from langflow.schema.message import Message
@ -13,12 +13,12 @@ class StoreMessageComponent(CustomComponent):
"message": {"display_name": "Message"},
}
def build(
async def build(
self,
message: Message,
) -> Message:
flow_id = self.graph.flow_id if hasattr(self, "graph") else None
store_message(message, flow_id=flow_id)
self.status = get_messages()
await astore_message(message, flow_id=flow_id)
self.status = await aget_messages()
return message

View file

@ -5,7 +5,7 @@ from langflow.field_typing import BaseChatMemory
from langflow.helpers.data import data_to_text
from langflow.inputs import HandleInput
from langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output
from langflow.memory import LCBuiltinChatMemory, get_messages
from langflow.memory import LCBuiltinChatMemory, aget_messages
from langflow.schema import Data
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
@ -74,7 +74,7 @@ class MemoryComponent(Component):
Output(display_name="Text", name="messages_text", method="retrieve_messages_as_text"),
]
def retrieve_messages(self) -> Data:
async def retrieve_messages(self) -> Data:
sender = self.sender
sender_name = self.sender_name
session_id = self.session_id
@ -88,7 +88,7 @@ class MemoryComponent(Component):
# override session_id
self.memory.session_id = session_id
stored = self.memory.messages
stored = await self.memory.aget_messages()
# langchain memories are supposed to return messages in ascending order
if order == "DESC":
stored = stored[::-1]
@ -99,7 +99,7 @@ class MemoryComponent(Component):
expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER
stored = [m for m in stored if m.type == expected_type]
else:
stored = get_messages(
stored = await aget_messages(
sender=sender,
sender_name=sender_name,
session_id=session_id,
@ -109,8 +109,8 @@ class MemoryComponent(Component):
self.status = stored
return stored
def retrieve_messages_as_text(self) -> Message:
stored_text = data_to_text(self.template, self.retrieve_messages())
async def retrieve_messages_as_text(self) -> Message:
stored_text = data_to_text(self.template, await self.retrieve_messages())
self.status = stored_text
return Message(text=stored_text)

View file

@ -1,7 +1,7 @@
from langflow.custom import Component
from langflow.inputs import HandleInput, MessageInput
from langflow.inputs.inputs import MessageTextInput
from langflow.memory import get_messages, store_message
from langflow.memory import aget_messages, astore_message
from langflow.schema.message import Message
from langflow.template import Output
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI
@ -47,7 +47,7 @@ class StoreMessageComponent(Component):
Output(display_name="Stored Messages", name="stored_messages", method="store_message"),
]
def store_message(self) -> Message:
async def store_message(self) -> Message:
message = self.message
message.session_id = self.session_id or message.session_id
@ -58,13 +58,15 @@ class StoreMessageComponent(Component):
# override session_id
self.memory.session_id = message.session_id
lc_message = message.to_lc_message()
self.memory.add_messages([lc_message])
stored = self.memory.messages
await self.memory.aadd_messages([lc_message])
stored = await self.memory.aget_messages()
stored = [Message.from_lc_message(m) for m in stored]
if message.sender:
stored = [m for m in stored if m.sender == message.sender]
else:
store_message(message, flow_id=self.graph.flow_id)
stored = get_messages(session_id=message.session_id, sender_name=message.sender_name, sender=message.sender)
await astore_message(message, flow_id=self.graph.flow_id)
stored = await aget_messages(
session_id=message.session_id, sender_name=message.sender_name, sender=message.sender
)
self.status = stored
return stored

View file

@ -78,11 +78,12 @@ class ChatInput(ChatComponent):
Output(display_name="Message", name="message", method="message_response"),
]
def message_response(self) -> Message:
async def message_response(self) -> Message:
_background_color = self.background_color
_text_color = self.text_color
_icon = self.chat_icon
message = Message(
message = await Message.create(
text=self.input_value,
sender=self.sender,
sender_name=self.sender_name,
@ -91,7 +92,7 @@ class ChatInput(ChatComponent):
properties={"background_color": _background_color, "text_color": _text_color, "icon": _icon},
)
if self.session_id and isinstance(message, Message) and self.should_store_message:
stored_message = self.send_message(
stored_message = await self.send_message(
message,
)
self.message.value = stored_message

View file

@ -90,7 +90,7 @@ class ChatOutput(ChatComponent):
source_dict["source"] = source
return Source(**source_dict)
def message_response(self) -> Message:
async def message_response(self) -> Message:
_source, _icon, _display_name, _source_id = self.get_properties_from_source_component()
_background_color = self.background_color
_text_color = self.text_color
@ -106,7 +106,7 @@ class ChatOutput(ChatComponent):
message.properties.background_color = _background_color
message.properties.text_color = _text_color
if self.session_id and isinstance(message, Message) and self.should_store_message:
stored_message = self.send_message(
stored_message = await self.send_message(
message,
)
self.message.value = stored_message

View file

@ -24,7 +24,7 @@ from langflow.exceptions.component import StreamingError
from langflow.field_typing import Tool # noqa: TCH001 Needed by _add_toolkit_output
from langflow.graph.state.model import create_state_model
from langflow.helpers.custom import format_type
from langflow.memory import delete_message, store_message, update_messages
from langflow.memory import astore_message, aupdate_messages, delete_message
from langflow.schema.artifact import get_artifact_type, post_process_raw
from langflow.schema.data import Data
from langflow.schema.message import ErrorMessage, Message
@ -847,7 +847,7 @@ class Component(CustomComponent):
return await self._build_with_tracing()
return await self._build_without_tracing()
except StreamingError as e:
self.send_error(
await self.send_error(
exception=e.cause,
session_id=session_id,
trace_name=getattr(self, "trace_name", None),
@ -855,7 +855,7 @@ class Component(CustomComponent):
)
raise e.cause # noqa: B904
except Exception as e:
self.send_error(
await self.send_error(
exception=e,
session_id=session_id,
source=Source(id=self._id, display_name=self.display_name, source=self.display_name),
@ -1016,10 +1016,10 @@ class Component(CustomComponent):
)
)
def send_message(self, message: Message, id_: str | None = None):
async def send_message(self, message: Message, id_: str | None = None):
if (hasattr(self, "graph") and self.graph.session_id) and (message is not None and not message.session_id):
message.session_id = self.graph.session_id
stored_message = self._store_message(message)
stored_message = await self._store_message(message)
self._stored_message_id = stored_message.id
try:
@ -1029,22 +1029,22 @@ class Component(CustomComponent):
and message is not None
and isinstance(message.text, AsyncIterator | Iterator)
):
complete_message = self._stream_message(message.text, stored_message)
complete_message = await self._stream_message(message.text, stored_message)
stored_message.text = complete_message
stored_message = self._update_stored_message(stored_message)
stored_message = await self._update_stored_message(stored_message)
else:
# Only send message event for non-streaming messages
self._send_message_event(stored_message, id_=id_)
except Exception:
# remove the message from the database
delete_message(stored_message.id)
await delete_message(stored_message.id)
raise
self.status = stored_message
return stored_message
def _store_message(self, message: Message) -> Message:
async def _store_message(self, message: Message) -> Message:
flow_id = self.graph.flow_id if hasattr(self, "graph") else None
messages = store_message(message, flow_id=flow_id)
messages = await astore_message(message, flow_id=flow_id)
if len(messages) != 1:
msg = "Only one message can be stored at a time."
raise ValueError(msg)
@ -1073,21 +1073,21 @@ class Component(CustomComponent):
and not isinstance(original_message.text, str)
)
def _update_stored_message(self, stored_message: Message) -> Message:
message_tables = update_messages(stored_message)
async def _update_stored_message(self, stored_message: Message) -> Message:
message_tables = await aupdate_messages(stored_message)
if len(message_tables) != 1:
msg = "Only one message can be updated at a time."
raise ValueError(msg)
message_table = message_tables[0]
return Message(**message_table.model_dump())
return await Message.create(**message_table.model_dump())
def _stream_message(self, iterator: AsyncIterator | Iterator, message: Message) -> str:
async def _stream_message(self, iterator: AsyncIterator | Iterator, message: Message) -> str:
if not isinstance(iterator, AsyncIterator | Iterator):
msg = "The message must be an iterator or an async iterator."
raise TypeError(msg)
if isinstance(iterator, AsyncIterator):
return run_until_complete(self._handle_async_iterator(iterator, message.id, message))
return await self._handle_async_iterator(iterator, message.id, message)
try:
complete_message = ""
first_chunk = True
@ -1129,7 +1129,7 @@ class Component(CustomComponent):
)
return complete_message
def send_error(
async def send_error(
self,
exception: Exception,
session_id: str,
@ -1145,7 +1145,7 @@ class Component(CustomComponent):
trace_name=trace_name,
source=source,
)
self.send_message(error_message)
await self.send_message(error_message)
return error_message
def _append_tool_to_outputs_map(self):

View file

@ -401,7 +401,7 @@ class InterfaceVertex(ComponentVertex):
type=ArtifactType.OBJECT.value,
).model_dump()
message = Message(
message = await Message.create(
text=complete_message,
sender=self.params.get("sender", ""),
sender_name=self.params.get("sender_name", ""),
@ -434,7 +434,7 @@ class InterfaceVertex(ComponentVertex):
and hasattr(self.custom_component, "should_store_message")
and hasattr(self.custom_component, "store_message")
):
self.custom_component.store_message(message)
await self.custom_component.store_message(message)
await log_vertex_build(
flow_id=self.graph.flow_id,
vertex_id=self.id,

View file

@ -7,13 +7,40 @@ from langchain_core.messages import BaseMessage
from loguru import logger
from sqlalchemy import delete
from sqlmodel import Session, col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.schema.message import Message
from langflow.services.database.models.message.model import MessageRead, MessageTable
from langflow.services.deps import session_scope
from langflow.services.deps import async_session_scope, session_scope
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
def _get_variable_query(
sender: str | None = None,
sender_name: str | None = None,
session_id: str | None = None,
order_by: str | None = "timestamp",
order: str | None = "DESC",
flow_id: UUID | None = None,
limit: int | None = None,
):
stmt = select(MessageTable).where(MessageTable.error == False) # noqa: E712
if sender:
stmt = stmt.where(MessageTable.sender == sender)
if sender_name:
stmt = stmt.where(MessageTable.sender_name == sender_name)
if session_id:
stmt = stmt.where(MessageTable.session_id == session_id)
if flow_id:
stmt = stmt.where(MessageTable.flow_id == flow_id)
if order_by:
col = getattr(MessageTable, order_by).desc() if order == "DESC" else getattr(MessageTable, order_by).asc()
stmt = stmt.order_by(col)
if limit:
stmt = stmt.limit(limit)
return stmt
def get_messages(
sender: str | None = None,
sender_name: str | None = None,
@ -38,24 +65,40 @@ def get_messages(
List[Data]: A list of Data objects representing the retrieved messages.
"""
with session_scope() as session:
stmt = select(MessageTable).where(MessageTable.error == False) # noqa: E712
if sender:
stmt = stmt.where(MessageTable.sender == sender)
if sender_name:
stmt = stmt.where(MessageTable.sender_name == sender_name)
if session_id:
stmt = stmt.where(MessageTable.session_id == session_id)
if flow_id:
stmt = stmt.where(MessageTable.flow_id == flow_id)
if order_by:
col = getattr(MessageTable, order_by).desc() if order == "DESC" else getattr(MessageTable, order_by).asc()
stmt = stmt.order_by(col)
if limit:
stmt = stmt.limit(limit)
stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit)
messages = session.exec(stmt)
return [Message(**d.model_dump()) for d in messages]
async def aget_messages(
sender: str | None = None,
sender_name: str | None = None,
session_id: str | None = None,
order_by: str | None = "timestamp",
order: str | None = "DESC",
flow_id: UUID | None = None,
limit: int | None = None,
) -> list[Message]:
"""Retrieves messages from the monitor service based on the provided filters.
Args:
sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User")
sender_name (Optional[str]): The name of the sender.
session_id (Optional[str]): The session ID associated with the messages.
order_by (Optional[str]): The field to order the messages by. Defaults to "timestamp".
order (Optional[str]): The order in which to retrieve the messages. Defaults to "DESC".
flow_id (Optional[UUID]): The flow ID associated with the messages.
limit (Optional[int]): The maximum number of messages to retrieve.
Returns:
List[Data]: A list of Data objects representing the retrieved messages.
"""
async with async_session_scope() as session:
stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit)
messages = await session.exec(stmt)
return [await Message.create(**d.model_dump()) for d in messages]
def add_messages(messages: Message | list[Message], flow_id: str | None = None):
"""Add a message to the monitor service."""
if not isinstance(messages, list):
@ -76,6 +119,26 @@ def add_messages(messages: Message | list[Message], flow_id: str | None = None):
raise
async def aadd_messages(messages: Message | list[Message], flow_id: str | None = None):
"""Add a message to the monitor service."""
if not isinstance(messages, list):
messages = [messages]
if not all(isinstance(message, Message) for message in messages):
types = ", ".join([str(type(message)) for message in messages])
msg = f"The messages must be instances of Message. Found: {types}"
raise ValueError(msg)
try:
messages_models = [MessageTable.from_message(msg, flow_id=flow_id) for msg in messages]
async with async_session_scope() as session:
messages_models = await aadd_messagetables(messages_models, session)
return [await Message.create(**message.model_dump()) for message in messages_models]
except Exception as e:
logger.exception(e)
raise
def update_messages(messages: Message | list[Message]) -> list[Message]:
if not isinstance(messages, list):
messages = [messages]
@ -95,6 +158,25 @@ def update_messages(messages: Message | list[Message]) -> list[Message]:
return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages]
async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
if not isinstance(messages, list):
messages = [messages]
async with async_session_scope() as session:
updated_messages: list[MessageTable] = []
for message in messages:
msg = await session.get(MessageTable, message.id)
if msg:
msg.sqlmodel_update(message.model_dump(exclude_unset=True, exclude_none=True))
session.add(msg)
await session.commit()
await session.refresh(msg)
updated_messages.append(msg)
else:
logger.warning(f"Message with id {message.id} not found")
return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages]
def add_messagetables(messages: list[MessageTable], session: Session):
for message in messages:
try:
@ -115,6 +197,27 @@ def add_messagetables(messages: list[MessageTable], session: Session):
return [MessageRead.model_validate(message, from_attributes=True) for message in new_messages]
async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
try:
for message in messages:
session.add(message)
await session.commit()
for message in messages:
await session.refresh(message)
except Exception as e:
logger.exception(e)
raise
new_messages = []
for msg in messages:
msg.properties = json.loads(msg.properties) if isinstance(msg.properties, str) else msg.properties # type: ignore[arg-type]
msg.content_blocks = [json.loads(j) if isinstance(j, str) else j for j in msg.content_blocks] # type: ignore[arg-type]
msg.category = msg.category or ""
new_messages.append(msg)
return [MessageRead.model_validate(message, from_attributes=True) for message in new_messages]
def delete_messages(session_id: str) -> None:
"""Delete messages from the monitor service based on the provided session ID.
@ -129,17 +232,32 @@ def delete_messages(session_id: str) -> None:
)
def delete_message(id_: str) -> None:
async def adelete_messages(session_id: str) -> None:
"""Delete messages from the monitor service based on the provided session ID.
Args:
session_id (str): The session ID associated with the messages to delete.
"""
async with async_session_scope() as session:
stmt = (
delete(MessageTable)
.where(col(MessageTable.session_id) == session_id)
.execution_options(synchronize_session="fetch")
)
await session.exec(stmt)
async def delete_message(id_: str) -> None:
"""Delete a message from the monitor service based on the provided ID.
Args:
id_ (str): The ID of the message to delete.
"""
with session_scope() as session:
message = session.get(MessageTable, id_)
async with async_session_scope() as session:
message = await session.get(MessageTable, id_)
if message:
session.delete(message)
session.commit()
await session.delete(message)
await session.commit()
def store_message(
@ -182,6 +300,35 @@ def store_message(
return add_messages([message], flow_id=flow_id)
async def astore_message(
message: Message,
flow_id: str | None = None,
) -> list[Message]:
"""Stores a message in the memory.
Args:
message (Message): The message to store.
flow_id (Optional[str]): The flow ID associated with the message.
When running from the CustomComponent you can access this using `self.graph.flow_id`.
Returns:
List[Message]: A list of data containing the stored message.
Raises:
ValueError: If any of the required parameters (session_id, sender, sender_name) is not provided.
"""
if not message:
logger.warning("No message provided.")
return []
if not message.session_id or not message.sender or not message.sender_name:
msg = "All of session_id, sender, and sender_name must be provided."
raise ValueError(msg)
if hasattr(message, "id") and message.id:
return await aupdate_messages([message])
return await aadd_messages([message], flow_id=flow_id)
class LCBuiltinChatMemory(BaseChatMessageHistory):
def __init__(
self,
@ -198,11 +345,26 @@ class LCBuiltinChatMemory(BaseChatMessageHistory):
)
return [m.to_lc_message() for m in messages if not m.error] # Exclude error messages
async def aget_messages(self) -> list[BaseMessage]:
messages = await aget_messages(
session_id=self.session_id,
)
return [m.to_lc_message() for m in messages if not m.error] # Exclude error messages
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
for lc_message in messages:
message = Message.from_lc_message(lc_message)
message.session_id = self.session_id
store_message(message, flow_id=self.flow_id)
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
for lc_message in messages:
message = Message.from_lc_message(lc_message)
message.session_id = self.session_id
await astore_message(message, flow_id=self.flow_id)
def clear(self) -> None:
delete_messages(self.session_id)
async def aclear(self) -> None:
await adelete_messages(self.session_id)

View file

@ -14,7 +14,7 @@ class LogFunctionType(Protocol):
class SendMessageFunctionType(Protocol):
def __call__(
async def __call__(
self,
message: Message | None = None,
text: str | None = None,

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import json
import re
import traceback
@ -267,6 +268,13 @@ class Message(Data):
instance.messages = instance.prompt.get("kwargs", {}).get("messages", [])
return instance
@classmethod
async def create(cls, **kwargs):
"""If files are present, create the message in a separate thread as is_image_file is blocking."""
if "files" in kwargs:
return await asyncio.to_thread(cls, **kwargs)
return cls(**kwargs)
class DefaultModel(BaseModel):
class Config:

View file

@ -29,6 +29,7 @@ from langflow.services.database.models.vertex_builds.crud import delete_vertex_b
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from loguru import logger
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import selectinload
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.ext.asyncio.session import AsyncSession
@ -151,6 +152,17 @@ def session_fixture():
SQLModel.metadata.drop_all(engine) # Add this line to clean up tables
@pytest.fixture
async def async_session():
engine = create_async_engine("sqlite+aiosqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async with AsyncSession(engine) as session:
yield session
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
class Config:
broker_url = "redis://localhost:6379/0"
result_backend = "redis://localhost:6379/0"

View file

@ -1,5 +1,5 @@
from langflow.components.inputs import ChatInput
from langflow.memory import get_messages
from langflow.memory import aget_messages
from langflow.schema.message import Message
from tests.integration.utils import run_single_component
@ -38,7 +38,7 @@ async def test_do_not_store_messages():
assert outputs["message"].text == "hello"
assert outputs["message"].session_id == session_id
assert len(get_messages(session_id=session_id)) == 1
assert len(await aget_messages(session_id=session_id)) == 1
session_id = "test-session-id-another"
outputs = await run_single_component(
@ -48,4 +48,4 @@ async def test_do_not_store_messages():
assert outputs["message"].text == "hello"
assert outputs["message"].session_id == session_id
assert len(get_messages(session_id=session_id)) == 0
assert len(await aget_messages(session_id=session_id)) == 0

View file

@ -1,5 +1,5 @@
from langflow.components.outputs import ChatOutput
from langflow.memory import get_messages
from langflow.memory import aget_messages
from langflow.schema.message import Message
from tests.integration.utils import run_single_component
@ -29,7 +29,7 @@ async def test_do_not_store_message():
assert isinstance(outputs["message"], Message)
assert outputs["message"].text == "hello"
assert len(get_messages(session_id=session_id)) == 1
assert len(await aget_messages(session_id=session_id)) == 1
session_id = "test-session-id-another"
outputs = await run_single_component(
@ -38,4 +38,4 @@ async def test_do_not_store_message():
assert isinstance(outputs["message"], Message)
assert outputs["message"].text == "hello"
assert len(get_messages(session_id=session_id)) == 0
assert len(await aget_messages(session_id=session_id)) == 0

View file

@ -1,6 +1,6 @@
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import AsyncMock
from langchain_core.agents import AgentFinish
from langflow.base.agents.agent import process_agent_events
@ -26,7 +26,7 @@ async def create_event_iterator(events: list[dict[str, Any]]) -> AsyncIterator[d
async def test_chain_start_event():
"""Test handling of on_chain_start event."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
events = [
{"event": "on_chain_start", "data": {"input": {"input": "test input", "chat_history": []}}, "start_time": 0}
@ -51,7 +51,7 @@ async def test_chain_start_event():
async def test_chain_end_event():
"""Test handling of on_chain_end event."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
# Create a mock AgentFinish output
output = AgentFinish(return_values={"output": "final output"}, log="test log")
@ -77,7 +77,7 @@ async def test_chain_end_event():
async def test_tool_start_event():
"""Test handling of on_tool_start event."""
send_message = MagicMock()
send_message = AsyncMock()
# Set up the send_message mock to return the modified message
def update_message(message):
@ -116,7 +116,7 @@ async def test_tool_start_event():
async def test_tool_end_event():
"""Test handling of on_tool_end event."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
events = [
{
@ -151,7 +151,7 @@ async def test_tool_end_event():
async def test_tool_error_event():
"""Test handling of on_tool_error event."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
events = [
{
@ -187,7 +187,7 @@ async def test_tool_error_event():
async def test_chain_stream_event():
"""Test handling of on_chain_stream event."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
events = [{"event": "on_chain_stream", "data": {"chunk": {"output": "streamed output"}}, "start_time": 0}]
agent_message = Message(
@ -205,7 +205,7 @@ async def test_chain_stream_event():
async def test_multiple_events():
"""Test handling of multiple events in sequence."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
# Create a mock AgentFinish output instead of MockOutput
output = AgentFinish(return_values={"output": "final output"}, log="test log")
@ -248,7 +248,7 @@ async def test_multiple_events():
async def test_unknown_event():
"""Test handling of unknown event type."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name="Agent",
@ -273,7 +273,7 @@ async def test_unknown_event():
async def test_handle_on_chain_start_with_input():
"""Test handle_on_chain_start with input."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name="Agent",
@ -282,7 +282,7 @@ async def test_handle_on_chain_start_with_input():
)
event = {"event": "on_chain_start", "data": {"input": {"input": "test input", "chat_history": []}}, "start_time": 0}
updated_message, start_time = handle_on_chain_start(event, agent_message, send_message, 0.0)
updated_message, start_time = await handle_on_chain_start(event, agent_message, send_message, 0.0)
assert updated_message.properties.icon == "Bot"
assert len(updated_message.content_blocks) == 1
@ -292,7 +292,7 @@ async def test_handle_on_chain_start_with_input():
async def test_handle_on_chain_start_no_input():
"""Test handle_on_chain_start without input."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name="Agent",
@ -301,7 +301,7 @@ async def test_handle_on_chain_start_no_input():
)
event = {"event": "on_chain_start", "data": {}, "start_time": 0}
updated_message, start_time = handle_on_chain_start(event, agent_message, send_message, 0.0)
updated_message, start_time = await handle_on_chain_start(event, agent_message, send_message, 0.0)
assert updated_message.properties.icon == "Bot"
assert len(updated_message.content_blocks) == 1
@ -311,7 +311,7 @@ async def test_handle_on_chain_start_no_input():
async def test_handle_on_chain_end_with_output():
"""Test handle_on_chain_end with output."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name="Agent",
@ -322,7 +322,7 @@ async def test_handle_on_chain_end_with_output():
output = AgentFinish(return_values={"output": "final output"}, log="test log")
event = {"event": "on_chain_end", "data": {"output": output}, "start_time": 0}
updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0)
updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0)
assert updated_message.properties.icon == "Bot"
assert updated_message.properties.state == "complete"
@ -332,7 +332,7 @@ async def test_handle_on_chain_end_with_output():
async def test_handle_on_chain_end_no_output():
"""Test handle_on_chain_end without output key in data."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name="Agent",
@ -341,7 +341,7 @@ async def test_handle_on_chain_end_no_output():
)
event = {"event": "on_chain_end", "data": {}, "start_time": 0}
updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0)
updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0)
assert updated_message.properties.icon == "Bot"
assert updated_message.properties.state == "partial"
@ -351,7 +351,7 @@ async def test_handle_on_chain_end_no_output():
async def test_handle_on_chain_end_empty_data():
"""Test handle_on_chain_end with empty data."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name="Agent",
@ -360,7 +360,7 @@ async def test_handle_on_chain_end_empty_data():
)
event = {"event": "on_chain_end", "data": {"output": None}, "start_time": 0}
updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0)
updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0)
assert updated_message.properties.icon == "Bot"
assert updated_message.properties.state == "partial"
@ -370,7 +370,7 @@ async def test_handle_on_chain_end_empty_data():
async def test_handle_on_chain_end_with_empty_return_values():
"""Test handle_on_chain_end with empty return_values."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name="Agent",
@ -384,7 +384,7 @@ async def test_handle_on_chain_end_with_empty_return_values():
event = {"event": "on_chain_end", "data": {"output": MockOutputEmptyReturnValues()}, "start_time": 0}
updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0)
updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0)
assert updated_message.properties.icon == "Bot"
assert updated_message.properties.state == "partial"
@ -394,7 +394,7 @@ async def test_handle_on_chain_end_with_empty_return_values():
async def test_handle_on_tool_start():
"""Test handle_on_tool_start event."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
tool_blocks_map = {}
agent_message = Message(
sender=MESSAGE_SENDER_AI,
@ -410,7 +410,7 @@ async def test_handle_on_tool_start():
"start_time": 0,
}
updated_message, start_time = handle_on_tool_start(event, agent_message, tool_blocks_map, send_message, 0.0)
updated_message, start_time = await handle_on_tool_start(event, agent_message, tool_blocks_map, send_message, 0.0)
assert len(updated_message.content_blocks) == 1
assert len(updated_message.content_blocks[0].contents) > 0
@ -426,7 +426,7 @@ async def test_handle_on_tool_start():
async def test_handle_on_tool_end():
"""Test handle_on_tool_end event."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
tool_blocks_map = {}
agent_message = Message(
sender=MESSAGE_SENDER_AI,
@ -441,7 +441,7 @@ async def test_handle_on_tool_end():
"run_id": "test_run",
"data": {"input": {"query": "tool input"}},
}
agent_message, _ = handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0)
agent_message, _ = await handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0)
end_event = {
"event": "on_tool_end",
@ -451,7 +451,7 @@ async def test_handle_on_tool_end():
"start_time": 0,
}
updated_message, start_time = handle_on_tool_end(end_event, agent_message, tool_blocks_map, send_message, 0.0)
updated_message, start_time = await handle_on_tool_end(end_event, agent_message, tool_blocks_map, send_message, 0.0)
f"{end_event['name']}_{end_event['run_id']}"
tool_content = updated_message.content_blocks[0].contents[-1]
@ -463,7 +463,7 @@ async def test_handle_on_tool_end():
async def test_handle_on_tool_error():
"""Test handle_on_tool_error event."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
tool_blocks_map = {}
agent_message = Message(
sender=MESSAGE_SENDER_AI,
@ -478,7 +478,7 @@ async def test_handle_on_tool_error():
"run_id": "test_run",
"data": {"input": {"query": "tool input"}},
}
agent_message, _ = handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0)
agent_message, _ = await handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0)
error_event = {
"event": "on_tool_error",
@ -488,7 +488,9 @@ async def test_handle_on_tool_error():
"start_time": 0,
}
updated_message, start_time = handle_on_tool_error(error_event, agent_message, tool_blocks_map, send_message, 0.0)
updated_message, start_time = await handle_on_tool_error(
error_event, agent_message, tool_blocks_map, send_message, 0.0
)
tool_content = updated_message.content_blocks[0].contents[-1]
assert tool_content.name == "test_tool"
@ -500,7 +502,7 @@ async def test_handle_on_tool_error():
async def test_handle_on_chain_stream_with_output():
"""Test handle_on_chain_stream with output."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name="Agent",
@ -512,7 +514,7 @@ async def test_handle_on_chain_stream_with_output():
"data": {"chunk": {"output": "streamed output"}},
}
updated_message, start_time = handle_on_chain_stream(event, agent_message, send_message, 0.0)
updated_message, start_time = await handle_on_chain_stream(event, agent_message, send_message, 0.0)
assert updated_message.text == "streamed output"
assert updated_message.properties.state == "complete"
@ -521,7 +523,7 @@ async def test_handle_on_chain_stream_with_output():
async def test_handle_on_chain_stream_no_output():
"""Test handle_on_chain_stream without output."""
send_message = MagicMock(side_effect=lambda message: message)
send_message = AsyncMock(side_effect=lambda message: message)
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name="Agent",
@ -534,7 +536,7 @@ async def test_handle_on_chain_stream_no_output():
"data": {"chunk": {}},
}
updated_message, start_time = handle_on_chain_stream(event, agent_message, send_message, 0.0)
updated_message, start_time = await handle_on_chain_stream(event, agent_message, send_message, 0.0)
assert updated_message.text == ""
assert updated_message.properties.state == "partial"

View file

@ -1,3 +1,5 @@
import asyncio
import pytest
from langflow.components.inputs import ChatInput, TextInputComponent
from langflow.schema.message import Message
@ -36,10 +38,10 @@ class TestChatInput(ComponentTestBaseWithClient):
{"version": "1.0.19", "module": "inputs", "file_name": "ChatInput"},
]
def test_message_response(self, component_class, default_kwargs):
async def test_message_response(self, component_class, default_kwargs):
"""Test that the message_response method returns a valid Message object."""
component = component_class(**default_kwargs)
message = component.message_response()
message = await component.message_response()
assert isinstance(message, Message)
assert message.text == default_kwargs["input_value"]
@ -58,7 +60,7 @@ class TestChatInput(ComponentTestBaseWithClient):
"targets": [],
}
def test_message_response_ai_sender(self, component_class):
async def test_message_response_ai_sender(self, component_class):
"""Test message response with AI sender type."""
kwargs = {
"input_value": "I am an AI assistant",
@ -67,13 +69,13 @@ class TestChatInput(ComponentTestBaseWithClient):
"session_id": "test_session_123",
}
component = component_class(**kwargs)
message = component.message_response()
message = await component.message_response()
assert isinstance(message, Message)
assert message.sender == MESSAGE_SENDER_AI
assert message.sender_name == "AI Assistant"
def test_message_response_without_session(self, component_class):
async def test_message_response_without_session(self, component_class):
"""Test message response without session ID."""
kwargs = {
"input_value": "Test message",
@ -82,16 +84,16 @@ class TestChatInput(ComponentTestBaseWithClient):
"session_id": "", # Empty session ID
}
component = component_class(**kwargs)
message = component.message_response()
message = await component.message_response()
assert isinstance(message, Message)
assert message.session_id == ""
def test_message_response_with_files(self, component_class, tmp_path):
async def test_message_response_with_files(self, component_class, tmp_path):
"""Test message response with file attachments."""
# Create a temporary test file
test_file = tmp_path / "test.txt"
test_file.write_text("Test content")
await asyncio.to_thread(test_file.write_text, "Test content")
kwargs = {
"input_value": "Message with file",
@ -101,13 +103,13 @@ class TestChatInput(ComponentTestBaseWithClient):
"files": [str(test_file)],
}
component = component_class(**kwargs)
message = component.message_response()
message = await component.message_response()
assert isinstance(message, Message)
assert len(message.files) == 1
assert message.files[0] == str(test_file)
def test_message_storage_disabled(self, component_class):
async def test_message_storage_disabled(self, component_class):
"""Test message response when storage is disabled."""
kwargs = {
"input_value": "Test message",
@ -117,7 +119,7 @@ class TestChatInput(ComponentTestBaseWithClient):
"session_id": "test_session_123",
}
component = component_class(**kwargs)
message = component.message_response()
message = await component.message_response()
assert isinstance(message, Message)
# The message should still be created but not stored

View file

@ -52,7 +52,7 @@ async def test_component_message_sending():
)
# Send the message
sent_message = await asyncio.to_thread(component.send_message, message)
sent_message = await component.send_message(message)
# Verify the message was sent
assert sent_message.id is not None
@ -85,7 +85,7 @@ async def test_component_tool_output():
)
# Send the message
sent_message = await asyncio.to_thread(component.send_message, message)
sent_message = await component.send_message(message)
# Verify the message was stored and processed
assert sent_message.id is not None
@ -112,8 +112,7 @@ async def test_component_error_handling():
msg = "Test error"
raise CustomError(msg)
except CustomError as e:
sent_message = await asyncio.to_thread(
component.send_error,
sent_message = await component.send_error(
exception=e,
session_id="test_session",
trace_name="test_trace",
@ -227,7 +226,7 @@ async def test_component_streaming_message():
)
# Send the streaming message
sent_message = await asyncio.to_thread(component.send_message, message)
sent_message = await component.send_message(message)
# Verify the message
assert sent_message.id is not None

View file

@ -55,6 +55,7 @@ async def test_graph_with_edge():
async def test_graph_functional():
chat_input = ChatInput(_id="chat_input")
chat_input.set(should_store_message=False)
chat_output = ChatOutput(input_value="test", _id="chat_output")
chat_output.set(sender_name=chat_input.message_response)
graph = await asyncio.to_thread(Graph, chat_input, chat_output)

View file

@ -75,9 +75,9 @@ def test_graph_functional_start_graph_state_update():
def test_graph_state_model_serialization():
chat_input = ChatInput(_id="chat_input")
chat_input.set(input_value="Test Sender Name")
chat_input.set(input_value="Test Sender Name", should_store_message=False)
chat_output = ChatOutput(input_value="test", _id="chat_output")
chat_output.set(sender_name=chat_input.message_response)
chat_output.set(sender_name=chat_input.message_response, should_store_message=False)
graph = Graph(chat_input, chat_output)
graph.prepare()

View file

@ -2,7 +2,7 @@ import json
from uuid import UUID
import pytest
from langflow.memory import get_messages
from langflow.memory import aget_messages
from langflow.services.database.models.flow import FlowCreate, FlowUpdate
from orjson import orjson
@ -14,7 +14,7 @@ async def test_build_flow(client, json_memory_chatbot_no_llm, logged_in_headers)
async with client.stream("POST", f"api/v1/build/{flow_id}/flow", json={}, headers=logged_in_headers) as r:
await consume_and_assert_stream(r)
check_messages(flow_id)
await check_messages(flow_id)
@pytest.mark.benchmark
@ -28,7 +28,7 @@ async def test_build_flow_from_request_data(client, json_memory_chatbot_no_llm,
) as r:
await consume_and_assert_stream(r)
check_messages(flow_id)
await check_messages(flow_id)
async def test_build_flow_with_frozen_path(client, json_memory_chatbot_no_llm, logged_in_headers):
@ -47,11 +47,11 @@ async def test_build_flow_with_frozen_path(client, json_memory_chatbot_no_llm, l
async with client.stream("POST", f"api/v1/build/{flow_id}/flow", json={}, headers=logged_in_headers) as r:
await consume_and_assert_stream(r)
check_messages(flow_id)
await check_messages(flow_id)
def check_messages(flow_id):
messages = get_messages(flow_id=UUID(flow_id), order="ASC")
async def check_messages(flow_id):
messages = await aget_messages(flow_id=UUID(flow_id), order="ASC")
assert len(messages) == 2
assert messages[0].session_id == flow_id
assert messages[0].sender == "User"

View file

@ -3,8 +3,14 @@ from uuid import UUID, uuid4
import pytest
from langflow.memory import (
aadd_messages,
aadd_messagetables,
add_messages,
add_messagetables,
adelete_messages,
aget_messages,
astore_message,
aupdate_messages,
delete_messages,
get_messages,
store_message,
@ -18,29 +24,29 @@ from langflow.schema.properties import Properties, Source
# Assuming you have these imports available
from langflow.services.database.models.message import MessageCreate, MessageRead
from langflow.services.database.models.message.model import MessageTable
from langflow.services.deps import session_scope
from langflow.services.deps import async_session_scope
from langflow.services.tracing.utils import convert_to_langchain_type
@pytest.fixture
def created_message():
with session_scope() as session:
async def created_message():
async with async_session_scope() as session:
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
messagetable = MessageTable.model_validate(message, from_attributes=True)
messagetables = add_messagetables([messagetable], session)
messagetables = await aadd_messagetables([messagetable], session)
return MessageRead.model_validate(messagetables[0], from_attributes=True)
@pytest.fixture
def created_messages(session): # noqa: ARG001
with session_scope() as _session:
async def created_messages(session): # noqa: ARG001
async with async_session_scope() as _session:
messages = [
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
]
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
messagetables = add_messagetables(messagetables, _session)
messagetables = await aadd_messagetables(messagetables, _session)
return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables]
@ -58,6 +64,20 @@ def test_get_messages():
assert messages[1].text == "Test message 2"
@pytest.mark.usefixtures("client")
async def test_aget_messages():
await aadd_messages(
[
Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
]
)
messages = await aget_messages(sender="User", session_id="session_id2", limit=2)
assert len(messages) == 2
assert messages[0].text == "Test message 1"
assert messages[1].text == "Test message 2"
@pytest.mark.usefixtures("client")
def test_add_messages():
message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")
@ -66,6 +86,14 @@ def test_add_messages():
assert messages[0].text == "New Test message"
@pytest.mark.usefixtures("client")
async def test_aadd_messages():
message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")
messages = await aadd_messages(message)
assert len(messages) == 1
assert messages[0].text == "New Test message"
@pytest.mark.usefixtures("client")
def test_add_messagetables(session):
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
@ -75,17 +103,53 @@ def test_add_messagetables(session):
@pytest.mark.usefixtures("client")
def test_delete_messages(session):
session_id = "session_id2"
async def test_aadd_messagetables(async_session):
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
added_messages = await aadd_messagetables(messages, async_session)
assert len(added_messages) == 1
assert added_messages[0].text == "New Test message"
@pytest.mark.usefixtures("client")
def test_delete_messages():
session_id = "new_session_id"
message = Message(text="New Test message", sender="User", sender_name="User", session_id=session_id)
add_messages([message])
messages = get_messages(sender="User", session_id=session_id)
assert len(messages) == 1
delete_messages(session_id)
messages = session.query(MessageTable).filter(MessageTable.session_id == session_id).all()
messages = get_messages(sender="User", session_id=session_id)
assert len(messages) == 0
@pytest.mark.usefixtures("client")
async def test_adelete_messages():
session_id = "new_session_id"
message = Message(text="New Test message", sender="User", sender_name="User", session_id=session_id)
await aadd_messages([message])
messages = await aget_messages(sender="User", session_id=session_id)
assert len(messages) == 1
await adelete_messages(session_id)
messages = await aget_messages(sender="User", session_id=session_id)
assert len(messages) == 0
@pytest.mark.usefixtures("client")
def test_store_message():
message = Message(text="Stored message", sender="User", sender_name="User", session_id="stored_session_id")
stored_messages = store_message(message)
session_id = "stored_session_id"
message = Message(text="Stored message", sender="User", sender_name="User", session_id=session_id)
store_message(message)
stored_messages = get_messages(sender="User", session_id=session_id)
assert len(stored_messages) == 1
assert stored_messages[0].text == "Stored message"
@pytest.mark.usefixtures("client")
async def test_astore_message():
session_id = "stored_session_id"
message = Message(text="Stored message", sender="User", sender_name="User", session_id=session_id)
await astore_message(message)
stored_messages = await aget_messages(sender="User", session_id=session_id)
assert len(stored_messages) == 1
assert stored_messages[0].text == "Stored message"
@ -298,3 +362,188 @@ def test_update_message_with_nested_properties(created_message):
assert updated[0].properties.allow_markdown is True
assert updated[0].properties.state == "complete"
assert updated[0].properties.targets == []
@pytest.mark.usefixtures("client")
async def test_aupdate_single_message(created_message):
# Modify the message
created_message.text = "Updated message"
updated = await aupdate_messages(created_message)
assert len(updated) == 1
assert updated[0].text == "Updated message"
assert updated[0].id == created_message.id
@pytest.mark.usefixtures("client")
async def test_aupdate_multiple_messages(created_messages):
# Modify the messages
for i, message in enumerate(created_messages):
message.text = f"Updated message {i}"
updated = await aupdate_messages(created_messages)
assert len(updated) == len(created_messages)
for i, message in enumerate(updated):
assert message.text == f"Updated message {i}"
assert message.id == created_messages[i].id
@pytest.mark.usefixtures("client")
async def test_aupdate_nonexistent_message():
# Create a message with a non-existent UUID
message = MessageRead(
id=uuid4(), # Generate a random UUID that won't exist in the database
text="Test message",
sender="User",
sender_name="User",
session_id="session_id",
flow_id=uuid4(),
)
updated = await aupdate_messages(message)
assert len(updated) == 0
@pytest.mark.usefixtures("client")
async def test_aupdate_mixed_messages(created_messages):
# Create a mix of existing and non-existing messages
nonexistent_message = MessageRead(
id=uuid4(), # Generate a random UUID that won't exist in the database
text="Test message",
sender="User",
sender_name="User",
session_id="session_id",
flow_id=uuid4(),
)
messages_to_update = created_messages[:1] + [nonexistent_message]
created_messages[0].text = "Updated existing message"
updated = await aupdate_messages(messages_to_update)
assert len(updated) == 1
assert updated[0].text == "Updated existing message"
assert updated[0].id == created_messages[0].id
assert isinstance(updated[0].id, UUID) # Verify ID is UUID type
@pytest.mark.usefixtures("client")
async def test_aupdate_message_with_timestamp(created_message):
# Set a specific timestamp
new_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
created_message.timestamp = new_timestamp
created_message.text = "Updated message with timestamp"
updated = await aupdate_messages(created_message)
assert len(updated) == 1
assert updated[0].text == "Updated message with timestamp"
# Compare timestamps without timezone info since DB doesn't preserve it
assert updated[0].timestamp.replace(tzinfo=None) == new_timestamp.replace(tzinfo=None)
assert updated[0].id == created_message.id
@pytest.mark.usefixtures("client")
async def test_aupdate_multiple_messages_with_timestamps(created_messages):
# Modify messages with different timestamps
for i, message in enumerate(created_messages):
message.text = f"Updated message {i}"
message.timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc)
updated = await aupdate_messages(created_messages)
assert len(updated) == len(created_messages)
for i, message in enumerate(updated):
assert message.text == f"Updated message {i}"
# Compare timestamps without timezone info
expected_timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc)
assert message.timestamp.replace(tzinfo=None) == expected_timestamp.replace(tzinfo=None)
assert message.id == created_messages[i].id
@pytest.mark.usefixtures("client")
async def test_aupdate_message_with_content_blocks(created_message):
# Create a content block using proper models
text_content = TextContent(
type="text", text="Test content", duration=5, header={"title": "Test Header", "icon": "TestIcon"}
)
tool_content = ToolContent(type="tool_use", name="test_tool", tool_input={"param": "value"}, duration=10)
content_block = ContentBlock(title="Test Block", contents=[text_content, tool_content], allow_markdown=True)
created_message.content_blocks = [content_block]
created_message.text = "Message with content blocks"
updated = await aupdate_messages(created_message)
assert len(updated) == 1
assert updated[0].text == "Message with content blocks"
assert len(updated[0].content_blocks) == 1
# Verify the content block structure
updated_block = updated[0].content_blocks[0]
assert updated_block.title == "Test Block"
assert len(updated_block.contents) == 2
# Verify text content
text_content = updated_block.contents[0]
assert text_content.type == "text"
assert text_content.text == "Test content"
assert text_content.duration == 5
assert text_content.header["title"] == "Test Header"
# Verify tool content
tool_content = updated_block.contents[1]
assert tool_content.type == "tool_use"
assert tool_content.name == "test_tool"
assert tool_content.tool_input == {"param": "value"}
assert tool_content.duration == 10
@pytest.mark.usefixtures("client")
async def test_aupdate_message_with_nested_properties(created_message):
# Create a text content with nested properties
text_content = TextContent(
type="text", text="Test content", header={"title": "Test Header", "icon": "TestIcon"}, duration=15
)
content_block = ContentBlock(
title="Test Properties",
contents=[text_content],
allow_markdown=True,
media_url=["http://example.com/image.jpg"],
)
# Set properties according to the Properties model structure
created_message.properties = Properties(
text_color="blue",
background_color="white",
edited=False,
source=Source(id="test_id", display_name="Test Source", source="test"),
icon="TestIcon",
allow_markdown=True,
state="complete",
targets=[],
)
created_message.text = "Message with nested properties"
created_message.content_blocks = [content_block]
updated = await aupdate_messages(created_message)
assert len(updated) == 1
assert updated[0].text == "Message with nested properties"
# Verify the properties were properly serialized and stored
assert updated[0].properties.text_color == "blue"
assert updated[0].properties.background_color == "white"
assert updated[0].properties.edited is False
assert updated[0].properties.source.id == "test_id"
assert updated[0].properties.source.display_name == "Test Source"
assert updated[0].properties.source.source == "test"
assert updated[0].properties.icon == "TestIcon"
assert updated[0].properties.allow_markdown is True
assert updated[0].properties.state == "complete"
assert updated[0].properties.targets == []

View file

@ -2,33 +2,33 @@ from uuid import UUID
import pytest
from httpx import AsyncClient
from langflow.memory import add_messagetables
from langflow.memory import aadd_messagetables
# Assuming you have these imports available
from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate
from langflow.services.database.models.message.model import MessageTable
from langflow.services.deps import session_scope
from langflow.services.deps import async_session_scope
@pytest.fixture
def created_message():
with session_scope() as session:
async def created_message():
async with async_session_scope() as session:
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
messagetable = MessageTable.model_validate(message, from_attributes=True)
messagetables = add_messagetables([messagetable], session)
messagetables = await aadd_messagetables([messagetable], session)
return MessageRead.model_validate(messagetables[0], from_attributes=True)
@pytest.fixture
def created_messages(session): # noqa: ARG001
with session_scope() as _session:
async def created_messages(session): # noqa: ARG001
async with async_session_scope() as _session:
messages = [
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
]
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
return add_messagetables(messagetables, _session)
return await aadd_messagetables(messagetables, _session)
@pytest.mark.api_key_required