Adds Tests for Login, Users and API keys (#821)
This commit is contained in:
commit
eab34e2fdc
44 changed files with 998 additions and 269 deletions
|
|
@ -0,0 +1,100 @@
|
|||
"""Add ApiKey table
|
||||
|
||||
Revision ID: 5512e39b4012
|
||||
Revises: 0a534bdfd84b
|
||||
Create Date: 2023-08-23 21:05:51.042203
|
||||
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "5512e39b4012"
|
||||
down_revision: Union[str, None] = "0a534bdfd84b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with contextlib.suppress(sa.exc.OperationalError):
|
||||
op.create_table(
|
||||
"apikey",
|
||||
sa.Column("api_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("create_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("last_used_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_apikey_api_key"), "apikey", ["api_key"], unique=True)
|
||||
|
||||
with contextlib.suppress(sa.exc.OperationalError):
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column("create_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("last_login_at", sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_user_username"), "user", ["username"], unique=True)
|
||||
with contextlib.suppress(sa.exc.OperationalError):
|
||||
op.drop_table("flowstyle")
|
||||
with contextlib.suppress(sa.exc.OperationalError):
|
||||
op.drop_index("ix_component_frontend_node_id", table_name="component")
|
||||
op.drop_index("ix_component_name", table_name="component")
|
||||
op.drop_table("component")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"component",
|
||||
sa.Column("id", sa.CHAR(length=32), nullable=False),
|
||||
sa.Column("frontend_node_id", sa.CHAR(length=32), nullable=False),
|
||||
sa.Column("name", sa.VARCHAR(), nullable=False),
|
||||
sa.Column("description", sa.VARCHAR(), nullable=True),
|
||||
sa.Column("python_code", sa.VARCHAR(), nullable=True),
|
||||
sa.Column("return_type", sa.VARCHAR(), nullable=True),
|
||||
sa.Column("is_disabled", sa.BOOLEAN(), nullable=False),
|
||||
sa.Column("is_read_only", sa.BOOLEAN(), nullable=False),
|
||||
sa.Column("create_at", sa.DATETIME(), nullable=False),
|
||||
sa.Column("update_at", sa.DATETIME(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_component_name", "component", ["name"], unique=False)
|
||||
op.create_index(
|
||||
"ix_component_frontend_node_id", "component", ["frontend_node_id"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"flowstyle",
|
||||
sa.Column("color", sa.VARCHAR(), nullable=False),
|
||||
sa.Column("emoji", sa.VARCHAR(), nullable=False),
|
||||
sa.Column("flow_id", sa.CHAR(length=32), nullable=True),
|
||||
sa.Column("id", sa.CHAR(length=32), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["flow_id"],
|
||||
["flow.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("id"),
|
||||
)
|
||||
op.drop_index(op.f("ix_user_username"), table_name="user")
|
||||
op.drop_table("user")
|
||||
op.drop_index(op.f("ix_apikey_api_key"), table_name="apikey")
|
||||
op.drop_table("apikey")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -6,6 +6,9 @@ from langflow.api.v1 import (
|
|||
validate_router,
|
||||
flows_router,
|
||||
component_router,
|
||||
users_router,
|
||||
api_key_router,
|
||||
login_router,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
|
|
@ -16,3 +19,6 @@ router.include_router(endpoints_router)
|
|||
router.include_router(validate_router)
|
||||
router.include_router(component_router)
|
||||
router.include_router(flows_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(api_key_router)
|
||||
router.include_router(login_router)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@ from langflow.api.v1.validate import router as validate_router
|
|||
from langflow.api.v1.chat import router as chat_router
|
||||
from langflow.api.v1.flows import router as flows_router
|
||||
from langflow.api.v1.components import router as component_router
|
||||
from langflow.api.v1.users import router as users_router
|
||||
from langflow.api.v1.api_key import router as api_key_router
|
||||
from langflow.api.v1.login import router as login_router
|
||||
|
||||
__all__ = [
|
||||
"chat_router",
|
||||
|
|
@ -10,4 +13,7 @@ __all__ = [
|
|||
"component_router",
|
||||
"validate_router",
|
||||
"flows_router",
|
||||
"users_router",
|
||||
"api_key_router",
|
||||
"login_router",
|
||||
]
|
||||
|
|
|
|||
61
src/backend/langflow/api/v1/api_key.py
Normal file
61
src/backend/langflow/api/v1/api_key.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from uuid import UUID
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from langflow.api.v1.schemas import ApiKeysResponse
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.api_key.api_key import (
|
||||
ApiKeyCreate,
|
||||
UnmaskedApiKeyRead,
|
||||
)
|
||||
|
||||
# Assuming you have these methods in your service layer
|
||||
from langflow.services.database.models.api_key.crud import (
|
||||
get_api_keys,
|
||||
create_api_key,
|
||||
delete_api_key,
|
||||
)
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.utils import get_session
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
router = APIRouter(tags=["APIKey"])
|
||||
|
||||
|
||||
@router.get("/api_key", response_model=ApiKeysResponse)
|
||||
def get_api_keys_route(
|
||||
db: Session = Depends(get_session),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
user_id = current_user.id
|
||||
keys = get_api_keys(db, user_id)
|
||||
|
||||
return ApiKeysResponse(total_count=len(keys), user_id=user_id, api_keys=keys)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/api_key", response_model=UnmaskedApiKeyRead)
|
||||
def create_api_key_route(
|
||||
req: ApiKeyCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
user_id = current_user.id
|
||||
return create_api_key(db, req, user_id=user_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/api_key/{api_key_id}")
|
||||
def delete_api_key_route(
|
||||
api_key_id: UUID,
|
||||
current_user=Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
delete_api_key(db, api_key_id)
|
||||
return {"detail": "API Key deleted"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
|
@ -58,8 +58,12 @@ def get_all():
|
|||
|
||||
logger.info(f"Loading {len(custom_component_dicts)} category(ies)")
|
||||
for custom_component_dict in custom_component_dicts:
|
||||
logger.debug(
|
||||
{key: len(value) for key, value in custom_component_dict.items()}
|
||||
# custom_component_dict is a dict of dicts
|
||||
if not custom_component_dict:
|
||||
continue
|
||||
category = list(custom_component_dict.keys())[0]
|
||||
logger.info(
|
||||
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
|
||||
)
|
||||
custom_components_from_file = merge_nested_dicts_with_renaming(
|
||||
custom_components_from_file, custom_component_dict
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from sqlalchemy.orm import Session
|
||||
from sqlmodel import Session
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.database.models.token import Token
|
||||
from langflow.auth.auth import (
|
||||
from langflow.services.database.models import Token
|
||||
from langflow.services.auth.utils import (
|
||||
authenticate_user,
|
||||
create_user_tokens,
|
||||
create_refresh_token,
|
||||
|
|
@ -13,7 +13,7 @@ from langflow.auth.auth import (
|
|||
|
||||
from langflow.services.utils import get_settings_manager
|
||||
|
||||
router = APIRouter()
|
||||
router = APIRouter(tags=["Login"])
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
|
|
@ -36,7 +36,7 @@ async def login_to_get_access_token(
|
|||
async def auto_login(db: Session = Depends(get_session)):
|
||||
settings_manager = get_settings_manager()
|
||||
|
||||
if settings_manager.settings.AUTO_LOGIN:
|
||||
if settings_manager.auth_settings.AUTO_LOGIN:
|
||||
return create_user_longterm_token(db)
|
||||
|
||||
raise HTTPException(
|
||||
|
|
@ -1,7 +1,10 @@
|
|||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
from langflow.services.database.models.api_key.api_key import ApiKeyRead
|
||||
from langflow.services.database.models.flow import FlowCreate, FlowRead
|
||||
from langflow.services.database.models.user import UserRead
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
|
@ -137,3 +140,26 @@ class ComponentListCreate(BaseModel):
|
|||
|
||||
class ComponentListRead(BaseModel):
|
||||
flows: List[FlowRead]
|
||||
|
||||
|
||||
class UsersResponse(BaseModel):
|
||||
total_count: int
|
||||
users: List[UserRead]
|
||||
|
||||
|
||||
class ApiKeyResponse(BaseModel):
|
||||
id: str
|
||||
api_key: str
|
||||
name: str
|
||||
created_at: str
|
||||
last_used_at: str
|
||||
|
||||
|
||||
class ApiKeysResponse(BaseModel):
|
||||
total_count: int
|
||||
user_id: UUID
|
||||
api_keys: List[ApiKeyRead]
|
||||
|
||||
|
||||
class CreateApiKeyRequest(BaseModel):
|
||||
name: str
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
from uuid import UUID
|
||||
from langflow.api.v1.schemas import UsersResponse
|
||||
from langflow.services.database.models.user import (
|
||||
User,
|
||||
UserCreate,
|
||||
UserRead,
|
||||
UserUpdate,
|
||||
)
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
|
@ -7,28 +14,27 @@ from sqlmodel import Session, select
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.auth.auth import get_current_active_user, get_password_hash
|
||||
from langflow.database.models.user import (
|
||||
User,
|
||||
UserAddModel,
|
||||
UserListModel,
|
||||
UserPatchModel,
|
||||
UsersResponse,
|
||||
from langflow.services.auth.utils import (
|
||||
get_current_active_superuser,
|
||||
get_current_active_user,
|
||||
get_password_hash,
|
||||
)
|
||||
from langflow.services.database.models.user.crud import (
|
||||
update_user,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Login"])
|
||||
router = APIRouter(tags=["Users"])
|
||||
|
||||
|
||||
@router.post("/user", response_model=UserListModel)
|
||||
@router.post("/user", response_model=UserRead, status_code=201)
|
||||
def add_user(
|
||||
user: UserAddModel,
|
||||
user: UserCreate,
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
"""
|
||||
Add a new user to the database.
|
||||
"""
|
||||
new_user = User(**user.dict())
|
||||
new_user = User.from_orm(user)
|
||||
try:
|
||||
new_user.password = get_password_hash(user.password)
|
||||
|
||||
|
|
@ -42,8 +48,10 @@ def add_user(
|
|||
return new_user
|
||||
|
||||
|
||||
@router.get("/user", response_model=UserListModel)
|
||||
def read_current_user(current_user: User = Depends(get_current_active_user)) -> User:
|
||||
@router.get("/user", response_model=UserRead)
|
||||
def read_current_user(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> User:
|
||||
"""
|
||||
Retrieve the current user's data.
|
||||
"""
|
||||
|
|
@ -54,7 +62,7 @@ def read_current_user(current_user: User = Depends(get_current_active_user)) ->
|
|||
def read_all_users(
|
||||
skip: int = 0,
|
||||
limit: int = 10,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
current_user: Session = Depends(get_current_active_superuser),
|
||||
db: Session = Depends(get_session),
|
||||
) -> UsersResponse:
|
||||
"""
|
||||
|
|
@ -68,14 +76,14 @@ def read_all_users(
|
|||
|
||||
return UsersResponse(
|
||||
total_count=total_count, # type: ignore
|
||||
users=[UserListModel(**dict(user.User)) for user in users],
|
||||
users=[UserRead(**dict(user.User)) for user in users],
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/user/{user_id}", response_model=UserListModel)
|
||||
@router.patch("/user/{user_id}", response_model=UserRead)
|
||||
def patch_user(
|
||||
user_id: UUID,
|
||||
user: UserPatchModel,
|
||||
user: UserUpdate,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
|
|
@ -88,12 +96,21 @@ def patch_user(
|
|||
@router.delete("/user/{user_id}")
|
||||
def delete_user(
|
||||
user_id: UUID,
|
||||
_: Session = Depends(get_current_active_user),
|
||||
current_user: User = Depends(get_current_active_superuser),
|
||||
db: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete a user from the database.
|
||||
"""
|
||||
if current_user.id == user_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't delete your own user account"
|
||||
)
|
||||
elif not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have the permission to delete this user"
|
||||
)
|
||||
|
||||
user_db = db.query(User).filter(User.id == user_id).first()
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
from sqlmodel import Field
|
||||
from uuid import UUID, uuid4
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import timezone, datetime
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from fastapi import HTTPException, Depends
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.services.database.models.base import SQLModelSerializable, SQLModel
|
||||
|
||||
|
||||
class User(SQLModelSerializable, table=True):
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True)
|
||||
username: str = Field(index=True, unique=True)
|
||||
password: str = Field()
|
||||
is_active: bool = Field(default=False)
|
||||
is_superuser: bool = Field(default=False)
|
||||
create_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_login_at: Optional[datetime] = Field()
|
||||
|
||||
|
||||
class UserAddModel(SQLModel):
|
||||
username: str = Field()
|
||||
password: str = Field()
|
||||
|
||||
|
||||
class UserListModel(SQLModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
username: str = Field()
|
||||
is_active: bool = Field()
|
||||
is_superuser: bool = Field()
|
||||
create_at: datetime = Field()
|
||||
updated_at: datetime = Field()
|
||||
last_login_at: Optional[datetime] = Field()
|
||||
|
||||
|
||||
class UserPatchModel(SQLModel):
|
||||
username: Optional[str] = Field()
|
||||
is_active: Optional[bool] = Field()
|
||||
is_superuser: Optional[bool] = Field()
|
||||
last_login_at: Optional[datetime] = Field()
|
||||
|
||||
|
||||
class UsersResponse(BaseModel):
|
||||
total_count: int
|
||||
users: List[UserListModel]
|
||||
|
||||
|
||||
def get_user_by_username(db: Session, username: str) -> User:
|
||||
db_user = db.query(User).filter(User.username == username).first()
|
||||
return User.from_orm(db_user) if db_user else None # type: ignore
|
||||
|
||||
|
||||
def get_user_by_id(db: Session, id: UUID) -> User:
|
||||
db_user = db.query(User).filter(User.id == id).first()
|
||||
return User.from_orm(db_user) if db_user else None # type: ignore
|
||||
|
||||
|
||||
def update_user(
|
||||
user_id: UUID, user: UserPatchModel, db: Session = Depends(get_session)
|
||||
) -> User:
|
||||
user_db = get_user_by_username(db, user.username) # type: ignore
|
||||
if user_db and user_db.id != user_id:
|
||||
raise HTTPException(status_code=409, detail="Username already exists")
|
||||
|
||||
user_db = get_user_by_id(db, user_id)
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
try:
|
||||
user_data = user.dict(exclude_unset=True)
|
||||
for key, value in user_data.items():
|
||||
setattr(user_db, key, value)
|
||||
|
||||
user_db.updated_at = datetime.now(timezone.utc)
|
||||
user_db = db.merge(user_db)
|
||||
db.commit()
|
||||
if db.identity_key(instance=user_db) is not None:
|
||||
db.refresh(user_db)
|
||||
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
return user_db
|
||||
|
||||
|
||||
def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)):
|
||||
user_data = UserPatchModel(last_login_at=datetime.now(timezone.utc)) # type: ignore
|
||||
|
||||
return update_user(user_id, user_data, db)
|
||||
|
|
@ -6,7 +6,7 @@ from fastapi.responses import FileResponse
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from langflow.api import router
|
||||
from langflow.routers import api_key, login, users, health
|
||||
|
||||
|
||||
from langflow.interface.utils import setup_llm_caching
|
||||
from langflow.services.database.utils import initialize_database
|
||||
|
|
@ -31,11 +31,6 @@ def create_app():
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(login.router)
|
||||
app.include_router(api_key.router)
|
||||
app.include_router(users.router)
|
||||
app.include_router(health.router)
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
app.on_event("startup")(initialize_services)
|
||||
|
|
|
|||
|
|
@ -1,49 +0,0 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
|
||||
router = APIRouter(tags=["APIKey"])
|
||||
|
||||
|
||||
@router.get("/api_key/{user_id}")
|
||||
def get_api_key(user_id: str):
|
||||
return {
|
||||
"total_count": 3,
|
||||
"user_id": user_id,
|
||||
"api_keys": [
|
||||
{
|
||||
"id": "4425707e-cce4-4d1b-a54e-bd2632064657",
|
||||
"api_key": "lf-...abcd",
|
||||
"name": "my api_key name - 01",
|
||||
"created_at": "2023-08-15T19:28:40.019613",
|
||||
"last_used_at": "2023-08-16T18:38:20.875210",
|
||||
},
|
||||
{
|
||||
"id": "6fb7282b-9f2e-4efe-9bda-0c3d8f899473",
|
||||
"api_key": "lf-...abcd",
|
||||
"name": "my api_key name - 02",
|
||||
"created_at": "2023-08-15T19:41:30.077942",
|
||||
"last_used_at": "2023-08-15T19:45:32.067899",
|
||||
},
|
||||
{
|
||||
"id": "c55f3b32-4920-42b6-a5cd-698b4251806e",
|
||||
"api_key": "lf-...abcd",
|
||||
"name": "my api_key name - 03",
|
||||
"created_at": "2023-08-15T20:29:40.577808",
|
||||
"last_used_at": "2023-08-15T20:29:40.577816",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/api_key/{user_id}")
|
||||
def create_api_key(user_id: str):
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"name": "my api-key 01",
|
||||
"api_key": "lf-eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI1YTBmODM1ZS0yMTQxLTQ2YWItYmQ4NS0yMWEzMjQ1MTE2ZDAiLCJleHAiOjE2OTIyMTUwMTN9.c_s0ZPRtjSI9yUrhi8ACIwyXf0feRLYfaeIZEbRVKQg",
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/api_key/{api_key_id}")
|
||||
def delete_api_key(api_key_id: str):
|
||||
return {"detail": "API Key deleted"}
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
def get_health():
|
||||
return {"status": "OK"}
|
||||
12
src/backend/langflow/services/auth/factory.py
Normal file
12
src/backend/langflow/services/auth/factory.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.auth.service import AuthManager
|
||||
|
||||
|
||||
class AuthManagerFactory(ServiceFactory):
|
||||
name = "auth_manager"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(AuthManager)
|
||||
|
||||
def create(self, settings_manager):
|
||||
return AuthManager(settings_manager)
|
||||
18
src/backend/langflow/services/auth/service.py
Normal file
18
src/backend/langflow/services/auth/service.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from fastapi import Request
|
||||
from langflow.services.base import Service
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.settings.manager import SettingsManager
|
||||
|
||||
|
||||
class AuthManager(Service):
|
||||
name = "auth_manager"
|
||||
|
||||
def __init__(self, settings_manager: "SettingsManager"):
|
||||
self.settings_manager = settings_manager
|
||||
|
||||
# We need to define a function that can be passed to the Depends() function.
|
||||
# This function will be called by FastAPI to run oauth2_scheme
|
||||
def run_oauth2_scheme(self, request: Request):
|
||||
return self.settings_manager.auth_settings.oauth2_scheme(request=request)
|
||||
|
|
@ -1,28 +1,30 @@
|
|||
from uuid import UUID
|
||||
from typing import Annotated
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.orm import Session
|
||||
from passlib.context import CryptContext
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from langflow.services.utils import get_settings_manager, get_session
|
||||
|
||||
from langflow.database.models.user import (
|
||||
User,
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from jose import JWTError, jwt
|
||||
from typing import Annotated, Coroutine
|
||||
from uuid import UUID
|
||||
from langflow.services.auth.service import AuthManager
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.database.models.user.crud import (
|
||||
get_user_by_id,
|
||||
get_user_by_username,
|
||||
update_user_last_login_at,
|
||||
)
|
||||
from langflow.services.utils import get_session, get_settings_manager
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
|
||||
def auth_scheme_dependency(request: Request):
|
||||
settings_manager = (
|
||||
get_settings_manager()
|
||||
) # Assuming get_settings_manager is defined
|
||||
|
||||
return AuthManager(settings_manager).run_oauth2_scheme(request)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session)
|
||||
token: Annotated[str, Depends(auth_scheme_dependency)],
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
settings_manager = get_settings_manager()
|
||||
|
||||
|
|
@ -32,11 +34,14 @@ async def get_current_user(
|
|||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if isinstance(token, Coroutine):
|
||||
token = await token
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings_manager.settings.SECRET_KEY,
|
||||
algorithms=[settings_manager.settings.ALGORITHM],
|
||||
settings_manager.auth_settings.SECRET_KEY,
|
||||
algorithms=[settings_manager.auth_settings.ALGORITHM],
|
||||
)
|
||||
user_id: UUID = payload.get("sub") # type: ignore
|
||||
token_type: str = payload.get("type") # type: ignore
|
||||
|
|
@ -47,25 +52,39 @@ async def get_current_user(
|
|||
raise credentials_exception from e
|
||||
|
||||
user = get_user_by_id(db, user_id) # type: ignore
|
||||
if user is None:
|
||||
if user is None or not user.is_active:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_active_superuser(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
) -> User:
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=401, detail="Inactive user")
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="The user doesn't have enough privileges"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
settings_manager = get_settings_manager()
|
||||
return settings_manager.auth_settings.pwd_context.verify(
|
||||
plain_password, hashed_password
|
||||
)
|
||||
|
||||
|
||||
def get_password_hash(password):
|
||||
return pwd_context.hash(password)
|
||||
settings_manager = get_settings_manager()
|
||||
return settings_manager.auth_settings.pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_token(data: dict, expires_delta: timedelta):
|
||||
|
|
@ -77,21 +96,23 @@ def create_token(data: dict, expires_delta: timedelta):
|
|||
|
||||
return jwt.encode(
|
||||
to_encode,
|
||||
settings_manager.settings.SECRET_KEY,
|
||||
algorithm=settings_manager.settings.ALGORITHM,
|
||||
settings_manager.auth_settings.SECRET_KEY,
|
||||
algorithm=settings_manager.auth_settings.ALGORITHM,
|
||||
)
|
||||
|
||||
|
||||
def create_super_user(db: Session = Depends(get_session)) -> User:
|
||||
settings_manager = get_settings_manager()
|
||||
|
||||
super_user = get_user_by_username(db, settings_manager.settings.FIRST_SUPERUSER)
|
||||
super_user = get_user_by_username(
|
||||
db, settings_manager.auth_settings.FIRST_SUPERUSER
|
||||
)
|
||||
|
||||
if not super_user:
|
||||
super_user = User(
|
||||
username=settings_manager.settings.FIRST_SUPERUSER,
|
||||
username=settings_manager.auth_settings.FIRST_SUPERUSER,
|
||||
password=get_password_hash(
|
||||
settings_manager.settings.FIRST_SUPERUSER_PASSWORD
|
||||
settings_manager.auth_settings.FIRST_SUPERUSER_PASSWORD
|
||||
),
|
||||
is_superuser=True,
|
||||
is_active=True,
|
||||
|
|
@ -147,7 +168,7 @@ def create_user_tokens(
|
|||
settings_manager = get_settings_manager()
|
||||
|
||||
access_token_expires = timedelta(
|
||||
minutes=settings_manager.settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
minutes=settings_manager.auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
access_token = create_token(
|
||||
data={"sub": str(user_id)},
|
||||
|
|
@ -155,7 +176,7 @@ def create_user_tokens(
|
|||
)
|
||||
|
||||
refresh_token_expires = timedelta(
|
||||
minutes=settings_manager.settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
minutes=settings_manager.auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
refresh_token = create_token(
|
||||
data={"sub": str(user_id), "type": "rf"},
|
||||
|
|
@ -179,8 +200,8 @@ def create_refresh_token(refresh_token: str, db: Session = Depends(get_session))
|
|||
try:
|
||||
payload = jwt.decode(
|
||||
refresh_token,
|
||||
settings_manager.settings.SECRET_KEY,
|
||||
algorithms=[settings_manager.settings.ALGORITHM],
|
||||
settings_manager.auth_settings.SECRET_KEY,
|
||||
algorithms=[settings_manager.auth_settings.ALGORITHM],
|
||||
)
|
||||
user_id: UUID = payload.get("sub") # type: ignore
|
||||
token_type: str = payload.get("type") # type: ignore
|
||||
|
|
@ -6,6 +6,6 @@ class CacheManagerFactory(ServiceFactory):
|
|||
def __init__(self):
|
||||
super().__init__(CacheManager)
|
||||
|
||||
def create(self, settings_service):
|
||||
def create(self):
|
||||
# Here you would have logic to create and configure a CacheManager
|
||||
return CacheManager()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,6 @@ class ChatManagerFactory(ServiceFactory):
|
|||
def __init__(self):
|
||||
super().__init__(ChatManager)
|
||||
|
||||
def create(self, settings_service):
|
||||
def create(self):
|
||||
# Here you would have logic to create and configure a ChatManager
|
||||
return ChatManager()
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ class DatabaseManagerFactory(ServiceFactory):
|
|||
def __init__(self):
|
||||
super().__init__(DatabaseManager)
|
||||
|
||||
def create(self, settings_service: "SettingsManager"):
|
||||
def create(self, settings_manager: "SettingsManager"):
|
||||
# Here you would have logic to create and configure a DatabaseManager
|
||||
if not settings_service.settings.DATABASE_URL:
|
||||
if not settings_manager.settings.DATABASE_URL:
|
||||
raise ValueError("No database URL provided")
|
||||
return DatabaseManager(settings_service.settings.DATABASE_URL)
|
||||
return DatabaseManager(settings_manager.settings.DATABASE_URL)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from .flow import Flow
|
||||
from .user import User
|
||||
from .token import Token
|
||||
from .api_key import ApiKey
|
||||
|
||||
|
||||
__all__ = ["Flow"]
|
||||
__all__ = ["Flow", "User", "Token", "ApiKey"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .api_key import ApiKey, ApiKeyCreate, UnmaskedApiKeyRead, ApiKeyRead
|
||||
|
||||
__all__ = ["ApiKey", "ApiKeyCreate", "UnmaskedApiKeyRead", "ApiKeyRead"]
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
from pydantic import validator
|
||||
from sqlmodel import Field, Relationship
|
||||
from uuid import UUID, uuid4
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
from langflow.services.database.models.base import SQLModelSerializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.models.user import User
|
||||
|
||||
|
||||
class ApiKeyBase(SQLModelSerializable):
|
||||
name: Optional[str] = Field(index=True)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_used_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
|
||||
class ApiKey(ApiKeyBase, table=True):
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True)
|
||||
api_key: str = Field(index=True, unique=True)
|
||||
# User relationship
|
||||
user_id: UUID = Field(index=True, foreign_key="user.id")
|
||||
user: "User" = Relationship(back_populates="api_keys")
|
||||
|
||||
|
||||
class ApiKeyCreate(ApiKeyBase):
|
||||
api_key: Optional[str] = None
|
||||
user_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class UnmaskedApiKeyRead(ApiKeyBase):
|
||||
id: UUID
|
||||
api_key: str = Field(index=True, unique=True)
|
||||
user_id: UUID = Field()
|
||||
|
||||
|
||||
class ApiKeyRead(ApiKeyBase):
|
||||
id: UUID
|
||||
api_key: str = Field(index=True, unique=True)
|
||||
user_id: UUID = Field()
|
||||
|
||||
@validator("api_key", always=True)
|
||||
def mask_api_key(cls, v):
|
||||
# This validator will always run, and will mask the API key
|
||||
return f"{'*' * 8}{v[-4:]}"
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
import secrets
|
||||
from uuid import UUID
|
||||
from typing import List
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from sqlmodel import Session, select
|
||||
from langflow.services.database.models.api_key import (
|
||||
ApiKey,
|
||||
ApiKeyCreate,
|
||||
UnmaskedApiKeyRead,
|
||||
ApiKeyRead,
|
||||
)
|
||||
|
||||
|
||||
def get_api_keys(session: Session, user_id: UUID) -> List[ApiKeyRead]:
|
||||
query = select(ApiKey).where(ApiKey.user_id == user_id)
|
||||
api_keys = session.exec(query).all()
|
||||
return [ApiKeyRead.from_orm(api_key) for api_key in api_keys]
|
||||
|
||||
|
||||
def create_api_key(
|
||||
session: Session, api_key_create: ApiKeyCreate, user_id: UUID
|
||||
) -> UnmaskedApiKeyRead:
|
||||
# Generate a random API key with 32 bytes of randomness
|
||||
generated_api_key = secrets.token_urlsafe(32)
|
||||
|
||||
# hash the API key
|
||||
hashed_api_key = get_password_hash(generated_api_key)
|
||||
# Use the generated key to create the ApiKey object
|
||||
|
||||
api_key = ApiKey(api_key=hashed_api_key, name=api_key_create.name, user_id=user_id)
|
||||
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
session.refresh(api_key)
|
||||
unmasked = UnmaskedApiKeyRead.from_orm(api_key)
|
||||
unmasked.api_key = generated_api_key
|
||||
return unmasked
|
||||
|
||||
|
||||
def delete_api_key(session: Session, api_key_id: UUID) -> None:
|
||||
api_key = session.get(ApiKey, api_key_id)
|
||||
if api_key is None:
|
||||
raise ValueError("API Key not found")
|
||||
session.delete(api_key)
|
||||
session.commit()
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from .component import Component, ComponentModel
|
||||
|
||||
__all__ = ["Component", "ComponentModel"]
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from .flow import Flow, FlowCreate, FlowRead, FlowUpdate
|
||||
|
||||
__all__ = ["Flow", "FlowCreate", "FlowRead", "FlowUpdate"]
|
||||
|
|
@ -6,8 +6,6 @@ from sqlmodel import Field, JSON, Column
|
|||
from uuid import UUID, uuid4
|
||||
from typing import Dict, Optional
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
|
||||
|
||||
class FlowBase(SQLModelSerializable):
|
||||
name: str = Field(index=True)
|
||||
|
|
@ -16,7 +14,6 @@ class FlowBase(SQLModelSerializable):
|
|||
|
||||
@validator("data")
|
||||
def validate_json(v):
|
||||
# dict_keys(['description', 'name', 'id', 'data'])
|
||||
if not v:
|
||||
return v
|
||||
if not isinstance(v, dict):
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
from .token import Token
|
||||
|
||||
__all__ = [
|
||||
"Token",
|
||||
]
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from .user import User, UserCreate, UserRead, UserUpdate
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"UserCreate",
|
||||
"UserRead",
|
||||
"UserUpdate",
|
||||
]
|
||||
53
src/backend/langflow/services/database/models/user/crud.py
Normal file
53
src/backend/langflow/services/database/models/user/crud.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Union
|
||||
from uuid import UUID
|
||||
from fastapi import Depends, HTTPException
|
||||
from langflow.services.database.models.user.user import User, UserUpdate
|
||||
from langflow.services.utils import get_session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
|
||||
def get_user_by_username(db: Session, username: str) -> Union[User, None]:
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
|
||||
def get_user_by_id(db: Session, id: UUID) -> Union[User, None]:
|
||||
return db.query(User).filter(User.id == id).first()
|
||||
|
||||
|
||||
def update_user(
|
||||
user_id: UUID, user: UserUpdate, db: Session = Depends(get_session)
|
||||
) -> User:
|
||||
user_db = get_user_by_id(db, user_id)
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user_db_by_username = get_user_by_username(db, user.username) # type: ignore
|
||||
if user_db_by_username and user_db_by_username.id != user_id:
|
||||
raise HTTPException(status_code=409, detail="Username already exists")
|
||||
|
||||
user_data = user.dict(exclude_unset=True)
|
||||
for attr, value in user_data.items():
|
||||
if hasattr(user_db, attr) and value is not None:
|
||||
setattr(user_db, attr, value)
|
||||
|
||||
user_db.updated_at = datetime.now(timezone.utc)
|
||||
flag_modified(user_db, "updated_at")
|
||||
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
return user_db
|
||||
|
||||
|
||||
def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)):
|
||||
user_data = UserUpdate(last_login_at=datetime.now(timezone.utc)) # type: ignore
|
||||
|
||||
return update_user(user_id, user_data, db)
|
||||
44
src/backend/langflow/services/database/models/user/user.py
Normal file
44
src/backend/langflow/services/database/models/user/user.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
from langflow.services.database.models.base import SQLModel, SQLModelSerializable
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.models.api_key import ApiKey
|
||||
|
||||
|
||||
class User(SQLModelSerializable, table=True):
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True)
|
||||
username: str = Field(index=True, unique=True)
|
||||
password: str = Field()
|
||||
is_active: bool = Field(default=False)
|
||||
is_superuser: bool = Field(default=False)
|
||||
create_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_login_at: Optional[datetime] = Field()
|
||||
api_keys: list["ApiKey"] = Relationship(back_populates="user")
|
||||
|
||||
|
||||
class UserCreate(SQLModel):
|
||||
username: str = Field()
|
||||
password: str = Field()
|
||||
|
||||
|
||||
class UserRead(SQLModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
username: str = Field()
|
||||
is_active: bool = Field()
|
||||
is_superuser: bool = Field()
|
||||
create_at: datetime = Field()
|
||||
updated_at: datetime = Field()
|
||||
last_login_at: Optional[datetime] = Field()
|
||||
|
||||
|
||||
class UserUpdate(SQLModel):
|
||||
username: Optional[str] = Field()
|
||||
is_active: Optional[bool] = Field()
|
||||
is_superuser: Optional[bool] = Field()
|
||||
last_login_at: Optional[datetime] = Field()
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.services.schema import ServiceType
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.factory import ServiceFactory
|
||||
|
|
@ -13,13 +13,21 @@ class ServiceManager:
|
|||
def __init__(self):
|
||||
self.services = {}
|
||||
self.factories = {}
|
||||
self.dependencies = {}
|
||||
|
||||
def register_factory(self, service_factory: "ServiceFactory"):
|
||||
def register_factory(
|
||||
self,
|
||||
service_factory: "ServiceFactory",
|
||||
dependencies: Optional[List[ServiceType]] = None,
|
||||
):
|
||||
"""
|
||||
Registers a new factory.
|
||||
Registers a new factory with dependencies.
|
||||
"""
|
||||
if service_factory.service_class.name not in self.factories:
|
||||
self.factories[service_factory.service_class.name] = service_factory
|
||||
if dependencies is None:
|
||||
dependencies = []
|
||||
service_name = service_factory.service_class.name
|
||||
self.factories[service_name] = service_factory
|
||||
self.dependencies[service_name] = dependencies
|
||||
|
||||
def get(self, service_name: ServiceType):
|
||||
"""
|
||||
|
|
@ -32,17 +40,25 @@ class ServiceManager:
|
|||
|
||||
def _create_service(self, service_name: ServiceType):
|
||||
"""
|
||||
Create a new service given its name.
|
||||
Create a new service given its name, handling dependencies.
|
||||
"""
|
||||
self._validate_service_creation(service_name)
|
||||
|
||||
if service_name == ServiceType.SETTINGS_MANAGER:
|
||||
self.services[service_name] = self.factories[service_name].create()
|
||||
else:
|
||||
settings_service = self.get(ServiceType.SETTINGS_MANAGER)
|
||||
self.services[service_name] = self.factories[service_name].create(
|
||||
settings_service
|
||||
)
|
||||
# Create dependencies first
|
||||
for dependency in self.dependencies.get(service_name, []):
|
||||
if dependency not in self.services:
|
||||
self._create_service(dependency)
|
||||
|
||||
# Collect the dependent services
|
||||
dependent_services = {
|
||||
dep.value: self.services[dep]
|
||||
for dep in self.dependencies.get(service_name, [])
|
||||
}
|
||||
|
||||
# Create the actual service
|
||||
self.services[service_name] = self.factories[service_name].create(
|
||||
**dependent_services
|
||||
)
|
||||
|
||||
def _validate_service_creation(self, service_name: ServiceType):
|
||||
"""
|
||||
|
|
@ -53,14 +69,6 @@ class ServiceManager:
|
|||
f"No factory registered for the service class '{service_name.name}'"
|
||||
)
|
||||
|
||||
if (
|
||||
ServiceType.SETTINGS_MANAGER not in self.factories
|
||||
and service_name != ServiceType.SETTINGS_MANAGER
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot create service '{service_name.name}' before the settings service"
|
||||
)
|
||||
|
||||
def update(self, service_name: ServiceType):
|
||||
"""
|
||||
Update a service by its name.
|
||||
|
|
@ -81,12 +89,24 @@ def initialize_services():
|
|||
from langflow.services.cache import factory as cache_factory
|
||||
from langflow.services.chat import factory as chat_factory
|
||||
from langflow.services.settings import factory as settings_factory
|
||||
from langflow.services.auth import factory as auth_factory
|
||||
|
||||
service_manager.register_factory(settings_factory.SettingsManagerFactory())
|
||||
service_manager.register_factory(database_factory.DatabaseManagerFactory())
|
||||
service_manager.register_factory(
|
||||
auth_factory.AuthManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER]
|
||||
)
|
||||
service_manager.register_factory(
|
||||
database_factory.DatabaseManagerFactory(),
|
||||
dependencies=[ServiceType.SETTINGS_MANAGER],
|
||||
)
|
||||
service_manager.register_factory(cache_factory.CacheManagerFactory())
|
||||
service_manager.register_factory(chat_factory.ChatManagerFactory())
|
||||
|
||||
# Test cache connection
|
||||
service_manager.get(ServiceType.CACHE_MANAGER)
|
||||
# Test database connection
|
||||
service_manager.get(ServiceType.DATABASE_MANAGER)
|
||||
|
||||
|
||||
def initialize_settings_manager():
|
||||
"""
|
||||
|
|
@ -95,3 +115,22 @@ def initialize_settings_manager():
|
|||
from langflow.services.settings import factory as settings_factory
|
||||
|
||||
service_manager.register_factory(settings_factory.SettingsManagerFactory())
|
||||
|
||||
|
||||
def initialize_session_manager():
|
||||
"""
|
||||
Initialize the session manager.
|
||||
"""
|
||||
from langflow.services.session import factory as session_manager_factory
|
||||
from langflow.services.cache import factory as cache_factory
|
||||
|
||||
initialize_settings_manager()
|
||||
|
||||
service_manager.register_factory(
|
||||
cache_factory.CacheManagerFactory(), dependencies=[ServiceType.SETTINGS_MANAGER]
|
||||
)
|
||||
|
||||
service_manager.register_factory(
|
||||
session_manager_factory.SessionManagerFactory(),
|
||||
dependencies=[ServiceType.CACHE_MANAGER],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ class ServiceType(str, Enum):
|
|||
registered with the service manager.
|
||||
"""
|
||||
|
||||
AUTH_MANAGER = "auth_manager"
|
||||
CACHE_MANAGER = "cache_manager"
|
||||
SETTINGS_MANAGER = "settings_manager"
|
||||
DATABASE_MANAGER = "database_manager"
|
||||
|
|
|
|||
35
src/backend/langflow/services/settings/auth.py
Normal file
35
src/backend/langflow/services/settings/auth.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
import secrets
|
||||
|
||||
from pydantic import BaseSettings
|
||||
from passlib.context import CryptContext
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
|
||||
class AuthSettings(BaseSettings):
|
||||
# Login settings
|
||||
SECRET_KEY: str = secrets.token_hex(32)
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES: int = 70
|
||||
|
||||
# API Key to execute /process endpoint
|
||||
API_KEY_SECRET_KEY: Optional[
|
||||
str
|
||||
] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a"
|
||||
API_KEY_ALGORITHM: str = "HS256"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# If AUTO_LOGIN = True
|
||||
# > The application does not request login and logs in automatically as a super user.
|
||||
AUTO_LOGIN: bool = True
|
||||
FIRST_SUPERUSER: str = "langflow"
|
||||
FIRST_SUPERUSER_PASSWORD: str = "langflow"
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{API_V1_STR}/login")
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
extra = "ignore"
|
||||
env_prefix = "LANGFLOW_"
|
||||
|
|
@ -3,7 +3,6 @@ import json
|
|||
import orjson
|
||||
import os
|
||||
from shutil import copy2
|
||||
import secrets
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -42,24 +41,6 @@ class Settings(BaseSettings):
|
|||
REMOVE_API_KEYS: bool = False
|
||||
COMPONENTS_PATH: List[str] = []
|
||||
|
||||
# Login settings
|
||||
SECRET_KEY: str = secrets.token_hex(32)
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES: int = 70
|
||||
|
||||
# API Key to execute /process endpoint
|
||||
API_KEY_SECRET_KEY: Optional[
|
||||
str
|
||||
] = "b82818e0ad4ff76615c5721ee21004b07d84cd9b87ba4d9cb42374da134b841a"
|
||||
API_KEY_ALGORITHM: str = "HS256"
|
||||
|
||||
# If AUTO_LOGIN = True
|
||||
# > The application does not request login and logs in automatically as a super user.
|
||||
AUTO_LOGIN: bool = True
|
||||
FIRST_SUPERUSER: str = "langflow"
|
||||
FIRST_SUPERUSER_PASSWORD: str = "langflow"
|
||||
|
||||
@validator("CONFIG_DIR", pre=True, allow_reuse=True)
|
||||
def set_langflow_dir(cls, value):
|
||||
if not value:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from langflow.services.base import Service
|
||||
from langflow.services.settings.auth import AuthSettings
|
||||
from langflow.services.settings.base import Settings
|
||||
from langflow.utils.logger import logger
|
||||
import os
|
||||
|
|
@ -8,9 +9,10 @@ import yaml
|
|||
class SettingsManager(Service):
|
||||
name = "settings_manager"
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
def __init__(self, settings: Settings, auth_settings: AuthSettings):
|
||||
super().__init__()
|
||||
self.settings = settings
|
||||
self.auth_settings = auth_settings
|
||||
|
||||
@classmethod
|
||||
def load_settings_from_yaml(cls, file_path: str) -> "SettingsManager":
|
||||
|
|
@ -33,4 +35,5 @@ class SettingsManager(Service):
|
|||
)
|
||||
|
||||
settings = Settings(**settings_dict)
|
||||
return cls(settings)
|
||||
auth_settings = AuthSettings()
|
||||
return cls(settings, auth_settings)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from typing import AsyncGenerator, TYPE_CHECKING
|
|||
from langflow.api.v1.flows import get_session
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.user.user import User, UserCreate
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
|
|
@ -155,3 +157,38 @@ def session_getter_fixture(client):
|
|||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(client):
|
||||
user_data = UserCreate(
|
||||
username="testuser",
|
||||
password="testpassword",
|
||||
)
|
||||
response = client.post("/api/v1/user", json=user_data.dict())
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def active_user(session):
|
||||
user = User(
|
||||
username="activeuser",
|
||||
password=get_password_hash(
|
||||
"testpassword"
|
||||
), # Assuming password needs to be hashed
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logged_in_headers(client, active_user):
|
||||
login_data = {"username": active_user.username, "password": "testpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 200
|
||||
tokens = response.json()
|
||||
a_token = tokens["access_token"]
|
||||
return {"Authorization": f"Bearer {a_token}"}
|
||||
|
|
|
|||
50
tests/test_api_key.py
Normal file
50
tests/test_api_key.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
import pytest
|
||||
from langflow.services.database.models.api_key import ApiKeyCreate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key(client, logged_in_headers, active_user):
|
||||
api_key = ApiKeyCreate(name="test-api-key")
|
||||
|
||||
response = client.post(
|
||||
"api/v1/api_key", data=api_key.json(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_get_api_keys(client, logged_in_headers, api_key):
|
||||
response = client.get("api/v1/api_key", headers=logged_in_headers)
|
||||
assert response.status_code == 200, response.text
|
||||
data = response.json()
|
||||
assert "total_count" in data
|
||||
assert "user_id" in data
|
||||
assert "api_keys" in data
|
||||
assert any("test-api-key" in api_key["name"] for api_key in data["api_keys"])
|
||||
# assert all api keys in data["api_keys"] are masked
|
||||
assert all("**" in api_key["api_key"] for api_key in data["api_keys"])
|
||||
# Add more assertions as needed based on the expected data structure and content
|
||||
|
||||
|
||||
def test_create_api_key(client, logged_in_headers):
|
||||
api_key_name = "test-api-key"
|
||||
response = client.post(
|
||||
"api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "name" in data and data["name"] == api_key_name
|
||||
assert "api_key" in data
|
||||
# When creating the API key is returned which is
|
||||
# the only time the API key is unmasked
|
||||
assert "**" not in data["api_key"]
|
||||
|
||||
|
||||
def test_delete_api_key(client, logged_in_headers, active_user, api_key):
|
||||
# Assuming a function to create a test API key, returning the key ID
|
||||
api_key_id = api_key["id"]
|
||||
response = client.delete(f"api/v1/api_key/{api_key_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["detail"] == "API Key deleted"
|
||||
# Optionally, add a follow-up check to ensure that the key is actually removed from the database
|
||||
|
|
@ -27,4 +27,4 @@ def test_components_path(runner, client, default_settings):
|
|||
)
|
||||
assert result.exit_code == 0, result.stdout
|
||||
settings_manager = utils.get_settings_manager()
|
||||
assert temp_dir in settings_manager.settings.COMPONENTS_PATH
|
||||
assert str(temp_dir) in settings_manager.settings.COMPONENTS_PATH
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import orjson
|
|||
import pytest
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlmodel import Session
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
|
|
|||
47
tests/test_login.py
Normal file
47
tests/test_login.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import pytest
|
||||
from langflow.services.database.models.user import User
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user():
|
||||
return User(
|
||||
username="testuser",
|
||||
password=get_password_hash(
|
||||
"testpassword"
|
||||
), # Assuming password needs to be hashed
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
|
||||
def test_login_successful(client, test_user, session):
|
||||
# Adding the test user to the database
|
||||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "testuser", "password": "testpassword"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
|
||||
|
||||
def test_login_unsuccessful_wrong_username(client):
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "wrongusername", "password": "testpassword"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
|
||||
|
||||
def test_login_unsuccessful_wrong_password(client, test_user, session):
|
||||
# Adding the test user to the database
|
||||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "testuser", "password": "wrongpassword"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
213
tests/test_user.py
Normal file
213
tests/test_user.py
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
from datetime import datetime
|
||||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.utils import get_settings_manager
|
||||
import pytest
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def super_user(client, session):
|
||||
return create_super_user(session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def super_user_headers(client, super_user):
|
||||
settings_manager = get_settings_manager()
|
||||
auth_settings = settings_manager.auth_settings
|
||||
login_data = {
|
||||
"username": auth_settings.FIRST_SUPERUSER,
|
||||
"password": auth_settings.FIRST_SUPERUSER_PASSWORD,
|
||||
}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 200
|
||||
tokens = response.json()
|
||||
a_token = tokens["access_token"]
|
||||
return {"Authorization": f"Bearer {a_token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deactivated_user(session):
|
||||
user = User(
|
||||
username="deactivateduser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
is_superuser=False,
|
||||
last_login_at=datetime.now(),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def test_user_waiting_for_approval(client, session):
|
||||
# Create a user that is not active and has never logged in
|
||||
user = User(
|
||||
username="waitingforapproval",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
last_login_at=None,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
login_data = {"username": "waitingforapproval", "password": "testpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Waiting for approval"
|
||||
|
||||
|
||||
def test_deactivated_user_cannot_login(client, deactivated_user):
|
||||
login_data = {"username": deactivated_user.username, "password": "testpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 400, response.json()
|
||||
assert response.json()["detail"] == "Inactive user"
|
||||
|
||||
|
||||
def test_deactivated_user_cannot_access(client, deactivated_user, logged_in_headers):
|
||||
# Assuming the headers for deactivated_user
|
||||
response = client.get("/api/v1/users", headers=logged_in_headers)
|
||||
assert response.status_code == 400, response.json()
|
||||
assert response.json()["detail"] == "The user doesn't have enough privileges"
|
||||
|
||||
|
||||
def test_data_consistency_after_update(client, active_user, logged_in_headers):
|
||||
user_id = active_user.id
|
||||
update_data = UserUpdate(username="newname")
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/user/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Fetch the updated user from the database
|
||||
response = client.get("/api/v1/user", headers=logged_in_headers)
|
||||
assert response.json()["username"] == "newname", response.json()
|
||||
|
||||
|
||||
def test_data_consistency_after_delete(client, test_user, super_user_headers):
|
||||
user_id = test_user["id"]
|
||||
response = client.delete(f"/api/v1/user/{user_id}", headers=super_user_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Attempt to fetch the deleted user from the database
|
||||
response = client.get("/api/v1/users", headers=super_user_headers)
|
||||
assert response.status_code == 200
|
||||
assert all(user["id"] != user_id for user in response.json()["users"])
|
||||
|
||||
|
||||
def test_inactive_user(client, session):
|
||||
# Create a user that is not active and has a last_login_at value
|
||||
user = User(
|
||||
username="inactiveuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
login_data = {"username": "inactiveuser", "password": "testpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Inactive user"
|
||||
|
||||
|
||||
def test_add_user(client, test_user):
|
||||
assert test_user["username"] == "testuser"
|
||||
|
||||
|
||||
# This is not used in the Frontend at the moment
|
||||
# def test_read_current_user(client: TestClient, active_user):
|
||||
# # First we need to login to get the access token
|
||||
# login_data = {"username": "testuser", "password": "testpassword"}
|
||||
# response = client.post("/api/v1/login", data=login_data)
|
||||
# assert response.status_code == 200
|
||||
|
||||
# headers = {"Authorization": f"Bearer {response.json()['access_token']}"}
|
||||
|
||||
# response = client.get("/api/v1/user", headers=headers)
|
||||
# assert response.status_code == 200, response.json()
|
||||
# assert response.json()["username"] == "testuser"
|
||||
|
||||
|
||||
def test_read_all_users(client, super_user_headers):
|
||||
response = client.get("/api/v1/users", headers=super_user_headers)
|
||||
assert response.status_code == 200, response.json()
|
||||
assert isinstance(response.json()["users"], list)
|
||||
|
||||
|
||||
def test_normal_user_cant_read_all_users(client, logged_in_headers):
|
||||
response = client.get("/api/v1/users", headers=logged_in_headers)
|
||||
assert response.status_code == 400, response.json()
|
||||
assert response.json() == {"detail": "The user doesn't have enough privileges"}
|
||||
|
||||
|
||||
def test_patch_user(client, active_user, logged_in_headers):
|
||||
user_id = active_user.id
|
||||
update_data = UserUpdate(
|
||||
username="newname",
|
||||
)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/user/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
|
||||
def test_patch_user_wrong_id(client, active_user, logged_in_headers):
|
||||
user_id = "wrong_id"
|
||||
update_data = UserUpdate(
|
||||
username="newname",
|
||||
)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/user/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 422, response.json()
|
||||
assert response.json() == {
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "user_id"],
|
||||
"msg": "value is not a valid uuid",
|
||||
"type": "type_error.uuid",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_delete_user(client, test_user, super_user_headers):
|
||||
user_id = test_user["id"]
|
||||
response = client.delete(f"/api/v1/user/{user_id}", headers=super_user_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"detail": "User deleted"}
|
||||
|
||||
|
||||
def test_delete_user_wrong_id(client, test_user, super_user_headers):
|
||||
user_id = "wrong_id"
|
||||
response = client.delete(f"/api/v1/user/{user_id}", headers=super_user_headers)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == {
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "user_id"],
|
||||
"msg": "value is not a valid uuid",
|
||||
"type": "type_error.uuid",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_normal_user_cant_delete_user(client, test_user, logged_in_headers):
|
||||
user_id = test_user["id"]
|
||||
response = client.delete(f"/api/v1/user/{user_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "The user doesn't have enough privileges"}
|
||||
|
||||
|
||||
# If you still want to test the superuser endpoint
|
||||
def test_add_super_user_for_testing_purposes_delete_me_before_merge_into_dev(client):
|
||||
response = client.post("/api/v1/super_user")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["username"] == "superuser"
|
||||
Loading…
Add table
Add a link
Reference in a new issue