ref: Use async list_variables (#5224)

Use async list_variables
This commit is contained in:
Christophe Bornet 2024-12-12 13:28:32 +01:00 committed by GitHub
commit 4cc336fa45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 13 additions and 49 deletions

View file

@ -13,7 +13,7 @@ from pydantic import BaseModel
from langflow.custom.custom_component.base_component import BaseComponent
from langflow.helpers.flow import list_flows, load_flow, run_flow
from langflow.schema import Data
from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service, session_scope
from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service
from langflow.services.storage.service import StorageService
from langflow.template.utils import update_frontend_node_with_template_values
from langflow.type_extraction.type_extraction import post_process_type
@ -442,7 +442,7 @@ class CustomComponent(BaseComponent):
user_id = self.user_id or ""
return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session)
def list_key_names(self):
async def list_key_names(self):
"""Lists the names of the variables for the current user.
Raises:
@ -456,8 +456,8 @@ class CustomComponent(BaseComponent):
raise ValueError(msg)
variable_service = get_variable_service()
with session_scope() as session:
return variable_service.list_variables_sync(user_id=self.user_id, session=session)
async with async_session_scope() as session:
return await variable_service.list_variables(user_id=self.user_id, session=session)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable.

View file

@ -1,7 +1,6 @@
import abc
from uuid import UUID
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.services.base import Service
@ -23,19 +22,6 @@ class VariableService(Service):
"""
@abc.abstractmethod
def get_variable_sync(self, user_id: UUID | str, name: str, field: str, session: Session) -> str:
"""Get a variable value.
Args:
user_id: The user ID.
name: The name of the variable.
field: The field of the variable.
session: The database session.
Returns:
The value of the variable.
"""
async def get_variable(self, user_id: UUID | str, name: str, field: str, session: AsyncSession) -> str:
"""Async get a variable value.
@ -48,20 +34,8 @@ class VariableService(Service):
Returns:
The value of the variable.
"""
return await session.run_sync(lambda session_: self.get_variable_sync(user_id, name, field, session_))
@abc.abstractmethod
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
"""List all variables.
Args:
user_id: The user ID.
session: The database session.
Returns:
A list of variable names.
"""
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
"""List all variables.
@ -72,7 +46,6 @@ class VariableService(Service):
Returns:
A list of variable names.
"""
return await session.run_sync(lambda session_: self.list_variables_sync(user_id, session_))
@abc.abstractmethod
async def update_variable(self, user_id: UUID | str, name: str, value: str, session: AsyncSession) -> Variable:

View file

@ -79,15 +79,9 @@ class KubernetesSecretService(VariableService, Service):
raise ValueError(msg)
@override
def get_variable_sync(
self,
user_id: UUID | str,
name: str,
field: str,
session: Session,
) -> str:
async def get_variable(self, user_id: UUID | str, name: str, field: str, session: AsyncSession) -> str:
secret_name = encode_user_id(user_id)
key, value = self.resolve_variable(secret_name, user_id, name)
key, value = await asyncio.to_thread(self.resolve_variable, secret_name, user_id, name)
if key.startswith(CREDENTIAL_TYPE + "_") and field == "session_id":
msg = (
f"variable {name} of type 'Credential' cannot be used in a Session ID field "
@ -97,12 +91,12 @@ class KubernetesSecretService(VariableService, Service):
return value
@override
def list_variables_sync(
async def list_variables(
self,
user_id: UUID | str,
session: Session,
) -> list[str | None]:
variables = self.kubernetes_secrets.get_secret(name=encode_user_id(user_id))
variables = await asyncio.to_thread(self.kubernetes_secrets.get_secret, name=encode_user_id(user_id))
if not variables:
return []

View file

@ -5,7 +5,7 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING
from loguru import logger
from sqlmodel import Session, select
from sqlmodel import select
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
@ -53,16 +53,17 @@ class DatabaseVariableService(VariableService, Service):
except Exception as e: # noqa: BLE001
logger.exception(f"Error processing {var_name} variable: {e!s}")
def get_variable_sync(
async def get_variable(
self,
user_id: UUID | str,
name: str,
field: str,
session: Session,
session: AsyncSession,
) -> str:
# we get the credential from the database
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name)
variable = (await session.exec(stmt)).first()
if not variable or not variable.value:
msg = f"{name} variable not found."
@ -82,10 +83,6 @@ class DatabaseVariableService(VariableService, Service):
stmt = select(Variable).where(Variable.user_id == user_id)
return list((await session.exec(stmt)).all())
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
return [variable.name for variable in variables if variable]
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
variables = await self.get_all(user_id=user_id, session=session)
return [variable.name for variable in variables if variable]