Add Variable model and remove Credential model

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-25 09:46:53 -03:00
commit 2ba305bc7a
18 changed files with 223 additions and 227 deletions

View file

@ -0,0 +1,63 @@
"""Replace Credential table with Variable
Revision ID: 1a110b568907
Revises: 63b9c451fd30
Create Date: 2024-03-25 09:40:02.743453
"""
from typing import Sequence, Union
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.engine.reflection import Inspector
# revision identifiers, used by Alembic.
revision: str = '1a110b568907'
down_revision: Union[str, None] = '63b9c451fd30'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn) # type: ignore
table_names = inspector.get_table_names()
# ### commands auto generated by Alembic - please adjust! ###
if "variable" not in table_names:
op.create_table('variable',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('value', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('user_id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='fk_variable_user_id'),
sa.PrimaryKeyConstraint('id')
)
if "credential" in table_names:
op.drop_table('credential')
# ### end Alembic commands ###
def downgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn) # type: ignore
table_names = inspector.get_table_names()
# ### commands auto generated by Alembic - please adjust! ###
if "credential" not in table_names:
op.create_table('credential',
sa.Column('name', sa.VARCHAR(), nullable=True),
sa.Column('value', sa.VARCHAR(), nullable=True),
sa.Column('provider', sa.VARCHAR(), nullable=True),
sa.Column('user_id', sa.CHAR(length=32), nullable=False),
sa.Column('id', sa.CHAR(length=32), nullable=False),
sa.Column('created_at', sa.DATETIME(), nullable=False),
sa.Column('updated_at', sa.DATETIME(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='fk_credential_user_id'),
sa.PrimaryKeyConstraint('id')
)
if "variable" in table_names:
op.drop_table('variable')
# ### end Alembic commands ###

View file

@ -4,7 +4,6 @@ from fastapi import APIRouter
from langflow.api.v1 import (
api_key_router,
chat_router,
credentials_router,
endpoints_router,
files_router,
flows_router,
@ -26,6 +25,6 @@ router.include_router(flows_router)
router.include_router(users_router)
router.include_router(api_key_router)
router.include_router(login_router)
router.include_router(credentials_router)
router.include_router(variables_router)
router.include_router(files_router)
router.include_router(monitor_router)

View file

@ -1,6 +1,5 @@
from langflow.api.v1.api_key import router as api_key_router
from langflow.api.v1.chat import router as chat_router
from langflow.api.v1.credential import router as credentials_router
from langflow.api.v1.endpoints import router as endpoints_router
from langflow.api.v1.files import router as files_router
from langflow.api.v1.flows import router as flows_router
@ -9,6 +8,7 @@ from langflow.api.v1.monitor import router as monitor_router
from langflow.api.v1.store import router as store_router
from langflow.api.v1.users import router as users_router
from langflow.api.v1.validate import router as validate_router
from langflow.api.v1.variable import router as variables_router
__all__ = [
"chat_router",
@ -19,7 +19,7 @@ __all__ = [
"users_router",
"api_key_router",
"login_router",
"credentials_router",
"variables_router",
"monitor_router",
"files_router",
]

View file

@ -1,118 +0,0 @@
from datetime import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from langflow.services.auth import utils as auth_utils
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.credential import (
Credential,
CredentialCreate,
CredentialRead,
CredentialUpdate,
)
from langflow.services.database.models.user.model import User
from langflow.services.deps import get_session, get_settings_service
router = APIRouter(prefix="/credentials", tags=["Credentials"])
@router.post("/", response_model=CredentialRead, status_code=201)
def create_credential(
*,
session: Session = Depends(get_session),
credential: CredentialCreate,
current_user: User = Depends(get_current_active_user),
settings_service=Depends(get_settings_service),
):
"""Create a new credential."""
try:
# check if credential name already exists
credential_exists = session.exec(
select(Credential).where(
Credential.name == credential.name,
Credential.user_id == current_user.id,
)
).first()
if credential_exists:
raise HTTPException(status_code=400, detail="Credential name already exists")
credential_dict = credential.model_dump()
credential_dict["user_id"] = current_user.id
db_credential = Credential.model_validate(credential_dict)
if not db_credential.value:
raise HTTPException(status_code=400, detail="Credential value cannot be empty")
encrypted = auth_utils.encrypt_api_key(db_credential.value, settings_service=settings_service)
db_credential.value = encrypted
db_credential.user_id = current_user.id
session.add(db_credential)
session.commit()
session.refresh(db_credential)
return db_credential
except Exception as e:
if isinstance(e, HTTPException):
raise e
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/", response_model=list[CredentialRead], status_code=200)
def read_credentials(
*,
session: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
):
"""Read all credentials."""
try:
credentials = session.exec(select(Credential).where(Credential.user_id == current_user.id)).all()
return credentials
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@router.patch("/{credential_id}", response_model=CredentialRead, status_code=200)
def update_credential(
*,
session: Session = Depends(get_session),
credential_id: UUID,
credential: CredentialUpdate,
current_user: User = Depends(get_current_active_user),
):
"""Update a credential."""
try:
db_credential = session.exec(
select(Credential).where(Credential.id == credential_id, Credential.user_id == current_user.id)
).first()
if not db_credential:
raise HTTPException(status_code=404, detail="Credential not found")
credential_data = credential.model_dump(exclude_unset=True)
for key, value in credential_data.items():
setattr(db_credential, key, value)
db_credential.updated_at = datetime.utcnow()
session.commit()
session.refresh(db_credential)
return db_credential
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/{credential_id}", status_code=204)
def delete_credential(
*,
session: Session = Depends(get_session),
credential_id: UUID,
current_user: User = Depends(get_current_active_user),
):
"""Delete a credential."""
try:
db_credential = session.exec(
select(Credential).where(Credential.id == credential_id, Credential.user_id == current_user.id)
).first()
if not db_credential:
raise HTTPException(status_code=404, detail="Credential not found")
session.delete(db_credential)
session.commit()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

View file

@ -0,0 +1,113 @@
from datetime import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from langflow.services.auth import utils as auth_utils
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.user.model import User
from langflow.services.database.models.variable import Variable, VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_session, get_settings_service
router = APIRouter(prefix="/variables", tags=["Variables"])
@router.post("/", response_model=VariableRead, status_code=201)
def create_variable(
*,
session: Session = Depends(get_session),
variable: VariableCreate,
current_user: User = Depends(get_current_active_user),
settings_service=Depends(get_settings_service),
):
"""Create a new variable."""
try:
# check if variable name already exists
variable_exists = session.exec(
select(Variable).where(
Variable.name == variable.name,
Variable.user_id == current_user.id,
)
).first()
if variable_exists:
raise HTTPException(status_code=400, detail="Variable name already exists")
variable_dict = variable.model_dump()
variable_dict["user_id"] = current_user.id
db_variable = Variable.model_validate(variable_dict)
if not db_variable.value:
raise HTTPException(status_code=400, detail="Variable value cannot be empty")
encrypted = auth_utils.encrypt_api_key(db_variable.value, settings_service=settings_service)
db_variable.value = encrypted
db_variable.user_id = current_user.id
session.add(db_variable)
session.commit()
session.refresh(db_variable)
return db_variable
except Exception as e:
if isinstance(e, HTTPException):
raise e
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/", response_model=list[VariableRead], status_code=200)
def read_variables(
*,
session: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
):
"""Read all variables."""
try:
variables = session.exec(select(Variable).where(Variable.user_id == current_user.id)).all()
return variables
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@router.patch("/{variable_id}", response_model=VariableRead, status_code=200)
def update_variable(
*,
session: Session = Depends(get_session),
variable_id: UUID,
variable: VariableUpdate,
current_user: User = Depends(get_current_active_user),
):
"""Update a variable."""
try:
db_variable = session.exec(
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
).first()
if not db_variable:
raise HTTPException(status_code=404, detail="Variable not found")
variable_data = variable.model_dump(exclude_unset=True)
for key, value in variable_data.items():
setattr(db_variable, key, value)
db_variable.updated_at = datetime.utcnow()
session.commit()
session.refresh(db_variable)
return db_variable
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/{variable_id}", status_code=204)
def delete_variable(
*,
session: Session = Depends(get_session),
variable_id: UUID,
current_user: User = Depends(get_current_active_user),
):
"""Delete a variable."""
try:
db_variable = session.exec(
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
).first()
if not db_variable:
raise HTTPException(status_code=404, detail="Variable not found")
session.delete(db_variable)
session.commit()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

View file

@ -1,6 +1,7 @@
import operator
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Sequence, Union
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, List, Optional,
Sequence, Union)
from uuid import UUID
import yaml
@ -11,14 +12,14 @@ from sqlmodel import select
from langflow.interface.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
extract_union_types_from_generic_alias)
from langflow.interface.custom.custom_component.component import Component
from langflow.schema import Record
from langflow.schema.dotdict import dotdict
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_credential_service, get_db_service, get_storage_service
from langflow.services.deps import get_db_service, get_storage_service, get_variable_service
get_variable_service)
from langflow.services.storage.service import StorageService
from langflow.utils import validate
@ -372,30 +373,30 @@ class CustomComponent(Component):
def get_credential(name: str):
if hasattr(self, "_user_id") and not self._user_id:
raise ValueError(f"User id is not set for {self.__class__.__name__}")
credential_service = get_credential_service() # Get service instance
variable_service = get_variable_service() # Get service instance
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return variable_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return get_credential
def list_key_names(self):
"""
Lists the names of the credentials for the current user.
Lists the names of the variables for the current user.
Raises:
ValueError: If the user id is not set.
Returns:
List[str]: The names of the credentials for the current user.
List[str]: The names of the variables for the current user.
"""
if hasattr(self, "_user_id") and not self._user_id:
raise ValueError(f"User id is not set for {self.__class__.__name__}")
credential_service = get_credential_service()
variable_service = get_variable_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(user_id=self._user_id, session=session)
return variable_service.list_variables(user_id=self._user_id, session=session)
def index(self, value: int = 0):
"""

View file

@ -1,15 +1,15 @@
from typing import TYPE_CHECKING
from langflow.services.credentials.service import CredentialService
from langflow.services.credentials.service import VariableService
from langflow.services.factory import ServiceFactory
if TYPE_CHECKING:
from langflow.services.settings.service import SettingsService
class CredentialServiceFactory(ServiceFactory):
class VariableServiceFactory(ServiceFactory):
def __init__(self):
super().__init__(CredentialService)
super().__init__(VariableService)
def create(self, settings_service: "SettingsService"):
return CredentialService(settings_service)
return VariableService(settings_service)

View file

@ -2,28 +2,27 @@ from typing import TYPE_CHECKING, Optional, Union
from uuid import UUID
from fastapi import Depends
from sqlmodel import Session, select
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.credential.model import Credential
from langflow.services.database.models.variable.model import Variable
from langflow.services.deps import get_session
from sqlmodel import Session, select
if TYPE_CHECKING:
from langflow.services.settings.service import SettingsService
class CredentialService(Service):
name = "credential_service"
class VariableService(Service):
name = "variable_service"
def __init__(self, settings_service: "SettingsService"):
self.settings_service = settings_service
def get_credential(self, user_id: Union[UUID, str], name: str, session: Session = Depends(get_session)) -> str:
# we get the credential from the database
# credential = session.query(Credential).filter(Credential.user_id == user_id, Credential.name == name).first()
credential = session.exec(
select(Credential).where(Credential.user_id == user_id, Credential.name == name)
).first()
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
credential = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
# we decrypt the value
if not credential or not credential.value:
raise ValueError(f"{name} credential not found.")
@ -33,5 +32,5 @@ class CredentialService(Service):
def list_credentials(
self, user_id: Union[UUID, str], session: Session = Depends(get_session)
) -> list[Optional[str]]:
credentials = session.exec(select(Credential).where(Credential.user_id == user_id)).all()
credentials = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
return [credential.name for credential in credentials]

View file

@ -1,6 +1,6 @@
from .api_key import ApiKey
from .credential import Credential
from .flow import Flow
from .user import User
from .variable import Variable
__all__ = ["Flow", "User", "ApiKey", "Credential"]
__all__ = ["Flow", "User", "ApiKey", "Variable"]

View file

@ -1,3 +0,0 @@
from .model import Credential, CredentialCreate, CredentialRead, CredentialUpdate
__all__ = ["Credential", "CredentialCreate", "CredentialRead", "CredentialUpdate"]

View file

@ -1,42 +0,0 @@
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from uuid import UUID, uuid4
from langflow.services.database.models.credential.schema import CredentialType
from sqlmodel import Field, Relationship, SQLModel
if TYPE_CHECKING:
from langflow.services.database.models.user import User
class CredentialBase(SQLModel):
name: Optional[str] = Field(None, description="Name of the credential")
value: Optional[str] = Field(None, description="Encrypted value of the credential")
provider: Optional[str] = Field(None, description="Provider of the credential (e.g OpenAI)")
class Credential(CredentialBase, table=True):
id: Optional[UUID] = Field(default_factory=uuid4, primary_key=True, description="Unique ID for the credential")
# name is unique per user
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time of the credential")
updated_at: Optional[datetime] = Field(None, description="Last update time of the credential")
# foreign key to user table
user_id: UUID = Field(description="User ID associated with this credential", foreign_key="user.id")
user: "User" = Relationship(back_populates="credentials")
class CredentialCreate(CredentialBase):
# AcceptedProviders is a custom Enum
provider: Optional[CredentialType] = Field(None, description="Provider of the credential (e.g OpenAI)")
class CredentialRead(SQLModel):
id: UUID
name: Optional[str] = Field(None, description="Name of the credential")
provider: Optional[str] = Field(None, description="Provider of the credential (e.g OpenAI)")
class CredentialUpdate(SQLModel):
id: UUID # Include the ID for updating
name: Optional[str] = Field(None, description="Name of the credential")
value: Optional[str] = Field(None, description="Encrypted value of the credential")

View file

@ -1,8 +0,0 @@
from enum import Enum
class CredentialType(str, Enum):
"""CredentialType is an Enum of the accepted providers"""
OPENAI_API_KEY = "OPENAI_API_KEY"
ANTHROPIC_API_KEY = "ANTHROPIC_API_KEY"

View file

@ -6,7 +6,7 @@ from sqlmodel import Field, Relationship, SQLModel
if TYPE_CHECKING:
from langflow.services.database.models.api_key import ApiKey
from langflow.services.database.models.credential import Credential
from langflow.services.database.models.variable import Variable
from langflow.services.database.models.flow import Flow
@ -26,7 +26,7 @@ class User(SQLModel, table=True):
)
store_api_key: Optional[str] = Field(default=None, nullable=True)
flows: list["Flow"] = Relationship(back_populates="user")
credentials: list["Credential"] = Relationship(
variables: list["Variable"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "delete"},
)

View file

@ -0,0 +1,3 @@
from .model import Variable, VariableCreate, VariableRead, VariableUpdate
__all__ = ["Variable", "VariableCreate", "VariableRead", "VariableUpdate"]

View file

@ -4,9 +4,11 @@ from typing import TYPE_CHECKING, Generator
from langflow.services import ServiceType, service_manager
if TYPE_CHECKING:
from sqlmodel import Session
from langflow.services.cache.service import BaseCacheService
from langflow.services.chat.service import ChatService
from langflow.services.credentials.service import CredentialService
from langflow.services.credentials.service import VariableService
from langflow.services.database.service import DatabaseService
from langflow.services.monitor.service import MonitorService
from langflow.services.plugins.service import PluginService
@ -16,7 +18,6 @@ if TYPE_CHECKING:
from langflow.services.storage.service import StorageService
from langflow.services.store.service import StoreService
from langflow.services.task.service import TaskService
from sqlmodel import Session
def get_socket_service() -> "SocketIOService":
@ -27,8 +28,8 @@ def get_storage_service() -> "StorageService":
return service_manager.get(ServiceType.STORAGE_SERVICE) # type: ignore
def get_credential_service() -> "CredentialService":
return service_manager.get(ServiceType.CREDENTIAL_SERVICE) # type: ignore
def get_variable_service() -> "VariableService":
return service_manager.get(ServiceType.VARIABLE_SERVICE) # type: ignore
def get_plugins_service() -> "PluginService":

View file

@ -16,7 +16,7 @@ class ServiceType(str, Enum):
TASK_SERVICE = "task_service"
PLUGIN_SERVICE = "plugin_service"
STORE_SERVICE = "store_service"
CREDENTIAL_SERVICE = "credential_service"
VARIABLE_SERVICE = "variable_service"
STORAGE_SERVICE = "storage_service"
MONITOR_SERVICE = "monitor_service"
SOCKET_IO_SERVICE = "socket_io_service"

View file

@ -5,10 +5,7 @@ from langflow.services.auth.utils import create_super_user, verify_password
from langflow.services.database.utils import initialize_database
from langflow.services.manager import service_manager
from langflow.services.schema import ServiceType
from langflow.services.settings.constants import (
DEFAULT_SUPERUSER,
DEFAULT_SUPERUSER_PASSWORD,
)
from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD
from langflow.services.socket.utils import set_socketio_server
from .deps import get_db_service, get_session, get_settings_service
@ -22,9 +19,7 @@ def get_factories_and_deps():
from langflow.services.database import factory as database_factory
from langflow.services.monitor import factory as monitor_factory
from langflow.services.plugins import factory as plugins_factory
from langflow.services.session import (
factory as session_service_factory,
) # type: ignore
from langflow.services.session import factory as session_service_factory # type: ignore
from langflow.services.settings import factory as settings_factory
from langflow.services.socket import factory as socket_factory
from langflow.services.storage import factory as storage_factory
@ -54,7 +49,7 @@ def get_factories_and_deps():
(plugins_factory.PluginServiceFactory(), [ServiceType.SETTINGS_SERVICE]),
(store_factory.StoreServiceFactory(), [ServiceType.SETTINGS_SERVICE]),
(
credentials_factory.CredentialServiceFactory(),
credentials_factory.VariableServiceFactory(),
[ServiceType.SETTINGS_SERVICE],
),
(
@ -191,14 +186,7 @@ def initialize_session_service():
"""
Initialize the session manager.
"""
from langflow.services.cache import factory as cache_factory
from langflow.services.session import (
factory as session_service_factory,
) # type: ignore
initialize_settings_service()
service_manager.register_factory(cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE])
from langflow.services.session import factory as session_service_factory
service_manager.register_factory(
session_service_factory.SessionServiceFactory(),

View file

@ -863,7 +863,7 @@ export async function getGlobalVariables(): Promise<{
[key: string]: { id: string; provider: string };
}> {
const globalVariables = {};
(await api.get(`${BASE_URL_API}credentials/`)).data.forEach((element) => {
(await api.get(`${BASE_URL_API}variables/`)).data.forEach((element) => {
globalVariables[element.name] = {
id: element.id,
provider: element.provider,
@ -881,7 +881,7 @@ export async function registerGlobalVariable({
value: string;
provider?: string;
}): Promise<AxiosResponse<{ name: string; id: string; provider: string }>> {
return await api.post(`${BASE_URL_API}credentials/`, {
return await api.post(`${BASE_URL_API}variables/`, {
name,
value,
provider,
@ -889,7 +889,7 @@ export async function registerGlobalVariable({
}
export async function deleteGlobalVariable(id: string) {
api.delete(`${BASE_URL_API}credentials/${id}`);
api.delete(`${BASE_URL_API}variables/${id}`);
}
export async function updateGlobalVariable(
@ -897,7 +897,7 @@ export async function updateGlobalVariable(
value: string,
id: string
) {
api.patch(`${BASE_URL_API}credentials/${id}`, {
api.patch(`${BASE_URL_API}variables/${id}`, {
name,
value,
});