Merge remote-tracking branch 'origin/dev' into fix_db_location

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-16 19:21:47 -03:00
commit 7539ba3166
114 changed files with 6750 additions and 3275 deletions

View file

View file

@ -0,0 +1,177 @@
from uuid import UUID
from typing import Annotated
from jose import JWTError, jwt
from sqlalchemy.orm import Session
from passlib.context import CryptContext
from fastapi.security import OAuth2PasswordBearer
from fastapi import Depends, HTTPException, status
from datetime import datetime, timedelta, timezone
from langflow.services.utils import get_settings_manager
from langflow.services.utils import get_session
from langflow.database.models.user import (
User,
get_user_by_id,
get_user_by_username,
update_user_last_login_at,
)
# Password hashing context shared by verify_password / get_password_hash (bcrypt scheme).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Bearer-token extractor used as a dependency by get_current_user; tokenUrl is "login".
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
async def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session)
) -> User:
    """Resolve the bearer access token to the user it was issued for.

    Args:
        token: Encoded JWT extracted by ``oauth2_scheme``.
        db: Database session used to look the user up.

    Returns:
        The user whose id is stored in the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, carries a ``type`` claim, or references an unknown user.
    """
    settings = get_settings_manager().settings
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
    except JWTError as exc:
        raise unauthorized from exc

    subject = claims.get("sub")
    # Tokens carrying a "type" claim (e.g. the "rf" refresh tokens minted by
    # create_user_tokens) are not accepted as access tokens.
    if subject is None or claims.get("type"):
        raise unauthorized

    found = get_user_by_id(db, subject)  # type: ignore
    if found is None:
        raise unauthorized
    return found
async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)]
):
    """Return the authenticated user, rejecting accounts that are not active.

    Raises:
        HTTPException: 400 with detail "Inactive user" when ``is_active`` is falsy.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
def verify_password(plain_password, hashed_password):
    """Check ``plain_password`` against ``hashed_password`` using ``pwd_context``."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password):
    """Return the hash of ``password`` produced by ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
def create_token(data: dict, expires_delta: timedelta):
    """Encode ``data`` as a signed JWT that expires after ``expires_delta``.

    The input dict is not mutated; an ``exp`` claim (current UTC time plus
    ``expires_delta``) is added to a copy before encoding with the configured
    secret key and algorithm.
    """
    settings = get_settings_manager().settings
    claims = dict(data)
    claims["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(
        claims,
        settings.SECRET_KEY,
        algorithm=settings.ALGORITHM,
    )
def create_user_longterm_token(
    user_id: UUID, db: Session = Depends(get_session), update_last_login: bool = False
) -> dict:
    """Issue a 365-day access token for ``user_id`` with no refresh token.

    Args:
        user_id: Id stored (as a string) in the token's ``sub`` claim.
        db: Session used only when ``update_last_login`` is true.
        update_last_login: When true, stamp the user's ``last_login_at``.

    Returns:
        Dict with ``access_token``, ``refresh_token`` (always ``None``) and
        ``token_type`` ("bearer").
    """
    token = create_token(
        data={"sub": str(user_id)},
        expires_delta=timedelta(days=365),
    )

    if update_last_login:
        update_user_last_login_at(user_id, db)

    return {
        "access_token": token,
        "refresh_token": None,
        "token_type": "bearer",
    }
def create_user_tokens(
    user_id: UUID, db: Session = Depends(get_session), update_last_login: bool = False
) -> dict:
    """Issue an access/refresh token pair for ``user_id``.

    Lifetimes come from ``ACCESS_TOKEN_EXPIRE_MINUTES`` and
    ``REFRESH_TOKEN_EXPIRE_MINUTES``. The refresh token carries the extra
    claim ``"type": "rf"``.

    Args:
        user_id: Id stored (as a string) in both tokens' ``sub`` claim.
        db: Session used only when ``update_last_login`` is true.
        update_last_login: When true, stamp the user's ``last_login_at``.

    Returns:
        Dict with ``access_token``, ``refresh_token`` and ``token_type`` ("bearer").
    """
    settings = get_settings_manager().settings
    subject = str(user_id)

    access = create_token(
        data={"sub": subject},
        expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    refresh = create_token(
        data={"sub": subject, "type": "rf"},
        expires_delta=timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES),
    )

    if update_last_login:
        update_user_last_login_at(user_id, db)

    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
    }
def create_refresh_token(refresh_token: str, db: Session = Depends(get_session)):
    """Exchange a refresh token for a fresh access/refresh token pair.

    Raises:
        HTTPException: 401 "Invalid refresh token" when the token cannot be
            decoded or lacks either the ``sub`` or the ``type`` claim.
    """
    settings = get_settings_manager().settings
    try:
        claims = jwt.decode(
            refresh_token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
        subject = claims.get("sub")
        kind = claims.get("type")
        # Both claims must be present; access tokens carry no "type" claim.
        if subject is not None and kind is not None:
            return create_user_tokens(subject, db)  # type: ignore
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        )
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from exc
def authenticate_user(
    username: str, password: str, db: Session = Depends(get_session)
) -> User | None:
    """Validate a username/password pair.

    The password is verified *before* the account status is examined, so the
    "Waiting for approval" / "Inactive user" details are only disclosed to a
    caller who actually knows the password.

    Args:
        username: Login name to look up.
        password: Plain-text password to verify against the stored hash.
        db: Database session used for the lookup.

    Returns:
        The user on success, or ``None`` when the user does not exist or the
        password does not match.

    Raises:
        HTTPException: 400 "Waiting for approval" for an inactive account that
            has never logged in; 400 "Inactive user" for any other inactive
            account.
    """
    user = get_user_by_username(db, username)
    if not user or not verify_password(password, user.password):
        return None

    if not user.is_active:
        if not user.last_login_at:
            raise HTTPException(status_code=400, detail="Waiting for approval")
        raise HTTPException(status_code=400, detail="Inactive user")

    return user

View file

@ -0,0 +1,7 @@
from pydantic import BaseModel
class Token(BaseModel):
    """Response body for the login endpoint: an access/refresh token pair."""

    # Encoded access token.
    access_token: str
    # Encoded refresh token.
    refresh_token: str
    # Token scheme reported to the client.
    token_type: str

View file

@ -0,0 +1,94 @@
from sqlmodel import Field
from uuid import UUID, uuid4
from pydantic import BaseModel
from typing import Optional, List
from sqlalchemy.orm import Session
from datetime import timezone, datetime
from sqlalchemy.exc import IntegrityError
from fastapi import HTTPException, Depends
from langflow.services.utils import get_session
from langflow.services.database.models.base import SQLModelSerializable, SQLModel
class User(SQLModelSerializable, table=True):
    """Database table model for an application user."""

    # Primary key, generated with uuid4 when not supplied.
    id: UUID = Field(primary_key=True, unique=True, default_factory=uuid4)
    # Unique, indexed login name.
    username: str = Field(unique=True, index=True)
    # Stored password value (the routers store a hash here).
    password: str = Field()
    # New accounts start inactive and without superuser rights.
    is_active: bool = Field(default=False)
    is_superuser: bool = Field(default=False)
    # Timestamps default to datetime.utcnow at instantiation.
    create_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    last_login_at: Optional[datetime] = Field()
class UserAddModel(SQLModel):
    """Request payload for creating a user: login name and plain password."""

    username: str = Field()
    password: str = Field()
class UserListModel(SQLModel):
    """Public view of a user; mirrors ``User`` without the password field."""

    id: UUID = Field(default_factory=uuid4)
    username: str = Field()
    # Account flags.
    is_active: bool = Field()
    is_superuser: bool = Field()
    # Timestamps.
    create_at: datetime = Field()
    updated_at: datetime = Field()
    last_login_at: Optional[datetime] = Field()
class UserPatchModel(SQLModel):
    """Partial-update payload for a user; every field is optional."""

    username: Optional[str] = Field()
    is_active: Optional[bool] = Field()
    is_superuser: Optional[bool] = Field()
    last_login_at: Optional[datetime] = Field()
class UsersResponse(BaseModel):
    """Paginated user listing: the total row count plus one page of users."""

    # Number of users in the table, independent of pagination.
    total_count: int
    # The requested page of users.
    users: List[UserListModel]
def get_user_by_username(db: Session, username: str) -> Optional[User]:
    """Fetch the user with the given login name.

    Args:
        db: Database session to query.
        username: Login name to match exactly.

    Returns:
        A ``User`` built via ``User.from_orm`` from the first matching row, or
        ``None`` when no row matches.
    """
    db_user = db.query(User).filter(User.username == username).first()
    if not db_user:
        return None
    return User.from_orm(db_user)
def get_user_by_id(db: Session, id: UUID) -> Optional[User]:
    """Fetch the user with the given primary key.

    Args:
        db: Database session to query.
        id: Primary key of the user.

    Returns:
        A ``User`` built via ``User.from_orm`` from the first matching row, or
        ``None`` when no row matches.
    """
    db_user = db.query(User).filter(User.id == id).first()
    if not db_user:
        return None
    return User.from_orm(db_user)
def update_user(
    user_id: UUID, user: UserPatchModel, db: Session = Depends(get_session)
) -> User:
    """Apply a partial update to a user and persist it.

    Only fields explicitly set on ``user`` are written; ``updated_at`` is
    always refreshed.

    Args:
        user_id: Primary key of the user to update.
        user: Patch payload.
        db: Database session.

    Returns:
        The merged, committed user instance.

    Raises:
        HTTPException: 409 when the requested username belongs to another
            user; 404 when ``user_id`` does not exist; 400 when the commit
            fails with an ``IntegrityError``.
    """
    # Only look for a clash when the patch actually carries a username;
    # patches such as the last-login update leave it unset (None).
    if user.username is not None:
        clash = get_user_by_username(db, user.username)
        if clash and clash.id != user_id:
            raise HTTPException(status_code=409, detail="Username already exists")

    user_db = get_user_by_id(db, user_id)
    if not user_db:
        raise HTTPException(status_code=404, detail="User not found")

    try:
        for key, value in user.dict(exclude_unset=True).items():
            setattr(user_db, key, value)

        user_db.updated_at = datetime.now(timezone.utc)
        # get_user_by_id returns a copy made with from_orm, so merge it back
        # into the session before committing.
        user_db = db.merge(user_db)
        db.commit()

        if db.identity_key(instance=user_db) is not None:
            db.refresh(user_db)
    except IntegrityError as e:
        db.rollback()
        raise HTTPException(status_code=400, detail=str(e)) from e

    return user_db
def update_user_last_login_at(user_id: UUID, db: Session = Depends(get_session)):
    """Set the user's ``last_login_at`` to the current UTC time via ``update_user``."""
    now = datetime.now(timezone.utc)
    patch = UserPatchModel(last_login_at=now)  # type: ignore
    return update_user(user_id, patch, db)

View file

@ -6,6 +6,8 @@ from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from langflow.api import router
from langflow.routers import login, users, health
from langflow.interface.utils import setup_llm_caching
from langflow.services.database.utils import initialize_database
from langflow.services.manager import initialize_services
@ -19,13 +21,7 @@ def create_app():
app = FastAPI()
origins = [
"*",
]
@app.get("/health")
def get_health():
return {"status": "OK"}
origins = ["*"]
app.add_middleware(
CORSMiddleware,
@ -34,6 +30,11 @@ def create_app():
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(login.router)
app.include_router(users.router)
app.include_router(health.router)
app.include_router(router)
app.on_event("startup")(initialize_services)

View file

View file

@ -0,0 +1,8 @@
from fastapi import APIRouter
# Router that carries the health-check route defined below.
router = APIRouter()
@router.get("/health")
def get_health():
    """Health-check endpoint: always reports status "OK"."""
    payload = {"status": "OK"}
    return payload

View file

@ -0,0 +1,62 @@
from uuid import UUID
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from langflow.services.utils import get_session
from langflow.database.models.token import Token
from langflow.auth.auth import (
authenticate_user,
create_user_tokens,
create_refresh_token,
create_user_longterm_token,
)
from langflow.services.utils import get_settings_manager
# Router for the authentication endpoints defined below (/login, /auto_login, /refresh).
router = APIRouter()
@router.post("/login", response_model=Token)
async def login_to_get_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_session),
):
    """Authenticate with username/password form data and return a token pair.

    A successful login also updates the user's ``last_login_at``.

    Raises:
        HTTPException: 401 "Incorrect username or password" when
            ``authenticate_user`` returns nothing.
    """
    user = authenticate_user(form_data.username, form_data.password, db)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return create_user_tokens(user_id=user.id, db=db, update_last_login=True)
@router.get("/auto_login")
async def auto_login(db: Session = Depends(get_session)):
    """Return a long-term token for a fixed user id when AUTO_LOGIN is enabled.

    Raises:
        HTTPException: 400 with an explanatory payload when the
            ``AUTO_LOGIN`` setting is off.
    """
    if not get_settings_manager().settings.AUTO_LOGIN:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "message": "Auto login is disabled. Please enable it in the settings",
                "auto_login": False,
            },
        )

    # Hard-coded id that the long-term token is issued for.
    fixed_user_id = UUID("3fa85f64-5717-4562-b3fc-2c963f66afa6")
    return create_user_longterm_token(fixed_user_id, db)
@router.post("/refresh")
async def refresh_token(token: str, db: Session = Depends(get_session)):
    """Exchange a refresh token for a new access/refresh token pair.

    The session is injected here and passed on explicitly:
    ``create_refresh_token`` is called as a plain function, so its own
    ``Depends(get_session)`` default would not be resolved by FastAPI.

    Args:
        token: The refresh token previously issued at login.
        db: Database session forwarded to ``create_refresh_token``.

    Raises:
        HTTPException: 401 "Invalid refresh token" when ``token`` is empty
            (or when ``create_refresh_token`` rejects it).
    """
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return create_refresh_token(token, db)

View file

@ -0,0 +1,133 @@
from uuid import UUID
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, select
from fastapi import APIRouter, Depends, HTTPException
from langflow.services.utils import get_session
from langflow.auth.auth import get_current_active_user, get_password_hash
from langflow.database.models.user import (
User,
UserAddModel,
UserListModel,
UserPatchModel,
UsersResponse,
update_user,
)
# Router for the user-management endpoints below; they are grouped under the "Login" tag.
router = APIRouter(tags=["Login"])
@router.post("/user", response_model=UserListModel)
def add_user(
    user: UserAddModel,
    db: Session = Depends(get_session),
) -> User:
    """
    Create a user from the payload, storing the hash of its password.

    Raises a 400 "User exists" error when the insert violates a constraint.
    """
    created = User(**user.dict())
    try:
        # Replace the plain-text password with its hash before persisting.
        created.password = get_password_hash(user.password)
        db.add(created)
        db.commit()
        db.refresh(created)
    except IntegrityError as exc:
        db.rollback()
        raise HTTPException(status_code=400, detail="User exists") from exc
    else:
        return created
@router.get("/user", response_model=UserListModel)
def read_current_user(current_user: User = Depends(get_current_active_user)) -> User:
    """
    Return the authenticated, active user making the request.
    """
    me = current_user
    return me
@router.get("/users", response_model=UsersResponse)
def read_all_users(
    skip: int = 0,
    limit: int = 10,
    _: Session = Depends(get_current_active_user),
    db: Session = Depends(get_session),
) -> UsersResponse:
    """
    Return one page of users (``skip``/``limit``) together with the total user count.

    Requires an authenticated, active user.
    """
    page_stmt = select(User).offset(skip).limit(limit)
    rows = db.execute(page_stmt).fetchall()

    total_stmt = select(func.count()).select_from(User)  # type: ignore
    total = db.execute(total_stmt).scalar()

    # Each result row exposes the entity under the ``User`` attribute.
    page = [UserListModel(**dict(row.User)) for row in rows]
    return UsersResponse(
        total_count=total,  # type: ignore
        users=page,
    )
@router.patch("/user/{user_id}", response_model=UserListModel)
def patch_user(
    user_id: UUID,
    user: UserPatchModel,
    _: Session = Depends(get_current_active_user),
    db: Session = Depends(get_session),
) -> User:
    """
    Apply a partial update to the user identified by ``user_id``.

    Requires an authenticated, active user; the work is delegated to ``update_user``.
    """
    # NOTE(review): any active user can patch any account, including
    # ``is_superuser`` - confirm whether an authorization check is intended.
    updated = update_user(user_id, user, db)
    return updated
@router.delete("/user/{user_id}")
def delete_user(
    user_id: UUID,
    _: Session = Depends(get_current_active_user),
    db: Session = Depends(get_session),
) -> dict:
    """
    Remove the user identified by ``user_id``.

    Requires an authenticated, active user. Responds with 404 "User not found"
    when no such user exists.
    """
    target = db.query(User).filter(User.id == user_id).first()
    if target:
        db.delete(target)
        db.commit()
        return {"detail": "User deleted"}

    raise HTTPException(status_code=404, detail="User not found")
# TODO: REMOVE - Just for testing purposes
@router.post("/super_user", response_model=UserListModel)
def add_super_user_for_testing_purposes_delete_me_before_merge_into_dev(
    db: Session = Depends(get_session),
) -> User:
    """
    Add a superuser for testing purposes.
    (This should be removed in production)

    Creates an active superuser named "superuser" with a fixed, well-known
    password. The response is serialised through ``UserListModel`` so the
    password hash is never returned to the caller.

    Raises a 400 "User exists" error when the insert violates a constraint.
    """
    new_user = User(
        username="superuser",
        # Hash up front so the plain-text value is never held on the model.
        password=get_password_hash("12345"),
        is_active=True,
        is_superuser=True,
        last_login_at=None,
    )

    try:
        db.add(new_user)
        db.commit()
        db.refresh(new_user)
    except IntegrityError as e:
        db.rollback()
        raise HTTPException(status_code=400, detail="User exists") from e

    return new_user

View file

@ -1,11 +1,16 @@
from pathlib import Path
from typing import TYPE_CHECKING
from langflow.services.base import Service
from langflow.services.utils import get_settings_manager
from sqlmodel import SQLModel, Session, create_engine
from langflow.utils.logger import logger
from alembic.config import Config
from alembic import command
from langflow.services.database import models # noqa
if TYPE_CHECKING:
from sqlalchemy.engine import Engine
class DatabaseManager(Service):
name = "database_manager"
@ -17,7 +22,19 @@ class DatabaseManager(Service):
langflow_dir = Path(__file__).parent.parent.parent
self.script_location = langflow_dir / "alembic"
self.alembic_cfg_path = langflow_dir / "alembic.ini"
self.engine = create_engine(database_url)
self.engine = self._create_engine()
def _create_engine(self) -> "Engine":
    """Create the engine for the database.

    The SQLite-specific ``check_same_thread=False`` connect argument is
    chosen from ``self.database_url`` - the URL the engine is actually built
    with - so the argument always matches the driver in use.

    Returns:
        The engine created for ``self.database_url``.
    """
    url = self.database_url
    # Compare against the string form of the URL.
    is_sqlite = bool(url) and str(url).startswith("sqlite")
    connect_args = {"check_same_thread": False} if is_sqlite else {}
    return create_engine(url, connect_args=connect_args)
def __enter__(self):
self._session = Session(self.engine)

View file

@ -2,6 +2,7 @@ import contextlib
import json
import os
from shutil import copy2
import secrets
from typing import Optional, List
from pathlib import Path
@ -39,6 +40,15 @@ class Settings(BaseSettings):
REMOVE_API_KEYS: bool = False
COMPONENTS_PATH: List[str] = []
# Login settings
SECRET_KEY: str = secrets.token_hex(32)
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
REFRESH_TOKEN_EXPIRE_MINUTES: int = 70
# If AUTO_LOGIN = True
# > The application does not request login and logs in automatically as a super user.
AUTO_LOGIN: bool = True
@validator("CONFIG_DIR", pre=True, allow_reuse=True)
def set_langflow_dir(cls, value):
if not value:

View file

@ -1,171 +0,0 @@
import contextlib
import json
import os
from typing import Optional, List
from pathlib import Path
import yaml
from pydantic import BaseSettings, root_validator, validator
from langflow.utils.logger import logger
# Path of the "components" directory next to this module; Settings always
# includes it in COMPONENTS_PATH.
BASE_COMPONENTS_PATH = str(Path(__file__).parent / "components")
class Settings(BaseSettings):
    """Application settings, loadable from YAML and ``LANGFLOW_``-prefixed env vars.

    The dict fields hold the per-category component configuration; the
    remaining fields are runtime options (dev mode, database URL, cache
    backend, API-key stripping, extra component paths).
    """

    CHAINS: dict = {}
    AGENTS: dict = {}
    PROMPTS: dict = {}
    LLMS: dict = {}
    TOOLS: dict = {}
    MEMORIES: dict = {}
    EMBEDDINGS: dict = {}
    VECTORSTORES: dict = {}
    DOCUMENTLOADERS: dict = {}
    WRAPPERS: dict = {}
    RETRIEVERS: dict = {}
    TOOLKITS: dict = {}
    TEXTSPLITTERS: dict = {}
    UTILITIES: dict = {}
    OUTPUT_PARSERS: dict = {}
    CUSTOM_COMPONENTS: dict = {}

    DEV: bool = False
    DATABASE_URL: Optional[str] = None
    CACHE: str = "InMemoryCache"
    REMOVE_API_KEYS: bool = False
    COMPONENTS_PATH: List[str] = []

    @validator("DATABASE_URL", pre=True)
    def set_database_url(cls, value):
        """Fall back to LANGFLOW_DATABASE_URL, then to a local sqlite file."""
        if not value:
            logger.debug(
                "No database_url provided, trying LANGFLOW_DATABASE_URL env variable"
            )
            if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"):
                value = langflow_database_url
                logger.debug("Using LANGFLOW_DATABASE_URL env variable.")
            else:
                logger.debug("No DATABASE_URL env variable, using sqlite database")
                value = "sqlite:///./langflow.db"

        return value

    @validator("COMPONENTS_PATH", pre=True)
    def set_components_path(cls, value):
        """Add LANGFLOW_COMPONENTS_PATH (if it exists) and the base components path."""
        if os.getenv("LANGFLOW_COMPONENTS_PATH"):
            logger.debug("Adding LANGFLOW_COMPONENTS_PATH to components_path")
            langflow_component_path = os.getenv("LANGFLOW_COMPONENTS_PATH")
            # os.getenv always returns a str, so only the single-path case
            # can occur; the path is added once, and only if it exists.
            if (
                Path(langflow_component_path).exists()
                and langflow_component_path not in value
            ):
                value.append(langflow_component_path)
                logger.debug(
                    f"Appending {langflow_component_path} to components_path"
                )

        if not value:
            value = [BASE_COMPONENTS_PATH]
            logger.debug("Setting default components path to components_path")
        elif BASE_COMPONENTS_PATH not in value:
            value.append(BASE_COMPONENTS_PATH)
            logger.debug("Adding default components path to components_path")

        logger.debug(f"Components path: {value}")
        return value

    class Config:
        validate_assignment = True
        extra = "ignore"
        env_prefix = "LANGFLOW_"

    @root_validator(allow_reuse=True)
    def validate_lists(cls, values):
        """Replace empty (falsy) non-boolean values with an empty list."""
        for key, value in values.items():
            # Booleans such as DEV and REMOVE_API_KEYS are legitimately False
            # and must keep their type; everything else that is empty
            # becomes [].
            if not isinstance(value, bool) and not value:
                values[key] = []
        return values

    def update_from_yaml(self, file_path: str, dev: bool = False):
        """Overwrite the component categories and paths from a YAML file.

        Args:
            file_path: YAML file passed to ``load_settings_from_yaml``.
            dev: Value assigned to ``DEV`` after loading.
        """
        new_settings = load_settings_from_yaml(file_path)
        self.CHAINS = new_settings.CHAINS or {}
        self.AGENTS = new_settings.AGENTS or {}
        self.PROMPTS = new_settings.PROMPTS or {}
        self.LLMS = new_settings.LLMS or {}
        self.TOOLS = new_settings.TOOLS or {}
        self.MEMORIES = new_settings.MEMORIES or {}
        self.WRAPPERS = new_settings.WRAPPERS or {}
        self.TOOLKITS = new_settings.TOOLKITS or {}
        self.TEXTSPLITTERS = new_settings.TEXTSPLITTERS or {}
        self.UTILITIES = new_settings.UTILITIES or {}
        self.EMBEDDINGS = new_settings.EMBEDDINGS or {}
        self.VECTORSTORES = new_settings.VECTORSTORES or {}
        self.DOCUMENTLOADERS = new_settings.DOCUMENTLOADERS or {}
        self.RETRIEVERS = new_settings.RETRIEVERS or {}
        self.OUTPUT_PARSERS = new_settings.OUTPUT_PARSERS or {}
        self.CUSTOM_COMPONENTS = new_settings.CUSTOM_COMPONENTS or {}
        self.COMPONENTS_PATH = new_settings.COMPONENTS_PATH or []
        self.DEV = dev

    def update_settings(self, **kwargs):
        """Update existing settings from keyword arguments.

        Unknown keys are skipped. List-valued settings are extended (a value
        that parses as a JSON list contributes each missing item, anything
        else is appended as one item); other settings are overwritten.
        """
        logger.debug("Updating settings")
        for key, value in kwargs.items():
            # value may contain sensitive information, so we don't want to log it
            if not hasattr(self, key):
                logger.debug(f"Key {key} not found in settings")
                continue
            logger.debug(f"Updating {key}")
            if isinstance(getattr(self, key), list):
                # value might be a '[something]' string
                with contextlib.suppress(json.decoder.JSONDecodeError):
                    value = json.loads(str(value))
                if isinstance(value, list):
                    for item in value:
                        if item not in getattr(self, key):
                            getattr(self, key).append(item)
                    logger.debug(f"Extended {key}")
                else:
                    getattr(self, key).append(value)
                    logger.debug(f"Appended {key}")
            else:
                setattr(self, key, value)
                logger.debug(f"Updated {key}")
            # The resulting value is deliberately not logged (see above).
def save_settings_to_yaml(settings: Settings, file_path: str):
    """Dump ``settings`` (as a dict) to ``file_path`` in YAML, overwriting the file."""
    with open(file_path, "w") as stream:
        yaml.dump(settings.dict(), stream)
def load_settings_from_yaml(file_path: str) -> Settings:
    """Build a ``Settings`` object from a YAML file.

    Args:
        file_path: Path to the YAML file. A bare file name (no directory
            component) is resolved relative to this module's directory.

    Returns:
        Settings constructed from the file's keys (upper-cased).

    Raises:
        KeyError: If the file contains a key that is not a ``Settings`` field.
    """
    # Check if a string is a valid path or a file name
    if not os.path.dirname(file_path):
        # Get current path
        current_path = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(current_path, file_path)

    with open(file_path, "r") as f:
        # An empty YAML document loads as None; treat it as "no overrides".
        settings_dict = yaml.safe_load(f) or {}
    settings_dict = {k.upper(): v for k, v in settings_dict.items()}

    for key, value in settings_dict.items():
        if key not in Settings.__fields__:
            raise KeyError(f"Key {key} not found in settings")
        # Scalar values (bool, int, None) have no length to report.
        if hasattr(value, "__len__"):
            logger.debug(f"Loading {len(value)} {key} from {file_path}")
        else:
            logger.debug(f"Loading {key} from {file_path}")

    return Settings(**settings_dict)
# Module-level settings instance, loaded at import time. "config.yaml" is a bare
# file name, so load_settings_from_yaml resolves it next to this module.
settings = load_settings_from_yaml("config.yaml")