ref: Remove sync get_session and DbSession (#5146)
Remove sync get_session and DbSession
This commit is contained in:
parent
d338a3e86f
commit
9270f1a0f8
16 changed files with 71 additions and 99 deletions
|
|
@ -5,7 +5,7 @@ from loguru import logger
|
|||
from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import AsyncDbSession
|
||||
from langflow.api.utils import DbSession
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.deps import get_chat_service
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ async def health():
|
|||
# It's a reliable health check for a langflow instance
|
||||
@health_check_router.get("/health_check")
|
||||
async def health_check(
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
) -> HealthResponse:
|
||||
response = HealthResponse()
|
||||
# use a fixed valid UUId that UUID collision is very unlikely
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from fastapi import Depends, HTTPException, Query
|
|||
from fastapi_pagination import Params
|
||||
from loguru import logger
|
||||
from sqlalchemy import delete
|
||||
from sqlmodel import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
|
|
@ -17,7 +16,7 @@ from langflow.services.database.models import User
|
|||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.transactions.model import TransactionTable
|
||||
from langflow.services.database.models.vertex_builds.model import VertexBuildTable
|
||||
from langflow.services.deps import async_session_scope, get_async_session, get_session
|
||||
from langflow.services.deps import async_session_scope, get_session
|
||||
from langflow.services.store.utils import get_lf_version_from_pypi
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -31,8 +30,7 @@ MAX_PAGE_SIZE = 50
|
|||
MIN_PAGE_SIZE = 1
|
||||
|
||||
CurrentActiveUser = Annotated[User, Depends(get_current_active_user)]
|
||||
DbSession = Annotated[Session, Depends(get_session)]
|
||||
AsyncDbSession = Annotated[AsyncSession, Depends(get_async_session)]
|
||||
DbSession = Annotated[AsyncSession, Depends(get_session)]
|
||||
|
||||
|
||||
def has_api_terms(word: str):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from uuid import UUID
|
|||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession
|
||||
from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
|
||||
|
|
@ -16,7 +16,7 @@ router = APIRouter(tags=["APIKey"], prefix="/api_key")
|
|||
|
||||
@router.get("/")
|
||||
async def get_api_keys_route(
|
||||
db: AsyncDbSession,
|
||||
db: DbSession,
|
||||
current_user: CurrentActiveUser,
|
||||
) -> ApiKeysResponse:
|
||||
try:
|
||||
|
|
@ -32,7 +32,7 @@ async def get_api_keys_route(
|
|||
async def create_api_key_route(
|
||||
req: ApiKeyCreate,
|
||||
current_user: CurrentActiveUser,
|
||||
db: AsyncDbSession,
|
||||
db: DbSession,
|
||||
) -> UnmaskedApiKeyRead:
|
||||
try:
|
||||
user_id = current_user.id
|
||||
|
|
@ -44,7 +44,7 @@ async def create_api_key_route(
|
|||
@router.delete("/{api_key_id}", dependencies=[Depends(auth_utils.get_current_active_user)])
|
||||
async def delete_api_key_route(
|
||||
api_key_id: UUID,
|
||||
db: AsyncDbSession,
|
||||
db: DbSession,
|
||||
):
|
||||
try:
|
||||
await delete_api_key(db, api_key_id)
|
||||
|
|
@ -58,7 +58,7 @@ async def save_store_api_key(
|
|||
api_key_request: ApiKeyCreateRequest,
|
||||
response: Response,
|
||||
current_user: CurrentActiveUser,
|
||||
db: AsyncDbSession,
|
||||
db: DbSession,
|
||||
):
|
||||
settings_service = get_settings_service()
|
||||
auth_settings = settings_service.auth_settings
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ from starlette.responses import ContentStream
|
|||
from starlette.types import Receive
|
||||
|
||||
from langflow.api.utils import (
|
||||
AsyncDbSession,
|
||||
CurrentActiveUser,
|
||||
DbSession,
|
||||
build_and_cache_graph_from_data,
|
||||
build_graph_from_data,
|
||||
build_graph_from_db,
|
||||
|
|
@ -44,7 +44,7 @@ from langflow.schema.schema import OutputValue
|
|||
from langflow.services.cache.utils import CacheMiss
|
||||
from langflow.services.chat.service import ChatService
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
from langflow.services.deps import async_session_scope, get_async_session, get_chat_service, get_telemetry_service
|
||||
from langflow.services.deps import async_session_scope, get_chat_service, get_session, get_telemetry_service
|
||||
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -77,7 +77,7 @@ async def retrieve_vertices_order(
|
|||
data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None,
|
||||
stop_component_id: str | None = None,
|
||||
start_component_id: str | None = None,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
) -> VerticesOrderResponse:
|
||||
"""Retrieve the vertices order for a given flow.
|
||||
|
||||
|
|
@ -153,7 +153,7 @@ async def build_flow(
|
|||
start_component_id: str | None = None,
|
||||
log_builds: bool | None = True,
|
||||
current_user: CurrentActiveUser,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
):
|
||||
chat_service = get_chat_service()
|
||||
telemetry_service = get_telemetry_service()
|
||||
|
|
@ -512,7 +512,7 @@ async def build_vertex(
|
|||
# If there's no cache
|
||||
logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
|
||||
graph: Graph = await build_graph_from_db(
|
||||
flow_id=flow_id_str, session=await anext(get_async_session()), chat_service=chat_service
|
||||
flow_id=flow_id_str, session=await anext(get_session()), chat_service=chat_service
|
||||
)
|
||||
else:
|
||||
graph = cache.get("result")
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from fastapi import (
|
|||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, parse_value
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession, parse_value
|
||||
from langflow.api.v1.schemas import (
|
||||
ConfigResponse,
|
||||
CustomComponentRequest,
|
||||
|
|
@ -379,7 +379,7 @@ async def webhook_run_flow(
|
|||
)
|
||||
async def experimental_run_flow(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
flow_id: UUID,
|
||||
inputs: list[InputValueRequest] | None = None,
|
||||
outputs: list[str] | None = None,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from uuid import UUID
|
|||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession
|
||||
from langflow.api.v1.schemas import UploadFileResponse
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.deps import get_storage_service
|
||||
|
|
@ -25,7 +25,7 @@ router = APIRouter(tags=["Files"], prefix="/files")
|
|||
async def get_flow_id(
|
||||
flow_id: UUID,
|
||||
current_user: CurrentActiveUser,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
):
|
||||
flow_id_str = str(flow_id)
|
||||
# AttributeError: 'SelectOfScalar' object has no attribute 'first'
|
||||
|
|
@ -43,7 +43,7 @@ async def upload_file(
|
|||
file: UploadFile,
|
||||
flow_id: Annotated[UUID, Depends(get_flow_id)],
|
||||
current_user: CurrentActiveUser,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
storage_service: Annotated[StorageService, Depends(get_storage_service)],
|
||||
) -> UploadFileResponse:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@ from sqlmodel import and_, col, select
|
|||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.api.utils import (
|
||||
AsyncDbSession,
|
||||
CurrentActiveUser,
|
||||
DbSession,
|
||||
cascade_delete_flow,
|
||||
remove_api_keys,
|
||||
validate_is_component,
|
||||
|
|
@ -124,7 +124,7 @@ async def _new_flow(
|
|||
@router.post("/", response_model=FlowRead, status_code=201)
|
||||
async def create_flow(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
flow: FlowCreate,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
@ -154,7 +154,7 @@ async def create_flow(
|
|||
async def read_flows(
|
||||
*,
|
||||
current_user: CurrentActiveUser,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
remove_example_flows: bool = False,
|
||||
components_only: bool = False,
|
||||
get_all: bool = True,
|
||||
|
|
@ -261,7 +261,7 @@ async def _read_flow(
|
|||
@router.get("/{flow_id}", response_model=FlowRead, status_code=200)
|
||||
async def read_flow(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
flow_id: UUID,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
@ -274,7 +274,7 @@ async def read_flow(
|
|||
@router.patch("/{flow_id}", response_model=FlowRead, status_code=200)
|
||||
async def update_flow(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
flow_id: UUID,
|
||||
flow: FlowUpdate,
|
||||
current_user: CurrentActiveUser,
|
||||
|
|
@ -334,7 +334,7 @@ async def update_flow(
|
|||
@router.delete("/{flow_id}", status_code=200)
|
||||
async def delete_flow(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
flow_id: UUID,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
@ -355,7 +355,7 @@ async def delete_flow(
|
|||
@router.post("/batch/", response_model=list[FlowRead], status_code=201)
|
||||
async def create_flows(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
flow_list: FlowListCreate,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
@ -375,7 +375,7 @@ async def create_flows(
|
|||
@router.post("/upload/", response_model=list[FlowRead], status_code=201)
|
||||
async def upload_file(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
file: Annotated[UploadFile, File(...)],
|
||||
current_user: CurrentActiveUser,
|
||||
folder_id: UUID | None = None,
|
||||
|
|
@ -420,7 +420,7 @@ async def upload_file(
|
|||
async def delete_multiple_flows(
|
||||
flow_ids: list[UUID],
|
||||
user: CurrentActiveUser,
|
||||
db: AsyncDbSession,
|
||||
db: DbSession,
|
||||
):
|
||||
"""Delete multiple flows by their IDs.
|
||||
|
||||
|
|
@ -458,7 +458,7 @@ async def delete_multiple_flows(
|
|||
async def download_multiple_file(
|
||||
flow_ids: list[UUID],
|
||||
user: CurrentActiveUser,
|
||||
db: AsyncDbSession,
|
||||
db: DbSession,
|
||||
):
|
||||
"""Download all flows as a zip file."""
|
||||
flows = (await db.exec(select(Flow).where(and_(Flow.user_id == user.id, Flow.id.in_(flow_ids))))).all() # type: ignore[attr-defined]
|
||||
|
|
@ -499,7 +499,7 @@ async def download_multiple_file(
|
|||
@router.get("/basic_examples/", response_model=list[FlowRead], status_code=200)
|
||||
async def read_basic_examples(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
):
|
||||
"""Retrieve a list of basic example flows.
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from sqlalchemy import or_, update
|
|||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, cascade_delete_flow, custom_params, remove_api_keys
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession, cascade_delete_flow, custom_params, remove_api_keys
|
||||
from langflow.api.v1.flows import create_flows
|
||||
from langflow.api.v1.schemas import FlowListCreate
|
||||
from langflow.helpers.flow import generate_unique_flow_name
|
||||
|
|
@ -37,7 +37,7 @@ router = APIRouter(prefix="/folders", tags=["Folders"])
|
|||
@router.post("/", response_model=FolderRead, status_code=201)
|
||||
async def create_folder(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
folder: FolderCreate,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
@ -93,7 +93,7 @@ async def create_folder(
|
|||
@router.get("/", response_model=list[FolderRead], status_code=200)
|
||||
async def read_folders(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
try:
|
||||
|
|
@ -113,7 +113,7 @@ async def read_folders(
|
|||
@router.get("/{folder_id}", response_model=FolderWithPaginatedFlows | FolderReadWithFlows, status_code=200)
|
||||
async def read_folder(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
folder_id: str,
|
||||
current_user: CurrentActiveUser,
|
||||
params: Annotated[Params | None, Depends(custom_params)],
|
||||
|
|
@ -164,7 +164,7 @@ async def read_folder(
|
|||
@router.patch("/{folder_id}", response_model=FolderRead, status_code=200)
|
||||
async def update_folder(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
folder_id: str,
|
||||
folder: FolderUpdate, # Assuming FolderUpdate is a Pydantic model defining updatable fields
|
||||
current_user: CurrentActiveUser,
|
||||
|
|
@ -225,7 +225,7 @@ async def update_folder(
|
|||
@router.delete("/{folder_id}", status_code=204)
|
||||
async def delete_folder(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
folder_id: str,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
@ -257,7 +257,7 @@ async def delete_folder(
|
|||
@router.get("/download/{folder_id}", status_code=200)
|
||||
async def download_file(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
folder_id: str,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
@ -305,7 +305,7 @@ async def download_file(
|
|||
@router.post("/upload/", response_model=list[FlowRead], status_code=201)
|
||||
async def upload_file(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
file: Annotated[UploadFile, File(...)],
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Annotated
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from langflow.api.utils import AsyncDbSession
|
||||
from langflow.api.utils import DbSession
|
||||
from langflow.api.v1.schemas import Token
|
||||
from langflow.services.auth.utils import (
|
||||
authenticate_user,
|
||||
|
|
@ -24,7 +24,7 @@ router = APIRouter(tags=["Login"])
|
|||
async def login_to_get_access_token(
|
||||
response: Response,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: AsyncDbSession,
|
||||
db: DbSession,
|
||||
):
|
||||
auth_settings = get_settings_service().auth_settings
|
||||
try:
|
||||
|
|
@ -78,7 +78,7 @@ async def login_to_get_access_token(
|
|||
|
||||
|
||||
@router.get("/auto_login")
|
||||
async def auto_login(response: Response, db: AsyncDbSession):
|
||||
async def auto_login(response: Response, db: DbSession):
|
||||
auth_settings = get_settings_service().auth_settings
|
||||
|
||||
if auth_settings.AUTO_LOGIN:
|
||||
|
|
@ -124,7 +124,7 @@ async def auto_login(response: Response, db: AsyncDbSession):
|
|||
async def refresh_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: AsyncDbSession,
|
||||
db: DbSession,
|
||||
):
|
||||
auth_settings = get_settings_service().auth_settings
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
|
|||
from sqlalchemy import delete
|
||||
from sqlmodel import col, select
|
||||
|
||||
from langflow.api.utils import AsyncDbSession
|
||||
from langflow.api.utils import DbSession
|
||||
from langflow.schema.message import MessageResponse
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate
|
||||
|
|
@ -21,7 +21,7 @@ router = APIRouter(prefix="/monitor", tags=["Monitor"])
|
|||
|
||||
|
||||
@router.get("/builds")
|
||||
async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbSession) -> VertexBuildMapModel:
|
||||
async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSession) -> VertexBuildMapModel:
|
||||
try:
|
||||
vertex_builds = await get_vertex_builds_by_flow_id(session, flow_id)
|
||||
return VertexBuildMapModel.from_list_of_dicts(vertex_builds)
|
||||
|
|
@ -30,7 +30,7 @@ async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbS
|
|||
|
||||
|
||||
@router.delete("/builds", status_code=204)
|
||||
async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbSession) -> None:
|
||||
async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSession) -> None:
|
||||
try:
|
||||
await delete_vertex_builds_by_flow_id(session, flow_id)
|
||||
except Exception as e:
|
||||
|
|
@ -39,7 +39,7 @@ async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: Async
|
|||
|
||||
@router.get("/messages")
|
||||
async def get_messages(
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
flow_id: Annotated[str | None, Query()] = None,
|
||||
session_id: Annotated[str | None, Query()] = None,
|
||||
sender: Annotated[str | None, Query()] = None,
|
||||
|
|
@ -66,7 +66,7 @@ async def get_messages(
|
|||
|
||||
|
||||
@router.delete("/messages", status_code=204, dependencies=[Depends(get_current_active_user)])
|
||||
async def delete_messages(message_ids: list[UUID], session: AsyncDbSession) -> None:
|
||||
async def delete_messages(message_ids: list[UUID], session: DbSession) -> None:
|
||||
try:
|
||||
await session.exec(delete(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore[attr-defined]
|
||||
await session.commit()
|
||||
|
|
@ -78,7 +78,7 @@ async def delete_messages(message_ids: list[UUID], session: AsyncDbSession) -> N
|
|||
async def update_message(
|
||||
message_id: UUID,
|
||||
message: MessageUpdate,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
):
|
||||
try:
|
||||
db_message = await session.get(MessageTable, message_id)
|
||||
|
|
@ -108,7 +108,7 @@ async def update_message(
|
|||
async def update_session_id(
|
||||
old_session_id: str,
|
||||
new_session_id: Annotated[str, Query(..., description="The new session ID to update to")],
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
) -> list[MessageResponse]:
|
||||
try:
|
||||
# Get all messages with the old session ID
|
||||
|
|
@ -141,7 +141,7 @@ async def update_session_id(
|
|||
@router.delete("/messages/session/{session_id}", status_code=204)
|
||||
async def delete_messages_session(
|
||||
session_id: str,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
):
|
||||
try:
|
||||
await session.exec(
|
||||
|
|
@ -159,7 +159,7 @@ async def delete_messages_session(
|
|||
@router.get("/transactions")
|
||||
async def get_transactions(
|
||||
flow_id: Annotated[UUID, Query()],
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
) -> list[TransactionReadResponse]:
|
||||
try:
|
||||
transactions = await get_transactions_by_flow_id(session, flow_id)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from sqlalchemy.exc import IntegrityError
|
|||
from sqlmodel import select
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession
|
||||
from langflow.api.v1.schemas import UsersResponse
|
||||
from langflow.services.auth.utils import (
|
||||
get_current_active_superuser,
|
||||
|
|
@ -25,7 +25,7 @@ router = APIRouter(tags=["Users"], prefix="/users")
|
|||
@router.post("/", response_model=UserRead, status_code=201)
|
||||
async def add_user(
|
||||
user: UserCreate,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
) -> User:
|
||||
"""Add a new user to the database."""
|
||||
new_user = User.model_validate(user, from_attributes=True)
|
||||
|
|
@ -58,7 +58,7 @@ async def read_all_users(
|
|||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 10,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
) -> UsersResponse:
|
||||
"""Retrieve a list of users from the database with pagination."""
|
||||
query: SelectOfScalar = select(User).offset(skip).limit(limit)
|
||||
|
|
@ -78,7 +78,7 @@ async def patch_user(
|
|||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
user: CurrentActiveUser,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
) -> User:
|
||||
"""Update an existing user's data."""
|
||||
update_password = bool(user_update.password)
|
||||
|
|
@ -105,7 +105,7 @@ async def reset_password(
|
|||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
user: CurrentActiveUser,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
) -> User:
|
||||
"""Reset a user's password."""
|
||||
if user_id != user.id:
|
||||
|
|
@ -127,7 +127,7 @@ async def reset_password(
|
|||
async def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: Annotated[User, Depends(get_current_active_superuser)],
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
) -> dict:
|
||||
"""Delete a user from the database."""
|
||||
if current_user.id == user_id:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from uuid import UUID
|
|||
from fastapi import APIRouter, HTTPException
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
|
||||
from langflow.api.utils import AsyncDbSession, CurrentActiveUser
|
||||
from langflow.api.utils import CurrentActiveUser, DbSession
|
||||
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
|
||||
from langflow.services.deps import get_variable_service
|
||||
from langflow.services.variable.constants import GENERIC_TYPE
|
||||
|
|
@ -15,7 +15,7 @@ router = APIRouter(prefix="/variables", tags=["Variables"])
|
|||
@router.post("/", response_model=VariableRead, status_code=201)
|
||||
async def create_variable(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
variable: VariableCreate,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
|
|
@ -50,7 +50,7 @@ async def create_variable(
|
|||
@router.get("/", response_model=list[VariableRead], status_code=200)
|
||||
async def read_variables(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
current_user: CurrentActiveUser,
|
||||
):
|
||||
"""Read all variables."""
|
||||
|
|
@ -67,7 +67,7 @@ async def read_variables(
|
|||
@router.patch("/{variable_id}", response_model=VariableRead, status_code=200)
|
||||
async def update_variable(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
variable_id: UUID,
|
||||
variable: VariableUpdate,
|
||||
current_user: CurrentActiveUser,
|
||||
|
|
@ -94,7 +94,7 @@ async def update_variable(
|
|||
@router.delete("/{variable_id}", status_code=204)
|
||||
async def delete_variable(
|
||||
*,
|
||||
session: AsyncDbSession,
|
||||
session: DbSession,
|
||||
variable_id: UUID,
|
||||
current_user: CurrentActiveUser,
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from starlette.websockets import WebSocket
|
|||
from langflow.services.database.models.api_key.crud import check_key
|
||||
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
|
||||
from langflow.services.database.models.user.model import User, UserRead
|
||||
from langflow.services.deps import get_async_session, get_db_service, get_settings_service
|
||||
from langflow.services.deps import get_db_service, get_session, get_settings_service
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -79,7 +79,7 @@ async def get_current_user(
|
|||
token: Annotated[str, Security(oauth2_login)],
|
||||
query_param: Annotated[str, Security(api_key_query)],
|
||||
header_param: Annotated[str, Security(api_key_header)],
|
||||
db: Annotated[AsyncSession, Depends(get_async_session)],
|
||||
db: Annotated[AsyncSession, Depends(get_session)],
|
||||
) -> User:
|
||||
if token:
|
||||
return await get_current_user_by_jwt(token, db)
|
||||
|
|
@ -156,7 +156,7 @@ async def get_current_user_by_jwt(
|
|||
|
||||
async def get_current_user_for_websocket(
|
||||
websocket: WebSocket,
|
||||
db: Annotated[AsyncSession, Depends(get_async_session)],
|
||||
db: Annotated[AsyncSession, Depends(get_session)],
|
||||
query_param: Annotated[str, Security(api_key_query)],
|
||||
) -> User | None:
|
||||
token = websocket.query_params.get("token")
|
||||
|
|
|
|||
|
|
@ -1,23 +1,8 @@
|
|||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.services.deps import get_session
|
||||
from langflow.utils.version import get_version_info
|
||||
|
||||
from .model import Flow
|
||||
|
||||
|
||||
def get_flow_by_id(session: Annotated[Session, Depends(get_session)], flow_id: str | None = None) -> Flow | None:
|
||||
"""Get flow by id."""
|
||||
if flow_id is None:
|
||||
msg = "Flow id is required."
|
||||
raise ValueError(msg)
|
||||
|
||||
return session.get(Flow, flow_id)
|
||||
|
||||
|
||||
def get_webhook_component_in_flow(flow_data: dict):
|
||||
"""Get webhook component in flow data."""
|
||||
for node in flow_data.get("nodes", []):
|
||||
|
|
|
|||
|
|
@ -142,22 +142,11 @@ def get_db_service() -> DatabaseService:
|
|||
return get_service(ServiceType.DATABASE_SERVICE, DatabaseServiceFactory())
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""Retrieves a session from the database service.
|
||||
|
||||
Yields:
|
||||
Session: A session object.
|
||||
|
||||
"""
|
||||
with get_db_service().with_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Retrieves an async session from the database service.
|
||||
|
||||
Yields:
|
||||
Session: An async session object.
|
||||
AsyncSession: An async session object.
|
||||
|
||||
"""
|
||||
async with get_db_service().with_async_session() as session:
|
||||
|
|
@ -173,7 +162,7 @@ def session_scope() -> Generator[Session, None, None]:
|
|||
and rolled back if an exception is raised.
|
||||
|
||||
Yields:
|
||||
session: The session object.
|
||||
Session: The session object.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs during the session scope.
|
||||
|
|
@ -199,7 +188,7 @@ async def async_session_scope() -> AsyncGenerator[AsyncSession, None]:
|
|||
and rolled back if an exception is raised.
|
||||
|
||||
Yields:
|
||||
session: The async session object.
|
||||
AsyncSession: The async session object.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs during the session scope.
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from langflow.graph.graph.base import Graph
|
|||
from langflow.graph.utils import log_vertex_build
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
from langflow.services.deps import get_async_session
|
||||
from langflow.services.deps import get_session
|
||||
|
||||
|
||||
def set_socketio_server(socketio_server) -> None:
|
||||
|
|
@ -23,7 +23,7 @@ def set_socketio_server(socketio_server) -> None:
|
|||
|
||||
async def get_vertices(sio, sid, flow_id, chat_service) -> None:
|
||||
try:
|
||||
session = await anext(get_async_session())
|
||||
session = await anext(get_session())
|
||||
stmt = select(Flow).where(Flow.id == flow_id)
|
||||
flow: Flow = (await session.exec(stmt)).first()
|
||||
if not flow or not flow.data:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue