diff --git a/src/backend/langflow/services/database/models/user/utils.py b/src/backend/langflow/services/database/models/user/utils.py deleted file mode 100644 index 3dc02a499..000000000 --- a/src/backend/langflow/services/database/models/user/utils.py +++ /dev/null @@ -1,53 +0,0 @@ -from datetime import datetime, timezone -from typing import Union -from uuid import UUID -from fastapi import Depends, HTTPException -from langflow.services.database.models.user.user import User, UserUpdate -from langflow.services.utils import get_session -from sqlalchemy.exc import IntegrityError -from sqlmodel import Session - - -from sqlalchemy.orm.attributes import flag_modified - - -def get_user_by_username(db: Session, username: str) -> Union[User, None]: - return db.query(User).filter(User.username == username).first() - - -def get_user_by_id(db: Session, id: UUID) -> Union[User, None]: - return db.query(User).filter(User.id == id).first() - - -def update_user( - user_id: UUID, user: UserUpdate, db: Session = Depends(get_session) -) -> User: - user_db = get_user_by_id(db, user_id) - if not user_db: - raise HTTPException(status_code=404, detail="User not found") - - user_db_by_username = get_user_by_username(db, user.username) # type: ignore - if user_db_by_username and user_db_by_username.id != user_id: - raise HTTPException(status_code=409, detail="Username already exists") - - user_data = user.dict(exclude_unset=True) - for attr, value in user_data.items(): - if hasattr(user_db, attr) and value is not None: - setattr(user_db, attr, value) - - user_db.updated_at = datetime.now(timezone.utc) - flag_modified(user_db, "updated_at") - - try: - db.commit() - except IntegrityError as e: - db.rollback() - raise HTTPException(status_code=400, detail=str(e)) from e - - return user_db - - -def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)): - user_data = UserUpdate(last_login_at=datetime.now(timezone.utc)) # type: ignore - - return update_user(user_id, user_data, db) diff --git a/tests/conftest.py b/tests/conftest.py index e90d03d0a..9abe89d49 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,8 @@ from typing import AsyncGenerator, TYPE_CHECKING from langflow.api.v1.flows import get_session from langflow.graph.graph.base import Graph +from langflow.services.auth.utils import get_password_hash +from langflow.services.database.models.user.user import User, UserCreate import pytest from fastapi.testclient import TestClient from httpx import AsyncClient @@ -155,3 +157,38 @@ def session_getter_fixture(client): @pytest.fixture def runner(): return CliRunner() + + +@pytest.fixture +def test_user(client): + user_data = UserCreate( + username="testuser", + password="testpassword", + ) + response = client.post("/api/v1/user", json=user_data.dict()) + return response.json() + + +@pytest.fixture(scope="function") +def active_user(session): + user = User( + username="activeuser", + password=get_password_hash( + "testpassword" + ), # Assuming password needs to be hashed + is_active=True, + is_superuser=False, + ) + session.add(user) + session.commit() + return user + + +@pytest.fixture +def logged_in_headers(client, active_user): + login_data = {"username": active_user.username, "password": "testpassword"} + response = client.post("/api/v1/login", data=login_data) + assert response.status_code == 200 + tokens = response.json() + a_token = tokens["access_token"] + return {"Authorization": f"Bearer {a_token}"} diff --git a/tests/test_api_key.py b/tests/test_api_key.py new file mode 100644 index 000000000..43b91fa43 --- /dev/null +++ b/tests/test_api_key.py @@ -0,0 +1,50 @@ +import pytest +from langflow.services.database.models.api_key import ApiKeyCreate + + +@pytest.fixture +def api_key(client, logged_in_headers, active_user): + api_key = ApiKeyCreate(name="test-api-key") + + response = client.post( + "api/v1/api_key", data=api_key.json(), headers=logged_in_headers + ) + assert response.status_code == 200, response.text + return response.json() + + +def test_get_api_keys(client, logged_in_headers, api_key): + response = client.get("api/v1/api_key", headers=logged_in_headers) + assert response.status_code == 200, response.text + data = response.json() + assert "total_count" in data + assert "user_id" in data + assert "api_keys" in data + assert any("test-api-key" in api_key["name"] for api_key in data["api_keys"]) + # assert all api keys in data["api_keys"] are masked + assert all("**" in api_key["api_key"] for api_key in data["api_keys"]) + # Add more assertions as needed based on the expected data structure and content + + +def test_create_api_key(client, logged_in_headers): + api_key_name = "test-api-key" + response = client.post( + "api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers + ) + assert response.status_code == 200 + data = response.json() + assert "name" in data and data["name"] == api_key_name + assert "api_key" in data + # When creating the API key is returned which is + # the only time the API key is unmasked + assert "**" not in data["api_key"] + + +def test_delete_api_key(client, logged_in_headers, active_user, api_key): + # Assuming a function to create a test API key, returning the key ID + api_key_id = api_key["id"] + response = client.delete(f"api/v1/api_key/{api_key_id}", headers=logged_in_headers) + assert response.status_code == 200 + data = response.json() + assert data["detail"] == "API Key deleted" + # Optionally, add a follow-up check to ensure that the key is actually removed from the database diff --git a/tests/test_user.py b/tests/test_user.py index f8d4ff788..d734e4d61 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -4,42 +4,7 @@ from langflow.services.auth.utils import create_super_user, get_password_hash from langflow.services.database.models.user.user import User from langflow.services.utils import get_settings_manager import pytest -from langflow.services.database.models.user import UserCreate, UserUpdate - - -@pytest.fixture -def test_user(client): - user_data = UserCreate( - username="testuser", - password="testpassword", - ) - response = client.post("/api/v1/user", json=user_data.dict()) - return response.json() - - -@pytest.fixture(scope="function") -def active_user(session): - user = User( - username="activeuser", - password=get_password_hash( - "testpassword" - ), # Assuming password needs to be hashed - is_active=True, - is_superuser=False, - ) - session.add(user) - session.commit() - return user - - -@pytest.fixture -def logged_in_headers(client, active_user): - login_data = {"username": active_user.username, "password": "testpassword"} - response = client.post("/api/v1/login", data=login_data) - assert response.status_code == 200 - tokens = response.json() - a_token = tokens["access_token"] - return {"Authorization": f"Bearer {a_token}"} +from langflow.services.database.models.user import UserUpdate @pytest.fixture