refactor: Update add_messages and add_messagetables functions to return Message objects
This commit is contained in:
parent
f56965b16f
commit
d1d5eb6e39
3 changed files with 15 additions and 9 deletions
|
|
@ -7,7 +7,7 @@ from sqlalchemy import delete
|
|||
from sqlmodel import Session, col, select
|
||||
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.database.models.message.model import MessageTable
|
||||
from langflow.services.database.models.message.model import MessageRead, MessageTable
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non
|
|||
messages_models.append(MessageTable.from_message(msg, flow_id=flow_id))
|
||||
with session_scope() as session:
|
||||
messages_models = add_messagetables(messages_models, session)
|
||||
return messages_models
|
||||
return [Message(**message.model_dump()) for message in messages_models]
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
|
@ -90,7 +90,7 @@ def add_messagetables(messages: list[MessageTable], session: Session):
|
|||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
return [Message(**message.model_dump()) for message in messages]
|
||||
return [MessageRead.model_validate(message, from_attributes=True) for message in messages]
|
||||
|
||||
|
||||
def delete_messages(session_id: str):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, AsyncIterator, Iterator, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from langchain_core.load import load
|
||||
|
|
@ -31,7 +32,14 @@ class Message(Data):
|
|||
timestamp: Annotated[str, BeforeValidator(_timestamp_to_str)] = Field(
|
||||
default=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
flow_id: Optional[str] = None
|
||||
flow_id: Optional[str | UUID] = None
|
||||
|
||||
@field_validator("flow_id", mode="before")
|
||||
@classmethod
|
||||
def validate_flow_id(cls, value):
|
||||
if isinstance(value, UUID):
|
||||
value = str(value)
|
||||
return value
|
||||
|
||||
@field_validator("files", mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -30,11 +30,9 @@ def created_messages(session):
|
|||
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
|
||||
messagetables = add_messagetables(messagetables, session)
|
||||
messages_read = [
|
||||
MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables
|
||||
]
|
||||
return messages_read
|
||||
message_list = add_messagetables(messagetables, session)
|
||||
|
||||
return message_list
|
||||
|
||||
|
||||
def test_delete_messages(client: TestClient, created_messages, logged_in_headers):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue