From d1d5eb6e39b9e4d674e39c3f65336e88ddd6f7a9 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 25 Jun 2024 20:20:12 -0300 Subject: [PATCH] refactor: Update add_messages and add_messagetables functions to return Message objects --- src/backend/base/langflow/memory.py | 6 +++--- src/backend/base/langflow/schema/message.py | 10 +++++++++- tests/test_messages_endpoints.py | 8 +++----- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index b0115d9b2..d1793740d 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -7,7 +7,7 @@ from sqlalchemy import delete from sqlmodel import Session, col, select from langflow.schema.message import Message -from langflow.services.database.models.message.model import MessageTable +from langflow.services.database.models.message.model import MessageRead, MessageTable from langflow.services.deps import session_scope @@ -75,7 +75,7 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non messages_models.append(MessageTable.from_message(msg, flow_id=flow_id)) with session_scope() as session: messages_models = add_messagetables(messages_models, session) - return messages_models + return [Message(**message.model_dump()) for message in messages_models] except Exception as e: logger.exception(e) raise e @@ -90,7 +90,7 @@ def add_messagetables(messages: list[MessageTable], session: Session): except Exception as e: logger.exception(e) raise e - return [Message(**message.model_dump()) for message in messages] + return [MessageRead.model_validate(message, from_attributes=True) for message in messages] def delete_messages(session_id: str): diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index c50dab880..c03c0cb71 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -1,5 +1,6 @@ from datetime import datetime, timezone from typing import Annotated, Any, AsyncIterator, Iterator, List, Optional +from uuid import UUID from fastapi.encoders import jsonable_encoder from langchain_core.load import load @@ -31,7 +32,14 @@ class Message(Data): timestamp: Annotated[str, BeforeValidator(_timestamp_to_str)] = Field( default=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") ) - flow_id: Optional[str] = None + flow_id: Optional[str | UUID] = None + + @field_validator("flow_id", mode="before") + @classmethod + def validate_flow_id(cls, value): + if isinstance(value, UUID): + value = str(value) + return value @field_validator("files", mode="before") @classmethod diff --git a/tests/test_messages_endpoints.py b/tests/test_messages_endpoints.py index 62fd0c941..ee4021784 100644 --- a/tests/test_messages_endpoints.py +++ b/tests/test_messages_endpoints.py @@ -30,11 +30,9 @@ def created_messages(session): MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), ] messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] - messagetables = add_messagetables(messagetables, session) - messages_read = [ - MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables - ] - return messages_read + message_list = add_messagetables(messagetables, session) + + return message_list def test_delete_messages(client: TestClient, created_messages, logged_in_headers):