Add default_fields column to the variable table

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-04-29 09:53:56 -03:00
commit 366ce591cb
3 changed files with 52 additions and 4 deletions

View file

@ -0,0 +1,45 @@
"""Add default_fields column
Revision ID: 1f4d6df60295
Revises: 58b28437a398
Create Date: 2024-04-29 09:49:46.864145
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.engine.reflection import Inspector
# revision identifiers, used by Alembic.
revision: str = "1f4d6df60295"
down_revision: Union[str, None] = "58b28437a398"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn) # type: ignore
table_names = inspector.get_table_names()
# ### commands auto generated by Alembic - please adjust! ###
column_names = [column["name"] for column in inspector.get_columns("variable")]
with op.batch_alter_table("variable", schema=None) as batch_op:
if "default_fields" not in column_names:
batch_op.add_column(sa.Column("default_fields", sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn) # type: ignore
table_names = inspector.get_table_names()
# ### commands auto generated by Alembic - please adjust! ###
column_names = [column["name"] for column in inspector.get_columns("variable")]
with op.batch_alter_table("variable", schema=None) as batch_op:
if "default_fields" in column_names:
batch_op.drop_column("default_fields")
# ### end Alembic commands ###

View file

@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
@ -85,7 +85,7 @@ def update_variable(
variable_data = variable.model_dump(exclude_unset=True)
for key, value in variable_data.items():
setattr(db_variable, key, value)
db_variable.updated_at = datetime.utcnow()
db_variable.updated_at = datetime.now(timezone.utc)
session.commit()
session.refresh(db_variable)
return db_variable

View file

@ -1,8 +1,8 @@
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional
from uuid import UUID, uuid4
from sqlmodel import Column, DateTime, Field, Relationship, SQLModel, func
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel, func
if TYPE_CHECKING:
from langflow.services.database.models.user.model import User
@ -15,6 +15,7 @@ def utc_now():
class VariableBase(SQLModel):
name: Optional[str] = Field(None, description="Name of the variable")
value: Optional[str] = Field(None, description="Encrypted value of the variable")
default_fields: Optional[List[str]] = Field(sa_column=Column(JSON))
type: Optional[str] = Field(None, description="Type of the variable")
@ -35,6 +36,7 @@ class Variable(VariableBase, table=True):
sa_column=Column(DateTime(timezone=True), nullable=True),
description="Last update time of the variable",
)
default_fields: Optional[List[str]] = Field(sa_column=Column(JSON))
# foreign key to user table
user_id: UUID = Field(description="User ID associated with this variable", foreign_key="user.id")
user: "User" = Relationship(back_populates="variables")
@ -56,3 +58,4 @@ class VariableUpdate(SQLModel):
id: UUID # Include the ID for updating
name: Optional[str] = Field(None, description="Name of the variable")
value: Optional[str] = Field(None, description="Encrypted value of the variable")
default_fields: Optional[List[str]] = Field(None, description="Default fields for the variable")