fix: Use AsyncSession for user management (#4491)
* Use AsyncSession for user management * Simplify check_key * Don't trigger blockbuster on settings service initialize * Fix mypy * Fix api key update_total_uses * Fix auto-login * Revert making CustomComponent.list_key_names async
This commit is contained in:
parent
2881346400
commit
6573ca14cc
24 changed files with 430 additions and 339 deletions
|
|
@ -26,7 +26,7 @@ from sqlmodel import select
|
|||
from langflow.logging.logger import configure, logger
|
||||
from langflow.main import setup_app
|
||||
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.deps import async_session_scope, get_db_service, get_settings_service
|
||||
from langflow.services.settings.constants import DEFAULT_SUPERUSER
|
||||
from langflow.services.utils import initialize_services
|
||||
|
|
@ -419,30 +419,35 @@ def superuser(
|
|||
) -> None:
|
||||
"""Create a superuser."""
|
||||
configure(log_level=log_level)
|
||||
initialize_services()
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
from langflow.services.auth.utils import create_super_user
|
||||
|
||||
if create_super_user(db=session, username=username, password=password):
|
||||
# Verify that the superuser was created
|
||||
from langflow.services.database.models.user.model import User
|
||||
async def _create_superuser():
|
||||
await initialize_services()
|
||||
async with async_session_getter(db_service) as session:
|
||||
from langflow.services.auth.utils import create_super_user
|
||||
|
||||
if await create_super_user(db=session, username=username, password=password):
|
||||
# Verify that the superuser was created
|
||||
from langflow.services.database.models.user.model import User
|
||||
|
||||
stmt = select(User).where(User.username == username)
|
||||
user: User = (await session.exec(stmt)).first()
|
||||
if user is None or not user.is_superuser:
|
||||
typer.echo("Superuser creation failed.")
|
||||
return
|
||||
# Now create the first folder for the user
|
||||
result = await create_default_folder_if_it_doesnt_exist(session, user.id)
|
||||
if result:
|
||||
typer.echo("Default folder created successfully.")
|
||||
else:
|
||||
msg = "Could not create default folder."
|
||||
raise RuntimeError(msg)
|
||||
typer.echo("Superuser created successfully.")
|
||||
|
||||
user: User = session.exec(select(User).where(User.username == username)).first()
|
||||
if user is None or not user.is_superuser:
|
||||
typer.echo("Superuser creation failed.")
|
||||
return
|
||||
# Now create the first folder for the user
|
||||
result = create_default_folder_if_it_doesnt_exist(session, user.id)
|
||||
if result:
|
||||
typer.echo("Default folder created successfully.")
|
||||
else:
|
||||
msg = "Could not create default folder."
|
||||
raise RuntimeError(msg)
|
||||
typer.echo("Superuser created successfully.")
|
||||
typer.echo("Superuser creation failed.")
|
||||
|
||||
else:
|
||||
typer.echo("Superuser creation failed.")
|
||||
asyncio.run(_create_superuser())
|
||||
|
||||
|
||||
# command to copy the langflow database from the cache to the current directory
|
||||
|
|
@ -494,7 +499,7 @@ def migration(
|
|||
):
|
||||
raise typer.Abort
|
||||
|
||||
initialize_services(fix_migration=fix)
|
||||
asyncio.run(initialize_services(fix_migration=fix))
|
||||
db_service = get_db_service()
|
||||
if not test:
|
||||
db_service.run_migrations()
|
||||
|
|
@ -515,18 +520,20 @@ def api_key(
|
|||
None
|
||||
"""
|
||||
configure(log_level=log_level)
|
||||
initialize_services()
|
||||
settings_service = get_settings_service()
|
||||
auth_settings = settings_service.auth_settings
|
||||
if not auth_settings.AUTO_LOGIN:
|
||||
typer.echo("Auto login is disabled. API keys cannot be created through the CLI.")
|
||||
return
|
||||
|
||||
async def aapi_key():
|
||||
await initialize_services()
|
||||
settings_service = get_settings_service()
|
||||
auth_settings = settings_service.auth_settings
|
||||
if not auth_settings.AUTO_LOGIN:
|
||||
typer.echo("Auto login is disabled. API keys cannot be created through the CLI.")
|
||||
return None
|
||||
|
||||
async with async_session_scope() as session:
|
||||
from langflow.services.database.models.user.model import User
|
||||
|
||||
superuser = (await session.exec(select(User).where(User.username == DEFAULT_SUPERUSER))).first()
|
||||
stmt = select(User).where(User.username == DEFAULT_SUPERUSER)
|
||||
superuser = (await session.exec(stmt)).first()
|
||||
if not superuser:
|
||||
typer.echo(
|
||||
"Default superuser not found. This command requires a superuser and AUTO_LOGIN to be enabled."
|
||||
|
|
@ -535,7 +542,8 @@ def api_key(
|
|||
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate
|
||||
from langflow.services.database.models.api_key.crud import create_api_key, delete_api_key
|
||||
|
||||
api_key = (await session.exec(select(ApiKey).where(ApiKey.user_id == superuser.id))).first()
|
||||
stmt = select(ApiKey).where(ApiKey.user_id == superuser.id)
|
||||
api_key = (await session.exec(stmt)).first()
|
||||
if api_key:
|
||||
await delete_api_key(session, api_key.id)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from uuid import UUID
|
|||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
|
||||
from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ async def save_store_api_key(
|
|||
api_key_request: ApiKeyCreateRequest,
|
||||
response: Response,
|
||||
current_user: CurrentActiveUser,
|
||||
db: DbSession,
|
||||
db: AsyncDbSession,
|
||||
):
|
||||
settings_service = get_settings_service()
|
||||
auth_settings = settings_service.auth_settings
|
||||
|
|
@ -74,7 +74,7 @@ async def save_store_api_key(
|
|||
encrypted = auth_utils.encrypt_api_key(api_key, settings_service=settings_service)
|
||||
current_user.store_api_key = encrypted
|
||||
db.add(current_user)
|
||||
db.commit()
|
||||
await db.commit()
|
||||
|
||||
response.set_cookie(
|
||||
"apikey_tkn_lflw",
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Annotated
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from langflow.api.utils import DbSession
|
||||
from langflow.api.utils import AsyncDbSession
|
||||
from langflow.api.v1.schemas import Token
|
||||
from langflow.services.auth.utils import (
|
||||
authenticate_user,
|
||||
|
|
@ -21,14 +21,14 @@ router = APIRouter(tags=["Login"])
|
|||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
def login_to_get_access_token(
|
||||
async def login_to_get_access_token(
|
||||
response: Response,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: DbSession,
|
||||
db: AsyncDbSession,
|
||||
):
|
||||
auth_settings = get_settings_service().auth_settings
|
||||
try:
|
||||
user = authenticate_user(form_data.username, form_data.password, db)
|
||||
user = await authenticate_user(form_data.username, form_data.password, db)
|
||||
except Exception as exc:
|
||||
if isinstance(exc, HTTPException):
|
||||
raise
|
||||
|
|
@ -38,7 +38,7 @@ def login_to_get_access_token(
|
|||
) from exc
|
||||
|
||||
if user:
|
||||
tokens = create_user_tokens(user_id=user.id, db=db, update_last_login=True)
|
||||
tokens = await create_user_tokens(user_id=user.id, db=db, update_last_login=True)
|
||||
response.set_cookie(
|
||||
"refresh_token_lf",
|
||||
tokens["refresh_token"],
|
||||
|
|
@ -66,9 +66,9 @@ def login_to_get_access_token(
|
|||
expires=None, # Set to None to make it a session cookie
|
||||
domain=auth_settings.COOKIE_DOMAIN,
|
||||
)
|
||||
get_variable_service().initialize_user_variables(user.id, db)
|
||||
await get_variable_service().initialize_user_variables(user.id, db)
|
||||
# Create default folder for user if it doesn't exist
|
||||
create_default_folder_if_it_doesnt_exist(db, user.id)
|
||||
await create_default_folder_if_it_doesnt_exist(db, user.id)
|
||||
return tokens
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -78,11 +78,11 @@ def login_to_get_access_token(
|
|||
|
||||
|
||||
@router.get("/auto_login")
|
||||
async def auto_login(response: Response, db: DbSession):
|
||||
async def auto_login(response: Response, db: AsyncDbSession):
|
||||
auth_settings = get_settings_service().auth_settings
|
||||
|
||||
if auth_settings.AUTO_LOGIN:
|
||||
user_id, tokens = create_user_longterm_token(db)
|
||||
user_id, tokens = await create_user_longterm_token(db)
|
||||
response.set_cookie(
|
||||
"access_token_lf",
|
||||
tokens["access_token"],
|
||||
|
|
@ -93,7 +93,7 @@ async def auto_login(response: Response, db: DbSession):
|
|||
domain=auth_settings.COOKIE_DOMAIN,
|
||||
)
|
||||
|
||||
user = get_user_by_id(db, user_id)
|
||||
user = await get_user_by_id(db, user_id)
|
||||
|
||||
if user:
|
||||
if user.store_api_key is None:
|
||||
|
|
@ -124,14 +124,14 @@ async def auto_login(response: Response, db: DbSession):
|
|||
async def refresh_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: DbSession,
|
||||
db: AsyncDbSession,
|
||||
):
|
||||
auth_settings = get_settings_service().auth_settings
|
||||
|
||||
token = request.cookies.get("refresh_token_lf")
|
||||
|
||||
if token:
|
||||
tokens = create_refresh_token(token, db)
|
||||
tokens = await create_refresh_token(token, db)
|
||||
response.set_cookie(
|
||||
"refresh_token_lf",
|
||||
tokens["refresh_token"],
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlmodel import select
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
|
||||
from langflow.api.v1.schemas import UsersResponse
|
||||
from langflow.services.auth.utils import (
|
||||
get_current_active_superuser,
|
||||
|
|
@ -25,7 +25,7 @@ router = APIRouter(tags=["Users"], prefix="/users")
|
|||
@router.post("/", response_model=UserRead, status_code=201)
|
||||
async def add_user(
|
||||
user: UserCreate,
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
) -> User:
|
||||
"""Add a new user to the database."""
|
||||
new_user = User.model_validate(user, from_attributes=True)
|
||||
|
|
@ -33,13 +33,13 @@ async def add_user(
|
|||
new_user.password = get_password_hash(user.password)
|
||||
new_user.is_active = get_settings_service().auth_settings.NEW_USER_IS_ACTIVE
|
||||
session.add(new_user)
|
||||
session.commit()
|
||||
session.refresh(new_user)
|
||||
folder = create_default_folder_if_it_doesnt_exist(session, new_user.id)
|
||||
await session.commit()
|
||||
await session.refresh(new_user)
|
||||
folder = await create_default_folder_if_it_doesnt_exist(session, new_user.id)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=500, detail="Error creating default folder")
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=400, detail="This username is unavailable.") from e
|
||||
|
||||
return new_user
|
||||
|
|
@ -58,14 +58,14 @@ async def read_all_users(
|
|||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 10,
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
) -> UsersResponse:
|
||||
"""Retrieve a list of users from the database with pagination."""
|
||||
query: SelectOfScalar = select(User).offset(skip).limit(limit)
|
||||
users = session.exec(query).fetchall()
|
||||
users = (await session.exec(query)).fetchall()
|
||||
|
||||
count_query = select(func.count()).select_from(User)
|
||||
total_count = session.exec(count_query).first()
|
||||
total_count = (await session.exec(count_query)).first()
|
||||
|
||||
return UsersResponse(
|
||||
total_count=total_count,
|
||||
|
|
@ -78,7 +78,7 @@ async def patch_user(
|
|||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
user: CurrentActiveUser,
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
) -> User:
|
||||
"""Update an existing user's data."""
|
||||
update_password = bool(user_update.password)
|
||||
|
|
@ -93,10 +93,10 @@ async def patch_user(
|
|||
raise HTTPException(status_code=400, detail="You can't change your password here")
|
||||
user_update.password = get_password_hash(user_update.password)
|
||||
|
||||
if user_db := get_user_by_id(session, user_id):
|
||||
if user_db := await get_user_by_id(session, user_id):
|
||||
if not update_password:
|
||||
user_update.password = user_db.password
|
||||
return update_user(user_db, user_update, session)
|
||||
return await update_user(user_db, user_update, session)
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ async def reset_password(
|
|||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
user: CurrentActiveUser,
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
) -> User:
|
||||
"""Reset a user's password."""
|
||||
if user_id != user.id:
|
||||
|
|
@ -117,8 +117,8 @@ async def reset_password(
|
|||
raise HTTPException(status_code=400, detail="You can't use your current password")
|
||||
new_password = get_password_hash(user_update.password)
|
||||
user.password = new_password
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from uuid import UUID
|
|||
from fastapi import APIRouter, HTTPException
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
|
||||
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
|
||||
from langflow.services.deps import get_variable_service
|
||||
from langflow.services.variable.constants import GENERIC_TYPE
|
||||
|
|
@ -15,7 +15,7 @@ router = APIRouter(prefix="/variables", tags=["Variables"])
|
|||
@router.post("/", response_model=VariableRead, status_code=201)
|
||||
async def create_variable(
|
||||
*,
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
variable: VariableCreate,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
@ -30,10 +30,10 @@ async def create_variable(
|
|||
if not variable.value:
|
||||
raise HTTPException(status_code=400, detail="Variable value cannot be empty")
|
||||
|
||||
if variable.name in variable_service.list_variables(user_id=current_user.id, session=session):
|
||||
if variable.name in await variable_service.list_variables(user_id=current_user.id, session=session):
|
||||
raise HTTPException(status_code=400, detail="Variable name already exists")
|
||||
try:
|
||||
return variable_service.create_variable(
|
||||
return await variable_service.create_variable(
|
||||
user_id=current_user.id,
|
||||
name=variable.name,
|
||||
value=variable.value,
|
||||
|
|
@ -50,7 +50,7 @@ async def create_variable(
|
|||
@router.get("/", response_model=list[VariableRead], status_code=200)
|
||||
async def read_variables(
|
||||
*,
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
"""Read all variables."""
|
||||
|
|
@ -59,7 +59,7 @@ async def read_variables(
|
|||
msg = "Variable service is not an instance of DatabaseVariableService"
|
||||
raise TypeError(msg)
|
||||
try:
|
||||
return variable_service.get_all(user_id=current_user.id, session=session)
|
||||
return await variable_service.get_all(user_id=current_user.id, session=session)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ async def read_variables(
|
|||
@router.patch("/{variable_id}", response_model=VariableRead, status_code=200)
|
||||
async def update_variable(
|
||||
*,
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
variable_id: UUID,
|
||||
variable: VariableUpdate,
|
||||
current_user: CurrentActiveUser,
|
||||
|
|
@ -78,7 +78,7 @@ async def update_variable(
|
|||
msg = "Variable service is not an instance of DatabaseVariableService"
|
||||
raise TypeError(msg)
|
||||
try:
|
||||
return variable_service.update_variable_fields(
|
||||
return await variable_service.update_variable_fields(
|
||||
user_id=current_user.id,
|
||||
variable_id=variable_id,
|
||||
variable=variable,
|
||||
|
|
@ -94,13 +94,13 @@ async def update_variable(
|
|||
@router.delete("/{variable_id}", status_code=204)
|
||||
async def delete_variable(
|
||||
*,
|
||||
session: DbSession,
|
||||
session: AsyncDbSession,
|
||||
variable_id: UUID,
|
||||
current_user: CurrentActiveUser,
|
||||
) -> None:
|
||||
"""Delete a variable."""
|
||||
variable_service = get_variable_service()
|
||||
try:
|
||||
variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
|
||||
await variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Add helper functions for each event type
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from time import perf_counter
|
||||
from typing import Any, Protocol
|
||||
|
|
@ -249,7 +250,7 @@ async def process_agent_events(
|
|||
agent_message.properties.icon = "Bot"
|
||||
agent_message.properties.state = "partial"
|
||||
# Store the initial message
|
||||
agent_message = send_message_method(message=agent_message)
|
||||
agent_message = await asyncio.to_thread(send_message_method, message=agent_message)
|
||||
try:
|
||||
# Create a mapping of run_ids to tool contents
|
||||
tool_blocks_map: dict[str, ToolContent] = {}
|
||||
|
|
|
|||
|
|
@ -448,7 +448,7 @@ class CustomComponent(BaseComponent):
|
|||
variable_service = get_variable_service()
|
||||
|
||||
with session_scope() as session:
|
||||
return variable_service.list_variables(user_id=self.user_id, session=session)
|
||||
return variable_service.list_variables_sync(user_id=self.user_id, session=session)
|
||||
|
||||
def index(self, value: int = 0):
|
||||
"""Returns a function that returns the value at the given index in the iterable.
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from langflow.services.database.models.folder.utils import (
|
|||
)
|
||||
from langflow.services.database.models.user.crud import get_user_by_username
|
||||
from langflow.services.deps import (
|
||||
async_session_scope,
|
||||
get_settings_service,
|
||||
get_storage_service,
|
||||
get_variable_service,
|
||||
|
|
@ -519,7 +520,7 @@ def _is_valid_uuid(val):
|
|||
return str(uuid_obj) == val
|
||||
|
||||
|
||||
def load_flows_from_directory() -> None:
|
||||
async def load_flows_from_directory() -> None:
|
||||
"""On langflow startup, this loads all flows from the directory specified in the settings.
|
||||
|
||||
All flows are uploaded into the default folder for the superuser.
|
||||
|
|
@ -533,8 +534,8 @@ def load_flows_from_directory() -> None:
|
|||
logger.warning("AUTO_LOGIN is disabled, not loading flows from directory")
|
||||
return
|
||||
|
||||
with session_scope() as session:
|
||||
user = get_user_by_username(session, settings_service.auth_settings.SUPERUSER)
|
||||
async with async_session_scope() as session:
|
||||
user = await get_user_by_username(session, settings_service.auth_settings.SUPERUSER)
|
||||
if user is None:
|
||||
msg = "Superuser not found in the database"
|
||||
raise NoResultFound(msg)
|
||||
|
|
@ -553,7 +554,7 @@ def load_flows_from_directory() -> None:
|
|||
flow["id"] = no_json_name
|
||||
flow_id = flow.get("id")
|
||||
|
||||
existing = find_existing_flow(session, flow_id, flow_endpoint_name)
|
||||
existing = await find_existing_flow(session, flow_id, flow_endpoint_name)
|
||||
if existing:
|
||||
logger.debug(f"Found existing flow: {existing.name}")
|
||||
logger.info(f"Updating existing flow: {flow_id} with endpoint name {flow_endpoint_name}")
|
||||
|
|
@ -585,15 +586,15 @@ def load_flows_from_directory() -> None:
|
|||
session.add(flow)
|
||||
|
||||
|
||||
def find_existing_flow(session, flow_id, flow_endpoint_name):
|
||||
async def find_existing_flow(session, flow_id, flow_endpoint_name):
|
||||
if flow_endpoint_name:
|
||||
logger.debug(f"flow_endpoint_name: {flow_endpoint_name}")
|
||||
stmt = select(Flow).where(Flow.endpoint_name == flow_endpoint_name)
|
||||
if existing := session.exec(stmt).first():
|
||||
if existing := (await session.exec(stmt)).first():
|
||||
logger.debug(f"Found existing flow by endpoint name: {existing.name}")
|
||||
return existing
|
||||
stmt = select(Flow).where(Flow.id == flow_id)
|
||||
if existing := session.exec(stmt).first():
|
||||
if existing := (await session.exec(stmt)).first():
|
||||
logger.debug(f"Found existing flow by id: {flow_id}")
|
||||
return existing
|
||||
return None
|
||||
|
|
@ -645,7 +646,7 @@ def create_or_update_starter_projects(all_types_dict: dict) -> None:
|
|||
)
|
||||
|
||||
|
||||
def initialize_super_user_if_needed() -> None:
|
||||
async def initialize_super_user_if_needed() -> None:
|
||||
settings_service = get_settings_service()
|
||||
if not settings_service.auth_settings.AUTO_LOGIN:
|
||||
return
|
||||
|
|
@ -655,8 +656,8 @@ def initialize_super_user_if_needed() -> None:
|
|||
msg = "SUPERUSER and SUPERUSER_PASSWORD must be set in the settings if AUTO_LOGIN is true."
|
||||
raise ValueError(msg)
|
||||
|
||||
with session_scope() as session:
|
||||
super_user = create_super_user(db=session, username=username, password=password)
|
||||
get_variable_service().initialize_user_variables(super_user.id, session)
|
||||
create_default_folder_if_it_doesnt_exist(session, super_user.id)
|
||||
logger.info("Super user initialized")
|
||||
async with async_session_scope() as async_session:
|
||||
super_user = await create_super_user(db=async_session, username=username, password=password)
|
||||
await get_variable_service().initialize_user_variables(super_user.id, async_session)
|
||||
await create_default_folder_if_it_doesnt_exist(async_session, super_user.id)
|
||||
logger.info("Super user initialized")
|
||||
|
|
|
|||
|
|
@ -89,11 +89,6 @@ class JavaScriptMIMETypeMiddleware(BaseHTTPMiddleware):
|
|||
def get_lifespan(*, fix_migration=False, version=None):
|
||||
telemetry_service = get_telemetry_service()
|
||||
|
||||
def _initialize():
|
||||
initialize_services(fix_migration=fix_migration)
|
||||
setup_llm_caching()
|
||||
initialize_super_user_if_needed()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI):
|
||||
configure(async_file=True)
|
||||
|
|
@ -104,12 +99,13 @@ def get_lifespan(*, fix_migration=False, version=None):
|
|||
else:
|
||||
rprint("[bold green]Starting Langflow...[/bold green]")
|
||||
try:
|
||||
await asyncio.to_thread(_initialize)
|
||||
await initialize_services(fix_migration=fix_migration)
|
||||
await asyncio.to_thread(setup_llm_caching)
|
||||
await initialize_super_user_if_needed()
|
||||
all_types_dict = await get_and_cache_all_types_dict(get_settings_service())
|
||||
await asyncio.to_thread(create_or_update_starter_projects, all_types_dict)
|
||||
telemetry_service.start()
|
||||
await asyncio.to_thread(load_flows_from_directory)
|
||||
|
||||
await load_flows_from_directory()
|
||||
yield
|
||||
|
||||
except Exception as exc:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import random
|
||||
import warnings
|
||||
from collections.abc import Coroutine
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
|
@ -12,16 +11,18 @@ from fastapi import Depends, HTTPException, Security, status
|
|||
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from loguru import logger
|
||||
from sqlmodel import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from langflow.services.database.models.api_key.crud import check_key
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.deps import get_db_service, get_session, get_settings_service
|
||||
from langflow.services.deps import get_async_session, get_db_service, get_settings_service
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
|
||||
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
|
||||
|
||||
API_KEY_NAME = "x-api-key"
|
||||
|
|
@ -33,14 +34,14 @@ MINIMUM_KEY_LENGTH = 32
|
|||
|
||||
|
||||
# Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py
|
||||
def api_key_security(
|
||||
async def api_key_security(
|
||||
query_param: Annotated[str, Security(api_key_query)],
|
||||
header_param: Annotated[str, Security(api_key_header)],
|
||||
) -> UserRead | None:
|
||||
settings_service = get_settings_service()
|
||||
result: ApiKey | User | None = None
|
||||
result: ApiKey | User | None
|
||||
|
||||
with get_db_service().with_session() as db:
|
||||
async with get_db_service().with_async_session() as db:
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
# Get the first user
|
||||
if not settings_service.auth_settings.SUPERUSER:
|
||||
|
|
@ -49,7 +50,7 @@ def api_key_security(
|
|||
detail="Missing first superuser credentials",
|
||||
)
|
||||
|
||||
result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
|
||||
result = await get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
|
||||
|
||||
elif not query_param and not header_param:
|
||||
raise HTTPException(
|
||||
|
|
@ -58,18 +59,16 @@ def api_key_security(
|
|||
)
|
||||
|
||||
elif query_param:
|
||||
result = check_key(db, query_param)
|
||||
result = await check_key(db, query_param)
|
||||
|
||||
else:
|
||||
result = check_key(db, header_param)
|
||||
result = await check_key(db, header_param)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid or missing API key",
|
||||
)
|
||||
if isinstance(result, ApiKey):
|
||||
return UserRead.model_validate(result.user, from_attributes=True)
|
||||
if isinstance(result, User):
|
||||
return UserRead.model_validate(result, from_attributes=True)
|
||||
msg = "Invalid result type"
|
||||
|
|
@ -80,11 +79,11 @@ async def get_current_user(
|
|||
token: Annotated[str, Security(oauth2_login)],
|
||||
query_param: Annotated[str, Security(api_key_query)],
|
||||
header_param: Annotated[str, Security(api_key_header)],
|
||||
db: Annotated[Session, Depends(get_session)],
|
||||
db: Annotated[AsyncSession, Depends(get_async_session)],
|
||||
) -> User:
|
||||
if token:
|
||||
return await get_current_user_by_jwt(token, db)
|
||||
user = await asyncio.to_thread(api_key_security, query_param, header_param)
|
||||
user = await api_key_security(query_param, header_param)
|
||||
if user:
|
||||
return user
|
||||
|
||||
|
|
@ -95,8 +94,8 @@ async def get_current_user(
|
|||
|
||||
|
||||
async def get_current_user_by_jwt(
|
||||
token: Annotated[str, Depends(oauth2_login)],
|
||||
db: Annotated[Session, Depends(get_session)],
|
||||
token: str,
|
||||
db: AsyncSession,
|
||||
) -> User:
|
||||
settings_service = get_settings_service()
|
||||
|
||||
|
|
@ -144,7 +143,7 @@ async def get_current_user_by_jwt(
|
|||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from e
|
||||
|
||||
user = get_user_by_id(db, user_id)
|
||||
user = await get_user_by_id(db, user_id)
|
||||
if user is None or not user.is_active:
|
||||
logger.info("User not found or inactive.")
|
||||
raise HTTPException(
|
||||
|
|
@ -157,7 +156,7 @@ async def get_current_user_by_jwt(
|
|||
|
||||
async def get_current_user_for_websocket(
|
||||
websocket: WebSocket,
|
||||
db: Annotated[Session, Depends(get_session)],
|
||||
db: Annotated[AsyncSession, Depends(get_async_session)],
|
||||
query_param: Annotated[str, Security(api_key_query)],
|
||||
) -> User | None:
|
||||
token = websocket.query_params.get("token")
|
||||
|
|
@ -165,7 +164,7 @@ async def get_current_user_for_websocket(
|
|||
if token:
|
||||
return await get_current_user_by_jwt(token, db)
|
||||
if api_key:
|
||||
return await asyncio.to_thread(api_key_security, api_key, query_param)
|
||||
return await api_key_security(api_key, query_param)
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -207,12 +206,12 @@ def create_token(data: dict, expires_delta: timedelta):
|
|||
)
|
||||
|
||||
|
||||
def create_super_user(
|
||||
async def create_super_user(
|
||||
username: str,
|
||||
password: str,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
) -> User:
|
||||
super_user = get_user_by_username(db, username)
|
||||
super_user = await get_user_by_username(db, username)
|
||||
|
||||
if not super_user:
|
||||
super_user = User(
|
||||
|
|
@ -224,17 +223,17 @@ def create_super_user(
|
|||
)
|
||||
|
||||
db.add(super_user)
|
||||
db.commit()
|
||||
db.refresh(super_user)
|
||||
await db.commit()
|
||||
await db.refresh(super_user)
|
||||
|
||||
return super_user
|
||||
|
||||
|
||||
def create_user_longterm_token(db: Session) -> tuple[UUID, dict]:
|
||||
async def create_user_longterm_token(db: AsyncSession) -> tuple[UUID, dict]:
|
||||
settings_service = get_settings_service()
|
||||
|
||||
username = settings_service.auth_settings.SUPERUSER
|
||||
super_user = get_user_by_username(db, username)
|
||||
super_user = await get_user_by_username(db, username)
|
||||
if not super_user:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created")
|
||||
access_token_expires_longterm = timedelta(days=365)
|
||||
|
|
@ -244,7 +243,7 @@ def create_user_longterm_token(db: Session) -> tuple[UUID, dict]:
|
|||
)
|
||||
|
||||
# Update: last_login_at
|
||||
update_user_last_login_at(super_user.id, db)
|
||||
await update_user_last_login_at(super_user.id, db)
|
||||
|
||||
return super_user.id, {
|
||||
"access_token": access_token,
|
||||
|
|
@ -270,7 +269,7 @@ def get_user_id_from_token(token: str) -> UUID:
|
|||
return UUID(int=0)
|
||||
|
||||
|
||||
def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool = False) -> dict:
|
||||
async def create_user_tokens(user_id: UUID, db: AsyncSession, *, update_last_login: bool = False) -> dict:
|
||||
settings_service = get_settings_service()
|
||||
|
||||
access_token_expires = timedelta(seconds=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
|
||||
|
|
@ -287,7 +286,7 @@ def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool =
|
|||
|
||||
# Update: last_login_at
|
||||
if update_last_login:
|
||||
update_user_last_login_at(user_id, db)
|
||||
await update_user_last_login_at(user_id, db)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
|
|
@ -296,7 +295,7 @@ def create_user_tokens(user_id: UUID, db: Session, *, update_last_login: bool =
|
|||
}
|
||||
|
||||
|
||||
def create_refresh_token(refresh_token: str, db: Session):
|
||||
async def create_refresh_token(refresh_token: str, db: AsyncSession):
|
||||
settings_service = get_settings_service()
|
||||
|
||||
try:
|
||||
|
|
@ -314,12 +313,12 @@ def create_refresh_token(refresh_token: str, db: Session):
|
|||
if user_id is None or token_type == "":
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
|
||||
|
||||
user_exists = get_user_by_id(db, user_id)
|
||||
user_exists = await get_user_by_id(db, user_id)
|
||||
|
||||
if user_exists is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
|
||||
|
||||
return create_user_tokens(user_id, db)
|
||||
return await create_user_tokens(user_id, db)
|
||||
|
||||
except JWTError as e:
|
||||
logger.exception("JWT decoding error")
|
||||
|
|
@ -329,8 +328,8 @@ def create_refresh_token(refresh_token: str, db: Session):
|
|||
) from e
|
||||
|
||||
|
||||
def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
||||
user = get_user_by_username(db, username)
|
||||
async def authenticate_user(username: str, password: str, db: AsyncSession) -> User | None:
|
||||
user = await get_user_by_username(db, username)
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,13 +1,17 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import secrets
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.database.models import User
|
||||
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead
|
||||
from langflow.services.database.utils import async_session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
|
@ -47,33 +51,29 @@ async def delete_api_key(session: AsyncSession, api_key_id: UUID) -> None:
|
|||
await session.commit()
|
||||
|
||||
|
||||
def check_key(session: Session, api_key: str) -> ApiKey | None:
|
||||
update_total_uses_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
async def check_key(session: AsyncSession, api_key: str) -> User | None:
|
||||
"""Check if the API key is valid."""
|
||||
query: SelectOfScalar = select(ApiKey).where(ApiKey.api_key == api_key)
|
||||
api_key_object: ApiKey | None = session.exec(query).first()
|
||||
query: SelectOfScalar = select(ApiKey).options(selectinload(ApiKey.user)).where(ApiKey.api_key == api_key)
|
||||
api_key_object: ApiKey | None = (await session.exec(query)).first()
|
||||
if api_key_object is not None:
|
||||
threading.Thread(
|
||||
target=update_total_uses,
|
||||
args=(
|
||||
session,
|
||||
api_key_object,
|
||||
),
|
||||
).start()
|
||||
return api_key_object
|
||||
task = asyncio.create_task(update_total_uses(api_key_object.id))
|
||||
task.add_done_callback(update_total_uses_tasks.discard)
|
||||
update_total_uses_tasks.add(task)
|
||||
return api_key_object.user
|
||||
return None
|
||||
|
||||
|
||||
def update_total_uses(session, api_key: ApiKey):
|
||||
async def update_total_uses(api_key_id: UUID):
|
||||
"""Update the total uses and last used at."""
|
||||
# This is running in a separate thread to avoid slowing down the request
|
||||
# but session is not thread safe so we need to create a new session
|
||||
|
||||
with Session(session.get_bind()) as new_session:
|
||||
new_api_key = new_session.get(ApiKey, api_key.id)
|
||||
async with async_session_getter(get_db_service()) as session:
|
||||
new_api_key = await session.get(ApiKey, api_key_id)
|
||||
if new_api_key is None:
|
||||
msg = "API Key not found"
|
||||
raise ValueError(msg)
|
||||
new_api_key.total_uses += 1
|
||||
new_api_key.last_used_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
new_session.add(new_api_key)
|
||||
new_session.commit()
|
||||
return new_api_key
|
||||
session.add(new_api_key)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Session, and_, select, update
|
||||
from sqlmodel import and_, select, update
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
|
||||
|
|
@ -8,8 +9,9 @@ from .constants import DEFAULT_FOLDER_DESCRIPTION, DEFAULT_FOLDER_NAME
|
|||
from .model import Folder
|
||||
|
||||
|
||||
def create_default_folder_if_it_doesnt_exist(session: Session, user_id: UUID):
|
||||
folder = session.exec(select(Folder).where(Folder.user_id == user_id)).first()
|
||||
async def create_default_folder_if_it_doesnt_exist(session: AsyncSession, user_id: UUID):
|
||||
stmt = select(Folder).where(Folder.user_id == user_id)
|
||||
folder = (await session.exec(stmt)).first()
|
||||
if not folder:
|
||||
folder = Folder(
|
||||
name=DEFAULT_FOLDER_NAME,
|
||||
|
|
@ -17,9 +19,9 @@ def create_default_folder_if_it_doesnt_exist(session: Session, user_id: UUID):
|
|||
description=DEFAULT_FOLDER_DESCRIPTION,
|
||||
)
|
||||
session.add(folder)
|
||||
session.commit()
|
||||
session.refresh(folder)
|
||||
session.exec(
|
||||
await session.commit()
|
||||
await session.refresh(folder)
|
||||
await session.exec(
|
||||
update(Flow)
|
||||
.where(
|
||||
and_(
|
||||
|
|
@ -29,12 +31,14 @@ def create_default_folder_if_it_doesnt_exist(session: Session, user_id: UUID):
|
|||
)
|
||||
.values(folder_id=folder.id)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return folder
|
||||
|
||||
|
||||
def get_default_folder_id(session: Session, user_id: UUID):
|
||||
folder = session.exec(select(Folder).where(Folder.name == DEFAULT_FOLDER_NAME, Folder.user_id == user_id)).first()
|
||||
async def get_default_folder_id(session: AsyncSession, user_id: UUID):
|
||||
folder = (
|
||||
await session.exec(select(Folder).where(Folder.name == DEFAULT_FOLDER_NAME, Folder.user_id == user_id))
|
||||
).first()
|
||||
if not folder:
|
||||
folder = create_default_folder_if_it_doesnt_exist(session, user_id)
|
||||
folder = await create_default_folder_if_it_doesnt_exist(session, user_id)
|
||||
return folder.id
|
||||
|
|
|
|||
|
|
@ -5,20 +5,23 @@ from fastapi import HTTPException, status
|
|||
from loguru import logger
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.database.models.user.model import User, UserUpdate
|
||||
|
||||
|
||||
def get_user_by_username(db: Session, username: str) -> User | None:
|
||||
return db.exec(select(User).where(User.username == username)).first()
|
||||
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
|
||||
stmt = select(User).where(User.username == username)
|
||||
return (await db.exec(stmt)).first()
|
||||
|
||||
|
||||
def get_user_by_id(db: Session, user_id: UUID) -> User | None:
|
||||
return db.exec(select(User).where(User.id == user_id)).first()
|
||||
async def get_user_by_id(db: AsyncSession, user_id: UUID) -> User | None:
|
||||
stmt = select(User).where(User.id == user_id)
|
||||
return (await db.exec(stmt)).first()
|
||||
|
||||
|
||||
def update_user(user_db: User | None, user: UserUpdate, db: Session) -> User:
|
||||
async def update_user(user_db: User | None, user: UserUpdate, db: AsyncSession) -> User:
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
|
|
@ -40,18 +43,18 @@ def update_user(user_db: User | None, user: UserUpdate, db: Session) -> User:
|
|||
flag_modified(user_db, "updated_at")
|
||||
|
||||
try:
|
||||
db.commit()
|
||||
await db.commit()
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
return user_db
|
||||
|
||||
|
||||
def update_user_last_login_at(user_id: UUID, db: Session):
|
||||
async def update_user_last_login_at(user_id: UUID, db: AsyncSession):
|
||||
try:
|
||||
user_data = UserUpdate(last_login_at=datetime.now(timezone.utc))
|
||||
user = get_user_by_id(db, user_id)
|
||||
return update_user(user, user_data, db)
|
||||
user = await get_user_by_id(db, user_id)
|
||||
return await update_user(user, user_data, db)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.opt(exception=True).debug("Error updating user last login at")
|
||||
|
|
|
|||
|
|
@ -142,28 +142,29 @@ class DatabaseService(Service):
|
|||
|
||||
@asynccontextmanager
|
||||
async def with_async_session(self):
|
||||
async with AsyncSession(self.async_engine) as session:
|
||||
async with AsyncSession(self.async_engine, expire_on_commit=False) as session:
|
||||
yield session
|
||||
|
||||
def migrate_flows_if_auto_login(self) -> None:
|
||||
async def migrate_flows_if_auto_login(self) -> None:
|
||||
# if auto_login is enabled, we need to migrate the flows
|
||||
# to the default superuser if they don't have a user id
|
||||
# associated with them
|
||||
settings_service = get_settings_service()
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
with self.with_session() as session:
|
||||
flows = session.exec(select(models.Flow).where(models.Flow.user_id is None)).all()
|
||||
async with self.with_async_session() as session:
|
||||
stmt = select(models.Flow).where(models.Flow.user_id is None)
|
||||
flows = (await session.exec(stmt)).all()
|
||||
if flows:
|
||||
logger.debug("Migrating flows to default superuser")
|
||||
username = settings_service.auth_settings.SUPERUSER
|
||||
user = get_user_by_username(session, username)
|
||||
user = await get_user_by_username(session, username)
|
||||
if not user:
|
||||
logger.error("Default superuser not found")
|
||||
msg = "Default superuser not found"
|
||||
raise RuntimeError(msg)
|
||||
for flow in flows:
|
||||
flow.user_id = user.id
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.debug("Flows migrated successfully")
|
||||
|
||||
def check_schema_health(self) -> bool:
|
||||
|
|
@ -346,20 +347,15 @@ class DatabaseService(Service):
|
|||
|
||||
logger.debug("Database and tables created successfully")
|
||||
|
||||
def _teardown(self) -> None:
|
||||
async def teardown(self) -> None:
|
||||
logger.debug("Tearing down database")
|
||||
try:
|
||||
settings_service = get_settings_service()
|
||||
# remove the default superuser if auto_login is enabled
|
||||
# using the SUPERUSER to get the user
|
||||
with self.with_session() as session:
|
||||
teardown_superuser(settings_service, session)
|
||||
|
||||
async with self.with_async_session() as session:
|
||||
await teardown_superuser(settings_service, session)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error tearing down database")
|
||||
|
||||
self.engine.dispose()
|
||||
|
||||
async def teardown(self) -> None:
|
||||
await asyncio.to_thread(self._teardown)
|
||||
await self.async_engine.dispose()
|
||||
await asyncio.to_thread(self.engine.dispose)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from alembic.util.exc import CommandError
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
|
@ -70,6 +71,19 @@ def session_getter(db_service: DatabaseService):
|
|||
session.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_session_getter(db_service: DatabaseService):
|
||||
try:
|
||||
session = AsyncSession(db_service.async_engine)
|
||||
yield session
|
||||
except Exception:
|
||||
logger.exception("Session rollback because of exception")
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
name: str
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.auth.utils import create_super_user, verify_password
|
||||
from langflow.services.cache.factory import CacheServiceFactory
|
||||
|
|
@ -9,13 +10,14 @@ from langflow.services.database.utils import initialize_database
|
|||
from langflow.services.schema import ServiceType
|
||||
from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD
|
||||
|
||||
from .deps import get_db_service, get_service, get_session, get_settings_service
|
||||
from .deps import get_db_service, get_service, get_settings_service
|
||||
|
||||
|
||||
def get_or_create_super_user(session: Session, username, password, is_default):
|
||||
async def get_or_create_super_user(session: AsyncSession, username, password, is_default):
|
||||
from langflow.services.database.models.user.model import User
|
||||
|
||||
user = session.exec(select(User).where(User.username == username)).first()
|
||||
stmt = select(User).where(User.username == username)
|
||||
user = (await session.exec(stmt)).first()
|
||||
|
||||
if user and user.is_superuser:
|
||||
return None # Superuser already exists
|
||||
|
|
@ -51,7 +53,7 @@ def get_or_create_super_user(session: Session, username, password, is_default):
|
|||
else:
|
||||
logger.debug("Creating superuser.")
|
||||
try:
|
||||
return create_super_user(username, password, db=session)
|
||||
return await create_super_user(username, password, db=session)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if "UNIQUE constraint failed: user.username" in str(exc):
|
||||
# This is to deal with workers running this
|
||||
|
|
@ -62,12 +64,12 @@ def get_or_create_super_user(session: Session, username, password, is_default):
|
|||
logger.opt(exception=True).debug("Error creating superuser.")
|
||||
|
||||
|
||||
def setup_superuser(settings_service, session: Session) -> None:
|
||||
async def setup_superuser(settings_service, session: AsyncSession) -> None:
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
logger.debug("AUTO_LOGIN is set to True. Creating default superuser.")
|
||||
else:
|
||||
# Remove the default superuser if it exists
|
||||
teardown_superuser(settings_service, session)
|
||||
await teardown_superuser(settings_service, session)
|
||||
|
||||
username = settings_service.auth_settings.SUPERUSER
|
||||
password = settings_service.auth_settings.SUPERUSER_PASSWORD
|
||||
|
|
@ -75,7 +77,9 @@ def setup_superuser(settings_service, session: Session) -> None:
|
|||
is_default = (username == DEFAULT_SUPERUSER) and (password == DEFAULT_SUPERUSER_PASSWORD)
|
||||
|
||||
try:
|
||||
user = get_or_create_super_user(session=session, username=username, password=password, is_default=is_default)
|
||||
user = await get_or_create_super_user(
|
||||
session=session, username=username, password=password, is_default=is_default
|
||||
)
|
||||
if user is not None:
|
||||
logger.debug("Superuser created successfully.")
|
||||
except Exception as exc:
|
||||
|
|
@ -86,7 +90,7 @@ def setup_superuser(settings_service, session: Session) -> None:
|
|||
settings_service.auth_settings.reset_credentials()
|
||||
|
||||
|
||||
def teardown_superuser(settings_service, session) -> None:
|
||||
async def teardown_superuser(settings_service, session: AsyncSession) -> None:
|
||||
"""Teardown the superuser."""
|
||||
# If AUTO_LOGIN is True, we will remove the default superuser
|
||||
# from the database.
|
||||
|
|
@ -97,30 +101,27 @@ def teardown_superuser(settings_service, session) -> None:
|
|||
username = DEFAULT_SUPERUSER
|
||||
from langflow.services.database.models.user.model import User
|
||||
|
||||
user = session.exec(select(User).where(User.username == username)).first()
|
||||
stmt = select(User).where(User.username == username)
|
||||
user = (await session.exec(stmt)).first()
|
||||
# Check if super was ever logged in, if not delete it
|
||||
# if it has logged in, it means the user is using it to login
|
||||
if user and user.is_superuser is True and not user.last_login_at:
|
||||
session.delete(user)
|
||||
session.commit()
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
logger.debug("Default superuser removed successfully.")
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
msg = "Could not remove default superuser."
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
|
||||
def _teardown_superuser():
|
||||
with get_db_service().with_session() as session:
|
||||
teardown_superuser(get_settings_service(), session)
|
||||
|
||||
|
||||
async def teardown_services() -> None:
|
||||
"""Teardown all the services."""
|
||||
try:
|
||||
await asyncio.to_thread(_teardown_superuser)
|
||||
async with get_db_service().with_async_session() as session:
|
||||
await teardown_superuser(get_settings_service(), session)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(exc)
|
||||
try:
|
||||
|
|
@ -156,15 +157,16 @@ def initialize_session_service() -> None:
|
|||
)
|
||||
|
||||
|
||||
def initialize_services(*, fix_migration: bool = False) -> None:
|
||||
async def initialize_services(*, fix_migration: bool = False) -> None:
|
||||
"""Initialize all the services needed."""
|
||||
# Test cache connection
|
||||
get_service(ServiceType.CACHE_SERVICE, default=CacheServiceFactory())
|
||||
# Setup the superuser
|
||||
initialize_database(fix_migration=fix_migration)
|
||||
setup_superuser(get_service(ServiceType.SETTINGS_SERVICE), next(get_session()))
|
||||
await asyncio.to_thread(initialize_database, fix_migration=fix_migration)
|
||||
async with get_db_service().with_async_session() as session:
|
||||
await setup_superuser(get_service(ServiceType.SETTINGS_SERVICE), session)
|
||||
try:
|
||||
get_db_service().migrate_flows_if_auto_login()
|
||||
await get_db_service().migrate_flows_if_auto_login()
|
||||
except Exception as exc:
|
||||
msg = "Error migrating flows"
|
||||
logger.exception(msg)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import abc
|
|||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database.models.variable.model import Variable
|
||||
|
|
@ -13,7 +14,7 @@ class VariableService(Service):
|
|||
name = "variable_service"
|
||||
|
||||
@abc.abstractmethod
|
||||
def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None:
|
||||
async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None:
|
||||
"""Initialize user variables.
|
||||
|
||||
Args:
|
||||
|
|
@ -36,7 +37,7 @@ class VariableService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def list_variables(self, user_id: UUID | str, session: Session) -> list[str | None]:
|
||||
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
|
||||
"""List all variables.
|
||||
|
||||
Args:
|
||||
|
|
@ -48,7 +49,19 @@ class VariableService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_variable(self, user_id: UUID | str, name: str, value: str, session: Session) -> Variable:
|
||||
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
|
||||
"""List all variables.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
session: The database session.
|
||||
|
||||
Returns:
|
||||
A list of variable names.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_variable(self, user_id: UUID | str, name: str, value: str, session: AsyncSession) -> Variable:
|
||||
"""Update a variable.
|
||||
|
||||
Args:
|
||||
|
|
@ -62,7 +75,7 @@ class VariableService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_variable(self, user_id: UUID | str, name: str, session: Session) -> None:
|
||||
async def delete_variable(self, user_id: UUID | str, name: str, session: AsyncSession) -> None:
|
||||
"""Delete a variable.
|
||||
|
||||
Args:
|
||||
|
|
@ -75,7 +88,7 @@ class VariableService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: Session) -> None:
|
||||
async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None:
|
||||
"""Delete a variable by ID.
|
||||
|
||||
Args:
|
||||
|
|
@ -85,7 +98,7 @@ class VariableService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_variable(
|
||||
async def create_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
|
|
@ -93,7 +106,7 @@ class VariableService(Service):
|
|||
*,
|
||||
default_fields: list[str],
|
||||
_type: str,
|
||||
session: Session,
|
||||
session: AsyncSession,
|
||||
) -> Variable:
|
||||
"""Create a variable.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
|
@ -17,6 +18,7 @@ if TYPE_CHECKING:
|
|||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
|
|
@ -28,7 +30,7 @@ class KubernetesSecretService(VariableService, Service):
|
|||
self.kubernetes_secrets = KubernetesSecretManager()
|
||||
|
||||
@override
|
||||
def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None:
|
||||
async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None:
|
||||
# Check for environment variables that should be stored in the database
|
||||
should_or_should_not = "Should" if self.settings_service.settings.store_environment_variables else "Should not"
|
||||
logger.info(f"{should_or_should_not} store environment variables in the kubernetes.")
|
||||
|
|
@ -45,7 +47,8 @@ class KubernetesSecretService(VariableService, Service):
|
|||
|
||||
try:
|
||||
secret_name = encode_user_id(user_id)
|
||||
self.kubernetes_secrets.create_secret(
|
||||
await asyncio.to_thread(
|
||||
self.kubernetes_secrets.create_secret,
|
||||
name=secret_name,
|
||||
data=variables,
|
||||
)
|
||||
|
|
@ -75,12 +78,13 @@ class KubernetesSecretService(VariableService, Service):
|
|||
msg = f"user_id {user_id} variable name {name} not found."
|
||||
raise ValueError(msg)
|
||||
|
||||
@override
|
||||
def get_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
field: str,
|
||||
_session: Session,
|
||||
session: Session,
|
||||
) -> str:
|
||||
secret_name = encode_user_id(user_id)
|
||||
key, value = self.resolve_variable(secret_name, user_id, name)
|
||||
|
|
@ -92,10 +96,11 @@ class KubernetesSecretService(VariableService, Service):
|
|||
raise TypeError(msg)
|
||||
return value
|
||||
|
||||
def list_variables(
|
||||
@override
|
||||
def list_variables_sync(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
_session: Session,
|
||||
session: Session,
|
||||
) -> list[str | None]:
|
||||
variables = self.kubernetes_secrets.get_secret(name=encode_user_id(user_id))
|
||||
if not variables:
|
||||
|
|
@ -109,28 +114,49 @@ class KubernetesSecretService(VariableService, Service):
|
|||
names.append(key)
|
||||
return names
|
||||
|
||||
def update_variable(
|
||||
@override
|
||||
async def list_variables(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
session: AsyncSession,
|
||||
) -> list[str | None]:
|
||||
return await asyncio.to_thread(self.list_variables_sync, user_id, session.sync_session)
|
||||
|
||||
def _update_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
value: str,
|
||||
_session: Session,
|
||||
):
|
||||
secret_name = encode_user_id(user_id)
|
||||
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
|
||||
return self.kubernetes_secrets.update_secret(name=secret_name, data={secret_key: value})
|
||||
|
||||
def delete_variable(self, user_id: UUID | str, name: str, _session: Session) -> None:
|
||||
secret_name = encode_user_id(user_id)
|
||||
@override
|
||||
async def update_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
value: str,
|
||||
session: AsyncSession,
|
||||
):
|
||||
return await asyncio.to_thread(self._update_variable, user_id, name, value)
|
||||
|
||||
def _delete_variable(self, user_id: UUID | str, name: str) -> None:
|
||||
secret_name = encode_user_id(user_id)
|
||||
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
|
||||
self.kubernetes_secrets.delete_secret_key(name=secret_name, key=secret_key)
|
||||
|
||||
def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID | str, _session: Session) -> None:
|
||||
self.delete_variable(user_id, _session, str(variable_id))
|
||||
@override
|
||||
async def delete_variable(self, user_id: UUID | str, name: str, session: AsyncSession) -> None:
|
||||
await asyncio.to_thread(self._delete_variable, user_id, name)
|
||||
|
||||
@override
|
||||
def create_variable(
|
||||
async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID | str, session: AsyncSession) -> None:
|
||||
await self.delete_variable(user_id, str(variable_id), session)
|
||||
|
||||
@override
|
||||
async def create_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
|
|
@ -138,7 +164,7 @@ class KubernetesSecretService(VariableService, Service):
|
|||
*,
|
||||
default_fields: list[str],
|
||||
_type: str,
|
||||
session: Session,
|
||||
session: AsyncSession,
|
||||
) -> Variable:
|
||||
secret_name = encode_user_id(user_id)
|
||||
secret_key = name
|
||||
|
|
@ -147,7 +173,9 @@ class KubernetesSecretService(VariableService, Service):
|
|||
else:
|
||||
_type = GENERIC_TYPE
|
||||
|
||||
self.kubernetes_secrets.upsert_secret(secret_name=secret_name, data={secret_key: value})
|
||||
await asyncio.to_thread(
|
||||
self.kubernetes_secrets.upsert_secret, secret_name=secret_name, data={secret_key: value}
|
||||
)
|
||||
|
||||
variable_base = VariableCreate(
|
||||
name=name,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ if TYPE_CHECKING:
|
|||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
|
||||
|
|
@ -24,7 +26,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
def __init__(self, settings_service: SettingsService):
|
||||
self.settings_service = settings_service
|
||||
|
||||
def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None:
|
||||
async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None:
|
||||
if not self.settings_service.settings.store_environment_variables:
|
||||
logger.info("Skipping environment variable storage.")
|
||||
return
|
||||
|
|
@ -34,12 +36,12 @@ class DatabaseVariableService(VariableService, Service):
|
|||
if var_name in os.environ and os.environ[var_name].strip():
|
||||
value = os.environ[var_name].strip()
|
||||
query = select(Variable).where(Variable.user_id == user_id, Variable.name == var_name)
|
||||
existing = session.exec(query).first()
|
||||
existing = (await session.exec(query)).first()
|
||||
try:
|
||||
if existing:
|
||||
self.update_variable(user_id, var_name, value, session)
|
||||
await self.update_variable(user_id, var_name, value, session)
|
||||
else:
|
||||
self.create_variable(
|
||||
await self.create_variable(
|
||||
user_id=user_id,
|
||||
name=var_name,
|
||||
value=value,
|
||||
|
|
@ -76,40 +78,46 @@ class DatabaseVariableService(VariableService, Service):
|
|||
# we decrypt the value
|
||||
return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
|
||||
|
||||
def get_all(self, user_id: UUID | str, session: Session) -> list[Variable | None]:
|
||||
return list(session.exec(select(Variable).where(Variable.user_id == user_id)).all())
|
||||
async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]:
|
||||
stmt = select(Variable).where(Variable.user_id == user_id)
|
||||
return list((await session.exec(stmt)).all())
|
||||
|
||||
def list_variables(self, user_id: UUID | str, session: Session) -> list[str | None]:
|
||||
variables = self.get_all(user_id=user_id, session=session)
|
||||
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
|
||||
variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
|
||||
return [variable.name for variable in variables if variable]
|
||||
|
||||
def update_variable(
|
||||
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
|
||||
variables = await self.get_all(user_id=user_id, session=session)
|
||||
return [variable.name for variable in variables if variable]
|
||||
|
||||
async def update_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
value: str,
|
||||
session: Session,
|
||||
session: AsyncSession,
|
||||
):
|
||||
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
|
||||
stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name)
|
||||
variable = (await session.exec(stmt)).first()
|
||||
if not variable:
|
||||
msg = f"{name} variable not found."
|
||||
raise ValueError(msg)
|
||||
encrypted = auth_utils.encrypt_api_key(value, settings_service=self.settings_service)
|
||||
variable.value = encrypted
|
||||
session.add(variable)
|
||||
session.commit()
|
||||
session.refresh(variable)
|
||||
await session.commit()
|
||||
await session.refresh(variable)
|
||||
return variable
|
||||
|
||||
def update_variable_fields(
|
||||
async def update_variable_fields(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
variable_id: UUID | str,
|
||||
variable: VariableUpdate,
|
||||
session: Session,
|
||||
session: AsyncSession,
|
||||
):
|
||||
query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
|
||||
db_variable = session.exec(query).one()
|
||||
db_variable = (await session.exec(query)).one()
|
||||
db_variable.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
variable.value = variable.value or ""
|
||||
|
|
@ -121,33 +129,34 @@ class DatabaseVariableService(VariableService, Service):
|
|||
setattr(db_variable, key, value)
|
||||
|
||||
session.add(db_variable)
|
||||
session.commit()
|
||||
session.refresh(db_variable)
|
||||
await session.commit()
|
||||
await session.refresh(db_variable)
|
||||
return db_variable
|
||||
|
||||
def delete_variable(
|
||||
async def delete_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
session: Session,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name)
|
||||
variable = session.exec(stmt).first()
|
||||
variable = (await session.exec(stmt)).first()
|
||||
if not variable:
|
||||
msg = f"{name} variable not found."
|
||||
raise ValueError(msg)
|
||||
session.delete(variable)
|
||||
session.commit()
|
||||
await session.delete(variable)
|
||||
await session.commit()
|
||||
|
||||
def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: Session) -> None:
|
||||
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)).first()
|
||||
async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None:
|
||||
stmt = select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)
|
||||
variable = (await session.exec(stmt)).first()
|
||||
if not variable:
|
||||
msg = f"{variable_id} variable not found."
|
||||
raise ValueError(msg)
|
||||
session.delete(variable)
|
||||
session.commit()
|
||||
await session.delete(variable)
|
||||
await session.commit()
|
||||
|
||||
def create_variable(
|
||||
async def create_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
|
|
@ -155,7 +164,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
*,
|
||||
default_fields: Sequence[str] = (),
|
||||
_type: str = GENERIC_TYPE,
|
||||
session: Session,
|
||||
session: AsyncSession,
|
||||
):
|
||||
variable_base = VariableCreate(
|
||||
name=name,
|
||||
|
|
@ -165,6 +174,6 @@ class DatabaseVariableService(VariableService, Service):
|
|||
)
|
||||
variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
|
||||
session.add(variable)
|
||||
session.commit()
|
||||
session.refresh(variable)
|
||||
await session.commit()
|
||||
await session.refresh(variable)
|
||||
return variable
|
||||
|
|
|
|||
|
|
@ -90,6 +90,8 @@ def _wrap_file_read_blocking(func):
|
|||
"_read_pyc",
|
||||
}:
|
||||
return func(self, *args, **kwargs)
|
||||
if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize":
|
||||
return func(self, *args, **kwargs)
|
||||
raise _blocking_error(func)
|
||||
|
||||
return file_op
|
||||
|
|
@ -104,6 +106,8 @@ def _wrap_file_write_blocking(func):
|
|||
for frame_info in inspect.stack():
|
||||
if frame_info.filename.endswith("_pytest/assertion/rewrite.py") and frame_info.function == "_write_pyc":
|
||||
return func(self, *args, **kwargs)
|
||||
if frame_info.filename.endswith("settings/service.py") and frame_info.function == "initialize":
|
||||
return func(self, *args, **kwargs)
|
||||
if self not in {sys.stdout, sys.stderr}:
|
||||
raise _blocking_error(func)
|
||||
return func(self, *args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ async def test_initialize_services():
|
|||
"""Benchmark the initialization of services."""
|
||||
from langflow.services.utils import initialize_services
|
||||
|
||||
await asyncio.to_thread(initialize_services, fix_migration=False)
|
||||
await initialize_services(fix_migration=False)
|
||||
settings_service = await asyncio.to_thread(get_settings_service)
|
||||
assert "test_performance.db" in settings_service.settings.database_url
|
||||
|
||||
|
|
@ -45,8 +45,8 @@ async def test_initialize_super_user():
|
|||
from langflow.initial_setup.setup import initialize_super_user_if_needed
|
||||
from langflow.services.utils import initialize_services
|
||||
|
||||
await asyncio.to_thread(initialize_services, fix_migration=False)
|
||||
await asyncio.to_thread(initialize_super_user_if_needed)
|
||||
await initialize_services(fix_migration=False)
|
||||
await initialize_super_user_if_needed()
|
||||
settings_service = await asyncio.to_thread(get_settings_service)
|
||||
assert "test_performance.db" in settings_service.settings.database_url
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ async def test_create_starter_projects():
|
|||
from langflow.interface.types import get_and_cache_all_types_dict
|
||||
from langflow.services.utils import initialize_services
|
||||
|
||||
await asyncio.to_thread(initialize_services, fix_migration=False)
|
||||
await initialize_services(fix_migration=False)
|
||||
settings_service = await asyncio.to_thread(get_settings_service)
|
||||
types_dict = await get_and_cache_all_types_dict(settings_service)
|
||||
await asyncio.to_thread(create_or_update_starter_projects, types_dict)
|
||||
|
|
@ -81,6 +81,6 @@ async def test_load_flows():
|
|||
"""Benchmark loading flows from directory."""
|
||||
from langflow.initial_setup.setup import load_flows_from_directory
|
||||
|
||||
await asyncio.to_thread(load_flows_from_directory)
|
||||
await load_flows_from_directory()
|
||||
settings_service = await asyncio.to_thread(get_settings_service)
|
||||
assert "test_performance.db" in settings_service.settings.database_url
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from langflow.services.database.models.variable.model import VariableUpdate
|
||||
|
|
@ -8,7 +8,9 @@ from langflow.services.deps import get_settings_service
|
|||
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
from langflow.services.variable.service import DatabaseVariableService
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import Session, SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -18,114 +20,125 @@ def service():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
async def session():
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def test_initialize_user_variables__create_and_update(service, session):
|
||||
def _get_variable(
|
||||
session: Session,
|
||||
service,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
field: str,
|
||||
):
|
||||
return service.get_variable(user_id, name, field, session=session)
|
||||
|
||||
|
||||
async def test_initialize_user_variables__create_and_update(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
field = ""
|
||||
good_vars = {k: f"value{i}" for i, k in enumerate(VARIABLES_TO_GET_FROM_ENVIRONMENT)}
|
||||
bad_vars = {"VAR1": "value1", "VAR2": "value2", "VAR3": "value3"}
|
||||
env_vars = {**good_vars, **bad_vars}
|
||||
|
||||
service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session)
|
||||
await service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session)
|
||||
env_vars["OPENAI_API_KEY"] = "updated_value"
|
||||
|
||||
with patch.dict("os.environ", env_vars, clear=True):
|
||||
service.initialize_user_variables(user_id=user_id, session=session)
|
||||
await service.initialize_user_variables(user_id=user_id, session=session)
|
||||
|
||||
variables = service.list_variables(user_id, session=session)
|
||||
variables = await service.list_variables(user_id, session=session)
|
||||
for name in variables:
|
||||
value = service.get_variable(user_id, name, field, session=session)
|
||||
value = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
assert value == env_vars[name]
|
||||
|
||||
assert all(i in variables for i in good_vars)
|
||||
assert all(i not in variables for i in bad_vars)
|
||||
|
||||
|
||||
def test_initialize_user_variables__not_found_variable(service, session):
|
||||
async def test_initialize_user_variables__not_found_variable(service, session: AsyncSession):
|
||||
with patch("langflow.services.variable.service.DatabaseVariableService.create_variable") as m:
|
||||
m.side_effect = Exception()
|
||||
service.initialize_user_variables(uuid4(), session=session)
|
||||
await service.initialize_user_variables(uuid4(), session=session)
|
||||
assert True
|
||||
|
||||
|
||||
def test_initialize_user_variables__skipping_environment_variable_storage(service, session):
|
||||
async def test_initialize_user_variables__skipping_environment_variable_storage(service, session: AsyncSession):
|
||||
service.settings_service.settings.store_environment_variables = False
|
||||
service.initialize_user_variables(uuid4(), session=session)
|
||||
await service.initialize_user_variables(uuid4(), session=session)
|
||||
assert True
|
||||
|
||||
|
||||
def test_get_variable(service, session):
|
||||
async def test_get_variable(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = ""
|
||||
service.create_variable(user_id, name, value, session=session)
|
||||
await service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
result = service.get_variable(user_id, name, field, session=session)
|
||||
result = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert result == value
|
||||
|
||||
|
||||
def test_get_variable__valueerror(service, session):
|
||||
async def test_get_variable__valueerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
field = ""
|
||||
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.get_variable(user_id, name, field, session)
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
|
||||
def test_get_variable__typeerror(service, session):
|
||||
async def test_get_variable__typeerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = "session_id"
|
||||
_type = CREDENTIAL_TYPE
|
||||
service.create_variable(user_id, name, value, _type=_type, session=session)
|
||||
await service.create_variable(user_id, name, value, _type=_type, session=session)
|
||||
|
||||
with pytest.raises(TypeError) as exc:
|
||||
service.get_variable(user_id, name, field, session)
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "purpose is to prevent the exposure of value" in str(exc.value)
|
||||
|
||||
|
||||
def test_list_variables(service, session):
|
||||
async def test_list_variables(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
names = ["name1", "name2", "name3"]
|
||||
value = "value"
|
||||
for name in names:
|
||||
service.create_variable(user_id, name, value, session=session)
|
||||
await service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
result = service.list_variables(user_id, session=session)
|
||||
result = await service.list_variables(user_id, session=session)
|
||||
|
||||
assert all(name in result for name in names)
|
||||
|
||||
|
||||
def test_list_variables__empty(service, session):
|
||||
result = service.list_variables(uuid4(), session=session)
|
||||
async def test_list_variables__empty(service, session: AsyncSession):
|
||||
result = await service.list_variables(uuid4(), session=session)
|
||||
|
||||
assert not result
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
def test_update_variable(service, session):
|
||||
async def test_update_variable(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
old_value = "old_value"
|
||||
new_value = "new_value"
|
||||
field = ""
|
||||
service.create_variable(user_id, name, old_value, session=session)
|
||||
await service.create_variable(user_id, name, old_value, session=session)
|
||||
|
||||
old_recovered = service.get_variable(user_id, name, field, session=session)
|
||||
result = service.update_variable(user_id, name, new_value, session=session)
|
||||
new_recovered = service.get_variable(user_id, name, field, session=session)
|
||||
old_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
result = await service.update_variable(user_id, name, new_value, session=session)
|
||||
new_recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert old_value == old_recovered
|
||||
assert new_value == new_recovered
|
||||
|
|
@ -139,26 +152,26 @@ def test_update_variable(service, session):
|
|||
assert isinstance(result.updated_at, datetime)
|
||||
|
||||
|
||||
def test_update_variable__valueerror(service, session):
|
||||
async def test_update_variable__valueerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.update_variable(user_id, name, value, session=session)
|
||||
await service.update_variable(user_id, name, value, session=session)
|
||||
|
||||
|
||||
def test_update_variable_fields(service, session):
|
||||
async def test_update_variable_fields(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
new_name = new_value = "donkey"
|
||||
variable = service.create_variable(user_id, "old_name", "old_value", session=session)
|
||||
variable = await service.create_variable(user_id, "old_name", "old_value", session=session)
|
||||
saved = variable.model_dump()
|
||||
variable = VariableUpdate(**saved)
|
||||
variable.name = new_name
|
||||
variable.value = new_value
|
||||
variable.default_fields = ["new_field"]
|
||||
|
||||
result = service.update_variable_fields(
|
||||
result = await service.update_variable_fields(
|
||||
user_id=user_id,
|
||||
variable_id=saved.get("id"),
|
||||
variable=variable,
|
||||
|
|
@ -177,58 +190,58 @@ def test_update_variable_fields(service, session):
|
|||
assert saved.get("updated_at") != result.updated_at
|
||||
|
||||
|
||||
def test_delete_variable(service, session):
|
||||
async def test_delete_variable(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = ""
|
||||
|
||||
service.create_variable(user_id, name, value, session=session)
|
||||
recovered = service.get_variable(user_id, name, field, session=session)
|
||||
service.delete_variable(user_id, name, session=session)
|
||||
await service.create_variable(user_id, name, value, session=session)
|
||||
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
await service.delete_variable(user_id, name, session=session)
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.get_variable(user_id, name, field, session)
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert recovered == value
|
||||
|
||||
|
||||
def test_delete_variable__valueerror(service, session):
|
||||
async def test_delete_variable__valueerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.delete_variable(user_id, name, session=session)
|
||||
await service.delete_variable(user_id, name, session=session)
|
||||
|
||||
|
||||
def test_delete_variable_by_id(service, session):
|
||||
async def test_delete_variable_by_id(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
field = "field"
|
||||
|
||||
saved = service.create_variable(user_id, name, value, session=session)
|
||||
recovered = service.get_variable(user_id, name, field, session=session)
|
||||
service.delete_variable_by_id(user_id, saved.id, session=session)
|
||||
saved = await service.create_variable(user_id, name, value, session=session)
|
||||
recovered = await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
await service.delete_variable_by_id(user_id, saved.id, session=session)
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.get_variable(user_id, name, field, session)
|
||||
await session.run_sync(_get_variable, service, user_id, name, field)
|
||||
|
||||
assert recovered == value
|
||||
|
||||
|
||||
def test_delete_variable_by_id__valueerror(service, session):
|
||||
async def test_delete_variable_by_id__valueerror(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
variable_id = uuid4()
|
||||
|
||||
with pytest.raises(ValueError, match=f"{variable_id} variable not found."):
|
||||
service.delete_variable_by_id(user_id, variable_id, session=session)
|
||||
await service.delete_variable_by_id(user_id, variable_id, session=session)
|
||||
|
||||
|
||||
def test_create_variable(service, session):
|
||||
async def test_create_variable(service, session: AsyncSession):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
|
||||
result = service.create_variable(user_id, name, value, session=session)
|
||||
result = await service.create_variable(user_id, name, value, session=session)
|
||||
|
||||
assert result.user_id == user_id
|
||||
assert result.name == name
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from langflow.services.settings.constants import (
|
||||
DEFAULT_SUPERUSER,
|
||||
|
|
@ -91,7 +92,7 @@ from langflow.services.utils import teardown_superuser
|
|||
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
async def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = True
|
||||
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
|
|
@ -104,29 +105,28 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting
|
|||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = iter([mock_session])
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
await teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.query.assert_not_called()
|
||||
|
||||
|
||||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
async def test_teardown_superuser_no_default_superuser():
|
||||
admin_user_name = "admin_user"
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
mock_settings_service.auth_settings.SUPERUSER = admin_user_name
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password" # noqa: S105
|
||||
mock_get_settings_service.return_value = mock_settings_service
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session = AsyncMock(return_value=asyncio.Future())
|
||||
mock_user = MagicMock()
|
||||
mock_user.is_superuser = False
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = [mock_session]
|
||||
mock_user.last_login_at = None
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = mock_user
|
||||
mock_session.exec.return_value = mock_result
|
||||
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
await teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.delete.assert_not_awaited()
|
||||
mock_session.commit.assert_not_awaited()
|
||||
|
|
|
|||
|
|
@ -5,18 +5,18 @@ from httpx import AsyncClient
|
|||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.database.utils import async_session_getter, session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def super_user(client): # noqa: ARG001
|
||||
async def super_user(client): # noqa: ARG001
|
||||
settings_manager = get_settings_service()
|
||||
auth_settings = settings_manager.auth_settings
|
||||
with session_getter(get_db_service()) as session:
|
||||
return create_super_user(
|
||||
db=session,
|
||||
async with async_session_getter(get_db_service()) as db:
|
||||
return await create_super_user(
|
||||
db=db,
|
||||
username=auth_settings.SUPERUSER,
|
||||
password=auth_settings.SUPERUSER_PASSWORD,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue