fix: Use AsyncSession in memory (#4665)
This commit is contained in:
parent
156597d3d1
commit
79b03ba133
26 changed files with 610 additions and 173 deletions
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
|
@ -168,7 +167,7 @@ class LCAgentComponent(Component):
|
|||
)
|
||||
except ExceptionWithMessageError as e:
|
||||
msg_id = e.agent_message.id
|
||||
await asyncio.to_thread(delete_message, id_=msg_id)
|
||||
await delete_message(id_=msg_id)
|
||||
self._send_message_event(e.agent_message, category="remove_message")
|
||||
raise
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# Add helper functions for each event type
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from time import perf_counter
|
||||
from typing import Any, Protocol
|
||||
|
|
@ -53,7 +52,7 @@ def _calculate_duration(start_time: float) -> int:
|
|||
return result
|
||||
|
||||
|
||||
def handle_on_chain_start(
|
||||
async def handle_on_chain_start(
|
||||
event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float
|
||||
) -> tuple[Message, float]:
|
||||
# Create content blocks if they don't exist
|
||||
|
|
@ -75,7 +74,7 @@ def handle_on_chain_start(
|
|||
header={"title": "Input", "icon": "MessageSquare"},
|
||||
)
|
||||
agent_message.content_blocks[0].contents.append(text_content)
|
||||
agent_message = send_message_method(message=agent_message)
|
||||
agent_message = await send_message_method(message=agent_message)
|
||||
start_time = perf_counter()
|
||||
return agent_message, start_time
|
||||
|
||||
|
|
@ -91,7 +90,7 @@ def _extract_output_text(output: str | list) -> str:
|
|||
return text
|
||||
|
||||
|
||||
def handle_on_chain_end(
|
||||
async def handle_on_chain_end(
|
||||
event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float
|
||||
) -> tuple[Message, float]:
|
||||
data_output = event["data"].get("output")
|
||||
|
|
@ -110,12 +109,12 @@ def handle_on_chain_end(
|
|||
header={"title": "Output", "icon": "MessageSquare"},
|
||||
)
|
||||
agent_message.content_blocks[0].contents.append(text_content)
|
||||
agent_message = send_message_method(message=agent_message)
|
||||
agent_message = await send_message_method(message=agent_message)
|
||||
start_time = perf_counter()
|
||||
return agent_message, start_time
|
||||
|
||||
|
||||
def handle_on_tool_start(
|
||||
async def handle_on_tool_start(
|
||||
event: dict[str, Any],
|
||||
agent_message: Message,
|
||||
tool_blocks_map: dict[str, ToolContent],
|
||||
|
|
@ -149,12 +148,12 @@ def handle_on_tool_start(
|
|||
tool_blocks_map[tool_key] = tool_content
|
||||
agent_message.content_blocks[0].contents.append(tool_content)
|
||||
|
||||
agent_message = send_message_method(message=agent_message)
|
||||
agent_message = await send_message_method(message=agent_message)
|
||||
tool_blocks_map[tool_key] = agent_message.content_blocks[0].contents[-1]
|
||||
return agent_message, new_start_time
|
||||
|
||||
|
||||
def handle_on_tool_end(
|
||||
async def handle_on_tool_end(
|
||||
event: dict[str, Any],
|
||||
agent_message: Message,
|
||||
tool_blocks_map: dict[str, ToolContent],
|
||||
|
|
@ -172,13 +171,13 @@ def handle_on_tool_end(
|
|||
tool_content.duration = duration
|
||||
tool_content.header = {"title": f"Executed **{tool_content.name}**", "icon": "Hammer"}
|
||||
|
||||
agent_message = send_message_method(message=agent_message)
|
||||
agent_message = await send_message_method(message=agent_message)
|
||||
new_start_time = perf_counter() # Get new start time for next operation
|
||||
return agent_message, new_start_time
|
||||
return agent_message, start_time
|
||||
|
||||
|
||||
def handle_on_tool_error(
|
||||
async def handle_on_tool_error(
|
||||
event: dict[str, Any],
|
||||
agent_message: Message,
|
||||
tool_blocks_map: dict[str, ToolContent],
|
||||
|
|
@ -194,12 +193,12 @@ def handle_on_tool_error(
|
|||
tool_content.error = event["data"].get("error", "Unknown error")
|
||||
tool_content.duration = _calculate_duration(start_time)
|
||||
tool_content.header = {"title": f"Error using **{tool_content.name}**", "icon": "Hammer"}
|
||||
agent_message = send_message_method(message=agent_message)
|
||||
agent_message = await send_message_method(message=agent_message)
|
||||
start_time = perf_counter()
|
||||
return agent_message, start_time
|
||||
|
||||
|
||||
def handle_on_chain_stream(
|
||||
async def handle_on_chain_stream(
|
||||
event: dict[str, Any],
|
||||
agent_message: Message,
|
||||
send_message_method: SendMessageFunctionType,
|
||||
|
|
@ -211,13 +210,13 @@ def handle_on_chain_stream(
|
|||
if output and isinstance(output, str | list):
|
||||
agent_message.text = _extract_output_text(output)
|
||||
agent_message.properties.state = "complete"
|
||||
agent_message = send_message_method(message=agent_message)
|
||||
agent_message = await send_message_method(message=agent_message)
|
||||
start_time = perf_counter()
|
||||
return agent_message, start_time
|
||||
|
||||
|
||||
class ToolEventHandler(Protocol):
|
||||
def __call__(
|
||||
async def __call__(
|
||||
self,
|
||||
event: dict[str, Any],
|
||||
agent_message: Message,
|
||||
|
|
@ -228,7 +227,7 @@ class ToolEventHandler(Protocol):
|
|||
|
||||
|
||||
class ChainEventHandler(Protocol):
|
||||
def __call__(
|
||||
async def __call__(
|
||||
self,
|
||||
event: dict[str, Any],
|
||||
agent_message: Message,
|
||||
|
|
@ -265,7 +264,7 @@ async def process_agent_events(
|
|||
agent_message.properties.icon = "Bot"
|
||||
agent_message.properties.state = "partial"
|
||||
# Store the initial message
|
||||
agent_message = await asyncio.to_thread(send_message_method, message=agent_message)
|
||||
agent_message = await send_message_method(message=agent_message)
|
||||
try:
|
||||
# Create a mapping of run_ids to tool contents
|
||||
tool_blocks_map: dict[str, ToolContent] = {}
|
||||
|
|
@ -273,14 +272,14 @@ async def process_agent_events(
|
|||
async for event in agent_executor:
|
||||
if event["event"] in TOOL_EVENT_HANDLERS:
|
||||
tool_handler = TOOL_EVENT_HANDLERS[event["event"]]
|
||||
agent_message, start_time = tool_handler(
|
||||
agent_message, start_time = await tool_handler(
|
||||
event, agent_message, tool_blocks_map, send_message_method, start_time
|
||||
)
|
||||
elif event["event"] in CHAIN_EVENT_HANDLERS:
|
||||
chain_handler = CHAIN_EVENT_HANDLERS[event["event"]]
|
||||
agent_message, start_time = chain_handler(event, agent_message, send_message_method, start_time)
|
||||
agent_message, start_time = await chain_handler(event, agent_message, send_message_method, start_time)
|
||||
agent_message.properties.state = "complete"
|
||||
except Exception as e:
|
||||
raise ExceptionWithMessageError(agent_message) from e
|
||||
|
||||
return Message(**agent_message.model_dump())
|
||||
return await Message.create(**agent_message.model_dump())
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.memory import store_message
|
||||
from langflow.memory import astore_message
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
|
@ -10,7 +11,7 @@ class ChatComponent(Component):
|
|||
display_name = "Chat Component"
|
||||
description = "Use as base for chat components."
|
||||
|
||||
def build_with_data(
|
||||
async def build_with_data(
|
||||
self,
|
||||
*,
|
||||
sender: str | None = "User",
|
||||
|
|
@ -20,13 +21,13 @@ class ChatComponent(Component):
|
|||
session_id: str | None = None,
|
||||
return_message: bool = False,
|
||||
) -> str | Message:
|
||||
message = self._create_message(input_value, sender, sender_name, files, session_id)
|
||||
message = await asyncio.to_thread(self._create_message, input_value, sender, sender_name, files, session_id)
|
||||
message_text = message.text if not return_message else message
|
||||
|
||||
self.status = message_text
|
||||
if session_id and isinstance(message, Message) and isinstance(message.text, str):
|
||||
flow_id = self.graph.flow_id if hasattr(self, "graph") else None
|
||||
messages = store_message(message, flow_id=flow_id)
|
||||
messages = await astore_message(message, flow_id=flow_id)
|
||||
self.status = messages
|
||||
self._send_messages_events(messages)
|
||||
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ def build_description(component: Component, output: Output) -> str:
|
|||
return f"{output.method}({args}) - {component.description}"
|
||||
|
||||
|
||||
def send_message_noop(
|
||||
async def send_message_noop(
|
||||
message: Message,
|
||||
text: str | None = None, # noqa: ARG001
|
||||
background_color: str | None = None, # noqa: ARG001
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
if llm_model is None:
|
||||
msg = "No language model selected"
|
||||
raise ValueError(msg)
|
||||
self.chat_history = self.get_memory_data()
|
||||
self.chat_history = await self.get_memory_data()
|
||||
|
||||
if self.add_current_date_tool:
|
||||
if not isinstance(self.tools, list): # type: ignore[has-type]
|
||||
|
|
@ -92,12 +92,12 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
agent = self.create_agent_runnable()
|
||||
return await self.run_agent(agent)
|
||||
|
||||
def get_memory_data(self):
|
||||
async def get_memory_data(self):
|
||||
memory_kwargs = {
|
||||
component_input.name: getattr(self, f"{component_input.name}") for component_input in self.memory_inputs
|
||||
}
|
||||
|
||||
return MemoryComponent().set(**memory_kwargs).retrieve_messages()
|
||||
return await MemoryComponent().set(**memory_kwargs).retrieve_messages()
|
||||
|
||||
def get_llm(self):
|
||||
if isinstance(self.agent_llm, str):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.custom import CustomComponent
|
||||
from langflow.memory import get_messages, store_message
|
||||
from langflow.memory import aget_messages, astore_message
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
||||
|
|
@ -13,12 +13,12 @@ class StoreMessageComponent(CustomComponent):
|
|||
"message": {"display_name": "Message"},
|
||||
}
|
||||
|
||||
def build(
|
||||
async def build(
|
||||
self,
|
||||
message: Message,
|
||||
) -> Message:
|
||||
flow_id = self.graph.flow_id if hasattr(self, "graph") else None
|
||||
store_message(message, flow_id=flow_id)
|
||||
self.status = get_messages()
|
||||
await astore_message(message, flow_id=flow_id)
|
||||
self.status = await aget_messages()
|
||||
|
||||
return message
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from langflow.field_typing import BaseChatMemory
|
|||
from langflow.helpers.data import data_to_text
|
||||
from langflow.inputs import HandleInput
|
||||
from langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output
|
||||
from langflow.memory import LCBuiltinChatMemory, get_messages
|
||||
from langflow.memory import LCBuiltinChatMemory, aget_messages
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.message import Message
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
|
||||
|
|
@ -74,7 +74,7 @@ class MemoryComponent(Component):
|
|||
Output(display_name="Text", name="messages_text", method="retrieve_messages_as_text"),
|
||||
]
|
||||
|
||||
def retrieve_messages(self) -> Data:
|
||||
async def retrieve_messages(self) -> Data:
|
||||
sender = self.sender
|
||||
sender_name = self.sender_name
|
||||
session_id = self.session_id
|
||||
|
|
@ -88,7 +88,7 @@ class MemoryComponent(Component):
|
|||
# override session_id
|
||||
self.memory.session_id = session_id
|
||||
|
||||
stored = self.memory.messages
|
||||
stored = await self.memory.aget_messages()
|
||||
# langchain memories are supposed to return messages in ascending order
|
||||
if order == "DESC":
|
||||
stored = stored[::-1]
|
||||
|
|
@ -99,7 +99,7 @@ class MemoryComponent(Component):
|
|||
expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER
|
||||
stored = [m for m in stored if m.type == expected_type]
|
||||
else:
|
||||
stored = get_messages(
|
||||
stored = await aget_messages(
|
||||
sender=sender,
|
||||
sender_name=sender_name,
|
||||
session_id=session_id,
|
||||
|
|
@ -109,8 +109,8 @@ class MemoryComponent(Component):
|
|||
self.status = stored
|
||||
return stored
|
||||
|
||||
def retrieve_messages_as_text(self) -> Message:
|
||||
stored_text = data_to_text(self.template, self.retrieve_messages())
|
||||
async def retrieve_messages_as_text(self) -> Message:
|
||||
stored_text = data_to_text(self.template, await self.retrieve_messages())
|
||||
self.status = stored_text
|
||||
return Message(text=stored_text)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from langflow.custom import Component
|
||||
from langflow.inputs import HandleInput, MessageInput
|
||||
from langflow.inputs.inputs import MessageTextInput
|
||||
from langflow.memory import get_messages, store_message
|
||||
from langflow.memory import aget_messages, astore_message
|
||||
from langflow.schema.message import Message
|
||||
from langflow.template import Output
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI
|
||||
|
|
@ -47,7 +47,7 @@ class StoreMessageComponent(Component):
|
|||
Output(display_name="Stored Messages", name="stored_messages", method="store_message"),
|
||||
]
|
||||
|
||||
def store_message(self) -> Message:
|
||||
async def store_message(self) -> Message:
|
||||
message = self.message
|
||||
|
||||
message.session_id = self.session_id or message.session_id
|
||||
|
|
@ -58,13 +58,15 @@ class StoreMessageComponent(Component):
|
|||
# override session_id
|
||||
self.memory.session_id = message.session_id
|
||||
lc_message = message.to_lc_message()
|
||||
self.memory.add_messages([lc_message])
|
||||
stored = self.memory.messages
|
||||
await self.memory.aadd_messages([lc_message])
|
||||
stored = await self.memory.aget_messages()
|
||||
stored = [Message.from_lc_message(m) for m in stored]
|
||||
if message.sender:
|
||||
stored = [m for m in stored if m.sender == message.sender]
|
||||
else:
|
||||
store_message(message, flow_id=self.graph.flow_id)
|
||||
stored = get_messages(session_id=message.session_id, sender_name=message.sender_name, sender=message.sender)
|
||||
await astore_message(message, flow_id=self.graph.flow_id)
|
||||
stored = await aget_messages(
|
||||
session_id=message.session_id, sender_name=message.sender_name, sender=message.sender
|
||||
)
|
||||
self.status = stored
|
||||
return stored
|
||||
|
|
|
|||
|
|
@ -78,11 +78,12 @@ class ChatInput(ChatComponent):
|
|||
Output(display_name="Message", name="message", method="message_response"),
|
||||
]
|
||||
|
||||
def message_response(self) -> Message:
|
||||
async def message_response(self) -> Message:
|
||||
_background_color = self.background_color
|
||||
_text_color = self.text_color
|
||||
_icon = self.chat_icon
|
||||
message = Message(
|
||||
|
||||
message = await Message.create(
|
||||
text=self.input_value,
|
||||
sender=self.sender,
|
||||
sender_name=self.sender_name,
|
||||
|
|
@ -91,7 +92,7 @@ class ChatInput(ChatComponent):
|
|||
properties={"background_color": _background_color, "text_color": _text_color, "icon": _icon},
|
||||
)
|
||||
if self.session_id and isinstance(message, Message) and self.should_store_message:
|
||||
stored_message = self.send_message(
|
||||
stored_message = await self.send_message(
|
||||
message,
|
||||
)
|
||||
self.message.value = stored_message
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class ChatOutput(ChatComponent):
|
|||
source_dict["source"] = source
|
||||
return Source(**source_dict)
|
||||
|
||||
def message_response(self) -> Message:
|
||||
async def message_response(self) -> Message:
|
||||
_source, _icon, _display_name, _source_id = self.get_properties_from_source_component()
|
||||
_background_color = self.background_color
|
||||
_text_color = self.text_color
|
||||
|
|
@ -106,7 +106,7 @@ class ChatOutput(ChatComponent):
|
|||
message.properties.background_color = _background_color
|
||||
message.properties.text_color = _text_color
|
||||
if self.session_id and isinstance(message, Message) and self.should_store_message:
|
||||
stored_message = self.send_message(
|
||||
stored_message = await self.send_message(
|
||||
message,
|
||||
)
|
||||
self.message.value = stored_message
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from langflow.exceptions.component import StreamingError
|
|||
from langflow.field_typing import Tool # noqa: TCH001 Needed by _add_toolkit_output
|
||||
from langflow.graph.state.model import create_state_model
|
||||
from langflow.helpers.custom import format_type
|
||||
from langflow.memory import delete_message, store_message, update_messages
|
||||
from langflow.memory import astore_message, aupdate_messages, delete_message
|
||||
from langflow.schema.artifact import get_artifact_type, post_process_raw
|
||||
from langflow.schema.data import Data
|
||||
from langflow.schema.message import ErrorMessage, Message
|
||||
|
|
@ -847,7 +847,7 @@ class Component(CustomComponent):
|
|||
return await self._build_with_tracing()
|
||||
return await self._build_without_tracing()
|
||||
except StreamingError as e:
|
||||
self.send_error(
|
||||
await self.send_error(
|
||||
exception=e.cause,
|
||||
session_id=session_id,
|
||||
trace_name=getattr(self, "trace_name", None),
|
||||
|
|
@ -855,7 +855,7 @@ class Component(CustomComponent):
|
|||
)
|
||||
raise e.cause # noqa: B904
|
||||
except Exception as e:
|
||||
self.send_error(
|
||||
await self.send_error(
|
||||
exception=e,
|
||||
session_id=session_id,
|
||||
source=Source(id=self._id, display_name=self.display_name, source=self.display_name),
|
||||
|
|
@ -1016,10 +1016,10 @@ class Component(CustomComponent):
|
|||
)
|
||||
)
|
||||
|
||||
def send_message(self, message: Message, id_: str | None = None):
|
||||
async def send_message(self, message: Message, id_: str | None = None):
|
||||
if (hasattr(self, "graph") and self.graph.session_id) and (message is not None and not message.session_id):
|
||||
message.session_id = self.graph.session_id
|
||||
stored_message = self._store_message(message)
|
||||
stored_message = await self._store_message(message)
|
||||
|
||||
self._stored_message_id = stored_message.id
|
||||
try:
|
||||
|
|
@ -1029,22 +1029,22 @@ class Component(CustomComponent):
|
|||
and message is not None
|
||||
and isinstance(message.text, AsyncIterator | Iterator)
|
||||
):
|
||||
complete_message = self._stream_message(message.text, stored_message)
|
||||
complete_message = await self._stream_message(message.text, stored_message)
|
||||
stored_message.text = complete_message
|
||||
stored_message = self._update_stored_message(stored_message)
|
||||
stored_message = await self._update_stored_message(stored_message)
|
||||
else:
|
||||
# Only send message event for non-streaming messages
|
||||
self._send_message_event(stored_message, id_=id_)
|
||||
except Exception:
|
||||
# remove the message from the database
|
||||
delete_message(stored_message.id)
|
||||
await delete_message(stored_message.id)
|
||||
raise
|
||||
self.status = stored_message
|
||||
return stored_message
|
||||
|
||||
def _store_message(self, message: Message) -> Message:
|
||||
async def _store_message(self, message: Message) -> Message:
|
||||
flow_id = self.graph.flow_id if hasattr(self, "graph") else None
|
||||
messages = store_message(message, flow_id=flow_id)
|
||||
messages = await astore_message(message, flow_id=flow_id)
|
||||
if len(messages) != 1:
|
||||
msg = "Only one message can be stored at a time."
|
||||
raise ValueError(msg)
|
||||
|
|
@ -1073,21 +1073,21 @@ class Component(CustomComponent):
|
|||
and not isinstance(original_message.text, str)
|
||||
)
|
||||
|
||||
def _update_stored_message(self, stored_message: Message) -> Message:
|
||||
message_tables = update_messages(stored_message)
|
||||
async def _update_stored_message(self, stored_message: Message) -> Message:
|
||||
message_tables = await aupdate_messages(stored_message)
|
||||
if len(message_tables) != 1:
|
||||
msg = "Only one message can be updated at a time."
|
||||
raise ValueError(msg)
|
||||
message_table = message_tables[0]
|
||||
return Message(**message_table.model_dump())
|
||||
return await Message.create(**message_table.model_dump())
|
||||
|
||||
def _stream_message(self, iterator: AsyncIterator | Iterator, message: Message) -> str:
|
||||
async def _stream_message(self, iterator: AsyncIterator | Iterator, message: Message) -> str:
|
||||
if not isinstance(iterator, AsyncIterator | Iterator):
|
||||
msg = "The message must be an iterator or an async iterator."
|
||||
raise TypeError(msg)
|
||||
|
||||
if isinstance(iterator, AsyncIterator):
|
||||
return run_until_complete(self._handle_async_iterator(iterator, message.id, message))
|
||||
return await self._handle_async_iterator(iterator, message.id, message)
|
||||
try:
|
||||
complete_message = ""
|
||||
first_chunk = True
|
||||
|
|
@ -1129,7 +1129,7 @@ class Component(CustomComponent):
|
|||
)
|
||||
return complete_message
|
||||
|
||||
def send_error(
|
||||
async def send_error(
|
||||
self,
|
||||
exception: Exception,
|
||||
session_id: str,
|
||||
|
|
@ -1145,7 +1145,7 @@ class Component(CustomComponent):
|
|||
trace_name=trace_name,
|
||||
source=source,
|
||||
)
|
||||
self.send_message(error_message)
|
||||
await self.send_message(error_message)
|
||||
return error_message
|
||||
|
||||
def _append_tool_to_outputs_map(self):
|
||||
|
|
|
|||
|
|
@ -401,7 +401,7 @@ class InterfaceVertex(ComponentVertex):
|
|||
type=ArtifactType.OBJECT.value,
|
||||
).model_dump()
|
||||
|
||||
message = Message(
|
||||
message = await Message.create(
|
||||
text=complete_message,
|
||||
sender=self.params.get("sender", ""),
|
||||
sender_name=self.params.get("sender_name", ""),
|
||||
|
|
@ -434,7 +434,7 @@ class InterfaceVertex(ComponentVertex):
|
|||
and hasattr(self.custom_component, "should_store_message")
|
||||
and hasattr(self.custom_component, "store_message")
|
||||
):
|
||||
self.custom_component.store_message(message)
|
||||
await self.custom_component.store_message(message)
|
||||
await log_vertex_build(
|
||||
flow_id=self.graph.flow_id,
|
||||
vertex_id=self.id,
|
||||
|
|
|
|||
|
|
@ -7,13 +7,40 @@ from langchain_core.messages import BaseMessage
|
|||
from loguru import logger
|
||||
from sqlalchemy import delete
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.database.models.message.model import MessageRead, MessageTable
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.services.deps import async_session_scope, session_scope
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
|
||||
|
||||
|
||||
def _get_variable_query(
|
||||
sender: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
session_id: str | None = None,
|
||||
order_by: str | None = "timestamp",
|
||||
order: str | None = "DESC",
|
||||
flow_id: UUID | None = None,
|
||||
limit: int | None = None,
|
||||
):
|
||||
stmt = select(MessageTable).where(MessageTable.error == False) # noqa: E712
|
||||
if sender:
|
||||
stmt = stmt.where(MessageTable.sender == sender)
|
||||
if sender_name:
|
||||
stmt = stmt.where(MessageTable.sender_name == sender_name)
|
||||
if session_id:
|
||||
stmt = stmt.where(MessageTable.session_id == session_id)
|
||||
if flow_id:
|
||||
stmt = stmt.where(MessageTable.flow_id == flow_id)
|
||||
if order_by:
|
||||
col = getattr(MessageTable, order_by).desc() if order == "DESC" else getattr(MessageTable, order_by).asc()
|
||||
stmt = stmt.order_by(col)
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
return stmt
|
||||
|
||||
|
||||
def get_messages(
|
||||
sender: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
|
|
@ -38,24 +65,40 @@ def get_messages(
|
|||
List[Data]: A list of Data objects representing the retrieved messages.
|
||||
"""
|
||||
with session_scope() as session:
|
||||
stmt = select(MessageTable).where(MessageTable.error == False) # noqa: E712
|
||||
if sender:
|
||||
stmt = stmt.where(MessageTable.sender == sender)
|
||||
if sender_name:
|
||||
stmt = stmt.where(MessageTable.sender_name == sender_name)
|
||||
if session_id:
|
||||
stmt = stmt.where(MessageTable.session_id == session_id)
|
||||
if flow_id:
|
||||
stmt = stmt.where(MessageTable.flow_id == flow_id)
|
||||
if order_by:
|
||||
col = getattr(MessageTable, order_by).desc() if order == "DESC" else getattr(MessageTable, order_by).asc()
|
||||
stmt = stmt.order_by(col)
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit)
|
||||
messages = session.exec(stmt)
|
||||
return [Message(**d.model_dump()) for d in messages]
|
||||
|
||||
|
||||
async def aget_messages(
|
||||
sender: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
session_id: str | None = None,
|
||||
order_by: str | None = "timestamp",
|
||||
order: str | None = "DESC",
|
||||
flow_id: UUID | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Message]:
|
||||
"""Retrieves messages from the monitor service based on the provided filters.
|
||||
|
||||
Args:
|
||||
sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User")
|
||||
sender_name (Optional[str]): The name of the sender.
|
||||
session_id (Optional[str]): The session ID associated with the messages.
|
||||
order_by (Optional[str]): The field to order the messages by. Defaults to "timestamp".
|
||||
order (Optional[str]): The order in which to retrieve the messages. Defaults to "DESC".
|
||||
flow_id (Optional[UUID]): The flow ID associated with the messages.
|
||||
limit (Optional[int]): The maximum number of messages to retrieve.
|
||||
|
||||
Returns:
|
||||
List[Data]: A list of Data objects representing the retrieved messages.
|
||||
"""
|
||||
async with async_session_scope() as session:
|
||||
stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit)
|
||||
messages = await session.exec(stmt)
|
||||
return [await Message.create(**d.model_dump()) for d in messages]
|
||||
|
||||
|
||||
def add_messages(messages: Message | list[Message], flow_id: str | None = None):
|
||||
"""Add a message to the monitor service."""
|
||||
if not isinstance(messages, list):
|
||||
|
|
@ -76,6 +119,26 @@ def add_messages(messages: Message | list[Message], flow_id: str | None = None):
|
|||
raise
|
||||
|
||||
|
||||
async def aadd_messages(messages: Message | list[Message], flow_id: str | None = None):
|
||||
"""Add a message to the monitor service."""
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
if not all(isinstance(message, Message) for message in messages):
|
||||
types = ", ".join([str(type(message)) for message in messages])
|
||||
msg = f"The messages must be instances of Message. Found: {types}"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
messages_models = [MessageTable.from_message(msg, flow_id=flow_id) for msg in messages]
|
||||
async with async_session_scope() as session:
|
||||
messages_models = await aadd_messagetables(messages_models, session)
|
||||
return [await Message.create(**message.model_dump()) for message in messages_models]
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise
|
||||
|
||||
|
||||
def update_messages(messages: Message | list[Message]) -> list[Message]:
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
|
@ -95,6 +158,25 @@ def update_messages(messages: Message | list[Message]) -> list[Message]:
|
|||
return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages]
|
||||
|
||||
|
||||
async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
async with async_session_scope() as session:
|
||||
updated_messages: list[MessageTable] = []
|
||||
for message in messages:
|
||||
msg = await session.get(MessageTable, message.id)
|
||||
if msg:
|
||||
msg.sqlmodel_update(message.model_dump(exclude_unset=True, exclude_none=True))
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
updated_messages.append(msg)
|
||||
else:
|
||||
logger.warning(f"Message with id {message.id} not found")
|
||||
return [MessageRead.model_validate(message, from_attributes=True) for message in updated_messages]
|
||||
|
||||
|
||||
def add_messagetables(messages: list[MessageTable], session: Session):
|
||||
for message in messages:
|
||||
try:
|
||||
|
|
@ -115,6 +197,27 @@ def add_messagetables(messages: list[MessageTable], session: Session):
|
|||
return [MessageRead.model_validate(message, from_attributes=True) for message in new_messages]
|
||||
|
||||
|
||||
async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
|
||||
try:
|
||||
for message in messages:
|
||||
session.add(message)
|
||||
await session.commit()
|
||||
for message in messages:
|
||||
await session.refresh(message)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise
|
||||
|
||||
new_messages = []
|
||||
for msg in messages:
|
||||
msg.properties = json.loads(msg.properties) if isinstance(msg.properties, str) else msg.properties # type: ignore[arg-type]
|
||||
msg.content_blocks = [json.loads(j) if isinstance(j, str) else j for j in msg.content_blocks] # type: ignore[arg-type]
|
||||
msg.category = msg.category or ""
|
||||
new_messages.append(msg)
|
||||
|
||||
return [MessageRead.model_validate(message, from_attributes=True) for message in new_messages]
|
||||
|
||||
|
||||
def delete_messages(session_id: str) -> None:
|
||||
"""Delete messages from the monitor service based on the provided session ID.
|
||||
|
||||
|
|
@ -129,17 +232,32 @@ def delete_messages(session_id: str) -> None:
|
|||
)
|
||||
|
||||
|
||||
def delete_message(id_: str) -> None:
|
||||
async def adelete_messages(session_id: str) -> None:
|
||||
"""Delete messages from the monitor service based on the provided session ID.
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID associated with the messages to delete.
|
||||
"""
|
||||
async with async_session_scope() as session:
|
||||
stmt = (
|
||||
delete(MessageTable)
|
||||
.where(col(MessageTable.session_id) == session_id)
|
||||
.execution_options(synchronize_session="fetch")
|
||||
)
|
||||
await session.exec(stmt)
|
||||
|
||||
|
||||
async def delete_message(id_: str) -> None:
|
||||
"""Delete a message from the monitor service based on the provided ID.
|
||||
|
||||
Args:
|
||||
id_ (str): The ID of the message to delete.
|
||||
"""
|
||||
with session_scope() as session:
|
||||
message = session.get(MessageTable, id_)
|
||||
async with async_session_scope() as session:
|
||||
message = await session.get(MessageTable, id_)
|
||||
if message:
|
||||
session.delete(message)
|
||||
session.commit()
|
||||
await session.delete(message)
|
||||
await session.commit()
|
||||
|
||||
|
||||
def store_message(
|
||||
|
|
@ -182,6 +300,35 @@ def store_message(
|
|||
return add_messages([message], flow_id=flow_id)
|
||||
|
||||
|
||||
async def astore_message(
|
||||
message: Message,
|
||||
flow_id: str | None = None,
|
||||
) -> list[Message]:
|
||||
"""Stores a message in the memory.
|
||||
|
||||
Args:
|
||||
message (Message): The message to store.
|
||||
flow_id (Optional[str]): The flow ID associated with the message.
|
||||
When running from the CustomComponent you can access this using `self.graph.flow_id`.
|
||||
|
||||
Returns:
|
||||
List[Message]: A list of data containing the stored message.
|
||||
|
||||
Raises:
|
||||
ValueError: If any of the required parameters (session_id, sender, sender_name) is not provided.
|
||||
"""
|
||||
if not message:
|
||||
logger.warning("No message provided.")
|
||||
return []
|
||||
|
||||
if not message.session_id or not message.sender or not message.sender_name:
|
||||
msg = "All of session_id, sender, and sender_name must be provided."
|
||||
raise ValueError(msg)
|
||||
if hasattr(message, "id") and message.id:
|
||||
return await aupdate_messages([message])
|
||||
return await aadd_messages([message], flow_id=flow_id)
|
||||
|
||||
|
||||
class LCBuiltinChatMemory(BaseChatMessageHistory):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -198,11 +345,26 @@ class LCBuiltinChatMemory(BaseChatMessageHistory):
|
|||
)
|
||||
return [m.to_lc_message() for m in messages if not m.error] # Exclude error messages
|
||||
|
||||
async def aget_messages(self) -> list[BaseMessage]:
|
||||
messages = await aget_messages(
|
||||
session_id=self.session_id,
|
||||
)
|
||||
return [m.to_lc_message() for m in messages if not m.error] # Exclude error messages
|
||||
|
||||
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
for lc_message in messages:
|
||||
message = Message.from_lc_message(lc_message)
|
||||
message.session_id = self.session_id
|
||||
store_message(message, flow_id=self.flow_id)
|
||||
|
||||
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
for lc_message in messages:
|
||||
message = Message.from_lc_message(lc_message)
|
||||
message.session_id = self.session_id
|
||||
await astore_message(message, flow_id=self.flow_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
delete_messages(self.session_id)
|
||||
|
||||
async def aclear(self) -> None:
|
||||
await adelete_messages(self.session_id)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class LogFunctionType(Protocol):
|
|||
|
||||
|
||||
class SendMessageFunctionType(Protocol):
|
||||
def __call__(
|
||||
async def __call__(
|
||||
self,
|
||||
message: Message | None = None,
|
||||
text: str | None = None,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
|
|
@ -267,6 +268,13 @@ class Message(Data):
|
|||
instance.messages = instance.prompt.get("kwargs", {}).get("messages", [])
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
async def create(cls, **kwargs):
|
||||
"""If files are present, create the message in a separate thread as is_image_file is blocking."""
|
||||
if "files" in kwargs:
|
||||
return await asyncio.to_thread(cls, **kwargs)
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
class Config:
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from langflow.services.database.models.vertex_builds.crud import delete_vertex_b
|
|||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from loguru import logger
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
|
@ -151,6 +152,17 @@ def session_fixture():
|
|||
SQLModel.metadata.drop_all(engine) # Add this line to clean up tables
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session():
|
||||
engine = create_async_engine("sqlite+aiosqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
|
||||
|
||||
class Config:
|
||||
broker_url = "redis://localhost:6379/0"
|
||||
result_backend = "redis://localhost:6379/0"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.components.inputs import ChatInput
|
||||
from langflow.memory import get_messages
|
||||
from langflow.memory import aget_messages
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from tests.integration.utils import run_single_component
|
||||
|
|
@ -38,7 +38,7 @@ async def test_do_not_store_messages():
|
|||
assert outputs["message"].text == "hello"
|
||||
assert outputs["message"].session_id == session_id
|
||||
|
||||
assert len(get_messages(session_id=session_id)) == 1
|
||||
assert len(await aget_messages(session_id=session_id)) == 1
|
||||
|
||||
session_id = "test-session-id-another"
|
||||
outputs = await run_single_component(
|
||||
|
|
@ -48,4 +48,4 @@ async def test_do_not_store_messages():
|
|||
assert outputs["message"].text == "hello"
|
||||
assert outputs["message"].session_id == session_id
|
||||
|
||||
assert len(get_messages(session_id=session_id)) == 0
|
||||
assert len(await aget_messages(session_id=session_id)) == 0
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.components.outputs import ChatOutput
|
||||
from langflow.memory import get_messages
|
||||
from langflow.memory import aget_messages
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from tests.integration.utils import run_single_component
|
||||
|
|
@ -29,7 +29,7 @@ async def test_do_not_store_message():
|
|||
assert isinstance(outputs["message"], Message)
|
||||
assert outputs["message"].text == "hello"
|
||||
|
||||
assert len(get_messages(session_id=session_id)) == 1
|
||||
assert len(await aget_messages(session_id=session_id)) == 1
|
||||
session_id = "test-session-id-another"
|
||||
|
||||
outputs = await run_single_component(
|
||||
|
|
@ -38,4 +38,4 @@ async def test_do_not_store_message():
|
|||
assert isinstance(outputs["message"], Message)
|
||||
assert outputs["message"].text == "hello"
|
||||
|
||||
assert len(get_messages(session_id=session_id)) == 0
|
||||
assert len(await aget_messages(session_id=session_id)) == 0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from langchain_core.agents import AgentFinish
|
||||
from langflow.base.agents.agent import process_agent_events
|
||||
|
|
@ -26,7 +26,7 @@ async def create_event_iterator(events: list[dict[str, Any]]) -> AsyncIterator[d
|
|||
|
||||
async def test_chain_start_event():
|
||||
"""Test handling of on_chain_start event."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
|
||||
events = [
|
||||
{"event": "on_chain_start", "data": {"input": {"input": "test input", "chat_history": []}}, "start_time": 0}
|
||||
|
|
@ -51,7 +51,7 @@ async def test_chain_start_event():
|
|||
|
||||
async def test_chain_end_event():
|
||||
"""Test handling of on_chain_end event."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
|
||||
# Create a mock AgentFinish output
|
||||
output = AgentFinish(return_values={"output": "final output"}, log="test log")
|
||||
|
|
@ -77,7 +77,7 @@ async def test_chain_end_event():
|
|||
|
||||
async def test_tool_start_event():
|
||||
"""Test handling of on_tool_start event."""
|
||||
send_message = MagicMock()
|
||||
send_message = AsyncMock()
|
||||
|
||||
# Set up the send_message mock to return the modified message
|
||||
def update_message(message):
|
||||
|
|
@ -116,7 +116,7 @@ async def test_tool_start_event():
|
|||
|
||||
async def test_tool_end_event():
|
||||
"""Test handling of on_tool_end event."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
|
||||
events = [
|
||||
{
|
||||
|
|
@ -151,7 +151,7 @@ async def test_tool_end_event():
|
|||
|
||||
async def test_tool_error_event():
|
||||
"""Test handling of on_tool_error event."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
|
||||
events = [
|
||||
{
|
||||
|
|
@ -187,7 +187,7 @@ async def test_tool_error_event():
|
|||
|
||||
async def test_chain_stream_event():
|
||||
"""Test handling of on_chain_stream event."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
|
||||
events = [{"event": "on_chain_stream", "data": {"chunk": {"output": "streamed output"}}, "start_time": 0}]
|
||||
agent_message = Message(
|
||||
|
|
@ -205,7 +205,7 @@ async def test_chain_stream_event():
|
|||
|
||||
async def test_multiple_events():
|
||||
"""Test handling of multiple events in sequence."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
|
||||
# Create a mock AgentFinish output instead of MockOutput
|
||||
output = AgentFinish(return_values={"output": "final output"}, log="test log")
|
||||
|
|
@ -248,7 +248,7 @@ async def test_multiple_events():
|
|||
|
||||
async def test_unknown_event():
|
||||
"""Test handling of unknown event type."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="Agent",
|
||||
|
|
@ -273,7 +273,7 @@ async def test_unknown_event():
|
|||
|
||||
async def test_handle_on_chain_start_with_input():
|
||||
"""Test handle_on_chain_start with input."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="Agent",
|
||||
|
|
@ -282,7 +282,7 @@ async def test_handle_on_chain_start_with_input():
|
|||
)
|
||||
event = {"event": "on_chain_start", "data": {"input": {"input": "test input", "chat_history": []}}, "start_time": 0}
|
||||
|
||||
updated_message, start_time = handle_on_chain_start(event, agent_message, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_chain_start(event, agent_message, send_message, 0.0)
|
||||
|
||||
assert updated_message.properties.icon == "Bot"
|
||||
assert len(updated_message.content_blocks) == 1
|
||||
|
|
@ -292,7 +292,7 @@ async def test_handle_on_chain_start_with_input():
|
|||
|
||||
async def test_handle_on_chain_start_no_input():
|
||||
"""Test handle_on_chain_start without input."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="Agent",
|
||||
|
|
@ -301,7 +301,7 @@ async def test_handle_on_chain_start_no_input():
|
|||
)
|
||||
event = {"event": "on_chain_start", "data": {}, "start_time": 0}
|
||||
|
||||
updated_message, start_time = handle_on_chain_start(event, agent_message, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_chain_start(event, agent_message, send_message, 0.0)
|
||||
|
||||
assert updated_message.properties.icon == "Bot"
|
||||
assert len(updated_message.content_blocks) == 1
|
||||
|
|
@ -311,7 +311,7 @@ async def test_handle_on_chain_start_no_input():
|
|||
|
||||
async def test_handle_on_chain_end_with_output():
|
||||
"""Test handle_on_chain_end with output."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="Agent",
|
||||
|
|
@ -322,7 +322,7 @@ async def test_handle_on_chain_end_with_output():
|
|||
output = AgentFinish(return_values={"output": "final output"}, log="test log")
|
||||
event = {"event": "on_chain_end", "data": {"output": output}, "start_time": 0}
|
||||
|
||||
updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0)
|
||||
|
||||
assert updated_message.properties.icon == "Bot"
|
||||
assert updated_message.properties.state == "complete"
|
||||
|
|
@ -332,7 +332,7 @@ async def test_handle_on_chain_end_with_output():
|
|||
|
||||
async def test_handle_on_chain_end_no_output():
|
||||
"""Test handle_on_chain_end without output key in data."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="Agent",
|
||||
|
|
@ -341,7 +341,7 @@ async def test_handle_on_chain_end_no_output():
|
|||
)
|
||||
event = {"event": "on_chain_end", "data": {}, "start_time": 0}
|
||||
|
||||
updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0)
|
||||
|
||||
assert updated_message.properties.icon == "Bot"
|
||||
assert updated_message.properties.state == "partial"
|
||||
|
|
@ -351,7 +351,7 @@ async def test_handle_on_chain_end_no_output():
|
|||
|
||||
async def test_handle_on_chain_end_empty_data():
|
||||
"""Test handle_on_chain_end with empty data."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="Agent",
|
||||
|
|
@ -360,7 +360,7 @@ async def test_handle_on_chain_end_empty_data():
|
|||
)
|
||||
event = {"event": "on_chain_end", "data": {"output": None}, "start_time": 0}
|
||||
|
||||
updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0)
|
||||
|
||||
assert updated_message.properties.icon == "Bot"
|
||||
assert updated_message.properties.state == "partial"
|
||||
|
|
@ -370,7 +370,7 @@ async def test_handle_on_chain_end_empty_data():
|
|||
|
||||
async def test_handle_on_chain_end_with_empty_return_values():
|
||||
"""Test handle_on_chain_end with empty return_values."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="Agent",
|
||||
|
|
@ -384,7 +384,7 @@ async def test_handle_on_chain_end_with_empty_return_values():
|
|||
|
||||
event = {"event": "on_chain_end", "data": {"output": MockOutputEmptyReturnValues()}, "start_time": 0}
|
||||
|
||||
updated_message, start_time = handle_on_chain_end(event, agent_message, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_chain_end(event, agent_message, send_message, 0.0)
|
||||
|
||||
assert updated_message.properties.icon == "Bot"
|
||||
assert updated_message.properties.state == "partial"
|
||||
|
|
@ -394,7 +394,7 @@ async def test_handle_on_chain_end_with_empty_return_values():
|
|||
|
||||
async def test_handle_on_tool_start():
|
||||
"""Test handle_on_tool_start event."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
tool_blocks_map = {}
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
|
|
@ -410,7 +410,7 @@ async def test_handle_on_tool_start():
|
|||
"start_time": 0,
|
||||
}
|
||||
|
||||
updated_message, start_time = handle_on_tool_start(event, agent_message, tool_blocks_map, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_tool_start(event, agent_message, tool_blocks_map, send_message, 0.0)
|
||||
|
||||
assert len(updated_message.content_blocks) == 1
|
||||
assert len(updated_message.content_blocks[0].contents) > 0
|
||||
|
|
@ -426,7 +426,7 @@ async def test_handle_on_tool_start():
|
|||
|
||||
async def test_handle_on_tool_end():
|
||||
"""Test handle_on_tool_end event."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
tool_blocks_map = {}
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
|
|
@ -441,7 +441,7 @@ async def test_handle_on_tool_end():
|
|||
"run_id": "test_run",
|
||||
"data": {"input": {"query": "tool input"}},
|
||||
}
|
||||
agent_message, _ = handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0)
|
||||
agent_message, _ = await handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0)
|
||||
|
||||
end_event = {
|
||||
"event": "on_tool_end",
|
||||
|
|
@ -451,7 +451,7 @@ async def test_handle_on_tool_end():
|
|||
"start_time": 0,
|
||||
}
|
||||
|
||||
updated_message, start_time = handle_on_tool_end(end_event, agent_message, tool_blocks_map, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_tool_end(end_event, agent_message, tool_blocks_map, send_message, 0.0)
|
||||
|
||||
f"{end_event['name']}_{end_event['run_id']}"
|
||||
tool_content = updated_message.content_blocks[0].contents[-1]
|
||||
|
|
@ -463,7 +463,7 @@ async def test_handle_on_tool_end():
|
|||
|
||||
async def test_handle_on_tool_error():
|
||||
"""Test handle_on_tool_error event."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
tool_blocks_map = {}
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
|
|
@ -478,7 +478,7 @@ async def test_handle_on_tool_error():
|
|||
"run_id": "test_run",
|
||||
"data": {"input": {"query": "tool input"}},
|
||||
}
|
||||
agent_message, _ = handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0)
|
||||
agent_message, _ = await handle_on_tool_start(start_event, agent_message, tool_blocks_map, send_message, 0.0)
|
||||
|
||||
error_event = {
|
||||
"event": "on_tool_error",
|
||||
|
|
@ -488,7 +488,9 @@ async def test_handle_on_tool_error():
|
|||
"start_time": 0,
|
||||
}
|
||||
|
||||
updated_message, start_time = handle_on_tool_error(error_event, agent_message, tool_blocks_map, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_tool_error(
|
||||
error_event, agent_message, tool_blocks_map, send_message, 0.0
|
||||
)
|
||||
|
||||
tool_content = updated_message.content_blocks[0].contents[-1]
|
||||
assert tool_content.name == "test_tool"
|
||||
|
|
@ -500,7 +502,7 @@ async def test_handle_on_tool_error():
|
|||
|
||||
async def test_handle_on_chain_stream_with_output():
|
||||
"""Test handle_on_chain_stream with output."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="Agent",
|
||||
|
|
@ -512,7 +514,7 @@ async def test_handle_on_chain_stream_with_output():
|
|||
"data": {"chunk": {"output": "streamed output"}},
|
||||
}
|
||||
|
||||
updated_message, start_time = handle_on_chain_stream(event, agent_message, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_chain_stream(event, agent_message, send_message, 0.0)
|
||||
|
||||
assert updated_message.text == "streamed output"
|
||||
assert updated_message.properties.state == "complete"
|
||||
|
|
@ -521,7 +523,7 @@ async def test_handle_on_chain_stream_with_output():
|
|||
|
||||
async def test_handle_on_chain_stream_no_output():
|
||||
"""Test handle_on_chain_stream without output."""
|
||||
send_message = MagicMock(side_effect=lambda message: message)
|
||||
send_message = AsyncMock(side_effect=lambda message: message)
|
||||
agent_message = Message(
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
sender_name="Agent",
|
||||
|
|
@ -534,7 +536,7 @@ async def test_handle_on_chain_stream_no_output():
|
|||
"data": {"chunk": {}},
|
||||
}
|
||||
|
||||
updated_message, start_time = handle_on_chain_stream(event, agent_message, send_message, 0.0)
|
||||
updated_message, start_time = await handle_on_chain_stream(event, agent_message, send_message, 0.0)
|
||||
|
||||
assert updated_message.text == ""
|
||||
assert updated_message.properties.state == "partial"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from langflow.components.inputs import ChatInput, TextInputComponent
|
||||
from langflow.schema.message import Message
|
||||
|
|
@ -36,10 +38,10 @@ class TestChatInput(ComponentTestBaseWithClient):
|
|||
{"version": "1.0.19", "module": "inputs", "file_name": "ChatInput"},
|
||||
]
|
||||
|
||||
def test_message_response(self, component_class, default_kwargs):
|
||||
async def test_message_response(self, component_class, default_kwargs):
|
||||
"""Test that the message_response method returns a valid Message object."""
|
||||
component = component_class(**default_kwargs)
|
||||
message = component.message_response()
|
||||
message = await component.message_response()
|
||||
|
||||
assert isinstance(message, Message)
|
||||
assert message.text == default_kwargs["input_value"]
|
||||
|
|
@ -58,7 +60,7 @@ class TestChatInput(ComponentTestBaseWithClient):
|
|||
"targets": [],
|
||||
}
|
||||
|
||||
def test_message_response_ai_sender(self, component_class):
|
||||
async def test_message_response_ai_sender(self, component_class):
|
||||
"""Test message response with AI sender type."""
|
||||
kwargs = {
|
||||
"input_value": "I am an AI assistant",
|
||||
|
|
@ -67,13 +69,13 @@ class TestChatInput(ComponentTestBaseWithClient):
|
|||
"session_id": "test_session_123",
|
||||
}
|
||||
component = component_class(**kwargs)
|
||||
message = component.message_response()
|
||||
message = await component.message_response()
|
||||
|
||||
assert isinstance(message, Message)
|
||||
assert message.sender == MESSAGE_SENDER_AI
|
||||
assert message.sender_name == "AI Assistant"
|
||||
|
||||
def test_message_response_without_session(self, component_class):
|
||||
async def test_message_response_without_session(self, component_class):
|
||||
"""Test message response without session ID."""
|
||||
kwargs = {
|
||||
"input_value": "Test message",
|
||||
|
|
@ -82,16 +84,16 @@ class TestChatInput(ComponentTestBaseWithClient):
|
|||
"session_id": "", # Empty session ID
|
||||
}
|
||||
component = component_class(**kwargs)
|
||||
message = component.message_response()
|
||||
message = await component.message_response()
|
||||
|
||||
assert isinstance(message, Message)
|
||||
assert message.session_id == ""
|
||||
|
||||
def test_message_response_with_files(self, component_class, tmp_path):
|
||||
async def test_message_response_with_files(self, component_class, tmp_path):
|
||||
"""Test message response with file attachments."""
|
||||
# Create a temporary test file
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("Test content")
|
||||
await asyncio.to_thread(test_file.write_text, "Test content")
|
||||
|
||||
kwargs = {
|
||||
"input_value": "Message with file",
|
||||
|
|
@ -101,13 +103,13 @@ class TestChatInput(ComponentTestBaseWithClient):
|
|||
"files": [str(test_file)],
|
||||
}
|
||||
component = component_class(**kwargs)
|
||||
message = component.message_response()
|
||||
message = await component.message_response()
|
||||
|
||||
assert isinstance(message, Message)
|
||||
assert len(message.files) == 1
|
||||
assert message.files[0] == str(test_file)
|
||||
|
||||
def test_message_storage_disabled(self, component_class):
|
||||
async def test_message_storage_disabled(self, component_class):
|
||||
"""Test message response when storage is disabled."""
|
||||
kwargs = {
|
||||
"input_value": "Test message",
|
||||
|
|
@ -117,7 +119,7 @@ class TestChatInput(ComponentTestBaseWithClient):
|
|||
"session_id": "test_session_123",
|
||||
}
|
||||
component = component_class(**kwargs)
|
||||
message = component.message_response()
|
||||
message = await component.message_response()
|
||||
|
||||
assert isinstance(message, Message)
|
||||
# The message should still be created but not stored
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ async def test_component_message_sending():
|
|||
)
|
||||
|
||||
# Send the message
|
||||
sent_message = await asyncio.to_thread(component.send_message, message)
|
||||
sent_message = await component.send_message(message)
|
||||
|
||||
# Verify the message was sent
|
||||
assert sent_message.id is not None
|
||||
|
|
@ -85,7 +85,7 @@ async def test_component_tool_output():
|
|||
)
|
||||
|
||||
# Send the message
|
||||
sent_message = await asyncio.to_thread(component.send_message, message)
|
||||
sent_message = await component.send_message(message)
|
||||
|
||||
# Verify the message was stored and processed
|
||||
assert sent_message.id is not None
|
||||
|
|
@ -112,8 +112,7 @@ async def test_component_error_handling():
|
|||
msg = "Test error"
|
||||
raise CustomError(msg)
|
||||
except CustomError as e:
|
||||
sent_message = await asyncio.to_thread(
|
||||
component.send_error,
|
||||
sent_message = await component.send_error(
|
||||
exception=e,
|
||||
session_id="test_session",
|
||||
trace_name="test_trace",
|
||||
|
|
@ -227,7 +226,7 @@ async def test_component_streaming_message():
|
|||
)
|
||||
|
||||
# Send the streaming message
|
||||
sent_message = await asyncio.to_thread(component.send_message, message)
|
||||
sent_message = await component.send_message(message)
|
||||
|
||||
# Verify the message
|
||||
assert sent_message.id is not None
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ async def test_graph_with_edge():
|
|||
|
||||
async def test_graph_functional():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
chat_input.set(should_store_message=False)
|
||||
chat_output = ChatOutput(input_value="test", _id="chat_output")
|
||||
chat_output.set(sender_name=chat_input.message_response)
|
||||
graph = await asyncio.to_thread(Graph, chat_input, chat_output)
|
||||
|
|
|
|||
|
|
@ -75,9 +75,9 @@ def test_graph_functional_start_graph_state_update():
|
|||
|
||||
def test_graph_state_model_serialization():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
chat_input.set(input_value="Test Sender Name")
|
||||
chat_input.set(input_value="Test Sender Name", should_store_message=False)
|
||||
chat_output = ChatOutput(input_value="test", _id="chat_output")
|
||||
chat_output.set(sender_name=chat_input.message_response)
|
||||
chat_output.set(sender_name=chat_input.message_response, should_store_message=False)
|
||||
|
||||
graph = Graph(chat_input, chat_output)
|
||||
graph.prepare()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from langflow.memory import get_messages
|
||||
from langflow.memory import aget_messages
|
||||
from langflow.services.database.models.flow import FlowCreate, FlowUpdate
|
||||
from orjson import orjson
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ async def test_build_flow(client, json_memory_chatbot_no_llm, logged_in_headers)
|
|||
async with client.stream("POST", f"api/v1/build/{flow_id}/flow", json={}, headers=logged_in_headers) as r:
|
||||
await consume_and_assert_stream(r)
|
||||
|
||||
check_messages(flow_id)
|
||||
await check_messages(flow_id)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
|
|
@ -28,7 +28,7 @@ async def test_build_flow_from_request_data(client, json_memory_chatbot_no_llm,
|
|||
) as r:
|
||||
await consume_and_assert_stream(r)
|
||||
|
||||
check_messages(flow_id)
|
||||
await check_messages(flow_id)
|
||||
|
||||
|
||||
async def test_build_flow_with_frozen_path(client, json_memory_chatbot_no_llm, logged_in_headers):
|
||||
|
|
@ -47,11 +47,11 @@ async def test_build_flow_with_frozen_path(client, json_memory_chatbot_no_llm, l
|
|||
async with client.stream("POST", f"api/v1/build/{flow_id}/flow", json={}, headers=logged_in_headers) as r:
|
||||
await consume_and_assert_stream(r)
|
||||
|
||||
check_messages(flow_id)
|
||||
await check_messages(flow_id)
|
||||
|
||||
|
||||
def check_messages(flow_id):
|
||||
messages = get_messages(flow_id=UUID(flow_id), order="ASC")
|
||||
async def check_messages(flow_id):
|
||||
messages = await aget_messages(flow_id=UUID(flow_id), order="ASC")
|
||||
assert len(messages) == 2
|
||||
assert messages[0].session_id == flow_id
|
||||
assert messages[0].sender == "User"
|
||||
|
|
|
|||
|
|
@ -3,8 +3,14 @@ from uuid import UUID, uuid4
|
|||
|
||||
import pytest
|
||||
from langflow.memory import (
|
||||
aadd_messages,
|
||||
aadd_messagetables,
|
||||
add_messages,
|
||||
add_messagetables,
|
||||
adelete_messages,
|
||||
aget_messages,
|
||||
astore_message,
|
||||
aupdate_messages,
|
||||
delete_messages,
|
||||
get_messages,
|
||||
store_message,
|
||||
|
|
@ -18,29 +24,29 @@ from langflow.schema.properties import Properties, Source
|
|||
# Assuming you have these imports available
|
||||
from langflow.services.database.models.message import MessageCreate, MessageRead
|
||||
from langflow.services.database.models.message.model import MessageTable
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.services.deps import async_session_scope
|
||||
from langflow.services.tracing.utils import convert_to_langchain_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def created_message():
|
||||
with session_scope() as session:
|
||||
async def created_message():
|
||||
async with async_session_scope() as session:
|
||||
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
|
||||
messagetable = MessageTable.model_validate(message, from_attributes=True)
|
||||
messagetables = add_messagetables([messagetable], session)
|
||||
messagetables = await aadd_messagetables([messagetable], session)
|
||||
return MessageRead.model_validate(messagetables[0], from_attributes=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def created_messages(session): # noqa: ARG001
|
||||
with session_scope() as _session:
|
||||
async def created_messages(session): # noqa: ARG001
|
||||
async with async_session_scope() as _session:
|
||||
messages = [
|
||||
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
|
||||
messagetables = add_messagetables(messagetables, _session)
|
||||
messagetables = await aadd_messagetables(messagetables, _session)
|
||||
return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables]
|
||||
|
||||
|
||||
|
|
@ -58,6 +64,20 @@ def test_get_messages():
|
|||
assert messages[1].text == "Test message 2"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aget_messages():
|
||||
await aadd_messages(
|
||||
[
|
||||
Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
)
|
||||
messages = await aget_messages(sender="User", session_id="session_id2", limit=2)
|
||||
assert len(messages) == 2
|
||||
assert messages[0].text == "Test message 1"
|
||||
assert messages[1].text == "Test message 2"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_add_messages():
|
||||
message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")
|
||||
|
|
@ -66,6 +86,14 @@ def test_add_messages():
|
|||
assert messages[0].text == "New Test message"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aadd_messages():
|
||||
message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")
|
||||
messages = await aadd_messages(message)
|
||||
assert len(messages) == 1
|
||||
assert messages[0].text == "New Test message"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_add_messagetables(session):
|
||||
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
|
||||
|
|
@ -75,17 +103,53 @@ def test_add_messagetables(session):
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_delete_messages(session):
|
||||
session_id = "session_id2"
|
||||
async def test_aadd_messagetables(async_session):
|
||||
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
|
||||
added_messages = await aadd_messagetables(messages, async_session)
|
||||
assert len(added_messages) == 1
|
||||
assert added_messages[0].text == "New Test message"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_delete_messages():
|
||||
session_id = "new_session_id"
|
||||
message = Message(text="New Test message", sender="User", sender_name="User", session_id=session_id)
|
||||
add_messages([message])
|
||||
messages = get_messages(sender="User", session_id=session_id)
|
||||
assert len(messages) == 1
|
||||
delete_messages(session_id)
|
||||
messages = session.query(MessageTable).filter(MessageTable.session_id == session_id).all()
|
||||
messages = get_messages(sender="User", session_id=session_id)
|
||||
assert len(messages) == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_adelete_messages():
|
||||
session_id = "new_session_id"
|
||||
message = Message(text="New Test message", sender="User", sender_name="User", session_id=session_id)
|
||||
await aadd_messages([message])
|
||||
messages = await aget_messages(sender="User", session_id=session_id)
|
||||
assert len(messages) == 1
|
||||
await adelete_messages(session_id)
|
||||
messages = await aget_messages(sender="User", session_id=session_id)
|
||||
assert len(messages) == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_store_message():
|
||||
message = Message(text="Stored message", sender="User", sender_name="User", session_id="stored_session_id")
|
||||
stored_messages = store_message(message)
|
||||
session_id = "stored_session_id"
|
||||
message = Message(text="Stored message", sender="User", sender_name="User", session_id=session_id)
|
||||
store_message(message)
|
||||
stored_messages = get_messages(sender="User", session_id=session_id)
|
||||
assert len(stored_messages) == 1
|
||||
assert stored_messages[0].text == "Stored message"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_astore_message():
|
||||
session_id = "stored_session_id"
|
||||
message = Message(text="Stored message", sender="User", sender_name="User", session_id=session_id)
|
||||
await astore_message(message)
|
||||
stored_messages = await aget_messages(sender="User", session_id=session_id)
|
||||
assert len(stored_messages) == 1
|
||||
assert stored_messages[0].text == "Stored message"
|
||||
|
||||
|
|
@ -298,3 +362,188 @@ def test_update_message_with_nested_properties(created_message):
|
|||
assert updated[0].properties.allow_markdown is True
|
||||
assert updated[0].properties.state == "complete"
|
||||
assert updated[0].properties.targets == []
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aupdate_single_message(created_message):
|
||||
# Modify the message
|
||||
created_message.text = "Updated message"
|
||||
updated = await aupdate_messages(created_message)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Updated message"
|
||||
assert updated[0].id == created_message.id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aupdate_multiple_messages(created_messages):
|
||||
# Modify the messages
|
||||
for i, message in enumerate(created_messages):
|
||||
message.text = f"Updated message {i}"
|
||||
|
||||
updated = await aupdate_messages(created_messages)
|
||||
|
||||
assert len(updated) == len(created_messages)
|
||||
for i, message in enumerate(updated):
|
||||
assert message.text == f"Updated message {i}"
|
||||
assert message.id == created_messages[i].id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aupdate_nonexistent_message():
|
||||
# Create a message with a non-existent UUID
|
||||
message = MessageRead(
|
||||
id=uuid4(), # Generate a random UUID that won't exist in the database
|
||||
text="Test message",
|
||||
sender="User",
|
||||
sender_name="User",
|
||||
session_id="session_id",
|
||||
flow_id=uuid4(),
|
||||
)
|
||||
|
||||
updated = await aupdate_messages(message)
|
||||
assert len(updated) == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aupdate_mixed_messages(created_messages):
|
||||
# Create a mix of existing and non-existing messages
|
||||
nonexistent_message = MessageRead(
|
||||
id=uuid4(), # Generate a random UUID that won't exist in the database
|
||||
text="Test message",
|
||||
sender="User",
|
||||
sender_name="User",
|
||||
session_id="session_id",
|
||||
flow_id=uuid4(),
|
||||
)
|
||||
|
||||
messages_to_update = created_messages[:1] + [nonexistent_message]
|
||||
created_messages[0].text = "Updated existing message"
|
||||
|
||||
updated = await aupdate_messages(messages_to_update)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Updated existing message"
|
||||
assert updated[0].id == created_messages[0].id
|
||||
assert isinstance(updated[0].id, UUID) # Verify ID is UUID type
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aupdate_message_with_timestamp(created_message):
|
||||
# Set a specific timestamp
|
||||
new_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
created_message.timestamp = new_timestamp
|
||||
created_message.text = "Updated message with timestamp"
|
||||
|
||||
updated = await aupdate_messages(created_message)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Updated message with timestamp"
|
||||
|
||||
# Compare timestamps without timezone info since DB doesn't preserve it
|
||||
assert updated[0].timestamp.replace(tzinfo=None) == new_timestamp.replace(tzinfo=None)
|
||||
assert updated[0].id == created_message.id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aupdate_multiple_messages_with_timestamps(created_messages):
|
||||
# Modify messages with different timestamps
|
||||
for i, message in enumerate(created_messages):
|
||||
message.text = f"Updated message {i}"
|
||||
message.timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
updated = await aupdate_messages(created_messages)
|
||||
|
||||
assert len(updated) == len(created_messages)
|
||||
for i, message in enumerate(updated):
|
||||
assert message.text == f"Updated message {i}"
|
||||
# Compare timestamps without timezone info
|
||||
expected_timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc)
|
||||
assert message.timestamp.replace(tzinfo=None) == expected_timestamp.replace(tzinfo=None)
|
||||
assert message.id == created_messages[i].id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aupdate_message_with_content_blocks(created_message):
|
||||
# Create a content block using proper models
|
||||
text_content = TextContent(
|
||||
type="text", text="Test content", duration=5, header={"title": "Test Header", "icon": "TestIcon"}
|
||||
)
|
||||
|
||||
tool_content = ToolContent(type="tool_use", name="test_tool", tool_input={"param": "value"}, duration=10)
|
||||
|
||||
content_block = ContentBlock(title="Test Block", contents=[text_content, tool_content], allow_markdown=True)
|
||||
|
||||
created_message.content_blocks = [content_block]
|
||||
created_message.text = "Message with content blocks"
|
||||
|
||||
updated = await aupdate_messages(created_message)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Message with content blocks"
|
||||
assert len(updated[0].content_blocks) == 1
|
||||
|
||||
# Verify the content block structure
|
||||
updated_block = updated[0].content_blocks[0]
|
||||
assert updated_block.title == "Test Block"
|
||||
assert len(updated_block.contents) == 2
|
||||
|
||||
# Verify text content
|
||||
text_content = updated_block.contents[0]
|
||||
assert text_content.type == "text"
|
||||
assert text_content.text == "Test content"
|
||||
assert text_content.duration == 5
|
||||
assert text_content.header["title"] == "Test Header"
|
||||
|
||||
# Verify tool content
|
||||
tool_content = updated_block.contents[1]
|
||||
assert tool_content.type == "tool_use"
|
||||
assert tool_content.name == "test_tool"
|
||||
assert tool_content.tool_input == {"param": "value"}
|
||||
assert tool_content.duration == 10
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_aupdate_message_with_nested_properties(created_message):
|
||||
# Create a text content with nested properties
|
||||
text_content = TextContent(
|
||||
type="text", text="Test content", header={"title": "Test Header", "icon": "TestIcon"}, duration=15
|
||||
)
|
||||
|
||||
content_block = ContentBlock(
|
||||
title="Test Properties",
|
||||
contents=[text_content],
|
||||
allow_markdown=True,
|
||||
media_url=["http://example.com/image.jpg"],
|
||||
)
|
||||
|
||||
# Set properties according to the Properties model structure
|
||||
created_message.properties = Properties(
|
||||
text_color="blue",
|
||||
background_color="white",
|
||||
edited=False,
|
||||
source=Source(id="test_id", display_name="Test Source", source="test"),
|
||||
icon="TestIcon",
|
||||
allow_markdown=True,
|
||||
state="complete",
|
||||
targets=[],
|
||||
)
|
||||
created_message.text = "Message with nested properties"
|
||||
created_message.content_blocks = [content_block]
|
||||
|
||||
updated = await aupdate_messages(created_message)
|
||||
|
||||
assert len(updated) == 1
|
||||
assert updated[0].text == "Message with nested properties"
|
||||
|
||||
# Verify the properties were properly serialized and stored
|
||||
assert updated[0].properties.text_color == "blue"
|
||||
assert updated[0].properties.background_color == "white"
|
||||
assert updated[0].properties.edited is False
|
||||
assert updated[0].properties.source.id == "test_id"
|
||||
assert updated[0].properties.source.display_name == "Test Source"
|
||||
assert updated[0].properties.source.source == "test"
|
||||
assert updated[0].properties.icon == "TestIcon"
|
||||
assert updated[0].properties.allow_markdown is True
|
||||
assert updated[0].properties.state == "complete"
|
||||
assert updated[0].properties.targets == []
|
||||
|
|
|
|||
|
|
@ -2,33 +2,33 @@ from uuid import UUID
|
|||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from langflow.memory import add_messagetables
|
||||
from langflow.memory import aadd_messagetables
|
||||
|
||||
# Assuming you have these imports available
|
||||
from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate
|
||||
from langflow.services.database.models.message.model import MessageTable
|
||||
from langflow.services.deps import session_scope
|
||||
from langflow.services.deps import async_session_scope
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def created_message():
|
||||
with session_scope() as session:
|
||||
async def created_message():
|
||||
async with async_session_scope() as session:
|
||||
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
|
||||
messagetable = MessageTable.model_validate(message, from_attributes=True)
|
||||
messagetables = add_messagetables([messagetable], session)
|
||||
messagetables = await aadd_messagetables([messagetable], session)
|
||||
return MessageRead.model_validate(messagetables[0], from_attributes=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def created_messages(session): # noqa: ARG001
|
||||
with session_scope() as _session:
|
||||
async def created_messages(session): # noqa: ARG001
|
||||
async with async_session_scope() as _session:
|
||||
messages = [
|
||||
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
|
||||
return add_messagetables(messagetables, _session)
|
||||
return await aadd_messagetables(messagetables, _session)
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue