From 4cc336fa45e138fad94228382cbfb3c43cf7a6bb Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 12 Dec 2024 13:28:32 +0100 Subject: [PATCH] ref: Use async list_variables (#5224) Use async list_variables --- .../custom_component/custom_component.py | 8 +++--- .../base/langflow/services/variable/base.py | 27 ------------------- .../langflow/services/variable/kubernetes.py | 14 +++------- .../langflow/services/variable/service.py | 13 ++++----- 4 files changed, 13 insertions(+), 49 deletions(-) diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index c0fd80931..b51cc287b 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from langflow.custom.custom_component.base_component import BaseComponent from langflow.helpers.flow import list_flows, load_flow, run_flow from langflow.schema import Data -from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service, session_scope +from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service from langflow.services.storage.service import StorageService from langflow.template.utils import update_frontend_node_with_template_values from langflow.type_extraction.type_extraction import post_process_type @@ -442,7 +442,7 @@ class CustomComponent(BaseComponent): user_id = self.user_id or "" return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session) - def list_key_names(self): + async def list_key_names(self): """Lists the names of the variables for the current user. Raises: @@ -456,8 +456,8 @@ class CustomComponent(BaseComponent): raise ValueError(msg) variable_service = get_variable_service() - with session_scope() as session: - return variable_service.list_variables_sync(user_id=self.user_id, session=session) + async with async_session_scope() as session: + return await variable_service.list_variables(user_id=self.user_id, session=session) def index(self, value: int = 0): """Returns a function that returns the value at the given index in the iterable. diff --git a/src/backend/base/langflow/services/variable/base.py b/src/backend/base/langflow/services/variable/base.py index 0e99b2578..8c9bce10c 100644 --- a/src/backend/base/langflow/services/variable/base.py +++ b/src/backend/base/langflow/services/variable/base.py @@ -1,7 +1,6 @@ import abc from uuid import UUID -from sqlmodel import Session from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.base import Service @@ -23,19 +22,6 @@ class VariableService(Service): """ @abc.abstractmethod - def get_variable_sync(self, user_id: UUID | str, name: str, field: str, session: Session) -> str: - """Get a variable value. - - Args: - user_id: The user ID. - name: The name of the variable. - field: The field of the variable. - session: The database session. - - Returns: - The value of the variable. - """ - async def get_variable(self, user_id: UUID | str, name: str, field: str, session: AsyncSession) -> str: """Async get a variable value. @@ -48,20 +34,8 @@ class VariableService(Service): Returns: The value of the variable. """ - return await session.run_sync(lambda session_: self.get_variable_sync(user_id, name, field, session_)) @abc.abstractmethod - def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]: - """List all variables. - - Args: - user_id: The user ID. - session: The database session. - - Returns: - A list of variable names. - """ - async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]: """List all variables. @@ -72,7 +46,6 @@ class VariableService(Service): Returns: A list of variable names. """ - return await session.run_sync(lambda session_: self.list_variables_sync(user_id, session_)) @abc.abstractmethod async def update_variable(self, user_id: UUID | str, name: str, value: str, session: AsyncSession) -> Variable: diff --git a/src/backend/base/langflow/services/variable/kubernetes.py b/src/backend/base/langflow/services/variable/kubernetes.py index 5dc8d069a..8efb67d52 100644 --- a/src/backend/base/langflow/services/variable/kubernetes.py +++ b/src/backend/base/langflow/services/variable/kubernetes.py @@ -79,15 +79,9 @@ class KubernetesSecretService(VariableService, Service): raise ValueError(msg) @override - def get_variable_sync( - self, - user_id: UUID | str, - name: str, - field: str, - session: Session, - ) -> str: + async def get_variable(self, user_id: UUID | str, name: str, field: str, session: AsyncSession) -> str: secret_name = encode_user_id(user_id) - key, value = self.resolve_variable(secret_name, user_id, name) + key, value = await asyncio.to_thread(self.resolve_variable, secret_name, user_id, name) if key.startswith(CREDENTIAL_TYPE + "_") and field == "session_id": msg = ( f"variable {name} of type 'Credential' cannot be used in a Session ID field " @@ -97,12 +91,12 @@ class KubernetesSecretService(VariableService, Service): return value @override - def list_variables_sync( + async def list_variables( self, user_id: UUID | str, session: Session, ) -> list[str | None]: - variables = self.kubernetes_secrets.get_secret(name=encode_user_id(user_id)) + variables = await asyncio.to_thread(self.kubernetes_secrets.get_secret, name=encode_user_id(user_id)) if not variables: return [] diff --git a/src/backend/base/langflow/services/variable/service.py b/src/backend/base/langflow/services/variable/service.py index 0053ed69d..1a4f6000c 100644 --- a/src/backend/base/langflow/services/variable/service.py +++ b/src/backend/base/langflow/services/variable/service.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING from loguru import logger -from sqlmodel import Session, select +from sqlmodel import select from langflow.services.auth import utils as auth_utils from langflow.services.base import Service @@ -53,16 +53,17 @@ class DatabaseVariableService(VariableService, Service): except Exception as e: # noqa: BLE001 logger.exception(f"Error processing {var_name} variable: {e!s}") - def get_variable_sync( + async def get_variable( self, user_id: UUID | str, name: str, field: str, - session: Session, + session: AsyncSession, ) -> str: # we get the credential from the database # credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first() - variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first() + stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name) + variable = (await session.exec(stmt)).first() if not variable or not variable.value: msg = f"{name} variable not found." @@ -82,10 +83,6 @@ class DatabaseVariableService(VariableService, Service): stmt = select(Variable).where(Variable.user_id == user_id) return list((await session.exec(stmt)).all()) - def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]: - variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all() - return [variable.name for variable in variables if variable] - async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]: variables = await self.get_all(user_id=user_id, session=session) return [variable.name for variable in variables if variable]