From c85cd91e09573c5c258bf12a6795948940d9fd47 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 22 Oct 2024 13:27:31 +0200 Subject: [PATCH] ref: Move get services out of API methods (#4216) * Move get services out of API methods * Rollback files API to be able to mock the dependency --- .../base/langflow/api/health_check_router.py | 4 +- src/backend/base/langflow/api/v1/api_key.py | 8 +-- src/backend/base/langflow/api/v1/chat.py | 28 +++++------ src/backend/base/langflow/api/v1/endpoints.py | 38 ++++++-------- src/backend/base/langflow/api/v1/files.py | 23 ++++----- src/backend/base/langflow/api/v1/flows.py | 11 ++--- src/backend/base/langflow/api/v1/login.py | 12 ++--- src/backend/base/langflow/api/v1/monitor.py | 15 +++--- .../base/langflow/api/v1/starter_projects.py | 4 +- src/backend/base/langflow/api/v1/store.py | 49 +++++++------------ src/backend/base/langflow/api/v1/users.py | 6 +-- src/backend/base/langflow/api/v1/validate.py | 12 ++--- src/backend/base/langflow/api/v1/variable.py | 2 +- src/backend/tests/conftest.py | 2 +- 14 files changed, 91 insertions(+), 123 deletions(-) diff --git a/src/backend/base/langflow/api/health_check_router.py b/src/backend/base/langflow/api/health_check_router.py index 1f03c9fe0..180edd1dd 100644 --- a/src/backend/base/langflow/api/health_check_router.py +++ b/src/backend/base/langflow/api/health_check_router.py @@ -36,10 +36,10 @@ async def health(): # /health_check evaluates key services # It's a reliable health check for a langflow instance -@health_check_router.get("/health_check", response_model=HealthResponse) +@health_check_router.get("/health_check") async def health_check( session: Annotated[Session, Depends(get_session)], -): +) -> HealthResponse: response = HealthResponse() # use a fixed valid UUId that UUID collision is very unlikely user_id = "da93c2bd-c857-4b10-8c8c-60988103320f" diff --git a/src/backend/base/langflow/api/v1/api_key.py b/src/backend/base/langflow/api/v1/api_key.py index 3cf3920b7..645a9cec9 100644 --- a/src/backend/base/langflow/api/v1/api_key.py +++ b/src/backend/base/langflow/api/v1/api_key.py @@ -19,11 +19,11 @@ if TYPE_CHECKING: router = APIRouter(tags=["APIKey"], prefix="/api_key") -@router.get("/", response_model=ApiKeysResponse) +@router.get("/") def get_api_keys_route( db: Annotated[Session, Depends(get_session)], current_user: Annotated[User, Depends(auth_utils.get_current_active_user)], -): +) -> ApiKeysResponse: try: user_id = current_user.id keys = get_api_keys(db, user_id) @@ -33,12 +33,12 @@ def get_api_keys_route( raise HTTPException(status_code=400, detail=str(exc)) from exc -@router.post("/", response_model=UnmaskedApiKeyRead) +@router.post("/") def create_api_key_route( req: ApiKeyCreate, current_user: Annotated[User, Depends(auth_utils.get_current_active_user)], db: Annotated[Session, Depends(get_session)], -): +) -> UnmaskedApiKeyRead: try: user_id = current_user.id return create_api_key(db, req, user_id=user_id) diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 11e7dd2e0..a6fd37a56 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -42,7 +42,6 @@ from langflow.services.auth.utils import get_current_active_user from langflow.services.chat.service import ChatService from langflow.services.deps import get_chat_service, get_session, get_telemetry_service from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload -from langflow.services.telemetry.service import TelemetryService if TYPE_CHECKING: from langflow.graph.vertex.types import InterfaceVertex @@ -66,17 +65,15 @@ async def try_running_celery_task(vertex, user_id): return vertex -@router.post("/build/{flow_id}/vertices", response_model=VerticesOrderResponse) +@router.post("/build/{flow_id}/vertices") async def retrieve_vertices_order( flow_id: uuid.UUID, background_tasks: BackgroundTasks, data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None, stop_component_id: str | None = None, start_component_id: str | None = None, - chat_service: ChatService = Depends(get_chat_service), session=Depends(get_session), - telemetry_service: TelemetryService = Depends(get_telemetry_service), -): +) -> VerticesOrderResponse: """Retrieve the vertices order for a given flow. Args: @@ -85,9 +82,7 @@ async def retrieve_vertices_order( data (Optional[FlowDataRequest], optional): The flow data. Defaults to None. stop_component_id (str, optional): The ID of the stop component. Defaults to None. start_component_id (str, optional): The ID of the start component. Defaults to None. - chat_service (ChatService, optional): The chat service dependency. Defaults to Depends(get_chat_service). session (Session, optional): The session dependency. Defaults to Depends(get_session). - telemetry_service (TelemetryService, optional): The telemetry service. Returns: VerticesOrderResponse: The response containing the ordered vertex IDs and the run ID. @@ -95,6 +90,8 @@ async def retrieve_vertices_order( Raises: HTTPException: If there is an error checking the build status. """ + chat_service = get_chat_service() + telemetry_service = get_telemetry_service() start_time = time.perf_counter() components_count = None try: @@ -150,11 +147,11 @@ async def build_flow( stop_component_id: str | None = None, start_component_id: str | None = None, log_builds: bool | None = True, - chat_service: ChatService = Depends(get_chat_service), current_user=Depends(get_current_active_user), - telemetry_service: TelemetryService = Depends(get_telemetry_service), session=Depends(get_session), ): + chat_service = get_chat_service() + telemetry_service = get_telemetry_service() if not inputs: inputs = InputValueRequest(session=str(flow_id)) @@ -464,10 +461,8 @@ async def build_vertex( background_tasks: BackgroundTasks, inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None, files: list[str] | None = None, - chat_service: ChatService = Depends(get_chat_service), current_user=Depends(get_current_active_user), - telemetry_service: TelemetryService = Depends(get_telemetry_service), -): +) -> VertexBuildResponse: """Build a vertex instead of the entire graph. Args: @@ -476,9 +471,7 @@ async def build_vertex( background_tasks (BackgroundTasks): The background tasks dependency. inputs (Optional[InputValueRequest], optional): The input values for the vertex. Defaults to None. files (List[str], optional): The files to use. Defaults to None. - chat_service (ChatService, optional): The chat service dependency. Defaults to Depends(get_chat_service). current_user (Any, optional): The current user dependency. Defaults to Depends(get_current_active_user). - telemetry_service (TelemetryService, optional): The telemetry service. Returns: VertexBuildResponse: The response containing the built vertex information. @@ -487,6 +480,8 @@ async def build_vertex( HTTPException: If there is an error building the vertex. """ + chat_service = get_chat_service() + telemetry_service = get_telemetry_service() flow_id_str = str(flow_id) next_runnable_vertices = [] @@ -699,7 +694,6 @@ async def _stream_vertex(flow_id: str, vertex_id: str, chat_service: ChatService async def build_vertex_stream( flow_id: uuid.UUID, vertex_id: str, - chat_service: ChatService = Depends(get_chat_service), ): """Build a vertex instead of the entire graph. @@ -727,6 +721,8 @@ async def build_vertex_stream( HTTPException: If an error occurs while building the vertex. """ try: - return StreamingResponse(_stream_vertex(str(flow_id), vertex_id, chat_service), media_type="text/event-stream") + return StreamingResponse( + _stream_vertex(str(flow_id), vertex_id, get_chat_service()), media_type="text/event-stream" + ) except Exception as exc: raise HTTPException(status_code=500, detail="Error building Component") from exc diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index e9a02712f..456ebd4c7 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -47,9 +47,7 @@ from langflow.services.deps import ( get_task_service, get_telemetry_service, ) -from langflow.services.session.service import SessionService from langflow.services.telemetry.schema import RunPayload -from langflow.services.telemetry.service import TelemetryService from langflow.utils.constants import SIDEBAR_CATEGORIES from langflow.utils.version import get_version_info @@ -60,14 +58,11 @@ router = APIRouter(tags=["Base"]) @router.get("/all", dependencies=[Depends(get_current_active_user)]) -async def get_all( - *, - settings_service=Depends(get_settings_service), -): +async def get_all(): from langflow.interface.types import get_and_cache_all_types_dict try: - return await get_and_cache_all_types_dict(settings_service=settings_service) + return await get_and_cache_all_types_dict(settings_service=get_settings_service()) except Exception as exc: raise HTTPException(status_code=500, detail=str(exc)) from exc @@ -165,7 +160,7 @@ async def simple_run_flow_task( logger.exception(f"Error running flow {flow.id} task") -@router.post("/run/{flow_id_or_name}", response_model=RunResponse, response_model_exclude_none=True) # noqa: RUF100, FAST003 +@router.post("/run/{flow_id_or_name}", response_model_exclude_none=True) # noqa: RUF100, FAST003 async def simplified_run_flow( *, background_tasks: BackgroundTasks, @@ -173,8 +168,7 @@ async def simplified_run_flow( input_request: SimplifiedAPIRequest | None = None, stream: bool = False, api_key_user: UserRead = Depends(api_key_security), - telemetry_service: TelemetryService = Depends(get_telemetry_service), -): +) -> RunResponse: """Executes a specified flow by ID. Executes a specified flow by ID with input customization, performance enhancements through caching, @@ -239,6 +233,7 @@ async def simplified_run_flow( supporting a wide range of applications by allowing for dynamic input and output configuration along with performance optimizations through session management and caching. """ + telemetry_service = get_telemetry_service() input_request = input_request if input_request is not None else SimplifiedAPIRequest() if flow is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Flow not found") @@ -297,7 +292,6 @@ async def webhook_run_flow( user: Annotated[User, Depends(get_user_by_flow_id_or_endpoint_name)], request: Request, background_tasks: BackgroundTasks, - telemetry_service: Annotated[TelemetryService, Depends(get_telemetry_service)], ): """Run a flow using a webhook request. @@ -306,7 +300,6 @@ async def webhook_run_flow( user (User): The flow user. request (Request): The incoming HTTP request. background_tasks (BackgroundTasks): The background tasks manager. - telemetry_service (TelemetryService): The telemetry service. Returns: dict: A dictionary containing the status of the task. @@ -314,6 +307,7 @@ async def webhook_run_flow( Raises: HTTPException: If the flow is not found or if there is an error processing the request. """ + telemetry_service = get_telemetry_service() start_time = time.perf_counter() logger.debug("Received webhook request") error_msg = "" @@ -378,8 +372,7 @@ async def experimental_run_flow( stream: Annotated[bool, Body(embed=True)] = False, session_id: Annotated[None | str, Body(embed=True)] = None, api_key_user: UserRead = Depends(api_key_security), - session_service: SessionService = Depends(get_session_service), -): +) -> RunResponse: """Executes a specified flow by ID with optional input values, output selection, tweaks, and streaming capability. This endpoint supports running flows with caching to enhance performance and efficiency. @@ -397,7 +390,6 @@ async def experimental_run_flow( - `session_id` (Union[None, str], optional): An optional session ID to utilize existing session data for the flow execution. - `api_key_user` (User): The user associated with the current API key. Automatically resolved from the API key. - - `session_service` (SessionService): The session service object for managing flow sessions. ### Returns: A `RunResponse` object containing the selected outputs (or all if not specified) of the executed flow @@ -427,6 +419,7 @@ async def experimental_run_flow( This endpoint facilitates complex flow executions with customized inputs, outputs, and configurations, catering to diverse application requirements. """ # noqa: E501 + session_service = get_session_service() flow_id_str = str(flow_id) if outputs is None: outputs = [] @@ -511,8 +504,8 @@ async def process(): ) -@router.get("/task/{task_id}", response_model=TaskStatusResponse) -async def get_task_status(task_id: str): +@router.get("/task/{task_id}") +async def get_task_status(task_id: str) -> TaskStatusResponse: task_service = get_task_service() task = task_service.get_task(task_id) result = None @@ -538,13 +531,12 @@ async def get_task_status(task_id: str): @router.post( "/upload/{flow_id}", - response_model=UploadFileResponse, status_code=HTTPStatus.CREATED, ) async def create_upload_file( file: UploadFile, flow_id: UUID, -): +) -> UploadFileResponse: try: flow_id_str = str(flow_id) file_path = save_uploaded_file(file, folder_name=flow_id_str) @@ -564,11 +556,11 @@ def get_version(): return get_version_info() -@router.post("/custom_component", status_code=HTTPStatus.OK, response_model=CustomComponentResponse) +@router.post("/custom_component", status_code=HTTPStatus.OK) async def custom_component( raw_code: CustomComponentRequest, user: Annotated[User, Depends(get_current_active_user)], -): +) -> CustomComponentResponse: component = Component(_code=raw_code.code) built_frontend_node, component_instance = build_custom_component_template(component, user_id=user.id) @@ -646,6 +638,6 @@ def get_config(): raise HTTPException(status_code=500, detail=str(exc)) from exc -@router.get("/sidebar_categories", response_model=SidebarCategoriesResponse) -def get_sidebar_categories(): +@router.get("/sidebar_categories") +def get_sidebar_categories() -> SidebarCategoriesResponse: return SidebarCategoriesResponse(categories=SIDEBAR_CATEGORIES) diff --git a/src/backend/base/langflow/api/v1/files.py b/src/backend/base/langflow/api/v1/files.py index 105e66add..8b24e99c9 100644 --- a/src/backend/base/langflow/api/v1/files.py +++ b/src/backend/base/langflow/api/v1/files.py @@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse from langflow.api.v1.schemas import UploadFileResponse from langflow.services.auth.utils import get_current_active_user from langflow.services.database.models.flow import Flow -from langflow.services.deps import get_session, get_storage_service +from langflow.services.deps import get_session, get_settings_service, get_storage_service from langflow.services.storage.service import StorageService from langflow.services.storage.utils import build_content_type_from_extension @@ -39,14 +39,15 @@ def get_flow_id( @router.post("/upload/{flow_id}", status_code=HTTPStatus.CREATED) async def upload_file( + *, file: UploadFile, flow_id: Annotated[UUID, Depends(get_flow_id)], current_user=Depends(get_current_active_user), session=Depends(get_session), - storage_service: StorageService = Depends(get_storage_service), -): + storage_service: Annotated[StorageService, Depends(get_storage_service)], +) -> UploadFileResponse: try: - max_file_size_upload = get_storage_service().settings_service.settings.max_file_size_upload + max_file_size_upload = get_settings_service().settings.max_file_size_upload except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e @@ -106,9 +107,8 @@ async def download_file( @router.get("/images/{flow_id}/{file_name}") -async def download_image( - file_name: str, flow_id: UUID, storage_service: Annotated[StorageService, Depends(get_storage_service)] -): +async def download_image(file_name: str, flow_id: UUID): + storage_service = get_storage_service() extension = file_name.split(".")[-1] flow_id_str = str(flow_id) @@ -135,11 +135,11 @@ async def download_image( async def download_profile_picture( folder_name: str, file_name: str, - storage_service: Annotated[StorageService, Depends(get_storage_service)], ): try: + storage_service = get_storage_service() extension = file_name.split(".")[-1] - config_dir = get_storage_service().settings_service.settings.config_dir + config_dir = storage_service.settings_service.settings.config_dir config_path = Path(config_dir) # type: ignore[arg-type] folder_path = config_path / "profile_pictures" / folder_name content_type = build_content_type_from_extension(extension) @@ -151,9 +151,10 @@ async def download_profile_picture( @router.get("/profile_pictures/list") -async def list_profile_pictures(storage_service: Annotated[StorageService, Depends(get_storage_service)]): +async def list_profile_pictures(): try: - config_dir = get_storage_service().settings_service.settings.config_dir + storage_service = get_storage_service() + config_dir = storage_service.settings_service.settings.config_dir config_path = Path(config_dir) # type: ignore[arg-type] people_path = config_path / "profile_pictures/People" diff --git a/src/backend/base/langflow/api/v1/flows.py b/src/backend/base/langflow/api/v1/flows.py index dc4727486..2aa35f10e 100644 --- a/src/backend/base/langflow/api/v1/flows.py +++ b/src/backend/base/langflow/api/v1/flows.py @@ -129,7 +129,6 @@ def read_flows( *, current_user: User = Depends(get_current_active_user), session: Session = Depends(get_session), - settings_service: SettingsService = Depends(get_settings_service), remove_example_flows: bool = False, components_only: bool = False, get_all: bool = True, @@ -158,7 +157,7 @@ def read_flows( A list of flows or a paginated response containing the list of flows or a list of flow headers. """ try: - auth_settings = settings_service.auth_settings + auth_settings = get_settings_service().auth_settings default_folder = session.exec(select(Folder).where(Folder.name == DEFAULT_FOLDER_NAME)).first() default_folder_id = default_folder.id if default_folder else None @@ -233,10 +232,9 @@ def read_flow( session: Session = Depends(get_session), flow_id: UUID, current_user: User = Depends(get_current_active_user), - settings_service: SettingsService = Depends(get_settings_service), ): """Read a flow.""" - if user_flow := _read_flow(session, flow_id, current_user, settings_service): + if user_flow := _read_flow(session, flow_id, current_user, get_settings_service()): return user_flow raise HTTPException(status_code=404, detail="Flow not found") @@ -248,9 +246,9 @@ def update_flow( flow_id: UUID, flow: FlowUpdate, current_user: User = Depends(get_current_active_user), - settings_service=Depends(get_settings_service), ): """Update a flow.""" + settings_service = get_settings_service() try: db_flow = _read_flow( session=session, @@ -307,14 +305,13 @@ async def delete_flow( session: Session = Depends(get_session), flow_id: UUID, current_user: User = Depends(get_current_active_user), - settings_service=Depends(get_settings_service), ): """Delete a flow.""" flow = _read_flow( session=session, flow_id=flow_id, current_user=current_user, - settings_service=settings_service, + settings_service=get_settings_service(), ) if not flow: raise HTTPException(status_code=404, detail="Flow not found") diff --git a/src/backend/base/langflow/api/v1/login.py b/src/backend/base/langflow/api/v1/login.py index 49569ebd8..ce61f3b40 100644 --- a/src/backend/base/langflow/api/v1/login.py +++ b/src/backend/base/langflow/api/v1/login.py @@ -16,7 +16,6 @@ from langflow.services.auth.utils import ( from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist from langflow.services.database.models.user.crud import get_user_by_id from langflow.services.deps import get_session, get_settings_service, get_variable_service -from langflow.services.settings.service import SettingsService from langflow.services.variable.service import VariableService router = APIRouter(tags=["Login"]) @@ -83,12 +82,10 @@ async def login_to_get_access_token( @router.get("/auto_login") -async def auto_login( - response: Response, db: Annotated[Session, Depends(get_session)], settings_service=Depends(get_settings_service) -): - auth_settings = settings_service.auth_settings +async def auto_login(response: Response, db: Annotated[Session, Depends(get_session)]): + auth_settings = get_settings_service().auth_settings - if settings_service.auth_settings.AUTO_LOGIN: + if auth_settings.AUTO_LOGIN: user_id, tokens = create_user_longterm_token(db) response.set_cookie( "access_token_lf", @@ -131,10 +128,9 @@ async def auto_login( async def refresh_token( request: Request, response: Response, - settings_service: Annotated[SettingsService, Depends(get_settings_service)], db: Annotated[Session, Depends(get_session)], ): - auth_settings = settings_service.auth_settings + auth_settings = get_settings_service().auth_settings token = request.cookies.get("refresh_token_lf") diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index ac3fabf3e..b6bddc779 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -20,11 +20,11 @@ from langflow.services.deps import get_session router = APIRouter(prefix="/monitor", tags=["Monitor"]) -@router.get("/builds", response_model=VertexBuildMapModel) +@router.get("/builds") async def get_vertex_builds( flow_id: Annotated[UUID, Query()], session: Annotated[Session, Depends(get_session)], -): +) -> VertexBuildMapModel: try: vertex_builds = get_vertex_builds_by_flow_id(session, flow_id) return VertexBuildMapModel.from_list_of_dicts(vertex_builds) @@ -43,7 +43,7 @@ async def delete_vertex_builds( raise HTTPException(status_code=500, detail=str(e)) from e -@router.get("/messages", response_model=list[MessageResponse]) +@router.get("/messages") async def get_messages( session: Annotated[Session, Depends(get_session)], flow_id: Annotated[str | None, Query()] = None, @@ -51,7 +51,7 @@ async def get_messages( sender: Annotated[str | None, Query()] = None, sender_name: Annotated[str | None, Query()] = None, order_by: Annotated[str | None, Query()] = "timestamp", -): +) -> list[MessageResponse]: try: stmt = select(MessageTable) if flow_id: @@ -113,13 +113,12 @@ async def update_message( @router.patch( "/messages/session/{old_session_id}", dependencies=[Depends(get_current_active_user)], - response_model=list[MessageResponse], ) async def update_session_id( old_session_id: str, new_session_id: Annotated[str, Query(..., description="The new session ID to update to")], session: Annotated[Session, Depends(get_session)], -): +) -> list[MessageResponse]: try: # Get all messages with the old session ID stmt = select(MessageTable).where(MessageTable.session_id == old_session_id) @@ -166,11 +165,11 @@ async def delete_messages_session( return {"message": "Messages deleted successfully"} -@router.get("/transactions", response_model=list[TransactionReadResponse]) +@router.get("/transactions") async def get_transactions( flow_id: Annotated[UUID, Query()], session: Annotated[Session, Depends(get_session)], -): +) -> list[TransactionReadResponse]: try: transactions = get_transactions_by_flow_id(session, flow_id) return [ diff --git a/src/backend/base/langflow/api/v1/starter_projects.py b/src/backend/base/langflow/api/v1/starter_projects.py index 3c5ac5126..1f76ab4e9 100644 --- a/src/backend/base/langflow/api/v1/starter_projects.py +++ b/src/backend/base/langflow/api/v1/starter_projects.py @@ -6,8 +6,8 @@ from langflow.services.auth.utils import get_current_active_user router = APIRouter(prefix="/starter-projects", tags=["Flows"]) -@router.get("/", dependencies=[Depends(get_current_active_user)], response_model=list[GraphDump], status_code=200) -def get_starter_projects(): +@router.get("/", dependencies=[Depends(get_current_active_user)], status_code=200) +def get_starter_projects() -> list[GraphDump]: """Get a list of starter projects.""" from langflow.initial_setup.load import get_starter_projects_dump diff --git a/src/backend/base/langflow/api/v1/store.py b/src/backend/base/langflow/api/v1/store.py index 9fd8141f2..0f79ace2e 100644 --- a/src/backend/base/langflow/api/v1/store.py +++ b/src/backend/base/langflow/api/v1/store.py @@ -17,7 +17,6 @@ from langflow.services.store.schema import ( TagResponse, UsersLikesResponse, ) -from langflow.services.store.service import StoreService router = APIRouter(prefix="/store", tags=["Components Store"]) @@ -48,24 +47,21 @@ def get_optional_user_store_api_key( @router.get("/check/") -def check_if_store_is_enabled( - settings_service=Depends(get_settings_service), -): +def check_if_store_is_enabled(): return { - "enabled": settings_service.settings.store, + "enabled": get_settings_service().settings.store, } @router.get("/check/api_key") async def check_if_store_has_api_key( api_key: Annotated[str | None, Depends(get_optional_user_store_api_key)], - store_service: Annotated[StoreService, Depends(get_store_service)], ): if api_key is None: return {"has_api_key": False, "is_valid": False} try: - is_valid = await store_service.check_api_key(api_key) + is_valid = await get_store_service().check_api_key(api_key) except Exception as e: raise HTTPException(status_code=400, detail=str(e)) from e @@ -75,31 +71,29 @@ async def check_if_store_has_api_key( @router.post("/components/", response_model=CreateComponentResponse, status_code=201) async def share_component( component: StoreComponentCreate, - store_service: Annotated[StoreService, Depends(get_store_service)], store_api_key: Annotated[str, Depends(get_user_store_api_key)], -): +) -> CreateComponentResponse: try: await check_langflow_version(component) - return await store_service.upload(store_api_key, component) + return await get_store_service().upload(store_api_key, component) except Exception as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc -@router.patch("/components/{component_id}", response_model=CreateComponentResponse, status_code=201) +@router.patch("/components/{component_id}", status_code=201) async def update_shared_component( component_id: UUID, component: StoreComponentCreate, - store_service: Annotated[StoreService, Depends(get_store_service)], store_api_key: Annotated[str, Depends(get_user_store_api_key)], -): +) -> CreateComponentResponse: try: await check_langflow_version(component) - return await store_service.update(store_api_key, component_id, component) + return await get_store_service().update(store_api_key, component_id, component) except Exception as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc -@router.get("/components/", response_model=ListComponentResponseModel) +@router.get("/components/") async def get_components( *, component_id: Annotated[str | None, Query()] = None, @@ -113,11 +107,10 @@ async def get_components( fields: Annotated[list[str] | None, Query()] = None, page: int = 1, limit: int = 10, - store_service: StoreService = Depends(get_store_service), store_api_key: str | None = Depends(get_optional_user_store_api_key), -): +) -> ListComponentResponseModel: try: - return await store_service.get_list_component_response_model( + return await get_store_service().get_list_component_response_model( component_id=component_id, search=search, private=private, @@ -140,11 +133,10 @@ async def get_components( @router.get("/components/{component_id}", response_model=DownloadComponentResponse) async def download_component( component_id: UUID, - store_service: Annotated[StoreService, Depends(get_store_service)], store_api_key: Annotated[str, Depends(get_user_store_api_key)], -): +) -> DownloadComponentResponse: try: - component = await store_service.download(store_api_key, component_id) + component = await get_store_service().download(store_api_key, component_id) except CustomError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: @@ -157,11 +149,9 @@ async def download_component( @router.get("/tags", response_model=list[TagResponse]) -async def get_tags( - store_service: Annotated[StoreService, Depends(get_store_service)], -): +async def get_tags(): try: - return await store_service.get_tags() + return await get_store_service().get_tags() except CustomError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: @@ -170,24 +160,23 @@ async def get_tags( @router.get("/users/likes", response_model=list[UsersLikesResponse]) async def get_list_of_components_liked_by_user( - store_service: Annotated[StoreService, Depends(get_store_service)], store_api_key: Annotated[str, Depends(get_user_store_api_key)], ): try: - return await store_service.get_user_likes(store_api_key) + return await get_store_service().get_user_likes(store_api_key) except CustomError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: raise HTTPException(status_code=500, detail=str(exc)) from exc -@router.post("/users/likes/{component_id}", response_model=UsersLikesResponse) +@router.post("/users/likes/{component_id}") async def like_component( component_id: UUID, - store_service: Annotated[StoreService, Depends(get_store_service)], store_api_key: Annotated[str, Depends(get_user_store_api_key)], -): +) -> UsersLikesResponse: try: + store_service = get_store_service() result = await store_service.like_component(store_api_key, str(component_id)) likes_count = await store_service.get_component_likes_count(str(component_id), store_api_key) diff --git a/src/backend/base/langflow/api/v1/users.py b/src/backend/base/langflow/api/v1/users.py index 77e0b9b75..0a1b69198 100644 --- a/src/backend/base/langflow/api/v1/users.py +++ b/src/backend/base/langflow/api/v1/users.py @@ -26,13 +26,12 @@ router = APIRouter(tags=["Users"], prefix="/users") def add_user( user: UserCreate, session: Annotated[Session, Depends(get_session)], - settings_service=Depends(get_settings_service), ) -> User: """Add a new user to the database.""" new_user = User.model_validate(user, from_attributes=True) try: new_user.password = get_password_hash(user.password) - new_user.is_active = settings_service.auth_settings.NEW_USER_IS_ACTIVE + new_user.is_active = get_settings_service().auth_settings.NEW_USER_IS_ACTIVE session.add(new_user) session.commit() session.refresh(new_user) @@ -54,11 +53,10 @@ def read_current_user( return current_user -@router.get("/") +@router.get("/", dependencies=[Depends(get_current_active_superuser)]) def read_all_users( skip: int = 0, limit: int = 10, - _: Session = Depends(get_current_active_superuser), session: Session = Depends(get_session), ) -> UsersResponse: """Retrieve a list of users from the database with pagination.""" diff --git a/src/backend/base/langflow/api/v1/validate.py b/src/backend/base/langflow/api/v1/validate.py index eaa5d4bbe..9f205d767 100644 --- a/src/backend/base/langflow/api/v1/validate.py +++ b/src/backend/base/langflow/api/v1/validate.py @@ -9,21 +9,21 @@ from langflow.utils.validate import validate_code router = APIRouter(prefix="/validate", tags=["Validate"]) -@router.post("/code", status_code=200, response_model=CodeValidationResponse) -def post_validate_code(code: Code): +@router.post("/code", status_code=200) +def post_validate_code(code: Code) -> CodeValidationResponse: try: errors = validate_code(code.code) return CodeValidationResponse( imports=errors.get("imports", {}), function=errors.get("function", {}), ) - except Exception as e: # noqa: BLE001 + except Exception as e: logger.opt(exception=True).debug("Error validating code") - return HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e -@router.post("/prompt", status_code=200, response_model=PromptValidationResponse) -def post_validate_prompt(prompt_request: ValidatePromptRequest): +@router.post("/prompt", status_code=200) +def post_validate_prompt(prompt_request: ValidatePromptRequest) -> PromptValidationResponse: try: if not prompt_request.frontend_node: return PromptValidationResponse( diff --git a/src/backend/base/langflow/api/v1/variable.py b/src/backend/base/langflow/api/v1/variable.py index f2c8ac6c4..c77ee2abd 100644 --- a/src/backend/base/langflow/api/v1/variable.py +++ b/src/backend/base/langflow/api/v1/variable.py @@ -21,9 +21,9 @@ def create_variable( session: Session = Depends(get_session), variable: VariableCreate, current_user: User = Depends(get_current_active_user), - variable_service: DatabaseVariableService = Depends(get_variable_service), ): """Create a new variable.""" + variable_service = get_variable_service() if not variable.name and not variable.value: raise HTTPException(status_code=400, detail="Variable name and value cannot be empty") diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index 5d80f9cc5..5a089f7e5 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -12,10 +12,10 @@ from uuid import UUID import orjson import pytest from asgi_lifespan import LifespanManager -from base.langflow.components.inputs.ChatInput import ChatInput from dotenv import load_dotenv from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient +from langflow.components.inputs.ChatInput import ChatInput from langflow.graph.graph.base import Graph from langflow.initial_setup.setup import STARTER_FOLDER_NAME from langflow.services.auth.utils import get_password_hash