fix: Use AsyncSession for user management (#4491)

* Use AsyncSession for user management

* Simplify check_key

* Don't trigger blockbuster on settings service initialize

* Fix mypy

* Fix api key update_total_uses

* Fix auto-login

* Revert making CustomComponent.list_key_names async
This commit is contained in:
Christophe Bornet 2024-11-16 02:09:33 +01:00 committed by GitHub
commit 6573ca14cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 430 additions and 339 deletions

View file

@ -26,7 +26,7 @@ from sqlmodel import select
from langflow.logging.logger import configure, logger
from langflow.main import setup_app
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
from langflow.services.database.utils import session_getter
from langflow.services.database.utils import async_session_getter
from langflow.services.deps import async_session_scope, get_db_service, get_settings_service
from langflow.services.settings.constants import DEFAULT_SUPERUSER
from langflow.services.utils import initialize_services
@ -419,30 +419,35 @@ def superuser(
) -> None:
"""Create a superuser."""
configure(log_level=log_level)
initialize_services()
db_service = get_db_service()
with session_getter(db_service) as session:
from langflow.services.auth.utils import create_super_user
if create_super_user(db=session, username=username, password=password):
# Verify that the superuser was created
from langflow.services.database.models.user.model import User
async def _create_superuser():
await initialize_services()
async with async_session_getter(db_service) as session:
from langflow.services.auth.utils import create_super_user
if await create_super_user(db=session, username=username, password=password):
# Verify that the superuser was created
from langflow.services.database.models.user.model import User
stmt = select(User).where(User.username == username)
user: User = (await session.exec(stmt)).first()
if user is None or not user.is_superuser:
typer.echo("Superuser creation failed.")
return
# Now create the first folder for the user
result = await create_default_folder_if_it_doesnt_exist(session, user.id)
if result:
typer.echo("Default folder created successfully.")
else:
msg = "Could not create default folder."
raise RuntimeError(msg)
typer.echo("Superuser created successfully.")
user: User = session.exec(select(User).where(User.username == username)).first()
if user is None or not user.is_superuser:
typer.echo("Superuser creation failed.")
return
# Now create the first folder for the user
result = create_default_folder_if_it_doesnt_exist(session, user.id)
if result:
typer.echo("Default folder created successfully.")
else:
msg = "Could not create default folder."
raise RuntimeError(msg)
typer.echo("Superuser created successfully.")
typer.echo("Superuser creation failed.")
else:
typer.echo("Superuser creation failed.")
asyncio.run(_create_superuser())
# command to copy the langflow database from the cache to the current directory
@ -494,7 +499,7 @@ def migration(
):
raise typer.Abort
initialize_services(fix_migration=fix)
asyncio.run(initialize_services(fix_migration=fix))
db_service = get_db_service()
if not test:
db_service.run_migrations()
@ -515,18 +520,20 @@ def api_key(
None
"""
configure(log_level=log_level)
initialize_services()
settings_service = get_settings_service()
auth_settings = settings_service.auth_settings
if not auth_settings.AUTO_LOGIN:
typer.echo("Auto login is disabled. API keys cannot be created through the CLI.")
return
async def aapi_key():
await initialize_services()
settings_service = get_settings_service()
auth_settings = settings_service.auth_settings
if not auth_settings.AUTO_LOGIN:
typer.echo("Auto login is disabled. API keys cannot be created through the CLI.")
return None
async with async_session_scope() as session:
from langflow.services.database.models.user.model import User
superuser = (await session.exec(select(User).where(User.username == DEFAULT_SUPERUSER))).first()
stmt = select(User).where(User.username == DEFAULT_SUPERUSER)
superuser = (await session.exec(stmt)).first()
if not superuser:
typer.echo(
"Default superuser not found. This command requires a superuser and AUTO_LOGIN to be enabled."
@ -535,7 +542,8 @@ def api_key(
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate
from langflow.services.database.models.api_key.crud import create_api_key, delete_api_key
api_key = (await session.exec(select(ApiKey).where(ApiKey.user_id == superuser.id))).first()
stmt = select(ApiKey).where(ApiKey.user_id == superuser.id)
api_key = (await session.exec(stmt)).first()
if api_key:
await delete_api_key(session, api_key.id)

View file

@ -3,7 +3,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Response
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
from langflow.services.auth import utils as auth_utils
@ -62,7 +62,7 @@ async def save_store_api_key(
api_key_request: ApiKeyCreateRequest,
response: Response,
current_user: CurrentActiveUser,
db: DbSession,
db: AsyncDbSession,
):
settings_service = get_settings_service()
auth_settings = settings_service.auth_settings
@ -74,7 +74,7 @@ async def save_store_api_key(
encrypted = auth_utils.encrypt_api_key(api_key, settings_service=settings_service)
current_user.store_api_key = encrypted
db.add(current_user)
db.commit()
await db.commit()
response.set_cookie(
"apikey_tkn_lflw",

View file

@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from langflow.api.utils import DbSession
from langflow.api.utils import AsyncDbSession
from langflow.api.v1.schemas import Token
from langflow.services.auth.utils import (
authenticate_user,
@ -21,14 +21,14 @@ router = APIRouter(tags=["Login"])
@router.post("/login", response_model=Token)
def login_to_get_access_token(
async def login_to_get_access_token(
response: Response,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: DbSession,
db: AsyncDbSession,
):
auth_settings = get_settings_service().auth_settings
try:
user = authenticate_user(form_data.username, form_data.password, db)
user = await authenticate_user(form_data.username, form_data.password, db)
except Exception as exc:
if isinstance(exc, HTTPException):
raise
@ -38,7 +38,7 @@ def login_to_get_access_token(
) from exc
if user:
tokens = create_user_tokens(user_id=user.id, db=db, update_last_login=True)
tokens = await create_user_tokens(user_id=user.id, db=db, update_last_login=True)
response.set_cookie(
"refresh_token_lf",
tokens["refresh_token"],
@ -66,9 +66,9 @@ def login_to_get_access_token(
expires=None, # Set to None to make it a session cookie
domain=auth_settings.COOKIE_DOMAIN,
)
get_variable_service().initialize_user_variables(user.id, db)
await get_variable_service().initialize_user_variables(user.id, db)
# Create default folder for user if it doesn't exist
create_default_folder_if_it_doesnt_exist(db, user.id)
await create_default_folder_if_it_doesnt_exist(db, user.id)
return tokens
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -78,11 +78,11 @@ def login_to_get_access_token(
@router.get("/auto_login")
async def auto_login(response: Response, db: DbSession):
async def auto_login(response: Response, db: AsyncDbSession):
auth_settings = get_settings_service().auth_settings
if auth_settings.AUTO_LOGIN:
user_id, tokens = create_user_longterm_token(db)
user_id, tokens = await create_user_longterm_token(db)
response.set_cookie(
"access_token_lf",
tokens["access_token"],
@ -93,7 +93,7 @@ async def auto_login(response: Response, db: DbSession):
domain=auth_settings.COOKIE_DOMAIN,
)
user = get_user_by_id(db, user_id)
user = await get_user_by_id(db, user_id)
if user:
if user.store_api_key is None:
@ -124,14 +124,14 @@ async def auto_login(response: Response, db: DbSession):
async def refresh_token(
request: Request,
response: Response,
db: DbSession,
db: AsyncDbSession,
):
auth_settings = get_settings_service().auth_settings
token = request.cookies.get("refresh_token_lf")
if token:
tokens = create_refresh_token(token, db)
tokens = await create_refresh_token(token, db)
response.set_cookie(
"refresh_token_lf",
tokens["refresh_token"],

View file

@ -7,7 +7,7 @@ from sqlalchemy.exc import IntegrityError
from sqlmodel import select
from sqlmodel.sql.expression import SelectOfScalar
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
from langflow.api.v1.schemas import UsersResponse
from langflow.services.auth.utils import (
get_current_active_superuser,
@ -25,7 +25,7 @@ router = APIRouter(tags=["Users"], prefix="/users")
@router.post("/", response_model=UserRead, status_code=201)
async def add_user(
user: UserCreate,
session: DbSession,
session: AsyncDbSession,
) -> User:
"""Add a new user to the database."""
new_user = User.model_validate(user, from_attributes=True)
@ -33,13 +33,13 @@ async def add_user(
new_user.password = get_password_hash(user.password)
new_user.is_active = get_settings_service().auth_settings.NEW_USER_IS_ACTIVE
session.add(new_user)
session.commit()
session.refresh(new_user)
folder = create_default_folder_if_it_doesnt_exist(session, new_user.id)
await session.commit()
await session.refresh(new_user)
folder = await create_default_folder_if_it_doesnt_exist(session, new_user.id)
if not folder:
raise HTTPException(status_code=500, detail="Error creating default folder")
except IntegrityError as e:
session.rollback()
await session.rollback()
raise HTTPException(status_code=400, detail="This username is unavailable.") from e
return new_user
@ -58,14 +58,14 @@ async def read_all_users(
*,
skip: int = 0,
limit: int = 10,
session: DbSession,
session: AsyncDbSession,
) -> UsersResponse:
"""Retrieve a list of users from the database with pagination."""
query: SelectOfScalar = select(User).offset(skip).limit(limit)
users = session.exec(query).fetchall()
users = (await session.exec(query)).fetchall()
count_query = select(func.count()).select_from(User)
total_count = session.exec(count_query).first()
total_count = (await session.exec(count_query)).first()
return UsersResponse(
total_count=total_count,
@ -78,7 +78,7 @@ async def patch_user(
user_id: UUID,
user_update: UserUpdate,
user: CurrentActiveUser,
session: DbSession,
session: AsyncDbSession,
) -> User:
"""Update an existing user's data."""
update_password = bool(user_update.password)
@ -93,10 +93,10 @@ async def patch_user(
raise HTTPException(status_code=400, detail="You can't change your password here")
user_update.password = get_password_hash(user_update.password)
if user_db := get_user_by_id(session, user_id):
if user_db := await get_user_by_id(session, user_id):
if not update_password:
user_update.password = user_db.password
return update_user(user_db, user_update, session)
return await update_user(user_db, user_update, session)
raise HTTPException(status_code=404, detail="User not found")
@ -105,7 +105,7 @@ async def reset_password(
user_id: UUID,
user_update: UserUpdate,
user: CurrentActiveUser,
session: DbSession,
session: AsyncDbSession,
) -> User:
"""Reset a user's password."""
if user_id != user.id:
@ -117,8 +117,8 @@ async def reset_password(
raise HTTPException(status_code=400, detail="You can't use your current password")
new_password = get_password_hash(user_update.password)
user.password = new_password
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
return user

View file

@ -3,7 +3,7 @@ from uuid import UUID
from fastapi import APIRouter, HTTPException
from sqlalchemy.exc import NoResultFound
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_variable_service
from langflow.services.variable.constants import GENERIC_TYPE
@ -15,7 +15,7 @@ router = APIRouter(prefix="/variables", tags=["Variables"])
@router.post("/", response_model=VariableRead, status_code=201)
async def create_variable(
*,
session: DbSession,
session: AsyncDbSession,
variable: VariableCreate,
current_user: CurrentActiveUser,
):
@ -30,10 +30,10 @@ async def create_variable(
if not variable.value:
raise HTTPException(status_code=400, detail="Variable value cannot be empty")
if variable.name in variable_service.list_variables(user_id=current_user.id, session=session):
if variable.name in await variable_service.list_variables(user_id=current_user.id, session=session):
raise HTTPException(status_code=400, detail="Variable name already exists")
try:
return variable_service.create_variable(
return await variable_service.create_variable(
user_id=current_user.id,
name=variable.name,
value=variable.value,
@ -50,7 +50,7 @@ async def create_variable(
@router.get("/", response_model=list[VariableRead], status_code=200)
async def read_variables(
*,
session: DbSession,
session: AsyncDbSession,
current_user: CurrentActiveUser,
):
"""Read all variables."""
@ -59,7 +59,7 @@ async def read_variables(
msg = "Variable service is not an instance of DatabaseVariableService"
raise TypeError(msg)
try:
return variable_service.get_all(user_id=current_user.id, session=session)
return await variable_service.get_all(user_id=current_user.id, session=session)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@ -67,7 +67,7 @@ async def read_variables(
@router.patch("/{variable_id}", response_model=VariableRead, status_code=200)
async def update_variable(
*,
session: DbSession,
session: AsyncDbSession,
variable_id: UUID,
variable: VariableUpdate,
current_user: CurrentActiveUser,
@ -78,7 +78,7 @@ async def update_variable(
msg = "Variable service is not an instance of DatabaseVariableService"
raise TypeError(msg)
try:
return variable_service.update_variable_fields(
return await variable_service.update_variable_fields(
user_id=current_user.id,
variable_id=variable_id,
variable=variable,
@ -94,13 +94,13 @@ async def update_variable(
@router.delete("/{variable_id}", status_code=204)
async def delete_variable(
*,
session: DbSession,
session: AsyncDbSession,
variable_id: UUID,
current_user: CurrentActiveUser,
) -> None:
"""Delete a variable."""
variable_service = get_variable_service()
try:
variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
await variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

View file

@ -1,4 +1,5 @@
# Add helper functions for each event type
import asyncio
from collections.abc import AsyncIterator
from time import perf_counter
from typing import Any, Protocol
@ -249,7 +250,7 @@ async def process_agent_events(
agent_message.properties.icon = "Bot"
agent_message.properties.state = "partial"
# Store the initial message
agent_message = send_message_method(message=agent_message)
agent_message = await asyncio.to_thread(send_message_method, message=agent_message)
try:
# Create a mapping of run_ids to tool contents
tool_blocks_map: dict[str, ToolContent] = {}

View file

@ -448,7 +448,7 @@ class CustomComponent(BaseComponent):
variable_service = get_variable_service()
with session_scope() as session:
return variable_service.list_variables(user_id=self.user_id, session=session)
return variable_service.list_variables_sync(user_id=self.user_id, session=session)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable.

View file

@ -28,6 +28,7 @@ from langflow.services.database.models.folder.utils import (
)
from langflow.services.database.models.user.crud import get_user_by_username
from langflow.services.deps import (
async_session_scope,
get_settings_service,
get_storage_service,
get_variable_service,
@ -519,7 +520,7 @@ def _is_valid_uuid(val):
return str(uuid_obj) == val
def load_flows_from_directory() -> None:
async def load_flows_from_directory() -> None:
"""On langflow startup, this loads all flows from the directory specified in the settings.
All flows are uploaded into the default folder for the superuser.
@ -533,8 +534,8 @@ def load_flows_from_directory() -> None:
logger.warning("AUTO_LOGIN is disabled, not loading flows from directory")
return
with session_scope() as session:
user = get_user_by_username(session, settings_service.auth_settings.SUPERUSER)
async with async_session_scope() as session:
user = await get_user_by_username(session, settings_service.auth_settings.SUPERUSER)
if user is None:
msg = "Superuser not found in the database"
raise NoResultFound(msg)
@ -553,7 +554,7 @@ def load_flows_from_directory() -> None:
flow["id"] = no_json_name
flow_id = flow.get("id")
existing = find_existing_flow(session, flow_id, flow_endpoint_name)
existing = await find_existing_flow(session, flow_id, flow_endpoint_name)
if existing:
logger.debug(f"Found existing flow: {existing.name}")
logger.info(f"Updating existing flow: {flow_id} with endpoint name {flow_endpoint_name}")
@ -585,15 +586,15 @@ def load_flows_from_directory() -> None:
session.add(flow)
def find_existing_flow(session, flow_id, flow_endpoint_name):
async def find_existing_flow(session, flow_id, flow_endpoint_name):
if flow_endpoint_name:
logger.debug(f"flow_endpoint_name: {flow_endpoint_name}")
stmt = select(Flow).where(Flow.endpoint_name == flow_endpoint_name)
if existing := session.exec(stmt).first():
if existing := (await session.exec(stmt)).first():
logger.debug(f"Found existing flow by endpoint name: {existing.name}")
return existing
stmt = select(Flow).where(Flow.id == flow_id)
if existing := session.exec(stmt).first():
if existing := (await session.exec(stmt)).first():
logger.debug(f"Found existing flow by id: {flow_id}")
return existing
return None
@ -645,7 +646,7 @@ def create_or_update_starter_projects(all_types_dict: dict) -> None:
)
def initialize_super_user_if_needed() -> None:
async def initialize_super_user_if_needed() -> None:
settings_service = get_settings_service()
if not settings_service.auth_settings.AUTO_LOGIN:
return
@ -655,8 +656,8 @@ def initialize_super_user_if_needed() -> None:
msg = "SUPERUSER and SUPERUSER_PASSWORD must be set in the settings if AUTO_LOGIN is true."
raise ValueError(msg)
with session_scope() as session:
super_user = create_super_user(db=session, username=username, password=password)
get_variable_service().initialize_user_variables(super_user.id, session)
create_default_folder_if_it_doesnt_exist(session, super_user.id)
logger.info("Super user initialized")
async with async_session_scope() as async_session:
super_user = await create_super_user(db=async_session, username=username, password=password)
await get_variable_service().initialize_user_variables(super_user.id, async_session)
await create_default_folder_if_it_doesnt_exist(async_session, super_user.id)
logger.info("Super user initialized")

View file

@ -89,11 +89,6 @@ class JavaScriptMIMETypeMiddleware(BaseHTTPMiddleware):
def get_lifespan(*, fix_migration=False, version=None):
telemetry_service = get_telemetry_service()
def _initialize():
initialize_services(fix_migration=fix_migration)
setup_llm_caching()
initialize_super_user_if_needed()
@asynccontextmanager
async def lifespan(_app: FastAPI):
configure(async_file=True)
@ -104,12 +99,13 @@ def get_lifespan(*, fix_migration=False, version=None):
else:
rprint("[bold green]Starting Langflow...[/bold green]")
try:
await asyncio.to_thread(_initialize)
await initialize_services(fix_migration=fix_migration)
await asyncio.to_thread(setup_llm_caching)
await initialize_super_user_if_needed()
all_types_dict = await get_and_cache_all_types_dict(get_settings_service())
await asyncio.to_thread(create_or_update_starter_projects, all_types_dict)
telemetry_service.start()
await asyncio.to_thread(load_flows_from_directory)
await load_flows_from_directory()
yield
except Exception as exc:

View file

@ -1,10 +1,9 @@
import asyncio
import base64
import random
import warnings
from collections.abc import Coroutine
from datetime import datetime, timedelta, timezone
from typing import Annotated
from typing import TYPE_CHECKING, Annotated
from uuid import UUID
from cryptography.fernet import Fernet
@ -12,16 +11,18 @@ from fastapi import Depends, HTTPException, Security, status
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer
from jose import JWTError, jwt
from loguru import logger
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.websockets import WebSocket
from langflow.services.database.models.api_key.crud import check_key
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import get_db_service, get_session, get_settings_service
from langflow.services.deps import get_async_session, get_db_service, get_settings_service
from langflow.services.settings.service import SettingsService
if TYPE_CHECKING:
from langflow.services.database.models.api_key.model import ApiKey
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
API_KEY_NAME = "x-api-key"
@ -33,14 +34,14 @@ MINIMUM_KEY_LENGTH = 32
# Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py
def api_key_security(
async def api_key_security(
query_param: Annotated[str, Security(api_key_query)],
header_param: Annotated[str, Security(api_key_header)],
) -> UserRead | None:
settings_service = get_settings_service()
result: ApiKey | User | None = None
result: ApiKey | User | None
with get_db_service().with_session() as db:
async with get_db_service().with_async_session() as db:
if settings_service.auth_settings.AUTO_LOGIN:
# Get the first user
if not settings_service.auth_settings.SUPERUSER:
@ -49,7 +50,7 @@ def api_key_security(
detail="Missing first superuser credentials",
)
result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
result = await get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
elif not query_param and not header_param:
raise HTTPException(
@ -58,18 +59,16 @@ def api_key_security(
)
elif query_param:
result = check_key(db, query_param)
result = await check_key(db, query_param)
else:
result = check_key(db, header_param)
result = await check_key(db, header_param)
if not result:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing API key",
)
if isinstance(result, ApiKey):
return UserRead.model_validate(result.user, from_attributes=True)
if isinstance(result, User):
return UserRead.model_validate(result, from_attributes=True)
msg = "Invalid result type"
@ -80,11 +79,11 @@ async def get_current_user(
token: Annotated[str, Security(oauth2_login)],
query_param: Annotated[str, Security(api_key_query)],
header_param: Annotated[str, Security(api_key_header)],
db: Annotated[Session, Depends(get_session)],
db: Annotated[AsyncSession, Depends(get_async_session)],
) -> User:
if token:
return await get_current_user_by_jwt(token, db)
user = await asyncio.to_thread(api_key_security, query_param, header_param)
user = await api_key_security(query_param, header_param)
if user:
return user
@ -95,8 +94,8 @@ async def get_current_user(
async def get_current_user_by_jwt(
token: Annotated[str, Depends(oauth2_login)],
db: Annotated[Session, Depends(get_session)],
token: str,
db: AsyncSession,
) -> User:
settings_service = get_settings_service()
@ -144,7 +143,7 @@ async def get_current_user_by_jwt(
headers={"WWW-Authenticate": "Bearer"},
) from e
user = get_user_by_id(db, user_id)
user = await get_user_by_id(db, user_id)
if user is None or not user.is_active:
logger.info("User not found or inactive.")
raise HTTPException(
@ -157,7 +156,7 @@ async def get_current_user_by_jwt(
async def get_current_user_for_websocket(
websocket: WebSocket,
db: Annotated[Session, Depends(get_session)],
db: Annotated[AsyncSession, Depends(get_async_session)],
query_param: Annotated[str, Security(api_key_query)],
) -> User | None:
token = websocket.query_params.get("token")
@ -165,7 +164,7 @@ async def get_current_user_for_websocket(
if token:
return await get_current_user_by_jwt(token, db)
if api_key:
return await asyncio.to_thread(api_key_security, api_key, query_param)
return await api_key_security(api_key, query_param)
return None
@ -207,12 +206,12 @@ def create_token(data: dict, expires_delta: timedelta):
)
def create_super_user(
async def create_super_user(
username: str,
password: str,
db: Session,
db: AsyncSession,
) -> User:
super_user = get_user_by_username(db, username)
super_user = await get_user_by_username(db, username)
if not super_user:
super_user = User(
@ -224,17 +223,17 @@ def create_super_user(
)
db.add(super_user)
db.commit()
db.refresh(super_user)
await db.commit()
await db.refresh(super_user)
return super_user
def create_user_longterm_token(db: Session) -> tuple[UUID, dict]:
async def create_user_longterm_token(db: AsyncSession) -> tuple[UUID, dict]:
settings_service = get_settings_service()
username = settings_service.auth_settings.SUPERUSER
super_user = get_user_by_username(db, username)
super_user = await get_user_by_username(db, username)
if not super_user:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created")
access_token_expires_longterm = timedelta(days=365)
@ -244,7 +243,7 @@ def create_user_longterm_token(db: Session) -> tuple[UUID, dict]:
)
# Update: last_login_at
update_user_last_login_at(super_user.id, db)
await update_user_last_login_at(super_user.id, db)
return super_user.id, {
"access_token": access_token,
@ -270,7 +269,7 @@ def get_user_id_from_token(token: str) -> UUID:
return UUID(int=0)
def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool = False) -> dict:
async def create_user_tokens(user_id: UUID, db: AsyncSession, *, update_last_login: bool = False) -> dict:
settings_service = get_settings_service()
access_token_expires = timedelta(seconds=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
@ -287,7 +286,7 @@ def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool =
# Update: last_login_at
if update_last_login:
update_user_last_login_at(user_id, db)
await update_user_last_login_at(user_id, db)
return {
"access_token": access_token,
@ -296,7 +295,7 @@ def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool =
}
def create_refresh_token(refresh_token: str, db: Session):
async def create_refresh_token(refresh_token: str, db: AsyncSession):
settings_service = get_settings_service()
try:
@ -314,12 +313,12 @@ def create_refresh_token(refresh_token: str, db: Session):
if user_id is None or token_type == "":
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
user_exists = get_user_by_id(db, user_id)
user_exists = await get_user_by_id(db, user_id)
if user_exists is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
return create_user_tokens(user_id, db)
return await create_user_tokens(user_id, db)
except JWTError as e:
logger.exception("JWT decoding error")
@ -329,8 +328,8 @@ def create_refresh_token(refresh_token: str, db: Session):
) from e
def authenticate_user(username: str, password: str, db: Session) -> User | None:
user = get_user_by_username(db, username)
async def authenticate_user(username: str, password: str, db: AsyncSession) -> User | None:
user = await get_user_by_username(db, username)
if not user:
return None

View file

@ -1,13 +1,17 @@
import asyncio
import datetime
import secrets
import threading
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Session, select
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.database.models import User
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead
from langflow.services.database.utils import async_session_getter
from langflow.services.deps import get_db_service
if TYPE_CHECKING:
from sqlmodel.sql.expression import SelectOfScalar
@ -47,33 +51,29 @@ async def delete_api_key(session: AsyncSession, api_key_id: UUID) -> None:
await session.commit()
def check_key(session: Session, api_key: str) -> ApiKey | None:
update_total_uses_tasks: set[asyncio.Task] = set()
async def check_key(session: AsyncSession, api_key: str) -> User | None:
"""Check if the API key is valid."""
query: SelectOfScalar = select(ApiKey).where(ApiKey.api_key == api_key)
api_key_object: ApiKey | None = session.exec(query).first()
query: SelectOfScalar = select(ApiKey).options(selectinload(ApiKey.user)).where(ApiKey.api_key == api_key)
api_key_object: ApiKey | None = (await session.exec(query)).first()
if api_key_object is not None:
threading.Thread(
target=update_total_uses,
args=(
session,
api_key_object,
),
).start()
return api_key_object
task = asyncio.create_task(update_total_uses(api_key_object.id))
task.add_done_callback(update_total_uses_tasks.discard)
update_total_uses_tasks.add(task)
return api_key_object.user
return None
def update_total_uses(session, api_key: ApiKey):
async def update_total_uses(api_key_id: UUID):
"""Update the total uses and last used at."""
# This is running in a separate thread to avoid slowing down the request
# but session is not thread safe so we need to create a new session
with Session(session.get_bind()) as new_session:
new_api_key = new_session.get(ApiKey, api_key.id)
async with async_session_getter(get_db_service()) as session:
new_api_key = await session.get(ApiKey, api_key_id)
if new_api_key is None:
msg = "API Key not found"
raise ValueError(msg)
new_api_key.total_uses += 1
new_api_key.last_used_at = datetime.datetime.now(datetime.timezone.utc)
new_session.add(new_api_key)
new_session.commit()
return new_api_key
session.add(new_api_key)
await session.commit()

View file

@ -1,6 +1,7 @@
from uuid import UUID
from sqlmodel import Session, and_, select, update
from sqlmodel import and_, select, update
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.database.models.flow.model import Flow
@ -8,8 +9,9 @@ from .constants import DEFAULT_FOLDER_DESCRIPTION, DEFAULT_FOLDER_NAME
from .model import Folder
def create_default_folder_if_it_doesnt_exist(session: Session, user_id: UUID):
folder = session.exec(select(Folder).where(Folder.user_id == user_id)).first()
async def create_default_folder_if_it_doesnt_exist(session: AsyncSession, user_id: UUID):
stmt = select(Folder).where(Folder.user_id == user_id)
folder = (await session.exec(stmt)).first()
if not folder:
folder = Folder(
name=DEFAULT_FOLDER_NAME,
@ -17,9 +19,9 @@ def create_default_folder_if_it_doesnt_exist(session: Session, user_id: UUID):
description=DEFAULT_FOLDER_DESCRIPTION,
)
session.add(folder)
session.commit()
session.refresh(folder)
session.exec(
await session.commit()
await session.refresh(folder)
await session.exec(
update(Flow)
.where(
and_(
@ -29,12 +31,14 @@ def create_default_folder_if_it_doesnt_exist(session: Session, user_id: UUID):
)
.values(folder_id=folder.id)
)
session.commit()
await session.commit()
return folder
def get_default_folder_id(session: Session, user_id: UUID):
folder = session.exec(select(Folder).where(Folder.name == DEFAULT_FOLDER_NAME, Folder.user_id == user_id)).first()
async def get_default_folder_id(session: AsyncSession, user_id: UUID):
folder = (
await session.exec(select(Folder).where(Folder.name == DEFAULT_FOLDER_NAME, Folder.user_id == user_id))
).first()
if not folder:
folder = create_default_folder_if_it_doesnt_exist(session, user_id)
folder = await create_default_folder_if_it_doesnt_exist(session, user_id)
return folder.id

View file

@ -5,20 +5,23 @@ from fastapi import HTTPException, status
from loguru import logger
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.database.models.user.model import User, UserUpdate
def get_user_by_username(db: Session, username: str) -> User | None:
return db.exec(select(User).where(User.username == username)).first()
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
stmt = select(User).where(User.username == username)
return (await db.exec(stmt)).first()
def get_user_by_id(db: Session, user_id: UUID) -> User | None:
return db.exec(select(User).where(User.id == user_id)).first()
async def get_user_by_id(db: AsyncSession, user_id: UUID) -> User | None:
stmt = select(User).where(User.id == user_id)
return (await db.exec(stmt)).first()
def update_user(user_db: User | None, user: UserUpdate, db: Session) -> User:
async def update_user(user_db: User | None, user: UserUpdate, db: AsyncSession) -> User:
if not user_db:
raise HTTPException(status_code=404, detail="User not found")
@ -40,18 +43,18 @@ def update_user(user_db: User | None, user: UserUpdate, db: Session) -> User:
flag_modified(user_db, "updated_at")
try:
db.commit()
await db.commit()
except IntegrityError as e:
db.rollback()
await db.rollback()
raise HTTPException(status_code=400, detail=str(e)) from e
return user_db
def update_user_last_login_at(user_id: UUID, db: Session):
async def update_user_last_login_at(user_id: UUID, db: AsyncSession):
try:
user_data = UserUpdate(last_login_at=datetime.now(timezone.utc))
user = get_user_by_id(db, user_id)
return update_user(user, user_data, db)
user = await get_user_by_id(db, user_id)
return await update_user(user, user_data, db)
except Exception: # noqa: BLE001
logger.opt(exception=True).debug("Error updating user last login at")

View file

@ -142,28 +142,29 @@ class DatabaseService(Service):
@asynccontextmanager
async def with_async_session(self):
async with AsyncSession(self.async_engine) as session:
async with AsyncSession(self.async_engine, expire_on_commit=False) as session:
yield session
def migrate_flows_if_auto_login(self) -> None:
async def migrate_flows_if_auto_login(self) -> None:
# if auto_login is enabled, we need to migrate the flows
# to the default superuser if they don't have a user id
# associated with them
settings_service = get_settings_service()
if settings_service.auth_settings.AUTO_LOGIN:
with self.with_session() as session:
flows = session.exec(select(models.Flow).where(models.Flow.user_id is None)).all()
async with self.with_async_session() as session:
stmt = select(models.Flow).where(models.Flow.user_id is None)
flows = (await session.exec(stmt)).all()
if flows:
logger.debug("Migrating flows to default superuser")
username = settings_service.auth_settings.SUPERUSER
user = get_user_by_username(session, username)
user = await get_user_by_username(session, username)
if not user:
logger.error("Default superuser not found")
msg = "Default superuser not found"
raise RuntimeError(msg)
for flow in flows:
flow.user_id = user.id
session.commit()
await session.commit()
logger.debug("Flows migrated successfully")
def check_schema_health(self) -> bool:
@ -346,20 +347,15 @@ class DatabaseService(Service):
logger.debug("Database and tables created successfully")
def _teardown(self) -> None:
async def teardown(self) -> None:
logger.debug("Tearing down database")
try:
settings_service = get_settings_service()
# remove the default superuser if auto_login is enabled
# using the SUPERUSER to get the user
with self.with_session() as session:
teardown_superuser(settings_service, session)
async with self.with_async_session() as session:
await teardown_superuser(settings_service, session)
except Exception: # noqa: BLE001
logger.exception("Error tearing down database")
self.engine.dispose()
async def teardown(self) -> None:
await asyncio.to_thread(self._teardown)
await self.async_engine.dispose()
await asyncio.to_thread(self.engine.dispose)

View file

@ -1,12 +1,13 @@
from __future__ import annotations
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING
from alembic.util.exc import CommandError
from loguru import logger
from sqlmodel import Session, text
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
@ -70,6 +71,19 @@ def session_getter(db_service: DatabaseService):
session.close()
@asynccontextmanager
async def async_session_getter(db_service: DatabaseService):
try:
session = AsyncSession(db_service.async_engine)
yield session
except Exception:
logger.exception("Session rollback because of exception")
await session.rollback()
raise
finally:
await session.close()
@dataclass
class Result:
name: str

View file

@ -1,7 +1,8 @@
import asyncio
from loguru import logger
from sqlmodel import Session, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.auth.utils import create_super_user, verify_password
from langflow.services.cache.factory import CacheServiceFactory
@ -9,13 +10,14 @@ from langflow.services.database.utils import initialize_database
from langflow.services.schema import ServiceType
from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD
from .deps import get_db_service, get_service, get_session, get_settings_service
from .deps import get_db_service, get_service, get_settings_service
def get_or_create_super_user(session: Session, username, password, is_default):
async def get_or_create_super_user(session: AsyncSession, username, password, is_default):
from langflow.services.database.models.user.model import User
user = session.exec(select(User).where(User.username == username)).first()
stmt = select(User).where(User.username == username)
user = (await session.exec(stmt)).first()
if user and user.is_superuser:
return None # Superuser already exists
@ -51,7 +53,7 @@ def get_or_create_super_user(session: Session, username, password, is_default):
else:
logger.debug("Creating superuser.")
try:
return create_super_user(username, password, db=session)
return await create_super_user(username, password, db=session)
except Exception as exc: # noqa: BLE001
if "UNIQUE constraint failed: user.username" in str(exc):
# This is to deal with workers running this
@ -62,12 +64,12 @@ def get_or_create_super_user(session: Session, username, password, is_default):
logger.opt(exception=True).debug("Error creating superuser.")
def setup_superuser(settings_service, session: Session) -> None:
async def setup_superuser(settings_service, session: AsyncSession) -> None:
if settings_service.auth_settings.AUTO_LOGIN:
logger.debug("AUTO_LOGIN is set to True. Creating default superuser.")
else:
# Remove the default superuser if it exists
teardown_superuser(settings_service, session)
await teardown_superuser(settings_service, session)
username = settings_service.auth_settings.SUPERUSER
password = settings_service.auth_settings.SUPERUSER_PASSWORD
@ -75,7 +77,9 @@ def setup_superuser(settings_service, session: Session) -> None:
is_default = (username == DEFAULT_SUPERUSER) and (password == DEFAULT_SUPERUSER_PASSWORD)
try:
user = get_or_create_super_user(session=session, username=username, password=password, is_default=is_default)
user = await get_or_create_super_user(
session=session, username=username, password=password, is_default=is_default
)
if user is not None:
logger.debug("Superuser created successfully.")
except Exception as exc:
@ -86,7 +90,7 @@ def setup_superuser(settings_service, session: Session) -> None:
settings_service.auth_settings.reset_credentials()
def teardown_superuser(settings_service, session) -> None:
async def teardown_superuser(settings_service, session: AsyncSession) -> None:
"""Teardown the superuser."""
# If AUTO_LOGIN is True, we will remove the default superuser
# from the database.
@ -97,30 +101,27 @@ def teardown_superuser(settings_service, session) -> None:
username = DEFAULT_SUPERUSER
from langflow.services.database.models.user.model import User
user = session.exec(select(User).where(User.username == username)).first()
stmt = select(User).where(User.username == username)
user = (await session.exec(stmt)).first()
# Check if super was ever logged in, if not delete it
# if it has logged in, it means the user is using it to login
if user and user.is_superuser is True and not user.last_login_at:
session.delete(user)
session.commit()
await session.delete(user)
await session.commit()
logger.debug("Default superuser removed successfully.")
except Exception as exc:
logger.exception(exc)
session.rollback()
await session.rollback()
msg = "Could not remove default superuser."
raise RuntimeError(msg) from exc
def _teardown_superuser():
with get_db_service().with_session() as session:
teardown_superuser(get_settings_service(), session)
async def teardown_services() -> None:
"""Teardown all the services."""
try:
await asyncio.to_thread(_teardown_superuser)
async with get_db_service().with_async_session() as session:
await teardown_superuser(get_settings_service(), session)
except Exception as exc: # noqa: BLE001
logger.exception(exc)
try:
@ -156,15 +157,16 @@ def initialize_session_service() -> None:
)
def initialize_services(*, fix_migration: bool = False) -> None:
async def initialize_services(*, fix_migration: bool = False) -> None:
"""Initialize all the services needed."""
# Test cache connection
get_service(ServiceType.CACHE_SERVICE, default=CacheServiceFactory())
# Setup the superuser
initialize_database(fix_migration=fix_migration)
setup_superuser(get_service(ServiceType.SETTINGS_SERVICE), next(get_session()))
await asyncio.to_thread(initialize_database, fix_migration=fix_migration)
async with get_db_service().with_async_session() as session:
await setup_superuser(get_service(ServiceType.SETTINGS_SERVICE), session)
try:
get_db_service().migrate_flows_if_auto_login()
await get_db_service().migrate_flows_if_auto_login()
except Exception as exc:
msg = "Error migrating flows"
logger.exception(msg)

View file

@ -2,6 +2,7 @@ import abc
from uuid import UUID
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable
@ -13,7 +14,7 @@ class VariableService(Service):
name = "variable_service"
@abc.abstractmethod
def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None:
async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None:
"""Initialize user variables.
Args:
@ -36,7 +37,7 @@ class VariableService(Service):
"""
@abc.abstractmethod
def list_variables(self, user_id: UUID | str, session: Session) -> list[str | None]:
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
"""List all variables.
Args:
@ -48,7 +49,19 @@ class VariableService(Service):
"""
@abc.abstractmethod
def update_variable(self, user_id: UUID | str, name: str, value: str, session: Session) -> Variable:
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
"""List all variables.
Args:
user_id: The user ID.
session: The database session.
Returns:
A list of variable names.
"""
@abc.abstractmethod
async def update_variable(self, user_id: UUID | str, name: str, value: str, session: AsyncSession) -> Variable:
"""Update a variable.
Args:
@ -62,7 +75,7 @@ class VariableService(Service):
"""
@abc.abstractmethod
def delete_variable(self, user_id: UUID | str, name: str, session: Session) -> None:
async def delete_variable(self, user_id: UUID | str, name: str, session: AsyncSession) -> None:
"""Delete a variable.
Args:
@ -75,7 +88,7 @@ class VariableService(Service):
"""
@abc.abstractmethod
def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: Session) -> None:
async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None:
"""Delete a variable by ID.
Args:
@ -85,7 +98,7 @@ class VariableService(Service):
"""
@abc.abstractmethod
def create_variable(
async def create_variable(
self,
user_id: UUID | str,
name: str,
@ -93,7 +106,7 @@ class VariableService(Service):
*,
default_fields: list[str],
_type: str,
session: Session,
session: AsyncSession,
) -> Variable:
"""Create a variable.

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import os
from typing import TYPE_CHECKING
@ -17,6 +18,7 @@ if TYPE_CHECKING:
from uuid import UUID
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.settings.service import SettingsService
@ -28,7 +30,7 @@ class KubernetesSecretService(VariableService, Service):
self.kubernetes_secrets = KubernetesSecretManager()
@override
def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None:
async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None:
# Check for environment variables that should be stored in the database
should_or_should_not = "Should" if self.settings_service.settings.store_environment_variables else "Should not"
logger.info(f"{should_or_should_not} store environment variables in the kubernetes.")
@ -45,7 +47,8 @@ class KubernetesSecretService(VariableService, Service):
try:
secret_name = encode_user_id(user_id)
self.kubernetes_secrets.create_secret(
await asyncio.to_thread(
self.kubernetes_secrets.create_secret,
name=secret_name,
data=variables,
)
@ -75,12 +78,13 @@ class KubernetesSecretService(VariableService, Service):
msg = f"user_id {user_id} variable name {name} not found."
raise ValueError(msg)
@override
def get_variable(
self,
user_id: UUID | str,
name: str,
field: str,
_session: Session,
session: Session,
) -> str:
secret_name = encode_user_id(user_id)
key, value = self.resolve_variable(secret_name, user_id, name)
@ -92,10 +96,11 @@ class KubernetesSecretService(VariableService, Service):
raise TypeError(msg)
return value
def list_variables(
@override
def list_variables_sync(
self,
user_id: UUID | str,
_session: Session,
session: Session,
) -> list[str | None]:
variables = self.kubernetes_secrets.get_secret(name=encode_user_id(user_id))
if not variables:
@ -109,28 +114,49 @@ class KubernetesSecretService(VariableService, Service):
names.append(key)
return names
def update_variable(
@override
async def list_variables(
self,
user_id: UUID | str,
session: AsyncSession,
) -> list[str | None]:
return await asyncio.to_thread(self.list_variables_sync, user_id, session.sync_session)
def _update_variable(
self,
user_id: UUID | str,
name: str,
value: str,
_session: Session,
):
secret_name = encode_user_id(user_id)
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
return self.kubernetes_secrets.update_secret(name=secret_name, data={secret_key: value})
def delete_variable(self, user_id: UUID | str, name: str, _session: Session) -> None:
secret_name = encode_user_id(user_id)
@override
async def update_variable(
self,
user_id: UUID | str,
name: str,
value: str,
session: AsyncSession,
):
return await asyncio.to_thread(self._update_variable, user_id, name, value)
def _delete_variable(self, user_id: UUID | str, name: str) -> None:
secret_name = encode_user_id(user_id)
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
self.kubernetes_secrets.delete_secret_key(name=secret_name, key=secret_key)
def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID | str, _session: Session) -> None:
self.delete_variable(user_id, _session, str(variable_id))
@override
async def delete_variable(self, user_id: UUID | str, name: str, session: AsyncSession) -> None:
await asyncio.to_thread(self._delete_variable, user_id, name)
@override
def create_variable(
async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID | str, session: AsyncSession) -> None:
await self.delete_variable(user_id, str(variable_id), session)
@override
async def create_variable(
self,
user_id: UUID | str,
name: str,
@ -138,7 +164,7 @@ class KubernetesSecretService(VariableService, Service):
*,
default_fields: list[str],
_type: str,
session: Session,
session: AsyncSession,
) -> Variable:
secret_name = encode_user_id(user_id)
secret_key = name
@ -147,7 +173,9 @@ class KubernetesSecretService(VariableService, Service):
else:
_type = GENERIC_TYPE
self.kubernetes_secrets.upsert_secret(secret_name=secret_name, data={secret_key: value})
await asyncio.to_thread(
self.kubernetes_secrets.upsert_secret, secret_name=secret_name, data={secret_key: value}
)
variable_base = VariableCreate(
name=name,

View file

@ -17,6 +17,8 @@ if TYPE_CHECKING:
from collections.abc import Sequence
from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.settings.service import SettingsService
@ -24,7 +26,7 @@ class DatabaseVariableService(VariableService, Service):
def __init__(self, settings_service: SettingsService):
self.settings_service = settings_service
def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None:
async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None:
if not self.settings_service.settings.store_environment_variables:
logger.info("Skipping environment variable storage.")
return
@ -34,12 +36,12 @@ class DatabaseVariableService(VariableService, Service):
if var_name in os.environ and os.environ[var_name].strip():
value = os.environ[var_name].strip()
query = select(Variable).where(Variable.user_id == user_id, Variable.name == var_name)
existing = session.exec(query).first()
existing = (await session.exec(query)).first()
try:
if existing:
self.update_variable(user_id, var_name, value, session)
await self.update_variable(user_id, var_name, value, session)
else:
self.create_variable(
await self.create_variable(
user_id=user_id,
name=var_name,
value=value,
@ -76,40 +78,46 @@ class DatabaseVariableService(VariableService, Service):
# we decrypt the value
return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
def get_all(self, user_id: UUID | str, session: Session) -> list[Variable | None]:
return list(session.exec(select(Variable).where(Variable.user_id == user_id)).all())
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]:
stmt = select(Variable).where(Variable.user_id == user_id)
return list((await session.exec(stmt)).all())
def list_variables(self, user_id: UUID | str, session: Session) -> list[str | None]:
variables = self.get_all(user_id=user_id, session=session)
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
return [variable.name for variable in variables if variable]
def update_variable(
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
variables = await self.get_all(user_id=user_id, session=session)
return [variable.name for variable in variables if variable]
async def update_variable(
self,
user_id: UUID | str,
name: str,
value: str,
session: Session,
session: AsyncSession,
):
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name)
variable = (await session.exec(stmt)).first()
if not variable:
msg = f"{name} variable not found."
raise ValueError(msg)
encrypted = auth_utils.encrypt_api_key(value, settings_service=self.settings_service)
variable.value = encrypted
session.add(variable)
session.commit()
session.refresh(variable)
await session.commit()
await session.refresh(variable)
return variable
def update_variable_fields(
async def update_variable_fields(
self,
user_id: UUID | str,
variable_id: UUID | str,
variable: VariableUpdate,
session: Session,
session: AsyncSession,
):
query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
db_variable = session.exec(query).one()
db_variable = (await session.exec(query)).one()
db_variable.updated_at = datetime.now(timezone.utc)
variable.value = variable.value or ""
@ -121,33 +129,34 @@ class DatabaseVariableService(VariableService, Service):
setattr(db_variable, key, value)
session.add(db_variable)
session.commit()
session.refresh(db_variable)
await session.commit()
await session.refresh(db_variable)
return db_variable
def delete_variable(
async def delete_variable(
self,
user_id: UUID | str,
name: str,
session: Session,
session: AsyncSession,
) -> None:
stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name)
variable = session.exec(stmt).first()
variable = (await session.exec(stmt)).first()
if not variable:
msg = f"{name} variable not found."
raise ValueError(msg)
session.delete(variable)
session.commit()
await session.delete(variable)
await session.commit()
def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: Session) -> None:
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)).first()
async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None:
stmt = select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)
variable = (await session.exec(stmt)).first()
if not variable:
msg = f"{variable_id} variable not found."
raise ValueError(msg)
session.delete(variable)
session.commit()
await session.delete(variable)
await session.commit()
def create_variable(
async def create_variable(
self,
user_id: UUID | str,
name: str,
@ -155,7 +164,7 @@ class DatabaseVariableService(VariableService, Service):
*,
default_fields: Sequence[str] = (),
_type: str = GENERIC_TYPE,
session: Session,
session: AsyncSession,
):
variable_base = VariableCreate(
name=name,
@ -165,6 +174,6 @@ class DatabaseVariableService(VariableService, Service):
)
variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
session.add(variable)
session.commit()
session.refresh(variable)
await session.commit()
await session.refresh(variable)
return variable

View file

@ -90,6 +90,8 @@ def _wrap_file_read_blocking(func):
"_read_pyc",
}:
return func(self, *args, **kwargs)
if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize":
return func(self, *args, **kwargs)
raise _blocking_error(func)
return file_op
@ -104,6 +106,8 @@ def _wrap_file_write_blocking(func):
for frame_info in inspect.stack():
if frame_info.filename.endswith("_pytest/assertion/rewrite.py") and frame_info.function == "_write_pyc":
return func(self, *args, **kwargs)
if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize":
return func(self, *args, **kwargs)
if self not in {sys.stdout, sys.stderr}:
raise _blocking_error(func)
return func(self, *args, **kwargs)

View file

@ -24,7 +24,7 @@ async def test_initialize_services():
"""Benchmark the initialization of services."""
from langflow.services.utils import initialize_services
await asyncio.to_thread(initialize_services, fix_migration=False)
await initialize_services(fix_migration=False)
settings_service = await asyncio.to_thread(get_settings_service)
assert "test_performance.db" in settings_service.settings.database_url
@ -45,8 +45,8 @@ async def test_initialize_super_user():
from langflow.initial_setup.setup import initialize_super_user_if_needed
from langflow.services.utils import initialize_services
await asyncio.to_thread(initialize_services, fix_migration=False)
await asyncio.to_thread(initialize_super_user_if_needed)
await initialize_services(fix_migration=False)
await initialize_super_user_if_needed()
settings_service = await asyncio.to_thread(get_settings_service)
assert "test_performance.db" in settings_service.settings.database_url
@ -69,7 +69,7 @@ async def test_create_starter_projects():
from langflow.interface.types import get_and_cache_all_types_dict
from langflow.services.utils import initialize_services
await asyncio.to_thread(initialize_services, fix_migration=False)
await initialize_services(fix_migration=False)
settings_service = await asyncio.to_thread(get_settings_service)
types_dict = await get_and_cache_all_types_dict(settings_service)
await asyncio.to_thread(create_or_update_starter_projects, types_dict)
@ -81,6 +81,6 @@ async def test_load_flows():
"""Benchmark loading flows from directory."""
from langflow.initial_setup.setup import load_flows_from_directory
await asyncio.to_thread(load_flows_from_directory)
await load_flows_from_directory()
settings_service = await asyncio.to_thread(get_settings_service)
assert "test_performance.db" in settings_service.settings.database_url

View file

@ -1,6 +1,6 @@
from datetime import datetime
from unittest.mock import patch
from uuid import uuid4
from uuid import UUID, uuid4
import pytest
from langflow.services.database.models.variable.model import VariableUpdate
@ -8,7 +8,9 @@ from langflow.services.deps import get_settings_service
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
from langflow.services.variable.service import DatabaseVariableService
from sqlmodel import Session, SQLModel, create_engine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import Session, SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
@pytest.fixture
@ -18,114 +20,125 @@ def service():
@pytest.fixture
def session():
engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
async def session():
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async with AsyncSession(engine) as session:
yield session
def test_initialize_user_variables__create_and_update(service, session):
def _get_variable(
session: Session,
service,
user_id: UUID | str,
name: str,
field: str,
):
return service.get_variable(user_id, name, field, session=session)
async def test_initialize_user_variables__create_and_update(service, session: AsyncSession):
user_id = uuid4()
field = ""
good_vars = {k: f"value{i}" for i, k in enumerate(VARIABLES_TO_GET_FROM_ENVIRONMENT)}
bad_vars = {"VAR1": "value1", "VAR2": "value2", "VAR3": "value3"}
env_vars = {**good_vars, **bad_vars}
service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session)
await service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session)
env_vars["OPENAI_API_KEY"] = "updated_value"
with patch.dict("os.environ", env_vars, clear=True):
service.initialize_user_variables(user_id=user_id, session=session)
await service.initialize_user_variables(user_id=user_id, session=session)
variables = service.list_variables(user_id, session=session)
variables = await service.list_variables(user_id, session=session)
for name in variables:
value = service.get_variable(user_id, name, field, session=session)
value = await session.run_sync(_get_variable, service, user_id, name, field)
assert value == env_vars[name]
assert all(i in variables for i in good_vars)
assert all(i not in variables for i in bad_vars)
def test_initialize_user_variables__not_found_variable(service, session):
async def test_initialize_user_variables__not_found_variable(service, session: AsyncSession):
with patch("langflow.services.variable.service.DatabaseVariableService.create_variable") as m:
m.side_effect = Exception()
service.initialize_user_variables(uuid4(), session=session)
await service.initialize_user_variables(uuid4(), session=session)
assert True
def test_initialize_user_variables__skipping_environment_variable_storage(service, session):
async def test_initialize_user_variables__skipping_environment_variable_storage(service, session: AsyncSession):
service.settings_service.settings.store_environment_variables = False
service.initialize_user_variables(uuid4(), session=session)
await service.initialize_user_variables(uuid4(), session=session)
assert True
def test_get_variable(service, session):
async def test_get_variable(service, session: AsyncSession):
user_id = uuid4()
name = "name"
value = "value"
field = ""
service.create_variable(user_id, name, value, session=session)
await service.create_variable(user_id, name, value, session=session)
result = service.get_variable(user_id, name, field, session=session)
result = await session.run_sync(_get_variable, service, user_id, name, field)
assert result == value
def test_get_variable__valueerror(service, session):
async def test_get_variable__valueerror(service, session: AsyncSession):
user_id = uuid4()
name = "name"
field = ""
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.get_variable(user_id, name, field, session)
await session.run_sync(_get_variable, service, user_id, name, field)
def test_get_variable__typeerror(service, session):
async def test_get_variable__typeerror(service, session: AsyncSession):
user_id = uuid4()
name = "name"
value = "value"
field = "session_id"
_type = CREDENTIAL_TYPE
service.create_variable(user_id, name, value, _type=_type, session=session)
await service.create_variable(user_id, name, value, _type=_type, session=session)
with pytest.raises(TypeError) as exc:
service.get_variable(user_id, name, field, session)
await session.run_sync(_get_variable, service, user_id, name, field)
assert name in str(exc.value)
assert "purpose is to prevent the exposure of value" in str(exc.value)
def test_list_variables(service, session):
async def test_list_variables(service, session: AsyncSession):
user_id = uuid4()
names = ["name1", "name2", "name3"]
value = "value"
for name in names:
service.create_variable(user_id, name, value, session=session)
await service.create_variable(user_id, name, value, session=session)
result = service.list_variables(user_id, session=session)
result = await service.list_variables(user_id, session=session)
assert all(name in result for name in names)
def test_list_variables__empty(service, session):
result = service.list_variables(uuid4(), session=session)
async def test_list_variables__empty(service, session: AsyncSession):
result = await service.list_variables(uuid4(), session=session)
assert not result
assert isinstance(result, list)
def test_update_variable(service, session):
async def test_update_variable(service, session: AsyncSession):
user_id = uuid4()
name = "name"
old_value = "old_value"
new_value = "new_value"
field = ""
service.create_variable(user_id, name, old_value, session=session)
await service.create_variable(user_id, name, old_value, session=session)
old_recovered = service.get_variable(user_id, name, field, session=session)
result = service.update_variable(user_id, name, new_value, session=session)
new_recovered = service.get_variable(user_id, name, field, session=session)
old_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
result = await service.update_variable(user_id, name, new_value, session=session)
new_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
assert old_value == old_recovered
assert new_value == new_recovered
@ -139,26 +152,26 @@ def test_update_variable(service, session):
assert isinstance(result.updated_at, datetime)
def test_update_variable__valueerror(service, session):
async def test_update_variable__valueerror(service, session: AsyncSession):
user_id = uuid4()
name = "name"
value = "value"
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.update_variable(user_id, name, value, session=session)
await service.update_variable(user_id, name, value, session=session)
def test_update_variable_fields(service, session):
async def test_update_variable_fields(service, session: AsyncSession):
user_id = uuid4()
new_name = new_value = "donkey"
variable = service.create_variable(user_id, "old_name", "old_value", session=session)
variable = await service.create_variable(user_id, "old_name", "old_value", session=session)
saved = variable.model_dump()
variable = VariableUpdate(**saved)
variable.name = new_name
variable.value = new_value
variable.default_fields = ["new_field"]
result = service.update_variable_fields(
result = await service.update_variable_fields(
user_id=user_id,
variable_id=saved.get("id"),
variable=variable,
@ -177,58 +190,58 @@ def test_update_variable_fields(service, session):
assert saved.get("updated_at") != result.updated_at
def test_delete_variable(service, session):
async def test_delete_variable(service, session: AsyncSession):
user_id = uuid4()
name = "name"
value = "value"
field = ""
service.create_variable(user_id, name, value, session=session)
recovered = service.get_variable(user_id, name, field, session=session)
service.delete_variable(user_id, name, session=session)
await service.create_variable(user_id, name, value, session=session)
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
await service.delete_variable(user_id, name, session=session)
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.get_variable(user_id, name, field, session)
await session.run_sync(_get_variable, service, user_id, name, field)
assert recovered == value
def test_delete_variable__valueerror(service, session):
async def test_delete_variable__valueerror(service, session: AsyncSession):
user_id = uuid4()
name = "name"
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.delete_variable(user_id, name, session=session)
await service.delete_variable(user_id, name, session=session)
def test_delete_variable_by_id(service, session):
async def test_delete_variable_by_id(service, session: AsyncSession):
user_id = uuid4()
name = "name"
value = "value"
field = "field"
saved = service.create_variable(user_id, name, value, session=session)
recovered = service.get_variable(user_id, name, field, session=session)
service.delete_variable_by_id(user_id, saved.id, session=session)
saved = await service.create_variable(user_id, name, value, session=session)
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
await service.delete_variable_by_id(user_id, saved.id, session=session)
with pytest.raises(ValueError, match=f"{name} variable not found."):
service.get_variable(user_id, name, field, session)
await session.run_sync(_get_variable, service, user_id, name, field)
assert recovered == value
def test_delete_variable_by_id__valueerror(service, session):
async def test_delete_variable_by_id__valueerror(service, session: AsyncSession):
user_id = uuid4()
variable_id = uuid4()
with pytest.raises(ValueError, match=f"{variable_id} variable not found."):
service.delete_variable_by_id(user_id, variable_id, session=session)
await service.delete_variable_by_id(user_id, variable_id, session=session)
def test_create_variable(service, session):
async def test_create_variable(service, session: AsyncSession):
user_id = uuid4()
name = "name"
value = "value"
result = service.create_variable(user_id, name, value, session=session)
result = await service.create_variable(user_id, name, value, session=session)
assert result.user_id == user_id
assert result.name == name

View file

@ -1,4 +1,5 @@
from unittest.mock import MagicMock, patch
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from langflow.services.settings.constants import (
DEFAULT_SUPERUSER,
@ -91,7 +92,7 @@ from langflow.services.utils import teardown_superuser
@patch("langflow.services.deps.get_settings_service")
@patch("langflow.services.deps.get_session")
def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
async def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
mock_settings_service = MagicMock()
mock_settings_service.auth_settings.AUTO_LOGIN = True
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
@ -104,29 +105,28 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
mock_get_session.return_value = iter([mock_session])
teardown_superuser(mock_settings_service, mock_session)
await teardown_superuser(mock_settings_service, mock_session)
mock_session.query.assert_not_called()
@patch("langflow.services.deps.get_settings_service")
@patch("langflow.services.deps.get_session")
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
async def test_teardown_superuser_no_default_superuser():
admin_user_name = "admin_user"
mock_settings_service = MagicMock()
mock_settings_service.auth_settings.AUTO_LOGIN = False
mock_settings_service.auth_settings.SUPERUSER = admin_user_name
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password" # noqa: S105
mock_get_settings_service.return_value = mock_settings_service
mock_session = MagicMock()
mock_session = AsyncMock(return_value=asyncio.Future())
mock_user = MagicMock()
mock_user.is_superuser = False
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
mock_get_session.return_value = [mock_session]
mock_user.last_login_at = None
teardown_superuser(mock_settings_service, mock_session)
mock_result = MagicMock()
mock_result.first.return_value = mock_user
mock_session.exec.return_value = mock_result
mock_session.query.assert_not_called()
mock_session.delete.assert_not_called()
mock_session.commit.assert_not_called()
await teardown_superuser(mock_settings_service, mock_session)
mock_session.delete.assert_not_awaited()
mock_session.commit.assert_not_awaited()

View file

@ -5,18 +5,18 @@ from httpx import AsyncClient
from langflow.services.auth.utils import create_super_user, get_password_hash
from langflow.services.database.models.user import UserUpdate
from langflow.services.database.models.user.model import User
from langflow.services.database.utils import session_getter
from langflow.services.database.utils import async_session_getter, session_getter
from langflow.services.deps import get_db_service, get_settings_service
from sqlmodel import select
@pytest.fixture
def super_user(client): # noqa: ARG001
async def super_user(client): # noqa: ARG001
settings_manager = get_settings_service()
auth_settings = settings_manager.auth_settings
with session_getter(get_db_service()) as session:
return create_super_user(
db=session,
async with async_session_getter(get_db_service()) as db:
return await create_super_user(
db=db,
username=auth_settings.SUPERUSER,
password=auth_settings.SUPERUSER_PASSWORD,
)