parent
46c9d13657
commit
4cc336fa45
4 changed files with 13 additions and 49 deletions
|
|
@ -13,7 +13,7 @@ from pydantic import BaseModel
|
|||
from langflow.custom.custom_component.base_component import BaseComponent
|
||||
from langflow.helpers.flow import list_flows, load_flow, run_flow
|
||||
from langflow.schema import Data
|
||||
from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service, session_scope
|
||||
from langflow.services.deps import async_session_scope, get_storage_service, get_variable_service
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.template.utils import update_frontend_node_with_template_values
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
|
|
@ -442,7 +442,7 @@ class CustomComponent(BaseComponent):
|
|||
user_id = self.user_id or ""
|
||||
return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session)
|
||||
|
||||
def list_key_names(self):
|
||||
async def list_key_names(self):
|
||||
"""Lists the names of the variables for the current user.
|
||||
|
||||
Raises:
|
||||
|
|
@ -456,8 +456,8 @@ class CustomComponent(BaseComponent):
|
|||
raise ValueError(msg)
|
||||
variable_service = get_variable_service()
|
||||
|
||||
with session_scope() as session:
|
||||
return variable_service.list_variables_sync(user_id=self.user_id, session=session)
|
||||
async with async_session_scope() as session:
|
||||
return await variable_service.list_variables(user_id=self.user_id, session=session)
|
||||
|
||||
def index(self, value: int = 0):
|
||||
"""Returns a function that returns the value at the given index in the iterable.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import abc
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from langflow.services.base import Service
|
||||
|
|
@ -23,19 +22,6 @@ class VariableService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_variable_sync(self, user_id: UUID | str, name: str, field: str, session: Session) -> str:
|
||||
"""Get a variable value.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
name: The name of the variable.
|
||||
field: The field of the variable.
|
||||
session: The database session.
|
||||
|
||||
Returns:
|
||||
The value of the variable.
|
||||
"""
|
||||
|
||||
async def get_variable(self, user_id: UUID | str, name: str, field: str, session: AsyncSession) -> str:
|
||||
"""Async get a variable value.
|
||||
|
||||
|
|
@ -48,20 +34,8 @@ class VariableService(Service):
|
|||
Returns:
|
||||
The value of the variable.
|
||||
"""
|
||||
return await session.run_sync(lambda session_: self.get_variable_sync(user_id, name, field, session_))
|
||||
|
||||
@abc.abstractmethod
|
||||
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
|
||||
"""List all variables.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
session: The database session.
|
||||
|
||||
Returns:
|
||||
A list of variable names.
|
||||
"""
|
||||
|
||||
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
|
||||
"""List all variables.
|
||||
|
||||
|
|
@ -72,7 +46,6 @@ class VariableService(Service):
|
|||
Returns:
|
||||
A list of variable names.
|
||||
"""
|
||||
return await session.run_sync(lambda session_: self.list_variables_sync(user_id, session_))
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_variable(self, user_id: UUID | str, name: str, value: str, session: AsyncSession) -> Variable:
|
||||
|
|
|
|||
|
|
@ -79,15 +79,9 @@ class KubernetesSecretService(VariableService, Service):
|
|||
raise ValueError(msg)
|
||||
|
||||
@override
|
||||
def get_variable_sync(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
field: str,
|
||||
session: Session,
|
||||
) -> str:
|
||||
async def get_variable(self, user_id: UUID | str, name: str, field: str, session: AsyncSession) -> str:
|
||||
secret_name = encode_user_id(user_id)
|
||||
key, value = self.resolve_variable(secret_name, user_id, name)
|
||||
key, value = await asyncio.to_thread(self.resolve_variable, secret_name, user_id, name)
|
||||
if key.startswith(CREDENTIAL_TYPE + "_") and field == "session_id":
|
||||
msg = (
|
||||
f"variable {name} of type 'Credential' cannot be used in a Session ID field "
|
||||
|
|
@ -97,12 +91,12 @@ class KubernetesSecretService(VariableService, Service):
|
|||
return value
|
||||
|
||||
@override
|
||||
def list_variables_sync(
|
||||
async def list_variables(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
session: Session,
|
||||
) -> list[str | None]:
|
||||
variables = self.kubernetes_secrets.get_secret(name=encode_user_id(user_id))
|
||||
variables = await asyncio.to_thread(self.kubernetes_secrets.get_secret, name=encode_user_id(user_id))
|
||||
if not variables:
|
||||
return []
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from datetime import datetime, timezone
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.base import Service
|
||||
|
|
@ -53,16 +53,17 @@ class DatabaseVariableService(VariableService, Service):
|
|||
except Exception as e: # noqa: BLE001
|
||||
logger.exception(f"Error processing {var_name} variable: {e!s}")
|
||||
|
||||
def get_variable_sync(
|
||||
async def get_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
name: str,
|
||||
field: str,
|
||||
session: Session,
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
# we get the credential from the database
|
||||
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
|
||||
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
|
||||
stmt = select(Variable).where(Variable.user_id == user_id, Variable.name == name)
|
||||
variable = (await session.exec(stmt)).first()
|
||||
|
||||
if not variable or not variable.value:
|
||||
msg = f"{name} variable not found."
|
||||
|
|
@ -82,10 +83,6 @@ class DatabaseVariableService(VariableService, Service):
|
|||
stmt = select(Variable).where(Variable.user_id == user_id)
|
||||
return list((await session.exec(stmt)).all())
|
||||
|
||||
def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
|
||||
variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
|
||||
return [variable.name for variable in variables if variable]
|
||||
|
||||
async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
|
||||
variables = await self.get_all(user_id=user_id, session=session)
|
||||
return [variable.name for variable in variables if variable]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue