refactor: Update add_messages and add_messagetables functions to return Message objects

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-25 20:20:12 -03:00
commit d1d5eb6e39
3 changed files with 15 additions and 9 deletions

View file

@ -7,7 +7,7 @@ from sqlalchemy import delete
from sqlmodel import Session, col, select
from langflow.schema.message import Message
from langflow.services.database.models.message.model import MessageTable
from langflow.services.database.models.message.model import MessageRead, MessageTable
from langflow.services.deps import session_scope
@ -75,7 +75,7 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non
messages_models.append(MessageTable.from_message(msg, flow_id=flow_id))
with session_scope() as session:
messages_models = add_messagetables(messages_models, session)
return messages_models
return [Message(**message.model_dump()) for message in messages_models]
except Exception as e:
logger.exception(e)
raise e
@ -90,7 +90,7 @@ def add_messagetables(messages: list[MessageTable], session: Session):
except Exception as e:
logger.exception(e)
raise e
return [Message(**message.model_dump()) for message in messages]
return [MessageRead.model_validate(message, from_attributes=True) for message in messages]
def delete_messages(session_id: str):

View file

@ -1,5 +1,6 @@
from datetime import datetime, timezone
from typing import Annotated, Any, AsyncIterator, Iterator, List, Optional
from uuid import UUID
from fastapi.encoders import jsonable_encoder
from langchain_core.load import load
@ -31,7 +32,14 @@ class Message(Data):
timestamp: Annotated[str, BeforeValidator(_timestamp_to_str)] = Field(
default=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
)
flow_id: Optional[str] = None
flow_id: Optional[str | UUID] = None
@field_validator("flow_id", mode="before")
@classmethod
def validate_flow_id(cls, value):
if isinstance(value, UUID):
value = str(value)
return value
@field_validator("files", mode="before")
@classmethod

View file

@ -30,11 +30,9 @@ def created_messages(session):
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
]
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
messagetables = add_messagetables(messagetables, session)
messages_read = [
MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables
]
return messages_read
message_list = add_messagetables(messagetables, session)
return message_list
def test_delete_messages(client: TestClient, created_messages, logged_in_headers):