diff --git a/.env.example b/.env.example index 6c6d2f667..c03638072 100644 --- a/.env.example +++ b/.env.example @@ -45,3 +45,11 @@ LANGFLOW_OPEN_BROWSER= # Values: true, false # Example: LANGFLOW_REMOVE_API_KEYS=false LANGFLOW_REMOVE_API_KEYS= + +# Superuser username +# Example: LANGFLOW_SUPERUSER=admin +LANGFLOW_SUPERUSER= + +# Superuser password +# Example: LANGFLOW_SUPERUSER_PASSWORD=123456 +LANGFLOW_SUPERUSER_PASSWORD= \ No newline at end of file diff --git a/Makefile b/Makefile index 7dc0e7254..2fe79291e 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,7 @@ coverage: --cov-report term-missing:skip-covered tests: + @make install_backend poetry run pytest tests -n auto format: diff --git a/src/backend/langflow/__main__.py b/src/backend/langflow/__main__.py index 3a110f380..84adf7bee 100644 --- a/src/backend/langflow/__main__.py +++ b/src/backend/langflow/__main__.py @@ -2,8 +2,9 @@ import sys import time import httpx from langflow.services.database.utils import session_getter -from langflow.services.manager import initialize_services, initialize_settings_manager -from langflow.services.utils import get_db_manager, get_settings_manager +from langflow.services.utils import initialize_services +from langflow.services.getters import get_db_manager, get_settings_manager +from langflow.services.utils import initialize_settings_manager from multiprocess import Process, cpu_count # type: ignore import platform @@ -360,8 +361,8 @@ def superuser( # Verify that the superuser was created from langflow.services.database.models.user.user import User - user = session.query(User).filter(User.username == username).first() - if user is None: + user: User = session.query(User).filter(User.username == username).first() + if user is None or not user.is_superuser: typer.echo("Superuser creation failed.") return diff --git a/src/backend/langflow/api/v1/api_key.py b/src/backend/langflow/api/v1/api_key.py index 280f240e8..7f5916d06 100644 --- a/src/backend/langflow/api/v1/api_key.py +++ b/src/backend/langflow/api/v1/api_key.py @@ -14,7 +14,7 @@ from langflow.services.database.models.api_key.crud import ( delete_api_key, ) from langflow.services.database.models.user.user import User -from langflow.services.utils import get_session +from langflow.services.getters import get_session from sqlmodel import Session diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 690cad60b..adc6b3d61 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -14,7 +14,7 @@ from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, St from langflow.graph.graph.base import Graph from langflow.services.auth.utils import get_current_active_user, get_current_user from loguru import logger -from langflow.services.utils import get_chat_manager, get_session +from langflow.services.getters import get_chat_manager, get_session from cachetools import LRUCache from sqlmodel import Session from langflow.services.chat.manager import ChatManager diff --git a/src/backend/langflow/api/v1/components.py b/src/backend/langflow/api/v1/components.py index 4071461fb..d2b39dfd2 100644 --- a/src/backend/langflow/api/v1/components.py +++ b/src/backend/langflow/api/v1/components.py @@ -2,7 +2,7 @@ from datetime import timezone from typing import List from uuid import UUID from langflow.services.database.models.component import Component, ComponentModel -from langflow.services.utils import get_session +from langflow.services.getters import get_session from sqlmodel import Session, select from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.exc import IntegrityError diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index d1f898105..870f91d28 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -6,7 +6,7 @@ from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow from langflow.processing.process import process_graph_cached, process_tweaks from langflow.services.database.models.user.user import User -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from loguru import logger from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body, status import sqlalchemy as sa @@ -27,7 +27,7 @@ from langflow.interface.types import ( build_langchain_custom_component_list_from_path, ) -from langflow.services.utils import get_session +from langflow.services.getters import get_session from sqlmodel import Session # build router diff --git a/src/backend/langflow/api/v1/flows.py b/src/backend/langflow/api/v1/flows.py index be65048d4..c323dae53 100644 --- a/src/backend/langflow/api/v1/flows.py +++ b/src/backend/langflow/api/v1/flows.py @@ -12,8 +12,8 @@ from langflow.services.database.models.flow import ( FlowUpdate, ) from langflow.services.database.models.user.user import User -from langflow.services.utils import get_session -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_session +from langflow.services.getters import get_settings_manager import orjson from sqlmodel import Session from fastapi import APIRouter, Depends, HTTPException diff --git a/src/backend/langflow/api/v1/login.py b/src/backend/langflow/api/v1/login.py index 4241b8d47..9ff059bf9 100644 --- a/src/backend/langflow/api/v1/login.py +++ b/src/backend/langflow/api/v1/login.py @@ -2,7 +2,7 @@ from sqlmodel import Session from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm -from langflow.services.utils import get_session +from langflow.services.getters import get_session from langflow.api.v1.schemas import Token from langflow.services.auth.utils import ( authenticate_user, @@ -12,7 +12,7 @@ from langflow.services.auth.utils import ( get_current_active_user, ) -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager router = APIRouter(tags=["Login"]) diff --git a/src/backend/langflow/api/v1/users.py b/src/backend/langflow/api/v1/users.py index e68512e43..e1e24d197 100644 --- a/src/backend/langflow/api/v1/users.py +++ b/src/backend/langflow/api/v1/users.py @@ -13,7 +13,7 @@ from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select from fastapi import APIRouter, Depends, HTTPException -from langflow.services.utils import get_session +from langflow.services.getters import get_session from langflow.services.auth.utils import ( get_current_active_superuser, get_current_active_user, diff --git a/src/backend/langflow/interface/agents/base.py b/src/backend/langflow/interface/agents/base.py index 574264e47..f48015b2c 100644 --- a/src/backend/langflow/interface/agents/base.py +++ b/src/backend/langflow/interface/agents/base.py @@ -5,7 +5,7 @@ from langchain.agents import types from langflow.custom.customs import get_custom_nodes from langflow.interface.agents.custom import CUSTOM_AGENTS from langflow.interface.base import LangChainTypeCreator -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.agents import AgentFrontendNode from loguru import logger diff --git a/src/backend/langflow/interface/base.py b/src/backend/langflow/interface/base.py index b006a3174..4bb657b3c 100644 --- a/src/backend/langflow/interface/base.py +++ b/src/backend/langflow/interface/base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Type, Union from langchain.chains.base import Chain from langchain.agents import AgentExecutor -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from pydantic import BaseModel from langflow.template.field.base import TemplateField diff --git a/src/backend/langflow/interface/chains/base.py b/src/backend/langflow/interface/chains/base.py index 755ac82dd..99b0a693f 100644 --- a/src/backend/langflow/interface/chains/base.py +++ b/src/backend/langflow/interface/chains/base.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Type from langflow.custom.customs import get_custom_nodes from langflow.interface.base import LangChainTypeCreator from langflow.interface.importing.utils import import_class -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.chains import ChainFrontendNode from loguru import logger diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 1357daf68..9de09507a 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -4,7 +4,7 @@ from fastapi import HTTPException from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES from langflow.interface.custom.component import Component from langflow.interface.custom.directory_reader import DirectoryReader -from langflow.services.utils import get_db_manager +from langflow.services.getters import get_db_manager from langflow.interface.custom.utils import extract_inner_type from langflow.utils import validate diff --git a/src/backend/langflow/interface/document_loaders/base.py b/src/backend/langflow/interface/document_loaders/base.py index a2c147e16..05311444b 100644 --- a/src/backend/langflow/interface/document_loaders/base.py +++ b/src/backend/langflow/interface/document_loaders/base.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.documentloaders import DocumentLoaderFrontNode from langflow.interface.custom_lists import documentloaders_type_to_cls_dict diff --git a/src/backend/langflow/interface/embeddings/base.py b/src/backend/langflow/interface/embeddings/base.py index 1063d10d1..0145d9859 100644 --- a/src/backend/langflow/interface/embeddings/base.py +++ b/src/backend/langflow/interface/embeddings/base.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator from langflow.interface.custom_lists import embedding_type_to_cls_dict -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.base import FrontendNode from langflow.template.frontend_node.embeddings import EmbeddingFrontendNode diff --git a/src/backend/langflow/interface/llms/base.py b/src/backend/langflow/interface/llms/base.py index 87e4937cf..17a2ae0ee 100644 --- a/src/backend/langflow/interface/llms/base.py +++ b/src/backend/langflow/interface/llms/base.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator from langflow.interface.custom_lists import llm_type_to_cls_dict -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.llms import LLMFrontendNode from loguru import logger diff --git a/src/backend/langflow/interface/memories/base.py b/src/backend/langflow/interface/memories/base.py index 61c6cc430..6c826d0ac 100644 --- a/src/backend/langflow/interface/memories/base.py +++ b/src/backend/langflow/interface/memories/base.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator from langflow.interface.custom_lists import memory_type_to_cls_dict -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.base import FrontendNode from langflow.template.frontend_node.memories import MemoryFrontendNode diff --git a/src/backend/langflow/interface/output_parsers/base.py b/src/backend/langflow/interface/output_parsers/base.py index b6eb36a0e..48bcd1896 100644 --- a/src/backend/langflow/interface/output_parsers/base.py +++ b/src/backend/langflow/interface/output_parsers/base.py @@ -4,7 +4,7 @@ from langchain import output_parsers from langflow.interface.base import LangChainTypeCreator from langflow.interface.importing.utils import import_class -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.output_parsers import OutputParserFrontendNode from loguru import logger diff --git a/src/backend/langflow/interface/prompts/base.py b/src/backend/langflow/interface/prompts/base.py index 70818429e..d74e0c1e8 100644 --- a/src/backend/langflow/interface/prompts/base.py +++ b/src/backend/langflow/interface/prompts/base.py @@ -5,7 +5,7 @@ from langchain import prompts from langflow.custom.customs import get_custom_nodes from langflow.interface.base import LangChainTypeCreator from langflow.interface.importing.utils import import_class -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.prompts import PromptFrontendNode from loguru import logger diff --git a/src/backend/langflow/interface/retrievers/base.py b/src/backend/langflow/interface/retrievers/base.py index 92e3f2f61..415a7fda8 100644 --- a/src/backend/langflow/interface/retrievers/base.py +++ b/src/backend/langflow/interface/retrievers/base.py @@ -4,7 +4,7 @@ from langchain import retrievers from langflow.interface.base import LangChainTypeCreator from langflow.interface.importing.utils import import_class -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.retrievers import RetrieverFrontendNode from loguru import logger diff --git a/src/backend/langflow/interface/text_splitters/base.py b/src/backend/langflow/interface/text_splitters/base.py index 8b21303ce..fba4e32cc 100644 --- a/src/backend/langflow/interface/text_splitters/base.py +++ b/src/backend/langflow/interface/text_splitters/base.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.textsplitters import TextSplittersFrontendNode from langflow.interface.custom_lists import textsplitter_type_to_cls_dict diff --git a/src/backend/langflow/interface/toolkits/base.py b/src/backend/langflow/interface/toolkits/base.py index fe0003b15..be602cb7c 100644 --- a/src/backend/langflow/interface/toolkits/base.py +++ b/src/backend/langflow/interface/toolkits/base.py @@ -4,7 +4,7 @@ from langchain.agents import agent_toolkits from langflow.interface.base import LangChainTypeCreator from langflow.interface.importing.utils import import_class, import_module -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from loguru import logger from langflow.utils.util import build_template_from_class diff --git a/src/backend/langflow/interface/tools/base.py b/src/backend/langflow/interface/tools/base.py index 1dbc9a6ed..999f93703 100644 --- a/src/backend/langflow/interface/tools/base.py +++ b/src/backend/langflow/interface/tools/base.py @@ -15,7 +15,7 @@ from langflow.interface.tools.constants import ( OTHER_TOOLS, ) from langflow.interface.tools.util import get_tool_params -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.field.base import TemplateField from langflow.template.template.base import Template diff --git a/src/backend/langflow/interface/utilities/base.py b/src/backend/langflow/interface/utilities/base.py index 9009983b0..3cec49be9 100644 --- a/src/backend/langflow/interface/utilities/base.py +++ b/src/backend/langflow/interface/utilities/base.py @@ -5,7 +5,7 @@ from langchain import SQLDatabase, utilities from langflow.custom.customs import get_custom_nodes from langflow.interface.base import LangChainTypeCreator from langflow.interface.importing.utils import import_class -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.utilities import UtilitiesFrontendNode from loguru import logger diff --git a/src/backend/langflow/interface/utils.py b/src/backend/langflow/interface/utils.py index 5bf44e203..f993c971b 100644 --- a/src/backend/langflow/interface/utils.py +++ b/src/backend/langflow/interface/utils.py @@ -10,7 +10,7 @@ from langchain.base_language import BaseLanguageModel from PIL.Image import Image from loguru import logger from langflow.services.chat.config import ChatConfig -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager def load_file_into_dict(file_path: str) -> dict: diff --git a/src/backend/langflow/interface/vector_store/base.py b/src/backend/langflow/interface/vector_store/base.py index f7aca8c9c..06b8668f3 100644 --- a/src/backend/langflow/interface/vector_store/base.py +++ b/src/backend/langflow/interface/vector_store/base.py @@ -4,7 +4,7 @@ from langchain import vectorstores from langflow.interface.base import LangChainTypeCreator from langflow.interface.importing.utils import import_class -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.template.frontend_node.vectorstores import VectorStoreFrontendNode from loguru import logger diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index c869ec138..9caa157d0 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -9,9 +9,11 @@ from langflow.api import router from langflow.interface.utils import setup_llm_caching -from langflow.services.database.utils import initialize_database -from langflow.services.manager import initialize_services, teardown_services +from langflow.services.utils import initialize_services from langflow.services.plugins.langfuse import LangfuseInstance +from langflow.services.utils import ( + teardown_services, +) from langflow.utils.logger import configure @@ -39,11 +41,12 @@ def create_app(): app.include_router(router) app.on_event("startup")(initialize_services) - app.on_event("startup")(initialize_database) app.on_event("startup")(setup_llm_caching) - app.on_event("shutdown")(teardown_services) app.on_event("startup")(LangfuseInstance.update) + + app.on_event("shutdown")(teardown_services) app.on_event("shutdown")(LangfuseInstance.teardown) + return app diff --git a/src/backend/langflow/services/auth/utils.py b/src/backend/langflow/services/auth/utils.py index 485968a38..801874c7d 100644 --- a/src/backend/langflow/services/auth/utils.py +++ b/src/backend/langflow/services/auth/utils.py @@ -12,7 +12,7 @@ from langflow.services.database.models.user.crud import ( get_user_by_username, update_user_last_login_at, ) -from langflow.services.utils import get_session, get_settings_manager +from langflow.services.getters import get_session, get_settings_manager from sqlmodel import Session oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login") @@ -37,15 +37,13 @@ async def api_key_security( result: Optional[Union[ApiKey, User]] = None if settings_manager.auth_settings.AUTO_LOGIN: # Get the first user - if not settings_manager.auth_settings.FIRST_SUPERUSER: + if not settings_manager.auth_settings.SUPERUSER: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Missing first superuser credentials", ) - result = get_user_by_username( - db, settings_manager.auth_settings.FIRST_SUPERUSER - ) + result = get_user_by_username(db, settings_manager.auth_settings.SUPERUSER) elif not query_param and not header_param: raise HTTPException( @@ -182,8 +180,8 @@ def create_super_user( def create_user_longterm_token(db: Session = Depends(get_session)) -> dict: settings_manager = get_settings_manager() - username = settings_manager.auth_settings.FIRST_SUPERUSER - password = settings_manager.auth_settings.FIRST_SUPERUSER_PASSWORD + username = settings_manager.auth_settings.SUPERUSER + password = settings_manager.auth_settings.SUPERUSER_PASSWORD if not username or not password: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/src/backend/langflow/services/base.py b/src/backend/langflow/services/base.py index aaa966047..301771944 100644 --- a/src/backend/langflow/services/base.py +++ b/src/backend/langflow/services/base.py @@ -3,6 +3,10 @@ from abc import ABC class Service(ABC): name: str + ready: bool = False def teardown(self): pass + + def set_ready(self): + self.ready = True diff --git a/src/backend/langflow/services/database/manager.py b/src/backend/langflow/services/database/manager.py index 7f8afab6f..3c842bf52 100644 --- a/src/backend/langflow/services/database/manager.py +++ b/src/backend/langflow/services/database/manager.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from langflow.services.base import Service from langflow.services.database.models.user.crud import get_user_by_username from langflow.services.database.utils import Result, TableResults -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from sqlalchemy import inspect import sqlalchemy as sa from sqlmodel import SQLModel, Session, create_engine @@ -166,10 +166,10 @@ class DatabaseManager(Service): try: settings_manager = get_settings_manager() # remove the default superuser if auto_login is enabled - # using the FIRST_SUPERUSER to get the user + # using the SUPERUSER to get the user if settings_manager.auth_settings.AUTO_LOGIN: logger.debug("Removing default superuser") - username = settings_manager.auth_settings.FIRST_SUPERUSER + username = settings_manager.auth_settings.SUPERUSER with Session(self.engine) as session: user = get_user_by_username(session, username) session.delete(user) diff --git a/src/backend/langflow/services/database/models/user/crud.py b/src/backend/langflow/services/database/models/user/crud.py index f7f5958fe..36f03e684 100644 --- a/src/backend/langflow/services/database/models/user/crud.py +++ b/src/backend/langflow/services/database/models/user/crud.py @@ -3,7 +3,7 @@ from typing import Union from uuid import UUID from fastapi import Depends, HTTPException, status from langflow.services.database.models.user.user import User, UserUpdate -from langflow.services.utils import get_session +from langflow.services.getters import get_session from sqlalchemy.exc import IntegrityError from sqlmodel import Session from typing import Optional diff --git a/src/backend/langflow/services/getters.py b/src/backend/langflow/services/getters.py new file mode 100644 index 000000000..8b32aef02 --- /dev/null +++ b/src/backend/langflow/services/getters.py @@ -0,0 +1,26 @@ +from langflow.services import ServiceType, service_manager +from typing import TYPE_CHECKING, Generator + + +if TYPE_CHECKING: + from langflow.services.database.manager import DatabaseManager + from langflow.services.settings.manager import SettingsManager + from langflow.services.chat.manager import ChatManager + from sqlmodel import Session + + +def get_settings_manager() -> "SettingsManager": + return service_manager.get(ServiceType.SETTINGS_MANAGER) + + +def get_db_manager() -> "DatabaseManager": + return service_manager.get(ServiceType.DATABASE_MANAGER) + + +def get_session() -> Generator["Session", None, None]: + db_manager = service_manager.get(ServiceType.DATABASE_MANAGER) + yield from db_manager.get_session() + + +def get_chat_manager() -> "ChatManager": + return service_manager.get(ServiceType.CHAT_MANAGER) diff --git a/src/backend/langflow/services/manager.py b/src/backend/langflow/services/manager.py index 60a93fe16..9398a10f4 100644 --- a/src/backend/langflow/services/manager.py +++ b/src/backend/langflow/services/manager.py @@ -61,6 +61,7 @@ class ServiceManager: self.services[service_name] = self.factories[service_name].create( **dependent_services ) + self.services[service_name].set_ready() def _validate_service_creation(self, service_name: ServiceType): """ @@ -93,65 +94,3 @@ class ServiceManager: service_manager = ServiceManager() - - -def initialize_services(): - """ - Initialize all the services needed. - """ - from langflow.services.database import factory as database_factory - from langflow.services.cache import factory as cache_factory - from langflow.services.chat import factory as chat_factory - from langflow.services.settings import factory as settings_factory - from langflow.services.auth import factory as auth_factory - - service_manager.register_factory(settings_factory.SettingsManagerFactory()) - service_manager.register_factory( - auth_factory.AuthManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER] - ) - service_manager.register_factory( - database_factory.DatabaseManagerFactory(), - dependencies=[ServiceType.SETTINGS_MANAGER], - ) - service_manager.register_factory(cache_factory.CacheManagerFactory()) - service_manager.register_factory(chat_factory.ChatManagerFactory()) - - # Test cache connection - service_manager.get(ServiceType.CACHE_MANAGER) - # Test database connection - service_manager.get(ServiceType.DATABASE_MANAGER) - - -def initialize_settings_manager(): - """ - Initialize the settings manager. - """ - from langflow.services.settings import factory as settings_factory - - service_manager.register_factory(settings_factory.SettingsManagerFactory()) - - -def initialize_session_manager(): - """ - Initialize the session manager. - """ - from langflow.services.session import factory as session_manager_factory # type: ignore - from langflow.services.cache import factory as cache_factory - - initialize_settings_manager() - - service_manager.register_factory( - cache_factory.CacheManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER] - ) - - service_manager.register_factory( - session_manager_factory.SessionManagerFactory(), - dependencies=[ServiceType.CACHE_MANAGER], - ) - - -def teardown_services(): - """ - Teardown all the services. - """ - service_manager.teardown() diff --git a/src/backend/langflow/services/plugins/langfuse.py b/src/backend/langflow/services/plugins/langfuse.py index ce3e25c53..333459080 100644 --- a/src/backend/langflow/services/plugins/langfuse.py +++ b/src/backend/langflow/services/plugins/langfuse.py @@ -1,4 +1,4 @@ -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager from langflow.utils.logger import logger ### Temporary implementation diff --git a/src/backend/langflow/services/settings/auth.py b/src/backend/langflow/services/settings/auth.py index 87a156df7..b6d288183 100644 --- a/src/backend/langflow/services/settings/auth.py +++ b/src/backend/langflow/services/settings/auth.py @@ -1,6 +1,10 @@ from pathlib import Path from typing import Optional import secrets +from langflow.services.settings.constants import ( + DEFAULT_SUPERUSER, + DEFAULT_SUPERUSER_PASSWORD, +) from langflow.services.settings.utils import read_secret_from_file, write_secret_to_file from pydantic import BaseSettings, Field, validator @@ -30,9 +34,9 @@ class AuthSettings(BaseSettings): # If AUTO_LOGIN = True # > The application does not request login and logs in automatically as a super user. - AUTO_LOGIN: bool = True - FIRST_SUPERUSER: str = "langflow" - FIRST_SUPERUSER_PASSWORD: str = "langflow" + AUTO_LOGIN: bool = False + SUPERUSER: str = DEFAULT_SUPERUSER + SUPERUSER_PASSWORD: str = DEFAULT_SUPERUSER_PASSWORD pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -41,6 +45,28 @@ class AuthSettings(BaseSettings): extra = "ignore" env_prefix = "LANGFLOW_" + def reset_credentials(self): + self.SUPERUSER = DEFAULT_SUPERUSER + self.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD + + # If autologin is true, then we need to set the credentials to + # the default values + # so we need to validate the superuser and superuser_password + # fields + @validator("SUPERUSER", "SUPERUSER_PASSWORD", pre=True) + def validate_superuser(cls, value, values): + if values.get("AUTO_LOGIN"): + if value != DEFAULT_SUPERUSER: + value = DEFAULT_SUPERUSER + logger.debug("Resetting superuser to default value") + if values.get("SUPERUSER_PASSWORD") != DEFAULT_SUPERUSER_PASSWORD: + values["SUPERUSER_PASSWORD"] = DEFAULT_SUPERUSER_PASSWORD + logger.debug("Resetting superuser password to default value") + + return value + + return value + @validator("SECRET_KEY", pre=True) def get_secret_key(cls, value, values): config_dir = values.get("CONFIG_DIR") diff --git a/src/backend/langflow/services/settings/constants.py b/src/backend/langflow/services/settings/constants.py new file mode 100644 index 000000000..6cf7d4823 --- /dev/null +++ b/src/backend/langflow/services/settings/constants.py @@ -0,0 +1,2 @@ +DEFAULT_SUPERUSER = "langflow" +DEFAULT_SUPERUSER_PASSWORD = "langflow" diff --git a/src/backend/langflow/services/utils.py b/src/backend/langflow/services/utils.py index 8b32aef02..5f8525797 100644 --- a/src/backend/langflow/services/utils.py +++ b/src/backend/langflow/services/utils.py @@ -1,26 +1,143 @@ -from langflow.services import ServiceType, service_manager -from typing import TYPE_CHECKING, Generator +from langflow.services.auth.utils import create_super_user +from langflow.services.database.utils import initialize_database +from langflow.services.manager import service_manager +from langflow.services.schema import ServiceType +from langflow.services.settings.constants import ( + DEFAULT_SUPERUSER, + DEFAULT_SUPERUSER_PASSWORD, +) +from .getters import get_session, get_settings_manager +from loguru import logger -if TYPE_CHECKING: - from langflow.services.database.manager import DatabaseManager - from langflow.services.settings.manager import SettingsManager - from langflow.services.chat.manager import ChatManager - from sqlmodel import Session +def setup_superuser(): + """ + Setup the superuser. + """ + # We will use the SUPERUSER and SUPERUSER_PASSWORD + # vars on settings_manager.auth_settings to create the superuser + # if it does not exist. + settings_manager = get_settings_manager() + if settings_manager.auth_settings.AUTO_LOGIN: + logger.debug("AUTO_LOGIN is set to True. Creating default superuser.") + + session = next(get_session()) + username = settings_manager.auth_settings.SUPERUSER + password = settings_manager.auth_settings.SUPERUSER_PASSWORD + if username == DEFAULT_SUPERUSER and password == DEFAULT_SUPERUSER_PASSWORD: + logger.debug("Default superuser credentials detected.") + logger.debug("Creating default superuser.") + else: + logger.debug("Creating superuser.") + + try: + from langflow.services.database.models.user.user import User + + user = session.query(User).filter(User.username == username).first() + if user and user.is_superuser is True: + return + except Exception as exc: + logger.exception(exc) + raise RuntimeError( + "Could not create superuser. Please create a superuser manually." + ) from exc + try: + # create superuser + create_super_user(db=session, username=username, password=password) + except Exception as exc: + logger.exception(exc) + raise RuntimeError( + "Could not create superuser. Please create a superuser manually." + ) from exc + # reset superuser credentials + settings_manager.auth_settings.reset_credentials() + logger.debug("Superuser created successfully.") -def get_settings_manager() -> "SettingsManager": - return service_manager.get(ServiceType.SETTINGS_MANAGER) +def teardown_superuser(): + """ + Teardown the superuser. + """ + # If AUTO_LOGIN is True, we will remove the default superuser + # from the database. + settings_manager = get_settings_manager() + if settings_manager.auth_settings.AUTO_LOGIN: + logger.debug("AUTO_LOGIN is set to True. Removing default superuser.") + session = next(get_session()) + username = settings_manager.auth_settings.SUPERUSER + from langflow.services.database.models.user.user import User + + user = session.query(User).filter(User.username == username).first() + if user and user.is_superuser: + session.delete(user) + session.commit() + logger.debug("Default superuser removed successfully.") + else: + logger.debug("Default superuser not found.") -def get_db_manager() -> "DatabaseManager": - return service_manager.get(ServiceType.DATABASE_MANAGER) +def teardown_services(): + """ + Teardown all the services. + """ + teardown_superuser() + service_manager.teardown() -def get_session() -> Generator["Session", None, None]: +def initialize_settings_manager(): + """ + Initialize the settings manager. + """ + from langflow.services.settings import factory as settings_factory + + service_manager.register_factory(settings_factory.SettingsManagerFactory()) + + +def initialize_session_manager(): + """ + Initialize the session manager. + """ + from langflow.services.session import factory as session_manager_factory # type: ignore + from langflow.services.cache import factory as cache_factory + + initialize_settings_manager() + + service_manager.register_factory( + cache_factory.CacheManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER] + ) + + service_manager.register_factory( + session_manager_factory.SessionManagerFactory(), + dependencies=[ServiceType.CACHE_MANAGER], + ) + + +def initialize_services(): + """ + Initialize all the services needed. + """ + from langflow.services.database import factory as database_factory + from langflow.services.cache import factory as cache_factory + from langflow.services.chat import factory as chat_factory + from langflow.services.settings import factory as settings_factory + from langflow.services.auth import factory as auth_factory + + service_manager.register_factory(settings_factory.SettingsManagerFactory()) + service_manager.register_factory( + auth_factory.AuthManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER] + ) + service_manager.register_factory( + database_factory.DatabaseManagerFactory(), + dependencies=[ServiceType.SETTINGS_MANAGER], + ) + service_manager.register_factory(cache_factory.CacheManagerFactory()) + service_manager.register_factory(chat_factory.ChatManagerFactory()) + + # Test cache connection + service_manager.get(ServiceType.CACHE_MANAGER) + # Test database connection db_manager = service_manager.get(ServiceType.DATABASE_MANAGER) - yield from db_manager.get_session() - - -def get_chat_manager() -> "ChatManager": - return service_manager.get(ServiceType.CHAT_MANAGER) + # Setup the superuser + initialize_database() + if db_manager.ready: + setup_superuser() diff --git a/tests/conftest.py b/tests/conftest.py index 95aba4462..2c8b9016e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,13 @@ from contextlib import contextmanager import json from pathlib import Path from typing import AsyncGenerator, TYPE_CHECKING -from langflow.api.v1.flows import get_session from langflow.graph.graph.base import Graph from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.flow.flow import Flow from langflow.services.database.models.user.user import User, UserCreate +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager import pytest from fastapi.testclient import TestClient from httpx import AsyncClient @@ -15,6 +16,9 @@ from sqlmodel import SQLModel, Session, create_engine from sqlmodel.pool import StaticPool from typer.testing import CliRunner +# we need to import tmpdir +import tempfile + if TYPE_CHECKING: from langflow.services.database.manager import DatabaseManager @@ -46,14 +50,14 @@ async def async_client() -> AsyncGenerator: # Create client fixture for FastAPI -@pytest.fixture(scope="module", autouse=True) -def client(): - from langflow.main import create_app +# @pytest.fixture(scope="module", autouse=True) +# def client(): +# from langflow.main import create_app - app = create_app() +# app = create_app() - with TestClient(app) as client: - yield client +# with TestClient(app) as client: +# yield client def get_graph(_type="basic"): @@ -111,8 +115,14 @@ def session_fixture(): yield session -@pytest.fixture(name="client") -def client_fixture(session: Session): +@pytest.fixture(name="client", autouse=True) +def client_fixture(session: Session, monkeypatch): + # Set the database url to a test database + db_dir = tempfile.mkdtemp() + db_path = Path(db_dir) / "test.db" + monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}") + # monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", 1) + def get_session_override(): return session @@ -120,10 +130,13 @@ def client_fixture(session: Session): app = create_app() - app.dependency_overrides[get_session] = get_session_override + # app.dependency_overrides[get_session] = get_session_override with TestClient(app) as client: yield client - app.dependency_overrides.clear() + # app.dependency_overrides.clear() + monkeypatch.undo() + # clear the temp db + db_path.unlink() # @contextmanager @@ -142,11 +155,6 @@ def client_fixture(session: Session): # create a fixture for session_getter above @pytest.fixture(name="session_getter") def session_getter_fixture(client): - engine = create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ) - SQLModel.metadata.create_all(engine) - @contextmanager def blank_session_getter(db_manager: "DatabaseManager"): with Session(db_manager.engine) as session: @@ -172,17 +180,18 @@ def test_user(client): @pytest.fixture(scope="function") -def active_user(client, session): - user = User( - username="activeuser", - password=get_password_hash( - "testpassword" - ), # Assuming password needs to be hashed - is_active=True, - is_superuser=False, - ) - session.add(user) - session.commit() +def active_user(client): + db_manager = get_db_manager() + with session_getter(db_manager) as session: + user = User( + username="activeuser", + password=get_password_hash("testpassword"), + is_active=True, + is_superuser=False, + ) + session.add(user) + session.commit() + session.refresh(user) return user @@ -197,7 +206,7 @@ def logged_in_headers(client, active_user): @pytest.fixture -def flow(client, json_flow: str, session, active_user): +def flow(client, json_flow: str, active_user): from langflow.services.database.models.flow.flow import FlowCreate loaded_json = json.loads(json_flow) @@ -205,7 +214,9 @@ def flow(client, json_flow: str, session, active_user): name="test_flow", data=loaded_json.get("data"), user_id=active_user.id ) flow = Flow(**flow_data.dict()) - session.add(flow) - session.commit() + with session_getter(get_db_manager()) as session: + session.add(flow) + session.commit() + session.refresh(flow) return flow diff --git a/tests/test_cli.py b/tests/test_cli.py index 4ed00893e..2884dc800 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,7 @@ from tempfile import tempdir from langflow.__main__ import app import pytest -from langflow.services import utils +from langflow.services import getters @pytest.fixture(scope="module") @@ -26,7 +26,7 @@ def test_components_path(runner, client, default_settings): ["run", "--components-path", str(temp_dir), *default_settings], ) assert result.exit_code == 0, result.stdout - settings_manager = utils.get_settings_manager() + settings_manager = getters.get_settings_manager() assert str(temp_dir) in settings_manager.settings.COMPONENTS_PATH diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index e75dc0e5b..1695cfd38 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -518,13 +518,13 @@ def db(app): app.db.drop_all() -def test_list_flows_return_type(component, session_getter): - flows = component.list_flows(get_session=session_getter) +def test_list_flows_return_type(component): + flows = component.list_flows() assert isinstance(flows, list) -def test_list_flows_flow_objects(component, session_getter): - flows = component.list_flows(get_session=session_getter) +def test_list_flows_flow_objects(component): + flows = component.list_flows() assert all(isinstance(flow, Flow) for flow in flows) diff --git a/tests/test_database.py b/tests/test_database.py index e4f68ca56..7641f1e65 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,4 +1,6 @@ from langflow.services.database.models.base import orjson_dumps +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager import orjson import pytest @@ -178,9 +180,7 @@ def test_upload_file( assert response_data[1]["data"] == data -def test_download_file( - client: TestClient, session: Session, json_flow, active_user, logged_in_headers -): +def test_download_file(client: TestClient, json_flow, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -190,18 +190,20 @@ def test_download_file( FlowCreate(name="Flow 2", description="description", data=data), ] ) - for flow in flow_list.flows: - flow.user_id = active_user.id - db_flow = Flow.from_orm(flow) - session.add(db_flow) - session.commit() + db_manager = get_db_manager() + with session_getter(db_manager) as session: + for flow in flow_list.flows: + flow.user_id = active_user.id + db_flow = Flow.from_orm(flow) + session.add(db_flow) + session.commit() # Make request to endpoint response = client.get("api/v1/flows/download/", headers=logged_in_headers) # Check response status code - assert response.status_code == 200 + assert response.status_code == 200, response.json() # Check response data response_data = response.json()["flows"] - assert len(response_data) == 2 + assert len(response_data) == 2, response_data assert response_data[0]["name"] == "Flow 1" assert response_data[0]["description"] == "description" assert response_data[0]["data"] == data diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index cbb1eb08c..474a72e31 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,7 +1,8 @@ import uuid from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.api_key.api_key import ApiKey -from langflow.services.utils import get_settings_manager +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager, get_settings_manager import pytest from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS @@ -88,7 +89,7 @@ PROMPT_REQUEST = { @pytest.fixture -def created_api_key(session, active_user): +def created_api_key(active_user): hashed = get_password_hash("random_key") api_key = ApiKey( name="test_api_key", @@ -96,10 +97,11 @@ def created_api_key(session, active_user): api_key="random_key", hashed_api_key=hashed, ) - - session.add(api_key) - session.commit() - session.refresh(api_key) + db_manager = get_db_manager() + with session_getter(db_manager) as session: + session.add(api_key) + session.commit() + session.refresh(api_key) return api_key diff --git a/tests/test_login.py b/tests/test_login.py index 07abb35ab..651e2264b 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -1,3 +1,5 @@ +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager import pytest from langflow.services.database.models.user import User from langflow.services.auth.utils import get_password_hash @@ -15,10 +17,11 @@ def test_user(): ) -def test_login_successful(client, test_user, session): +def test_login_successful(client, test_user): # Adding the test user to the database - session.add(test_user) - session.commit() + with session_getter(get_db_manager()) as session: + session.add(test_user) + session.commit() response = client.post( "api/v1/login", data={"username": "testuser", "password": "testpassword"} diff --git a/tests/test_prompts_template.py b/tests/test_prompts_template.py index 434691038..b5c39f3a0 100644 --- a/tests/test_prompts_template.py +++ b/tests/test_prompts_template.py @@ -1,5 +1,5 @@ from fastapi.testclient import TestClient -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager def test_prompts_settings(client: TestClient, logged_in_headers): diff --git a/tests/test_setup_superuser.py b/tests/test_setup_superuser.py new file mode 100644 index 000000000..f1566d9ae --- /dev/null +++ b/tests/test_setup_superuser.py @@ -0,0 +1,140 @@ +from unittest.mock import patch, Mock, MagicMock, call +from langflow.services.database.models.user.user import User +from langflow.services.settings.constants import ( + DEFAULT_SUPERUSER, + DEFAULT_SUPERUSER_PASSWORD, +) +from langflow.services.utils import setup_superuser, teardown_superuser + + +@patch("langflow.services.utils.get_settings_manager") +@patch("langflow.services.utils.create_super_user") +@patch("langflow.services.utils.get_session") +def test_setup_superuser( + mock_get_session, mock_create_super_user, mock_get_settings_manager +): + # Test when AUTO_LOGIN is True + calls = [] + mock_settings_manager = Mock() + mock_settings_manager.auth_settings.AUTO_LOGIN = True + mock_settings_manager.auth_settings.SUPERUSER = DEFAULT_SUPERUSER + mock_settings_manager.auth_settings.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD + mock_get_settings_manager.return_value = mock_settings_manager + mock_session = Mock() + mock_session.query.return_value.filter.return_value.first.return_value = ( + mock_session + ) + # return value of get_session is a generator + mock_get_session.return_value = iter([mock_session, mock_session, mock_session]) + setup_superuser() + mock_session.query.assert_called_once_with(User) + actual_expr = mock_session.query.return_value.filter.call_args[0][0] + expected_expr = User.username == DEFAULT_SUPERUSER + + assert str(actual_expr) == str(expected_expr) + create_call = call( + db=mock_session, username=DEFAULT_SUPERUSER, password=DEFAULT_SUPERUSER_PASSWORD + ) + calls.append(create_call) + mock_create_super_user.assert_has_calls(calls) + assert 1 == mock_create_super_user.call_count + + def reset_mock_credentials(): + mock_settings_manager.auth_settings.SUPERUSER = DEFAULT_SUPERUSER + mock_settings_manager.auth_settings.SUPERUSER_PASSWORD = ( + DEFAULT_SUPERUSER_PASSWORD + ) + + ADMIN_USER_NAME = "admin_user" + # Test when username and password are default + mock_settings_manager.auth_settings = Mock() + mock_settings_manager.auth_settings.AUTO_LOGIN = False + mock_settings_manager.auth_settings.SUPERUSER = ADMIN_USER_NAME + mock_settings_manager.auth_settings.SUPERUSER_PASSWORD = "password" + mock_settings_manager.auth_settings.reset_credentials = Mock( + side_effect=reset_mock_credentials + ) + + mock_get_settings_manager.return_value = mock_settings_manager + + setup_superuser() + mock_session.query.assert_called_with(User) + actual_expr = mock_session.query.return_value.filter.call_args[0][0] + expected_expr = User.username == ADMIN_USER_NAME + + assert str(actual_expr) == str(expected_expr) + create_call = call(db=mock_session, username=ADMIN_USER_NAME, password="password") + calls.append(create_call) + mock_create_super_user.assert_has_calls(calls) + assert 2 == mock_create_super_user.call_count + # Test that superuser credentials are reset + mock_settings_manager.auth_settings.reset_credentials.assert_called_once() + assert mock_settings_manager.auth_settings.SUPERUSER != ADMIN_USER_NAME + assert mock_settings_manager.auth_settings.SUPERUSER_PASSWORD != "password" + + # Test when superuser already exists + mock_settings_manager.auth_settings.AUTO_LOGIN = False + mock_settings_manager.auth_settings.SUPERUSER = ADMIN_USER_NAME + mock_settings_manager.auth_settings.SUPERUSER_PASSWORD = "password" + mock_user = Mock() + mock_user.is_superuser = True + mock_session.query.return_value.filter.return_value.first.return_value = mock_user + setup_superuser() + mock_session.query.assert_called_with(User) + actual_expr = mock_session.query.return_value.filter.call_args[0][0] + expected_expr = User.username == ADMIN_USER_NAME + + assert str(actual_expr) == str(expected_expr) + + +@patch("langflow.services.utils.get_settings_manager") +@patch("langflow.services.utils.get_session") +def test_teardown_superuser_default_superuser( + mock_get_session, mock_get_settings_manager +): + mock_settings_manager = MagicMock() + mock_settings_manager.auth_settings.AUTO_LOGIN = True + mock_settings_manager.auth_settings.SUPERUSER = DEFAULT_SUPERUSER + mock_settings_manager.auth_settings.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD + mock_get_settings_manager.return_value = mock_settings_manager + + mock_session = MagicMock() + mock_user = MagicMock() + mock_user.is_superuser = True + mock_session.query.return_value.filter.return_value.first.return_value = mock_user + mock_get_session.return_value = iter([mock_session]) + + teardown_superuser() + + mock_session.query.assert_called_once_with(User) + actual_expr = mock_session.query.return_value.filter.call_args[0][0] + expected_expr = User.username == DEFAULT_SUPERUSER + + assert str(actual_expr) == str(expected_expr) + mock_session.delete.assert_called_once_with(mock_user) + mock_session.commit.assert_called_once() + + +@patch("langflow.services.utils.get_settings_manager") +@patch("langflow.services.utils.get_session") +def test_teardown_superuser_no_default_superuser( + mock_get_session, mock_get_settings_manager +): + ADMIN_USER_NAME = "admin_user" + mock_settings_manager = MagicMock() + mock_settings_manager.auth_settings.AUTO_LOGIN = False + mock_settings_manager.auth_settings.SUPERUSER = ADMIN_USER_NAME + mock_settings_manager.auth_settings.SUPERUSER_PASSWORD = "password" + mock_get_settings_manager.return_value = mock_settings_manager + + mock_session = MagicMock() + mock_user = MagicMock() + mock_user.is_superuser = False + mock_session.query.return_value.filter.return_value.first.return_value = mock_user + mock_get_session.return_value = [mock_session] + + teardown_superuser() + + mock_session.query.assert_not_called() + mock_session.delete.assert_not_called() + mock_session.commit.assert_not_called() diff --git a/tests/test_user.py b/tests/test_user.py index 54a713ef1..27894e515 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -2,20 +2,22 @@ from datetime import datetime from langflow.services.auth.utils import create_super_user, get_password_hash from langflow.services.database.models.user.user import User -from langflow.services.utils import get_settings_manager +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_manager, get_settings_manager import pytest from langflow.services.database.models.user import UserUpdate @pytest.fixture -def super_user(client, session): +def super_user(client): settings_manager = get_settings_manager() auth_settings = settings_manager.auth_settings - return create_super_user( - db=session, - username=auth_settings.FIRST_SUPERUSER, - password=auth_settings.FIRST_SUPERUSER_PASSWORD, - ) + with session_getter(get_db_manager()) as session: + return create_super_user( + db=session, + username=auth_settings.SUPERUSER, + password=auth_settings.SUPERUSER_PASSWORD, + ) @pytest.fixture @@ -23,8 +25,8 @@ def super_user_headers(client, super_user): settings_manager = get_settings_manager() auth_settings = settings_manager.auth_settings login_data = { - "username": auth_settings.FIRST_SUPERUSER, - "password": auth_settings.FIRST_SUPERUSER_PASSWORD, + "username": auth_settings.SUPERUSER, + "password": auth_settings.SUPERUSER_PASSWORD, } response = client.post("/api/v1/login", data=login_data) assert response.status_code == 200 @@ -34,29 +36,34 @@ def super_user_headers(client, super_user): @pytest.fixture -def deactivated_user(session): - user = User( - username="deactivateduser", - password=get_password_hash("testpassword"), - is_active=False, - is_superuser=False, - last_login_at=datetime.now(), - ) - session.add(user) - session.commit() +def deactivated_user(): + with session_getter(get_db_manager()) as session: + user = User( + username="deactivateduser", + password=get_password_hash("testpassword"), + is_active=False, + is_superuser=False, + last_login_at=datetime.now(), + ) + session.add(user) + session.commit() + session.refresh(user) return user -def test_user_waiting_for_approval(client, session): +def test_user_waiting_for_approval( + client, +): # Create a user that is not active and has never logged in - user = User( - username="waitingforapproval", - password=get_password_hash("testpassword"), - is_active=False, - last_login_at=None, - ) - session.add(user) - session.commit() + with session_getter(get_db_manager()) as session: + user = User( + username="waitingforapproval", + password=get_password_hash("testpassword"), + is_active=False, + last_login_at=None, + ) + session.add(user) + session.commit() login_data = {"username": "waitingforapproval", "password": "testpassword"} response = client.post("/api/v1/login", data=login_data) @@ -106,16 +113,17 @@ def test_data_consistency_after_delete(client, test_user, super_user_headers): assert all(user["id"] != user_id for user in response.json()["users"]) -def test_inactive_user(client, session): +def test_inactive_user(client): # Create a user that is not active and has a last_login_at value - user = User( - username="inactiveuser", - password=get_password_hash("testpassword"), - is_active=False, - last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string - ) - session.add(user) - session.commit() + with session_getter(get_db_manager()) as session: + user = User( + username="inactiveuser", + password=get_password_hash("testpassword"), + is_active=False, + last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string + ) + session.add(user) + session.commit() login_data = {"username": "inactiveuser", "password": "testpassword"} response = client.post("/api/v1/login", data=login_data) diff --git a/tests/test_vectorstore_template.py b/tests/test_vectorstore_template.py index 87394b890..5bd629906 100644 --- a/tests/test_vectorstore_template.py +++ b/tests/test_vectorstore_template.py @@ -1,5 +1,5 @@ from fastapi.testclient import TestClient -from langflow.services.utils import get_settings_manager +from langflow.services.getters import get_settings_manager # check that all agents are in settings.agents