ref: Some fixes for ruff rules preview mode (#4197)

* Some fixes for ruff rules preview mode

* Fix mypy error

* More Annotated[] types for fastapi endpoints

* Use type aliases for Depends(get_session) and Depends(get_current_active_user)
This commit is contained in:
Christophe Bornet 2024-10-22 16:04:31 +02:00 committed by GitHub
commit f0eb7b50a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 189 additions and 192 deletions

View file

@ -78,6 +78,7 @@ def set_var_for_macos_issue() -> None:
@app.command()
def run(
*,
host: str | None = typer.Option(None, help="Host to bind the server to.", show_default=False),
workers: int | None = typer.Option(None, help="Number of worker processes.", show_default=False),
worker_timeout: int | None = typer.Option(None, help="Worker timeout in seconds.", show_default=False),

View file

@ -1,13 +1,13 @@
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, HTTPException, status
from loguru import logger
from pydantic import BaseModel
from sqlmodel import Session, select
from sqlmodel import select
from langflow.api.utils import DbSession
from langflow.services.database.models.flow import Flow
from langflow.services.deps import get_chat_service, get_session
from langflow.services.deps import get_chat_service
health_check_router = APIRouter(tags=["Health Check"])
@ -38,7 +38,7 @@ async def health():
# It's a reliable health check for a langflow instance
@health_check_router.get("/health_check")
async def health_check(
session: Annotated[Session, Depends(get_session)],
session: DbSession,
) -> HealthResponse:
response = HealthResponse()
# use a fixed valid UUId that UUID collision is very unlikely

View file

@ -2,22 +2,24 @@ from __future__ import annotations
import uuid
from datetime import timedelta
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Annotated, Any
from fastapi import HTTPException, Query
from fastapi import Depends, HTTPException, Query
from fastapi_pagination import Params
from loguru import logger
from sqlalchemy import delete
from sqlmodel import Session
from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models import User
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.transactions.model import TransactionTable
from langflow.services.database.models.vertex_builds.model import VertexBuildTable
from langflow.services.deps import get_session
from langflow.services.store.utils import get_lf_version_from_pypi
if TYPE_CHECKING:
from sqlmodel import Session
from langflow.services.chat.service import ChatService
from langflow.services.store.schema import StoreComponentCreate
@ -27,6 +29,9 @@ API_WORDS = ["api", "key", "token"]
MAX_PAGE_SIZE = 50
MIN_PAGE_SIZE = 1
CurrentActiveUser = Annotated[User, Depends(get_current_active_user)]
DbSession = Annotated[Session, Depends(get_session)]
def has_api_terms(word: str):
return "api" in word and ("key" in word or ("token" in word and "tokens" not in word))

View file

@ -1,17 +1,16 @@
from typing import TYPE_CHECKING, Annotated
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Response
from sqlmodel import Session
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
from langflow.services.auth import utils as auth_utils
# Assuming you have these methods in your service layer
from langflow.services.database.models.api_key.crud import create_api_key, delete_api_key, get_api_keys
from langflow.services.database.models.api_key.model import ApiKeyCreate, UnmaskedApiKeyRead
from langflow.services.database.models.user.model import User
from langflow.services.deps import get_session, get_settings_service
from langflow.services.deps import get_settings_service
if TYPE_CHECKING:
pass
@ -21,8 +20,8 @@ router = APIRouter(tags=["APIKey"], prefix="/api_key")
@router.get("/")
def get_api_keys_route(
db: Annotated[Session, Depends(get_session)],
current_user: Annotated[User, Depends(auth_utils.get_current_active_user)],
db: DbSession,
current_user: CurrentActiveUser,
) -> ApiKeysResponse:
try:
user_id = current_user.id
@ -36,8 +35,8 @@ def get_api_keys_route(
@router.post("/")
def create_api_key_route(
req: ApiKeyCreate,
current_user: Annotated[User, Depends(auth_utils.get_current_active_user)],
db: Annotated[Session, Depends(get_session)],
current_user: CurrentActiveUser,
db: DbSession,
) -> UnmaskedApiKeyRead:
try:
user_id = current_user.id
@ -49,7 +48,7 @@ def create_api_key_route(
@router.delete("/{api_key_id}", dependencies=[Depends(auth_utils.get_current_active_user)])
def delete_api_key_route(
api_key_id: UUID,
db: Session = Depends(get_session),
db: DbSession,
):
try:
delete_api_key(db, api_key_id)
@ -62,10 +61,10 @@ def delete_api_key_route(
def save_store_api_key(
api_key_request: ApiKeyCreateRequest,
response: Response,
current_user: Annotated[User, Depends(auth_utils.get_current_active_user)],
db: Annotated[Session, Depends(get_session)],
settings_service=Depends(get_settings_service),
current_user: CurrentActiveUser,
db: DbSession,
):
settings_service = get_settings_service()
auth_settings = settings_service.auth_settings
try:
@ -95,8 +94,8 @@ def save_store_api_key(
@router.delete("/store")
def delete_store_api_key(
current_user: Annotated[User, Depends(auth_utils.get_current_active_user)],
db: Annotated[Session, Depends(get_session)],
current_user: CurrentActiveUser,
db: DbSession,
):
try:
current_user.store_api_key = None

View file

@ -8,7 +8,7 @@ import typing
import uuid
from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from starlette.background import BackgroundTask
@ -16,6 +16,8 @@ from starlette.responses import ContentStream
from starlette.types import Receive
from langflow.api.utils import (
CurrentActiveUser,
DbSession,
build_and_cache_graph_from_data,
build_graph_from_data,
build_graph_from_db,
@ -38,7 +40,6 @@ from langflow.exceptions.component import ComponentBuildError
from langflow.graph.graph.base import Graph
from langflow.graph.utils import log_vertex_build
from langflow.schema.schema import OutputValue
from langflow.services.auth.utils import get_current_active_user
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service, get_session, get_telemetry_service
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
@ -67,12 +68,13 @@ async def try_running_celery_task(vertex, user_id):
@router.post("/build/{flow_id}/vertices")
async def retrieve_vertices_order(
*,
flow_id: uuid.UUID,
background_tasks: BackgroundTasks,
data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None,
stop_component_id: str | None = None,
start_component_id: str | None = None,
session=Depends(get_session),
session: DbSession,
) -> VerticesOrderResponse:
"""Retrieve the vertices order for a given flow.
@ -147,8 +149,8 @@ async def build_flow(
stop_component_id: str | None = None,
start_component_id: str | None = None,
log_builds: bool | None = True,
current_user=Depends(get_current_active_user),
session=Depends(get_session),
current_user: CurrentActiveUser,
session: DbSession,
):
chat_service = get_chat_service()
telemetry_service = get_telemetry_service()
@ -220,7 +222,7 @@ async def build_flow(
lock = chat_service._async_cache_locks[flow_id_str]
vertex_build_result = await graph.build_vertex(
vertex_id=vertex_id,
user_id=current_user.id,
user_id=str(current_user.id),
inputs_dict=inputs.model_dump() if inputs else {},
files=files,
get_cache=chat_service.get_cache,
@ -456,12 +458,13 @@ class DisconnectHandlerStreamingResponse(StreamingResponse):
@router.post("/build/{flow_id}/vertices/{vertex_id}")
async def build_vertex(
*,
flow_id: uuid.UUID,
vertex_id: str,
background_tasks: BackgroundTasks,
inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
files: list[str] | None = None,
current_user=Depends(get_current_active_user),
current_user: CurrentActiveUser,
) -> VertexBuildResponse:
"""Build a vertex instead of the entire graph.
@ -505,7 +508,7 @@ async def build_vertex(
lock = chat_service._async_cache_locks[flow_id_str]
vertex_build_result = await graph.build_vertex(
vertex_id=vertex_id,
user_id=current_user.id,
user_id=str(current_user.id),
inputs_dict=inputs.model_dump() if inputs else {},
files=files,
get_cache=chat_service.get_cache,

View file

@ -8,9 +8,9 @@ from uuid import UUID
import sqlalchemy as sa
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status
from loguru import logger
from sqlmodel import Session, select
from sqlmodel import select
from langflow.api.utils import parse_value
from langflow.api.utils import CurrentActiveUser, DbSession, parse_value
from langflow.api.v1.schemas import (
ConfigResponse,
CustomComponentRequest,
@ -40,7 +40,6 @@ from langflow.services.database.models.flow.model import FlowRead
from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import (
get_session,
get_session_service,
get_settings_service,
get_task_service,
@ -166,7 +165,7 @@ async def simplified_run_flow(
flow: Annotated[FlowRead | None, Depends(get_flow_by_id_or_endpoint_name)],
input_request: SimplifiedAPIRequest | None = None,
stream: bool = False,
api_key_user: UserRead = Depends(api_key_security),
api_key_user: Annotated[UserRead, Depends(api_key_security)],
) -> RunResponse:
"""Executes a specified flow by ID.
@ -352,7 +351,7 @@ async def webhook_run_flow(
RunPayload(
run_is_webhook=True,
run_seconds=int(time.perf_counter() - start_time),
run_success=error_msg == "",
run_success=not error_msg,
run_error_message=error_msg,
),
)
@ -363,14 +362,14 @@ async def webhook_run_flow(
@router.post("/run/advanced/{flow_id}", response_model=RunResponse, response_model_exclude_none=True)
async def experimental_run_flow(
*,
session: Annotated[Session, Depends(get_session)],
session: DbSession,
flow_id: UUID,
inputs: list[InputValueRequest] | None = None,
outputs: list[str] | None = None,
tweaks: Annotated[Tweaks | None, Body(embed=True)] = None,
stream: Annotated[bool, Body(embed=True)] = False,
session_id: Annotated[None | str, Body(embed=True)] = None,
api_key_user: UserRead = Depends(api_key_security),
api_key_user: Annotated[UserRead, Depends(api_key_security)],
) -> RunResponse:
"""Executes a specified flow by ID with optional input values, output selection, tweaks, and streaming capability.
@ -556,7 +555,7 @@ def get_version():
@router.post("/custom_component", status_code=HTTPStatus.OK)
async def custom_component(
raw_code: CustomComponentRequest,
user: Annotated[User, Depends(get_current_active_user)],
user: CurrentActiveUser,
) -> CustomComponentResponse:
component = Component(_code=raw_code.code)
@ -571,7 +570,7 @@ async def custom_component(
@router.post("/custom_component/update", status_code=HTTPStatus.OK)
async def custom_component_update(
code_request: UpdateCustomComponentRequest,
user: Annotated[User, Depends(get_current_active_user)],
user: CurrentActiveUser,
):
"""Update a custom component with the provided code request.

View file

@ -9,10 +9,10 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.v1.schemas import UploadFileResponse
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.flow import Flow
from langflow.services.deps import get_session, get_settings_service, get_storage_service
from langflow.services.deps import get_settings_service, get_storage_service
from langflow.services.storage.service import StorageService
from langflow.services.storage.utils import build_content_type_from_extension
@ -24,8 +24,8 @@ router = APIRouter(tags=["Files"], prefix="/files")
# using the current user as the owner
def get_flow_id(
flow_id: UUID,
current_user=Depends(get_current_active_user),
session=Depends(get_session),
current_user: CurrentActiveUser,
session: DbSession,
):
flow_id_str = str(flow_id)
# AttributeError: 'SelectOfScalar' object has no attribute 'first'
@ -42,8 +42,8 @@ async def upload_file(
*,
file: UploadFile,
flow_id: Annotated[UUID, Depends(get_flow_id)],
current_user=Depends(get_current_active_user),
session=Depends(get_session),
current_user: CurrentActiveUser,
session: DbSession,
storage_service: Annotated[StorageService, Depends(get_storage_service)],
) -> UploadFileResponse:
try:

View file

@ -16,10 +16,9 @@ from fastapi_pagination import Page, Params, add_pagination
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlmodel import Session, and_, col, select
from langflow.api.utils import cascade_delete_flow, remove_api_keys, validate_is_component
from langflow.api.utils import CurrentActiveUser, DbSession, cascade_delete_flow, remove_api_keys, validate_is_component
from langflow.api.v1.schemas import FlowListCreate
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.flow import Flow, FlowCreate, FlowRead, FlowUpdate
from langflow.services.database.models.flow.model import FlowHeader
from langflow.services.database.models.flow.utils import get_webhook_component_in_flow
@ -28,7 +27,7 @@ from langflow.services.database.models.folder.model import Folder
from langflow.services.database.models.transactions.crud import get_transactions_by_flow_id
from langflow.services.database.models.user.model import User
from langflow.services.database.models.vertex_builds.crud import get_vertex_builds_by_flow_id
from langflow.services.deps import get_session, get_settings_service
from langflow.services.deps import get_settings_service
from langflow.services.settings.service import SettingsService
# build router
@ -38,9 +37,9 @@ router = APIRouter(prefix="/flows", tags=["Flows"])
@router.post("/", response_model=FlowRead, status_code=201)
def create_flow(
*,
session: Session = Depends(get_session),
session: DbSession,
flow: FlowCreate,
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
try:
"""Create a new flow."""
@ -127,13 +126,13 @@ def create_flow(
@router.get("/", response_model=list[FlowRead] | Page[FlowRead] | list[FlowHeader], status_code=200)
def read_flows(
*,
current_user: User = Depends(get_current_active_user),
session: Session = Depends(get_session),
current_user: CurrentActiveUser,
session: DbSession,
remove_example_flows: bool = False,
components_only: bool = False,
get_all: bool = True,
folder_id: UUID | None = None,
params: Params = Depends(),
params: Annotated[Params, Depends()],
header_flows: bool = False,
):
"""Retrieve a list of flows with pagination support.
@ -229,9 +228,9 @@ def _read_flow(
@router.get("/{flow_id}", response_model=FlowRead, status_code=200)
def read_flow(
*,
session: Session = Depends(get_session),
session: DbSession,
flow_id: UUID,
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
"""Read a flow."""
if user_flow := _read_flow(session, flow_id, current_user, get_settings_service()):
@ -242,10 +241,10 @@ def read_flow(
@router.patch("/{flow_id}", response_model=FlowRead, status_code=200)
def update_flow(
*,
session: Session = Depends(get_session),
session: DbSession,
flow_id: UUID,
flow: FlowUpdate,
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
"""Update a flow."""
settings_service = get_settings_service()
@ -302,9 +301,9 @@ def update_flow(
@router.delete("/{flow_id}", status_code=200)
async def delete_flow(
*,
session: Session = Depends(get_session),
session: DbSession,
flow_id: UUID,
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
"""Delete a flow."""
flow = _read_flow(
@ -323,9 +322,9 @@ async def delete_flow(
@router.post("/batch/", response_model=list[FlowRead], status_code=201)
def create_flows(
*,
session: Session = Depends(get_session),
session: DbSession,
flow_list: FlowListCreate,
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
"""Create multiple new flows."""
db_flows = []
@ -343,9 +342,9 @@ def create_flows(
@router.post("/upload/", response_model=list[FlowRead], status_code=201)
async def upload_file(
*,
session: Session = Depends(get_session),
file: UploadFile = File(...),
current_user: User = Depends(get_current_active_user),
session: DbSession,
file: Annotated[UploadFile, File(...)],
current_user: CurrentActiveUser,
folder_id: UUID | None = None,
):
"""Upload flows from a file."""
@ -367,8 +366,8 @@ async def upload_file(
@router.delete("/")
async def delete_multiple_flows(
flow_ids: list[UUID],
user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[Session, Depends(get_session)],
user: CurrentActiveUser,
db: DbSession,
):
"""Delete multiple flows by their IDs.
@ -403,8 +402,8 @@ async def delete_multiple_flows(
@router.post("/download/", status_code=200)
async def download_multiple_file(
flow_ids: list[UUID],
user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[Session, Depends(get_session)],
user: CurrentActiveUser,
db: DbSession,
):
"""Download all flows as a zip file."""
flows = db.exec(select(Flow).where(and_(Flow.user_id == user.id, Flow.id.in_(flow_ids)))).all() # type: ignore[attr-defined]
@ -445,7 +444,7 @@ async def download_multiple_file(
@router.get("/basic_examples/", response_model=list[FlowRead], status_code=200)
def read_basic_examples(
*,
session: Session = Depends(get_session),
session: DbSession,
):
"""Retrieve a list of basic example flows.

View file

@ -1,17 +1,18 @@
from typing import Annotated
import orjson
from fastapi import APIRouter, Depends, File, HTTPException, Response, UploadFile, status
from fastapi_pagination import Params
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy import or_, update
from sqlmodel import Session, select
from sqlmodel import select
from langflow.api.utils import cascade_delete_flow, custom_params
from langflow.api.utils import CurrentActiveUser, DbSession, cascade_delete_flow, custom_params
from langflow.api.v1.flows import create_flows
from langflow.api.v1.schemas import FlowListCreate, FlowListReadWithFolderName
from langflow.helpers.flow import generate_unique_flow_name
from langflow.helpers.folders import generate_unique_folder_name
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.flow.model import Flow, FlowCreate, FlowRead
from langflow.services.database.models.folder.constants import DEFAULT_FOLDER_NAME
from langflow.services.database.models.folder.model import (
@ -22,8 +23,6 @@ from langflow.services.database.models.folder.model import (
FolderUpdate,
)
from langflow.services.database.models.folder.pagination_model import FolderWithPaginatedFlows
from langflow.services.database.models.user.model import User
from langflow.services.deps import get_session
router = APIRouter(prefix="/folders", tags=["Folders"])
@ -31,9 +30,9 @@ router = APIRouter(prefix="/folders", tags=["Folders"])
@router.post("/", response_model=FolderRead, status_code=201)
def create_folder(
*,
session: Session = Depends(get_session),
session: DbSession,
folder: FolderCreate,
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
try:
new_folder = Folder.model_validate(folder, from_attributes=True)
@ -85,8 +84,8 @@ def create_folder(
@router.get("/", response_model=list[FolderRead], status_code=200)
def read_folders(
*,
session: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
session: DbSession,
current_user: CurrentActiveUser,
):
try:
folders = session.exec(
@ -103,10 +102,10 @@ def read_folders(
@router.get("/{folder_id}", response_model=FolderWithPaginatedFlows | FolderReadWithFlows, status_code=200)
def read_folder(
*,
session: Session = Depends(get_session),
session: DbSession,
folder_id: str,
current_user: User = Depends(get_current_active_user),
params: Params | None = Depends(custom_params),
current_user: CurrentActiveUser,
params: Annotated[Params | None, Depends(custom_params)],
is_component: bool = False,
is_flow: bool = False,
search: str = "",
@ -148,10 +147,10 @@ def read_folder(
@router.patch("/{folder_id}", response_model=FolderRead, status_code=200)
def update_folder(
*,
session: Session = Depends(get_session),
session: DbSession,
folder_id: str,
folder: FolderUpdate, # Assuming FolderUpdate is a Pydantic model defining updatable fields
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
try:
existing_folder = session.exec(
@ -209,9 +208,9 @@ def update_folder(
@router.delete("/{folder_id}", status_code=204)
async def delete_folder(
*,
session: Session = Depends(get_session),
session: DbSession,
folder_id: str,
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
try:
flows = session.exec(select(Flow).where(Flow.folder_id == folder_id, Folder.user_id == current_user.id)).all()
@ -237,9 +236,9 @@ async def delete_folder(
@router.get("/download/{folder_id}", response_model=FlowListReadWithFolderName, status_code=200)
async def download_file(
*,
session: Session = Depends(get_session),
session: DbSession,
folder_id: str,
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
"""Download all flows from folder."""
try:
@ -258,9 +257,9 @@ async def download_file(
@router.post("/upload/", response_model=list[FlowRead], status_code=201)
async def upload_file(
*,
session: Session = Depends(get_session),
file: UploadFile = File(...),
current_user: User = Depends(get_current_active_user),
session: DbSession,
file: Annotated[UploadFile, File(...)],
current_user: CurrentActiveUser,
):
"""Upload flows from a file."""
contents = await file.read()

View file

@ -4,8 +4,8 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from langflow.api.utils import DbSession
from langflow.api.v1.schemas import Token
from langflow.services.auth.utils import (
authenticate_user,
@ -15,8 +15,7 @@ from langflow.services.auth.utils import (
)
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
from langflow.services.database.models.user.crud import get_user_by_id
from langflow.services.deps import get_session, get_settings_service, get_variable_service
from langflow.services.variable.service import VariableService
from langflow.services.deps import get_settings_service, get_variable_service
router = APIRouter(tags=["Login"])
@ -25,12 +24,9 @@ router = APIRouter(tags=["Login"])
async def login_to_get_access_token(
response: Response,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Annotated[Session, Depends(get_session)],
# _: Session = Depends(get_current_active_user)
settings_service=Depends(get_settings_service),
variable_service: VariableService = Depends(get_variable_service),
db: DbSession,
):
auth_settings = settings_service.auth_settings
auth_settings = get_settings_service().auth_settings
try:
user = authenticate_user(form_data.username, form_data.password, db)
except Exception as exc:
@ -70,7 +66,7 @@ async def login_to_get_access_token(
expires=None, # Set to None to make it a session cookie
domain=auth_settings.COOKIE_DOMAIN,
)
variable_service.initialize_user_variables(user.id, db)
get_variable_service().initialize_user_variables(user.id, db)
# Create default folder for user if it doesn't exist
create_default_folder_if_it_doesnt_exist(db, user.id)
return tokens
@ -82,7 +78,7 @@ async def login_to_get_access_token(
@router.get("/auto_login")
async def auto_login(response: Response, db: Annotated[Session, Depends(get_session)]):
async def auto_login(response: Response, db: DbSession):
auth_settings = get_settings_service().auth_settings
if auth_settings.AUTO_LOGIN:
@ -128,7 +124,7 @@ async def auto_login(response: Response, db: Annotated[Session, Depends(get_sess
async def refresh_token(
request: Request,
response: Response,
db: Annotated[Session, Depends(get_session)],
db: DbSession,
):
auth_settings = get_settings_service().auth_settings

View file

@ -3,8 +3,9 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import delete
from sqlmodel import Session, col, select
from sqlmodel import col, select
from langflow.api.utils import DbSession
from langflow.schema.message import MessageResponse
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate
@ -15,16 +16,12 @@ from langflow.services.database.models.vertex_builds.crud import (
get_vertex_builds_by_flow_id,
)
from langflow.services.database.models.vertex_builds.model import VertexBuildMapModel
from langflow.services.deps import get_session
router = APIRouter(prefix="/monitor", tags=["Monitor"])
@router.get("/builds")
async def get_vertex_builds(
flow_id: Annotated[UUID, Query()],
session: Annotated[Session, Depends(get_session)],
) -> VertexBuildMapModel:
async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSession) -> VertexBuildMapModel:
try:
vertex_builds = get_vertex_builds_by_flow_id(session, flow_id)
return VertexBuildMapModel.from_list_of_dicts(vertex_builds)
@ -33,10 +30,7 @@ async def get_vertex_builds(
@router.delete("/builds", status_code=204)
async def delete_vertex_builds(
flow_id: Annotated[UUID, Query()],
session: Annotated[Session, Depends(get_session)],
) -> None:
async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSession) -> None:
try:
delete_vertex_builds_by_flow_id(session, flow_id)
except Exception as e:
@ -45,7 +39,7 @@ async def delete_vertex_builds(
@router.get("/messages")
async def get_messages(
session: Annotated[Session, Depends(get_session)],
session: DbSession,
flow_id: Annotated[str | None, Query()] = None,
session_id: Annotated[str | None, Query()] = None,
sender: Annotated[str | None, Query()] = None,
@ -72,10 +66,7 @@ async def get_messages(
@router.delete("/messages", status_code=204, dependencies=[Depends(get_current_active_user)])
async def delete_messages(
message_ids: list[UUID],
session: Annotated[Session, Depends(get_session)],
) -> None:
async def delete_messages(message_ids: list[UUID], session: DbSession) -> None:
try:
session.exec(delete(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore[attr-defined]
session.commit()
@ -87,7 +78,7 @@ async def delete_messages(
async def update_message(
message_id: UUID,
message: MessageUpdate,
session: Annotated[Session, Depends(get_session)],
session: DbSession,
):
try:
db_message = session.get(MessageTable, message_id)
@ -117,7 +108,7 @@ async def update_message(
async def update_session_id(
old_session_id: str,
new_session_id: Annotated[str, Query(..., description="The new session ID to update to")],
session: Annotated[Session, Depends(get_session)],
session: DbSession,
) -> list[MessageResponse]:
try:
# Get all messages with the old session ID
@ -150,7 +141,7 @@ async def update_session_id(
@router.delete("/messages/session/{session_id}", status_code=204)
async def delete_messages_session(
session_id: str,
session: Annotated[Session, Depends(get_session)],
session: DbSession,
):
try:
session.exec(
@ -168,7 +159,7 @@ async def delete_messages_session(
@router.get("/transactions")
async def get_transactions(
flow_id: Annotated[UUID, Query()],
session: Annotated[Session, Depends(get_session)],
session: DbSession,
) -> list[TransactionReadResponse]:
try:
transactions = get_transactions_by_flow_id(session, flow_id)

View file

@ -4,9 +4,8 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from loguru import logger
from langflow.api.utils import check_langflow_version
from langflow.api.utils import CurrentActiveUser, check_langflow_version
from langflow.services.auth import utils as auth_utils
from langflow.services.database.models.user.model import User
from langflow.services.deps import get_settings_service, get_store_service
from langflow.services.store.exceptions import CustomError
from langflow.services.store.schema import (
@ -22,7 +21,7 @@ router = APIRouter(prefix="/store", tags=["Components Store"])
def get_user_store_api_key(
user: User = Depends(auth_utils.get_current_active_user),
user: CurrentActiveUser,
settings_service=Depends(get_settings_service),
):
if not user.store_api_key:
@ -34,7 +33,7 @@ def get_user_store_api_key(
def get_optional_user_store_api_key(
user: User = Depends(auth_utils.get_current_active_user),
user: CurrentActiveUser,
settings_service=Depends(get_settings_service),
):
if not user.store_api_key:
@ -107,7 +106,7 @@ async def get_components(
fields: Annotated[list[str] | None, Query()] = None,
page: int = 1,
limit: int = 10,
store_api_key: str | None = Depends(get_optional_user_store_api_key),
store_api_key: Annotated[str | None, Depends(get_optional_user_store_api_key)],
) -> ListComponentResponseModel:
try:
return await get_store_service().get_list_component_response_model(

View file

@ -4,20 +4,20 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, select
from sqlmodel import select
from sqlmodel.sql.expression import SelectOfScalar
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.v1.schemas import UsersResponse
from langflow.services.auth.utils import (
get_current_active_superuser,
get_current_active_user,
get_password_hash,
verify_password,
)
from langflow.services.database.models.folder.utils import create_default_folder_if_it_doesnt_exist
from langflow.services.database.models.user import User, UserCreate, UserRead, UserUpdate
from langflow.services.database.models.user.crud import get_user_by_id, update_user
from langflow.services.deps import get_session, get_settings_service
from langflow.services.deps import get_settings_service
router = APIRouter(tags=["Users"], prefix="/users")
@ -25,7 +25,7 @@ router = APIRouter(tags=["Users"], prefix="/users")
@router.post("/", response_model=UserRead, status_code=201)
def add_user(
user: UserCreate,
session: Annotated[Session, Depends(get_session)],
session: DbSession,
) -> User:
"""Add a new user to the database."""
new_user = User.model_validate(user, from_attributes=True)
@ -47,7 +47,7 @@ def add_user(
@router.get("/whoami", response_model=UserRead)
def read_current_user(
current_user: Annotated[User, Depends(get_current_active_user)],
current_user: CurrentActiveUser,
) -> User:
"""Retrieve the current user's data."""
return current_user
@ -55,9 +55,10 @@ def read_current_user(
@router.get("/", dependencies=[Depends(get_current_active_superuser)])
def read_all_users(
*,
skip: int = 0,
limit: int = 10,
session: Session = Depends(get_session),
session: DbSession,
) -> UsersResponse:
"""Retrieve a list of users from the database with pagination."""
query: SelectOfScalar = select(User).offset(skip).limit(limit)
@ -76,11 +77,11 @@ def read_all_users(
def patch_user(
user_id: UUID,
user_update: UserUpdate,
user: Annotated[User, Depends(get_current_active_user)],
session: Annotated[Session, Depends(get_session)],
user: CurrentActiveUser,
session: DbSession,
) -> User:
"""Update an existing user's data."""
update_password = user_update.password is not None and user_update.password != ""
update_password = bool(user_update.password)
if not user.is_superuser and user_update.is_superuser:
raise HTTPException(status_code=403, detail="Permission denied")
@ -103,8 +104,8 @@ def patch_user(
def reset_password(
user_id: UUID,
user_update: UserUpdate,
user: Annotated[User, Depends(get_current_active_user)],
session: Annotated[Session, Depends(get_session)],
user: CurrentActiveUser,
session: DbSession,
) -> User:
"""Reset a user's password."""
if user_id != user.id:
@ -126,7 +127,7 @@ def reset_password(
def delete_user(
user_id: UUID,
current_user: Annotated[User, Depends(get_current_active_superuser)],
session: Annotated[Session, Depends(get_session)],
session: DbSession,
) -> dict:
"""Delete a user from the database."""
if current_user.id == user_id:

View file

@ -1,14 +1,11 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, HTTPException
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.user.model import User
from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_session, get_variable_service
from langflow.services.variable.base import VariableService
from langflow.services.deps import get_variable_service
from langflow.services.variable.constants import GENERIC_TYPE
from langflow.services.variable.service import DatabaseVariableService
@ -18,9 +15,9 @@ router = APIRouter(prefix="/variables", tags=["Variables"])
@router.post("/", response_model=VariableRead, status_code=201)
def create_variable(
*,
session: Session = Depends(get_session),
session: DbSession,
variable: VariableCreate,
current_user: User = Depends(get_current_active_user),
current_user: CurrentActiveUser,
):
"""Create a new variable."""
variable_service = get_variable_service()
@ -53,11 +50,14 @@ def create_variable(
@router.get("/", response_model=list[VariableRead], status_code=200)
def read_variables(
*,
session: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
variable_service: DatabaseVariableService = Depends(get_variable_service),
session: DbSession,
current_user: CurrentActiveUser,
):
"""Read all variables."""
variable_service = get_variable_service()
if not isinstance(variable_service, DatabaseVariableService):
msg = "Variable service is not an instance of DatabaseVariableService"
raise TypeError(msg)
try:
return variable_service.get_all(user_id=current_user.id, session=session)
except Exception as e:
@ -67,13 +67,16 @@ def read_variables(
@router.patch("/{variable_id}", response_model=VariableRead, status_code=200)
def update_variable(
*,
session: Session = Depends(get_session),
session: DbSession,
variable_id: UUID,
variable: VariableUpdate,
current_user: User = Depends(get_current_active_user),
variable_service: DatabaseVariableService = Depends(get_variable_service),
current_user: CurrentActiveUser,
):
"""Update a variable."""
variable_service = get_variable_service()
if not isinstance(variable_service, DatabaseVariableService):
msg = "Variable service is not an instance of DatabaseVariableService"
raise TypeError(msg)
try:
return variable_service.update_variable_fields(
user_id=current_user.id,
@ -91,12 +94,12 @@ def update_variable(
@router.delete("/{variable_id}", status_code=204)
def delete_variable(
*,
session: Session = Depends(get_session),
session: DbSession,
variable_id: UUID,
current_user: User = Depends(get_current_active_user),
variable_service: VariableService = Depends(get_variable_service),
current_user: CurrentActiveUser,
) -> None:
"""Delete a variable."""
variable_service = get_variable_service()
try:
variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
except Exception as e:

View file

@ -56,7 +56,7 @@ def get_chat_result(
inputs: list | dict = messages or {}
try:
if config and config.get("output_parser") is not None:
runnable = runnable | config["output_parser"]
runnable |= config["output_parser"]
if config:
runnable = runnable.with_config(

View file

@ -104,7 +104,7 @@ class AstraAssistantManager(ComponentWithCache):
logger.info(self.tool)
tools = []
tool_obj = None
if self.tool is not None and self.tool != "":
if self.tool:
tool_cls = tools_and_names[self.tool]
tool_obj = tool_cls()
tools.append(tool_obj)

View file

@ -36,7 +36,7 @@ class RedisIndexChatMemory(LCChatMemoryComponent):
password: str | None = self.password
if self.key_prefix:
kwargs["key_prefix"] = self.key_prefix
if password is not None and password != "":
if password:
password = parse.quote_plus(password)
url = f"redis://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"

View file

@ -111,7 +111,7 @@ class AstraDBCQLToolComponent(LCToolComponent):
url = f'{astra_url}{"/".join(key)}?page-size={self.number_of_results}'
if self.projection_fields != "*":
url += f'&fields={urllib.parse.quote(self.projection_fields.replace(" ","")) }'
url += f'&fields={urllib.parse.quote(self.projection_fields.replace(" ", ""))}'
res = requests.request("GET", url=url, headers=headers)

View file

@ -147,9 +147,9 @@ class PythonCodeStructuredTool(LCToolComponent):
params: dict = {}
def run(**kwargs):
for key in kwargs:
for key, arg in kwargs.items():
if key not in PythonCodeToolFunc.params:
PythonCodeToolFunc.params[key] = kwargs[key]
PythonCodeToolFunc.params[key] = arg
return _local_namespace[self.tool_function](**PythonCodeToolFunc.params)
_globals = globals()

View file

@ -15,8 +15,11 @@ from .WikipediaAPI import WikipediaAPIComponent
from .WolframAlphaAPI import WolframAlphaAPIComponent
__all__ = [
"AstraDBCQLToolComponent",
"AstraDBToolComponent",
"BingSearchAPIComponent",
"CalculatorToolComponent",
"CalculatorToolComponent",
"GleanSearchAPIComponent",
"GoogleSearchAPIComponent",
"GoogleSerperAPIComponent",
@ -28,7 +31,4 @@ __all__ = [
"SerpAPIComponent",
"WikipediaAPIComponent",
"WolframAlphaAPIComponent",
"CalculatorToolComponent",
"AstraDBToolComponent",
"AstraDBCQLToolComponent",
]

View file

@ -178,7 +178,7 @@ class ElasticsearchVectorStoreComponent(LCVectorStoreComponent):
if query:
search_type = self.search_type.lower()
if search_type not in ["similarity", "mmr"]:
if search_type not in {"similarity", "mmr"}:
msg = f"Invalid search type: {self.search_type}"
logger.error(msg)
raise ValueError(msg)

View file

@ -1416,6 +1416,7 @@ class Graph:
def get_vertex_edges(
self,
vertex_id: str,
*,
is_target: bool | None = None,
is_source: bool | None = None,
) -> list[CycleEdge]:
@ -1887,11 +1888,11 @@ class Graph:
def sort_chat_inputs_first(self, vertices_layers: list[list[str]]) -> list[list[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
if "ChatInput" in vertex_id:
# Remove the ChatInput from the layer
layer.remove(vertex_id)
chat_inputs_first.append(vertex_id)
layer_chat_inputs_first = [vertex_id for vertex_id in layer if "ChatInput" in vertex_id]
chat_inputs_first.extend(layer_chat_inputs_first)
for vertex_id in layer_chat_inputs_first:
# Remove the ChatInput from the layer
layer.remove(vertex_id)
if not chat_inputs_first:
return vertices_layers

View file

@ -413,7 +413,7 @@ class Vertex:
if isinstance(val, bool):
params[field_name] = val
elif isinstance(val, str):
params[field_name] = val != ""
params[field_name] = bool(val)
elif field.get("type") == "table" and val is not None:
# check if the value is a list of dicts
# if it is, create a pandas dataframe from it

View file

@ -156,10 +156,9 @@ class ComponentVertex(Vertex):
List[str]: The extracted messages.
"""
messages = []
for key in artifacts:
artifact = artifacts[key]
for key, artifact in artifacts.items():
if any(
key not in artifact for key in ["text", "sender", "sender_name", "session_id", "stream_url"]
k not in artifact for k in ["text", "sender", "sender_name", "session_id", "stream_url"]
) and not isinstance(artifact, Message):
continue
message_dict = artifact if isinstance(artifact, dict) else artifact.model_dump()

View file

@ -29,14 +29,17 @@ __all__ = [
"CodeInput",
"DataInput",
"DefaultPromptField",
"DefaultPromptField",
"DictInput",
"DropdownInput",
"FileInput",
"FloatInput",
"HandleInput",
"Input",
"Input",
"IntInput",
"LinkInput",
"LinkInput",
"MessageInput",
"MessageTextInput",
"MultilineInput",
@ -45,10 +48,7 @@ __all__ = [
"NestedDictInput",
"PromptInput",
"SecretStrInput",
"SliderInput",
"StrInput",
"TableInput",
"Input",
"DefaultPromptField",
"LinkInput",
"SliderInput",
]

View file

@ -29,6 +29,7 @@ __all__ = [
"CodeInput",
"DataInput",
"DefaultPromptField",
"DefaultPromptField",
"DictInput",
"DropdownInput",
"FileInput",
@ -36,6 +37,7 @@ __all__ = [
"HandleInput",
"IntInput",
"LinkInput",
"LinkInput",
"MessageInput",
"MessageTextInput",
"MultilineInput",
@ -45,9 +47,7 @@ __all__ = [
"Output",
"PromptInput",
"SecretStrInput",
"SliderInput",
"StrInput",
"TableInput",
"DefaultPromptField",
"LinkInput",
"SliderInput",
]

View file

@ -33,9 +33,9 @@ MINIMUM_KEY_LENGTH = 32
# Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py
async def api_key_security(
query_param: str = Security(api_key_query),
header_param: str = Security(api_key_header),
db: Session = Depends(get_session),
query_param: Annotated[str, Security(api_key_query)],
header_param: Annotated[str, Security(api_key_header)],
db: Annotated[Session, Depends(get_session)],
) -> UserRead | None:
settings_service = get_settings_service()
result: ApiKey | User | None = None
@ -75,10 +75,10 @@ async def api_key_security(
async def get_current_user(
token: str = Security(oauth2_login),
query_param: str = Security(api_key_query),
header_param: str = Security(api_key_header),
db: Session = Depends(get_session),
token: Annotated[str, Security(oauth2_login)],
query_param: Annotated[str, Security(api_key_query)],
header_param: Annotated[str, Security(api_key_header)],
db: Annotated[Session, Depends(get_session)],
) -> User:
if token:
return await get_current_user_by_jwt(token, db)
@ -94,7 +94,7 @@ async def get_current_user(
async def get_current_user_by_jwt(
token: Annotated[str, Depends(oauth2_login)],
db: Session = Depends(get_session),
db: Annotated[Session, Depends(get_session)],
) -> User:
settings_service = get_settings_service()
@ -155,8 +155,8 @@ async def get_current_user_by_jwt(
async def get_current_user_for_websocket(
websocket: WebSocket,
db: Session = Depends(get_session),
query_param: str = Security(api_key_query),
db: Annotated[Session, Depends(get_session)],
query_param: Annotated[str, Security(api_key_query)],
) -> User | None:
token = websocket.query_params.get("token")
api_key = websocket.query_params.get("x-api-key")

View file

@ -1,3 +1,5 @@
from typing import Annotated
from fastapi import Depends
from sqlmodel import Session
@ -7,7 +9,7 @@ from langflow.utils.version import get_version_info
from .model import Flow
def get_flow_by_id(session: Session = Depends(get_session), flow_id: str | None = None) -> Flow | None:
def get_flow_by_id(session: Annotated[Session, Depends(get_session)], flow_id: str | None = None) -> Flow | None:
"""Get flow by id."""
if flow_id is None:
msg = "Flow id is required."