Merge branch 'dev' into fix_group_graph
This commit is contained in:
commit
4e5a1414f7
16 changed files with 411 additions and 140 deletions
2
.github/workflows/lint-js.yml
vendored
2
.github/workflows/lint-js.yml
vendored
|
|
@ -5,7 +5,7 @@ on:
|
|||
paths:
|
||||
- "src/frontend/**"
|
||||
merge_group:
|
||||
branches: [dev]
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
NODE_VERSION: "21"
|
||||
|
|
|
|||
2
.github/workflows/python_test.yml
vendored
2
.github/workflows/python_test.yml
vendored
|
|
@ -45,7 +45,7 @@ jobs:
|
|||
poetry run python -m langflow run --host 127.0.0.1 --port 7860 --backend-only &
|
||||
SERVER_PID=$!
|
||||
# Wait for the server to start
|
||||
timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1)
|
||||
timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/api/v1/auto_login; do sleep 5; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1)
|
||||
# Terminate the server
|
||||
kill $SERVER_PID || (echo "Failed to terminate the server" && exit 1)
|
||||
sleep 10 # give the server some time to terminate
|
||||
|
|
|
|||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
|
|
@ -62,7 +62,7 @@ jobs:
|
|||
python -m langflow run --host 127.0.0.1 --port 7860 &
|
||||
SERVER_PID=$!
|
||||
# Wait for the server to start
|
||||
timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1)
|
||||
timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/api/v1/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1)
|
||||
# Terminate the server
|
||||
kill $SERVER_PID || (echo "Failed to terminate the server" && exit 1)
|
||||
sleep 10 # give the server some time to terminate
|
||||
|
|
@ -124,7 +124,7 @@ jobs:
|
|||
python -m langflow run --host 127.0.0.1 --port 7860 &
|
||||
SERVER_PID=$!
|
||||
# Wait for the server to start
|
||||
timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1)
|
||||
timeout 120 bash -c 'until curl -f http://127.0.0.1:7860/api/v1/auto_login; do sleep 2; done' || (echo "Server did not start in time" && kill $SERVER_PID && exit 1)
|
||||
# Terminate the server
|
||||
kill $SERVER_PID || (echo "Failed to terminate the server" && exit 1)
|
||||
sleep 10 # give the server some time to terminate
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
"""Add message table
|
||||
|
||||
Revision ID: 325180f0c4e1
|
||||
Revises: 631faacf5da2
|
||||
Create Date: 2024-06-23 21:29:28.220100
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
from langflow.utils import migration
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "325180f0c4e1"
|
||||
down_revision: Union[str, None] = "631faacf5da2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
if not migration.table_exists("message", conn):
|
||||
op.create_table(
|
||||
"message",
|
||||
sa.Column("timestamp", sa.DateTime(), nullable=False),
|
||||
sa.Column("sender", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("sender_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("session_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("flow_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("files", sa.JSON(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["flow_id"],
|
||||
["flow.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
if migration.table_exists("message", conn):
|
||||
op.drop_table("message")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import delete
|
||||
from sqlmodel import Session, col, select
|
||||
|
||||
from langflow.services.deps import get_monitor_service
|
||||
from langflow.services.monitor.schema import (
|
||||
MessageModelRequest,
|
||||
MessageModelResponse,
|
||||
TransactionModelResponse,
|
||||
VertexBuildMapModel,
|
||||
)
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import get_monitor_service, get_session
|
||||
from langflow.services.monitor.schema import MessageModelResponse, TransactionModelResponse, VertexBuildMapModel
|
||||
from langflow.services.monitor.service import MonitorService
|
||||
|
||||
router = APIRouter(prefix="/monitor", tags=["Monitor"])
|
||||
|
|
@ -52,45 +52,58 @@ async def get_messages(
|
|||
sender: Optional[str] = Query(None),
|
||||
sender_name: Optional[str] = Query(None),
|
||||
order_by: Optional[str] = Query("timestamp"),
|
||||
monitor_service: MonitorService = Depends(get_monitor_service),
|
||||
session: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
df = monitor_service.get_messages(
|
||||
flow_id=flow_id,
|
||||
sender=sender,
|
||||
sender_name=sender_name,
|
||||
session_id=session_id,
|
||||
order_by=order_by,
|
||||
)
|
||||
dicts = df.to_dict(orient="records")
|
||||
return [MessageModelResponse(**d) for d in dicts]
|
||||
stmt = select(MessageTable)
|
||||
if flow_id:
|
||||
stmt = stmt.where(MessageTable.flow_id == flow_id)
|
||||
if session_id:
|
||||
stmt = stmt.where(MessageTable.session_id == session_id)
|
||||
if sender:
|
||||
stmt = stmt.where(MessageTable.sender == sender)
|
||||
if sender_name:
|
||||
stmt = stmt.where(MessageTable.sender_name == sender_name)
|
||||
if order_by:
|
||||
col = getattr(MessageTable, order_by).asc()
|
||||
stmt = stmt.order_by(col)
|
||||
messages = session.exec(stmt)
|
||||
return [MessageModelResponse.model_validate(d, from_attributes=True) for d in messages]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/messages", status_code=204)
|
||||
async def delete_messages(
|
||||
message_ids: List[int],
|
||||
monitor_service: MonitorService = Depends(get_monitor_service),
|
||||
message_ids: List[UUID],
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
monitor_service.delete_messages(message_ids=message_ids)
|
||||
session.exec(select(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/messages/{message_id}", response_model=MessageModelResponse)
|
||||
@router.put("/messages/{message_id}", response_model=MessageRead)
|
||||
async def update_message(
|
||||
message_id: int,
|
||||
message: MessageModelRequest,
|
||||
monitor_service: MonitorService = Depends(get_monitor_service),
|
||||
message_id: UUID,
|
||||
message: MessageUpdate,
|
||||
session: Session = Depends(get_session),
|
||||
user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
message_dict = message.model_dump(exclude_none=True)
|
||||
message_dict.pop("index", None)
|
||||
monitor_service.update_message(message_id=message_id, **message_dict) # type: ignore
|
||||
return MessageModelResponse(index=message_id, **message_dict)
|
||||
|
||||
db_message = session.get(MessageTable, message_id)
|
||||
if not db_message:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
message_dict = message.model_dump(exclude_unset=True, exclude_none=True)
|
||||
db_message.sqlmodel_update(message_dict)
|
||||
session.add(db_message)
|
||||
session.commit()
|
||||
session.refresh(db_message)
|
||||
return db_message
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
|
@ -98,10 +111,16 @@ async def update_message(
|
|||
@router.delete("/messages/session/{session_id}", status_code=204)
|
||||
async def delete_messages_session(
|
||||
session_id: str,
|
||||
monitor_service: MonitorService = Depends(get_monitor_service),
|
||||
session: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
monitor_service.delete_messages_session(session_id=session_id)
|
||||
session.exec( # type: ignore
|
||||
delete(MessageTable)
|
||||
.where(col(MessageTable.session_id) == session_id)
|
||||
.execution_options(synchronize_session="fetch")
|
||||
)
|
||||
session.commit()
|
||||
return {"message": "Messages deleted successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
|
@ -137,4 +156,3 @@ async def get_transactions(
|
|||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
|||
|
|
@ -119,7 +119,11 @@ class LCModelComponent(Component):
|
|||
return status_message
|
||||
|
||||
def get_chat_result(
|
||||
self, runnable: LanguageModel, stream: bool, input_value: str | Message, system_message: Optional[str] = None
|
||||
self,
|
||||
runnable: LanguageModel,
|
||||
stream: bool,
|
||||
input_value: str | Message,
|
||||
system_message: Optional[str] = None,
|
||||
):
|
||||
messages: list[Union[BaseMessage]] = []
|
||||
if not input_value and not system_message:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
import warnings
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import delete
|
||||
from sqlmodel import Session, col, select
|
||||
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.deps import get_monitor_service
|
||||
from langflow.services.monitor.schema import MessageModel
|
||||
from langflow.services.database.models.message.model import MessageRead, MessageTable
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
def get_messages(
|
||||
|
|
@ -14,6 +17,7 @@ def get_messages(
|
|||
session_id: Optional[str] = None,
|
||||
order_by: Optional[str] = "timestamp",
|
||||
order: Optional[str] = "DESC",
|
||||
flow_id: Optional[UUID] = None,
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
|
|
@ -29,34 +33,29 @@ def get_messages(
|
|||
Returns:
|
||||
List[Data]: A list of Data objects representing the retrieved messages.
|
||||
"""
|
||||
monitor_service = get_monitor_service()
|
||||
messages_df = monitor_service.get_messages(
|
||||
sender=sender,
|
||||
sender_name=sender_name,
|
||||
session_id=session_id,
|
||||
order_by=order_by,
|
||||
limit=limit,
|
||||
order=order,
|
||||
)
|
||||
messages_read: list[Message] = []
|
||||
with session_scope() as session:
|
||||
stmt = select(MessageTable)
|
||||
if sender:
|
||||
stmt = stmt.where(MessageTable.sender == sender)
|
||||
if sender_name:
|
||||
stmt = stmt.where(MessageTable.sender_name == sender_name)
|
||||
if session_id:
|
||||
stmt = stmt.where(MessageTable.session_id == session_id)
|
||||
if flow_id:
|
||||
stmt = stmt.where(MessageTable.flow_id == flow_id)
|
||||
if order_by:
|
||||
if order == "DESC":
|
||||
col = getattr(MessageTable, order_by).desc()
|
||||
else:
|
||||
col = getattr(MessageTable, order_by).asc()
|
||||
stmt = stmt.order_by(col)
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
messages = session.exec(stmt)
|
||||
messages_read = [Message(**d.model_dump()) for d in messages]
|
||||
|
||||
messages: list[Message] = []
|
||||
# messages_df has a timestamp
|
||||
# it gets the last 5 messages, for example
|
||||
# but now they are ordered from most recent to least recent
|
||||
# so we need to reverse the order
|
||||
messages_df = messages_df[::-1] if order == "DESC" else messages_df
|
||||
for row in messages_df.itertuples():
|
||||
msg = Message(
|
||||
text=row.text,
|
||||
sender=row.sender,
|
||||
session_id=row.session_id,
|
||||
sender_name=row.sender_name,
|
||||
timestamp=row.timestamp,
|
||||
)
|
||||
|
||||
messages.append(msg)
|
||||
|
||||
return messages
|
||||
return messages_read
|
||||
|
||||
|
||||
def add_messages(messages: Message | list[Message], flow_id: Optional[str] = None):
|
||||
|
|
@ -64,7 +63,6 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non
|
|||
Add a message to the monitor service.
|
||||
"""
|
||||
try:
|
||||
monitor_service = get_monitor_service()
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
|
|
@ -72,25 +70,29 @@ def add_messages(messages: Message | list[Message], flow_id: Optional[str] = Non
|
|||
types = ", ".join([str(type(message)) for message in messages])
|
||||
raise ValueError(f"The messages must be instances of Message. Found: {types}")
|
||||
|
||||
messages_models: list[MessageModel] = []
|
||||
messages_models: list[MessageTable] = []
|
||||
for msg in messages:
|
||||
if not msg.timestamp:
|
||||
msg.timestamp = monitor_service.get_timestamp()
|
||||
messages_models.append(MessageModel.from_message(msg, flow_id=flow_id))
|
||||
|
||||
for message_model in messages_models:
|
||||
try:
|
||||
monitor_service.add_message(message_model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message to monitor service: {e}")
|
||||
logger.exception(e)
|
||||
raise e
|
||||
return messages_models
|
||||
messages_models.append(MessageTable.from_message(msg, flow_id=flow_id))
|
||||
with session_scope() as session:
|
||||
messages_models = add_messagetables(messages_models, session)
|
||||
return [Message(**message.model_dump()) for message in messages_models]
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
|
||||
def add_messagetables(messages: list[MessageTable], session: Session):
|
||||
for message in messages:
|
||||
try:
|
||||
session.add(message)
|
||||
session.commit()
|
||||
session.refresh(message)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
return [MessageRead.model_validate(message, from_attributes=True) for message in messages]
|
||||
|
||||
|
||||
def delete_messages(session_id: str):
|
||||
"""
|
||||
Delete messages from the monitor service based on the provided session ID.
|
||||
|
|
@ -98,8 +100,13 @@ def delete_messages(session_id: str):
|
|||
Args:
|
||||
session_id (str): The session ID associated with the messages to delete.
|
||||
"""
|
||||
monitor_service = get_monitor_service()
|
||||
monitor_service.delete_messages_session(session_id)
|
||||
with session_scope() as session:
|
||||
session.exec(
|
||||
delete(MessageTable)
|
||||
.where(col(MessageTable.session_id) == session_id)
|
||||
.execution_options(synchronize_session="fetch")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
def store_message(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, AsyncIterator, Iterator, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from langchain_core.load import load
|
||||
|
|
@ -31,7 +32,14 @@ class Message(Data):
|
|||
timestamp: Annotated[str, BeforeValidator(_timestamp_to_str)] = Field(
|
||||
default=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
flow_id: Optional[str] = None
|
||||
flow_id: Optional[str | UUID] = None
|
||||
|
||||
@field_validator("flow_id", mode="before")
|
||||
@classmethod
|
||||
def validate_flow_id(cls, value):
|
||||
if isinstance(value, UUID):
|
||||
value = str(value)
|
||||
return value
|
||||
|
||||
@field_validator("files", mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from .api_key import ApiKey
|
||||
from .flow import Flow
|
||||
from .folder import Folder
|
||||
from .message import MessageTable
|
||||
from .user import User
|
||||
from .variable import Variable
|
||||
|
||||
__all__ = ["Flow", "User", "ApiKey", "Variable", "Folder"]
|
||||
__all__ = ["Flow", "User", "ApiKey", "Variable", "Folder", "MessageTable"]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import re
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import emoji
|
||||
|
|
@ -17,6 +17,7 @@ from langflow.schema import Data
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.models.folder import Folder
|
||||
from langflow.services.database.models.message import MessageTable
|
||||
from langflow.services.database.models.user import User
|
||||
|
||||
|
||||
|
|
@ -141,6 +142,7 @@ class Flow(FlowBase, table=True):
|
|||
user: "User" = Relationship(back_populates="flows")
|
||||
folder_id: Optional[UUID] = Field(default=None, foreign_key="folder.id", nullable=True, index=True)
|
||||
folder: Optional["Folder"] = Relationship(back_populates="flows")
|
||||
messages: List["MessageTable"] = Relationship(back_populates="flow")
|
||||
|
||||
def to_data(self):
|
||||
serialized = self.model_dump()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .model import MessageTable, MessageCreate, MessageRead, MessageUpdate
|
||||
|
||||
__all__ = ["MessageTable", "MessageCreate", "MessageRead", "MessageUpdate"]
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import field_validator
|
||||
from sqlmodel import JSON, Column, Field, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
|
||||
|
||||
class MessageBase(SQLModel):
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
sender: str
|
||||
sender_name: str
|
||||
session_id: str
|
||||
text: str
|
||||
files: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("files", mode="before")
|
||||
@classmethod
|
||||
def validate_files(cls, value):
|
||||
if not value:
|
||||
value = []
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: "Message", flow_id: str | None = None):
|
||||
# first check if the record has all the required fields
|
||||
if message.text is None or not message.sender or not message.sender_name:
|
||||
raise ValueError("The message does not have the required fields (text, sender, sender_name).")
|
||||
if isinstance(message.timestamp, str):
|
||||
timestamp = datetime.fromisoformat(message.timestamp)
|
||||
else:
|
||||
timestamp = message.timestamp
|
||||
return cls(
|
||||
sender=message.sender,
|
||||
sender_name=message.sender_name,
|
||||
text=message.text,
|
||||
session_id=message.session_id,
|
||||
files=message.files or [],
|
||||
timestamp=timestamp,
|
||||
flow_id=flow_id,
|
||||
)
|
||||
|
||||
|
||||
class MessageTable(MessageBase, table=True):
|
||||
__tablename__ = "message"
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
flow_id: Optional[UUID] = Field(default=None, foreign_key="flow.id")
|
||||
flow: "Flow" = Relationship(back_populates="messages")
|
||||
files: List[str] = Field(sa_column=Column(JSON))
|
||||
|
||||
# Needed for Column(JSON)
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class MessageRead(MessageBase):
|
||||
id: UUID
|
||||
flow_id: Optional[UUID] = Field()
|
||||
|
||||
|
||||
class MessageCreate(MessageBase):
|
||||
pass
|
||||
|
||||
|
||||
class MessageUpdate(SQLModel):
|
||||
text: Optional[str] = None
|
||||
sender: Optional[str] = None
|
||||
sender_name: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
files: Optional[list[str]] = None
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
|
|
@ -81,8 +82,8 @@ class TransactionModelResponse(DefaultModel):
|
|||
|
||||
|
||||
class MessageModel(DefaultModel):
|
||||
index: Optional[int] = Field(default=None)
|
||||
flow_id: Optional[str] = Field(default=None, alias="flow_id")
|
||||
id: Optional[str | UUID] = Field(default=None)
|
||||
flow_id: Optional[UUID] = Field(default=None)
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
sender: str
|
||||
sender_name: str
|
||||
|
|
@ -127,16 +128,7 @@ class MessageModel(DefaultModel):
|
|||
|
||||
|
||||
class MessageModelResponse(MessageModel):
|
||||
index: Optional[int] = Field(default=None)
|
||||
|
||||
@field_validator("index", mode="before")
|
||||
def validate_id(cls, v):
|
||||
if isinstance(v, float):
|
||||
try:
|
||||
return int(v)
|
||||
except ValueError:
|
||||
return None
|
||||
return v
|
||||
pass
|
||||
|
||||
|
||||
class MessageModelRequest(MessageModel):
|
||||
|
|
|
|||
|
|
@ -3,14 +3,15 @@ from pathlib import Path
|
|||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import duckdb
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch
|
||||
from loguru import logger
|
||||
from platformdirs import user_cache_dir
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.monitor.utils import add_row_to_table, drop_and_create_table_if_schema_mismatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.settings.service import SettingsService
|
||||
from langflow.services.monitor.schema import MessageModel, TransactionModel, VertexBuildModel
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
|
||||
class MonitorService(Service):
|
||||
|
|
@ -129,45 +130,6 @@ class MonitorService(Service):
|
|||
|
||||
return self.exec_query(query, read_only=False)
|
||||
|
||||
def add_message(self, message: "MessageModel"):
|
||||
self.add_row("messages", message)
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
flow_id: Optional[str] = None,
|
||||
sender: Optional[str] = None,
|
||||
sender_name: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
order_by: Optional[str] = "timestamp",
|
||||
order: Optional[str] = "DESC",
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
query = "SELECT index, flow_id, sender_name, sender, session_id, text, files, timestamp FROM messages"
|
||||
conditions = []
|
||||
if sender:
|
||||
conditions.append(f"sender = '{sender}'")
|
||||
if sender_name:
|
||||
conditions.append(f"sender_name = '{sender_name}'")
|
||||
if session_id:
|
||||
conditions.append(f"session_id = '{session_id}'")
|
||||
if flow_id:
|
||||
conditions.append(f"flow_id = '{flow_id}'")
|
||||
|
||||
if conditions:
|
||||
query += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
if order_by and order:
|
||||
# Make sure the order is from newest to oldest
|
||||
query += f" ORDER BY {order_by} {order.upper()}"
|
||||
|
||||
if limit is not None:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
with duckdb.connect(str(self.db_path), read_only=True) as conn:
|
||||
df = conn.execute(query).df()
|
||||
|
||||
return df
|
||||
|
||||
def get_transactions(
|
||||
self,
|
||||
source: Optional[str] = None,
|
||||
|
|
|
|||
76
tests/test_messages_endpoints.py
Normal file
76
tests/test_messages_endpoints.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from langflow.memory import add_messagetables
|
||||
|
||||
# Assuming you have these imports available
|
||||
from langflow.services.database.models.message import MessageCreate, MessageRead, MessageUpdate
|
||||
from langflow.services.database.models.message.model import MessageTable
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def created_message():
|
||||
with session_scope() as session:
|
||||
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
|
||||
messagetable = MessageTable.model_validate(message, from_attributes=True)
|
||||
messagetables = add_messagetables([messagetable], session)
|
||||
message_read = MessageRead.model_validate(messagetables[0], from_attributes=True)
|
||||
return message_read
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def created_messages(session):
|
||||
with session_scope() as session:
|
||||
messages = [
|
||||
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
|
||||
message_list = add_messagetables(messagetables, session)
|
||||
|
||||
return message_list
|
||||
|
||||
|
||||
def test_delete_messages(client: TestClient, created_messages, logged_in_headers):
|
||||
response = client.request(
|
||||
"DELETE", "api/v1/monitor/messages", json=[str(msg.id) for msg in created_messages], headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 204, response.text
|
||||
assert response.reason_phrase == "No Content"
|
||||
|
||||
|
||||
def test_update_message(client: TestClient, logged_in_headers, created_message):
|
||||
message_id = created_message.id
|
||||
message_update = MessageUpdate(text="Updated content")
|
||||
response = client.put(
|
||||
f"api/v1/monitor/messages/{message_id}", json=message_update.model_dump(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
updated_message = MessageRead(**response.json())
|
||||
assert updated_message.text == "Updated content"
|
||||
|
||||
|
||||
def test_update_message_not_found(client: TestClient, logged_in_headers):
|
||||
non_existent_id = UUID("00000000-0000-0000-0000-000000000000")
|
||||
message_update = MessageUpdate(text="Updated content")
|
||||
response = client.put(
|
||||
f"api/v1/monitor/messages/{non_existent_id}", json=message_update.model_dump(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 404, response.text
|
||||
assert response.json()["detail"] == "Message not found"
|
||||
|
||||
|
||||
def test_delete_messages_session(client: TestClient, created_messages, logged_in_headers):
|
||||
session_id = "session_id2"
|
||||
response = client.delete(f"api/v1/monitor/messages/session/{session_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 204
|
||||
assert response.reason_phrase == "No Content"
|
||||
|
||||
assert len(created_messages) == 3
|
||||
response = client.get("api/v1/monitor/messages", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 0
|
||||
72
tests/unit/test_messages.py
Normal file
72
tests/unit/test_messages.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import pytest
|
||||
|
||||
from langflow.memory import add_messages, add_messagetables, delete_messages, get_messages, store_message
|
||||
from langflow.schema.message import Message
|
||||
|
||||
# Assuming you have these imports available
|
||||
from langflow.services.database.models.message import MessageCreate, MessageRead
|
||||
from langflow.services.database.models.message.model import MessageTable
|
||||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def created_message():
|
||||
with session_scope() as session:
|
||||
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
|
||||
messagetable = MessageTable.model_validate(message, from_attributes=True)
|
||||
messagetables = add_messagetables([messagetable], session)
|
||||
message_read = MessageRead.model_validate(messagetables[0], from_attributes=True)
|
||||
return message_read
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def created_messages(session):
|
||||
with session_scope() as session:
|
||||
messages = [
|
||||
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
|
||||
messagetables = add_messagetables(messagetables, session)
|
||||
messages_read = [
|
||||
MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables
|
||||
]
|
||||
return messages_read
|
||||
|
||||
|
||||
def test_get_messages(session):
|
||||
add_messages(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"))
|
||||
add_messages(Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"))
|
||||
messages = get_messages(sender="User", session_id="session_id2", limit=2)
|
||||
assert len(messages) == 2
|
||||
assert messages[0].text == "Test message 1"
|
||||
assert messages[1].text == "Test message 2"
|
||||
|
||||
|
||||
def test_add_messages(session):
|
||||
message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")
|
||||
messages = add_messages(message)
|
||||
assert len(messages) == 1
|
||||
assert messages[0].text == "New Test message"
|
||||
|
||||
|
||||
def test_add_messagetables(session):
|
||||
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
|
||||
added_messages = add_messagetables(messages, session)
|
||||
assert len(added_messages) == 1
|
||||
assert added_messages[0].text == "New Test message"
|
||||
|
||||
|
||||
def test_delete_messages(session):
|
||||
session_id = "session_id2"
|
||||
delete_messages(session_id)
|
||||
messages = session.query(MessageTable).filter(MessageTable.session_id == session_id).all()
|
||||
assert len(messages) == 0
|
||||
|
||||
|
||||
def test_store_message(session):
|
||||
message = Message(text="Stored message", sender="User", sender_name="User", session_id="stored_session_id")
|
||||
stored_messages = store_message(message)
|
||||
assert len(stored_messages) == 1
|
||||
assert stored_messages[0].text == "Stored message"
|
||||
Loading…
Add table
Add a link
Reference in a new issue