diff --git a/src/backend/langflow/services/utils.py b/src/backend/langflow/services/utils.py index 02d9816f5..20c23f8c0 100644 --- a/src/backend/langflow/services/utils.py +++ b/src/backend/langflow/services/utils.py @@ -1,4 +1,4 @@ -from langflow.services.auth.utils import create_super_user +from langflow.services.auth.utils import create_super_user, verify_password from langflow.services.database.utils import initialize_database from langflow.services.manager import service_manager from langflow.services.schema import ServiceType @@ -6,50 +6,82 @@ from langflow.services.settings.constants import ( DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD, ) +from sqlmodel import Session from .getters import get_session, get_settings_service from loguru import logger -def setup_superuser(settings_service, session): - """ - Setup the superuser. - """ - # We will use the SUPERUSER and SUPERUSER_PASSWORD - # vars on settings_manager.auth_settings to create the superuser - # if it does not exist. +def get_or_create_super_user(session: Session, username, password, is_default): + from langflow.services.database.models.user.user import User + + user = session.query(User).filter(User.username == username).first() + + if user and user.is_superuser and verify_password(password, user.password): + return None # Superuser already exists + + if user and is_default: + if user.is_superuser: + if verify_password(password, user.password): + return None + else: + # Superuser exists but password is incorrect + # which means that the user has changed the + # base superuser credentials. + # This means that the user has already created + # a superuser and changed the password in the UI + # so we don't need to do anything. + logger.debug( + "Superuser exists but password is incorrect. " + "This means that the user has changed the " + "base superuser credentials." + ) + return None + else: + logger.debug( + "User with superuser credentials exists but is not a superuser." + ) + return None + + if user: + if verify_password(password, user.password): + raise ValueError( + "User with superuser credentials exists but is not a superuser." + ) + else: + raise ValueError("Incorrect superuser credentials") + + if is_default: + logger.debug("Creating default superuser.") + else: + logger.debug("Creating superuser.") + + return create_super_user(username, password, db=session) + + +def setup_superuser(settings_service, session: Session): if settings_service.auth_settings.AUTO_LOGIN: logger.debug("AUTO_LOGIN is set to True. Creating default superuser.") username = settings_service.auth_settings.SUPERUSER password = settings_service.auth_settings.SUPERUSER_PASSWORD - if username == DEFAULT_SUPERUSER and password == DEFAULT_SUPERUSER_PASSWORD: - logger.debug("Default superuser credentials detected.") - logger.debug("Creating default superuser.") - else: - logger.debug("Creating superuser.") + + is_default = (username == DEFAULT_SUPERUSER) and ( + password == DEFAULT_SUPERUSER_PASSWORD + ) try: - from langflow.services.database.models.user.user import User - - user = session.query(User).filter(User.username == username).first() - if user and user.is_superuser is True: - return + user = get_or_create_super_user( + session=session, username=username, password=password, is_default=is_default + ) + if user is not None: + logger.debug("Superuser created successfully.") except Exception as exc: logger.exception(exc) raise RuntimeError( "Could not create superuser. Please create a superuser manually." ) from exc - try: - # create superuser - create_super_user(db=session, username=username, password=password) - except Exception as exc: - logger.exception(exc) - raise RuntimeError( - "Could not create superuser. Please create a superuser manually." - ) from exc - # reset superuser credentials - settings_service.auth_settings.reset_credentials() - logger.debug("Superuser created successfully.") + finally: + settings_service.auth_settings.reset_credentials() def teardown_superuser(settings_service, session): @@ -124,30 +156,40 @@ def initialize_services(): from langflow.services.task import factory as task_factory from langflow.services.session import factory as session_service_factory # type: ignore - service_manager.register_factory(settings_factory.SettingsServiceFactory()) - service_manager.register_factory( - auth_factory.AuthServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE] - ) - service_manager.register_factory( - database_factory.DatabaseServiceFactory(), - dependencies=[ServiceType.SETTINGS_SERVICE], - ) - service_manager.register_factory( - cache_factory.CacheServiceFactory(), dependencies=[ServiceType.SETTINGS_SERVICE] - ) - service_manager.register_factory(chat_factory.ChatServiceFactory()) + factory_and_dependencies = [ + (settings_factory.SettingsServiceFactory(), []), + ( + auth_factory.AuthServiceFactory(), + [ServiceType.SETTINGS_SERVICE], + ), + ( + database_factory.DatabaseServiceFactory(), + [ServiceType.SETTINGS_SERVICE], + ), + ( + cache_factory.CacheServiceFactory(), + [ServiceType.SETTINGS_SERVICE], + ), + (chat_factory.ChatServiceFactory(), []), + (task_factory.TaskServiceFactory(), []), + ( + session_service_factory.SessionServiceFactory(), + [ServiceType.CACHE_SERVICE], + ), + ] + for factory, dependencies in factory_and_dependencies: + try: + service_manager.register_factory(factory, dependencies=dependencies) + except Exception as exc: + logger.exception(exc) + raise RuntimeError( + "Could not initialize services. Please check your settings." + ) from exc - service_manager.register_factory(task_factory.TaskServiceFactory()) - - service_manager.register_factory( - session_service_factory.SessionServiceFactory(), - dependencies=[ServiceType.CACHE_SERVICE], - ) # Test cache connection service_manager.get(ServiceType.CACHE_SERVICE) - # Test database connection - service_manager.get(ServiceType.DATABASE_SERVICE) # Setup the superuser initialize_database() - session = next(get_session()) - setup_superuser(service_manager.get(ServiceType.SETTINGS_SERVICE), session) + setup_superuser( + service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session()) + )