fix: environment variable loading and improve error handling in DB retrieval (#4168)
* test: rewrite and enable variable loading test * refactor: simplify environment variable storage logic * refactor: simplify parameter loading logic from db * fix: handle additional error case when loading variables
This commit is contained in:
parent
c670778ecb
commit
1a72aa71b6
3 changed files with 60 additions and 78 deletions
|
|
@ -109,38 +109,28 @@ def update_params_with_load_from_db_fields(
|
|||
*,
|
||||
fallback_to_env_vars=False,
|
||||
):
|
||||
# For each field in load_from_db_fields, we will check if it's in the params
|
||||
# and if it is, we will get the value from the custom_component.keys(name)
|
||||
# and update the params with the value
|
||||
for field in load_from_db_fields:
|
||||
if field in params:
|
||||
try:
|
||||
key = None
|
||||
try:
|
||||
key = custom_component.variables(params[field], field)
|
||||
except ValueError as e:
|
||||
# check if "User id is not set" is in the error message, this is an internal bug
|
||||
if "User id is not set" in str(e):
|
||||
raise
|
||||
logger.debug(str(e))
|
||||
if fallback_to_env_vars and key is None:
|
||||
key = os.getenv(params[field])
|
||||
if key is None:
|
||||
msg = f"Environment variable {params[field]} is not set."
|
||||
logger.error(msg)
|
||||
else:
|
||||
logger.info(f"Using environment variable {params[field]} for {field}")
|
||||
if key is None:
|
||||
logger.warning(f"Could not get value for {field}. Setting it to None.")
|
||||
if field not in params:
|
||||
continue
|
||||
|
||||
params[field] = key
|
||||
|
||||
except TypeError:
|
||||
try:
|
||||
key = custom_component.variables(params[field], field)
|
||||
except ValueError as e:
|
||||
if any(reason in str(e) for reason in ["User id is not set", "variable not found."]):
|
||||
raise
|
||||
logger.debug(str(e))
|
||||
key = None
|
||||
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(f"Failed to get value for {field} from custom component. Setting it to None.")
|
||||
params[field] = None
|
||||
if fallback_to_env_vars and key is None:
|
||||
key = os.getenv(params[field])
|
||||
if key:
|
||||
logger.info(f"Using environment variable {params[field]} for {field}")
|
||||
else:
|
||||
logger.error(f"Environment variable {params[field]} is not set.")
|
||||
|
||||
params[field] = key if key is not None else None
|
||||
if key is None:
|
||||
logger.warning(f"Could not get value for {field}. Setting it to None.")
|
||||
|
||||
return params
|
||||
|
||||
|
|
|
|||
|
|
@ -27,46 +27,31 @@ class DatabaseVariableService(VariableService, Service):
|
|||
self.settings_service = settings_service
|
||||
|
||||
def initialize_user_variables(self, user_id: UUID | str, session: Session = Depends(get_session)):
|
||||
# Check for environment variables that should be stored in the database
|
||||
should_or_should_not = "Should" if self.settings_service.settings.store_environment_variables else "Should not"
|
||||
logger.info(f"{should_or_should_not} store environment variables in the database.")
|
||||
if self.settings_service.settings.store_environment_variables:
|
||||
for var in self.settings_service.settings.variables_to_get_from_environment:
|
||||
if var in os.environ:
|
||||
logger.debug(f"Creating {var} variable from environment.")
|
||||
|
||||
if found_variable := session.exec(
|
||||
select(Variable).where(Variable.user_id == user_id, Variable.name == var)
|
||||
).first():
|
||||
# Update it
|
||||
value = os.environ[var]
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
# If the secret_key changes the stored value could be invalid
|
||||
# so we need to re-encrypt it
|
||||
encrypted = auth_utils.encrypt_api_key(value, settings_service=self.settings_service)
|
||||
found_variable.value = encrypted
|
||||
session.add(found_variable)
|
||||
session.commit()
|
||||
else:
|
||||
# Create it
|
||||
try:
|
||||
value = os.environ[var]
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
self.create_variable(
|
||||
user_id=user_id,
|
||||
name=var,
|
||||
value=value,
|
||||
default_fields=[],
|
||||
_type=CREDENTIAL_TYPE,
|
||||
session=session,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(f"Error creating {var} variable")
|
||||
|
||||
else:
|
||||
if not self.settings_service.settings.store_environment_variables:
|
||||
logger.info("Skipping environment variable storage.")
|
||||
return
|
||||
|
||||
logger.info("Storing environment variables in the database.")
|
||||
for var_name in self.settings_service.settings.variables_to_get_from_environment:
|
||||
if var_name in os.environ and os.environ[var_name].strip():
|
||||
value = os.environ[var_name].strip()
|
||||
query = select(Variable).where(Variable.user_id == user_id, Variable.name == var_name)
|
||||
existing = session.exec(query).first()
|
||||
try:
|
||||
if existing:
|
||||
self.update_variable(user_id, var_name, value, session)
|
||||
else:
|
||||
self.create_variable(
|
||||
user_id=user_id,
|
||||
name=var_name,
|
||||
value=value,
|
||||
default_fields=[],
|
||||
_type=CREDENTIAL_TYPE,
|
||||
session=session,
|
||||
)
|
||||
logger.info(f"Processed {var_name} variable from environment.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.exception(f"Error processing {var_name} variable: {e!s}")
|
||||
|
||||
def get_variable(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langflow.services.database.models.variable.model import VariableUpdate
|
|||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
from langflow.services.variable.service import DatabaseVariableService
|
||||
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -25,26 +26,32 @@ def session():
|
|||
yield session
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Temporarily disabled")
|
||||
def test_initialize_user_variables__donkey(service, session):
|
||||
def test_initialize_user_variables__create_and_update(service, session):
|
||||
user_id = uuid4()
|
||||
name = "OPENAI_API_KEY"
|
||||
value = "donkey"
|
||||
service.initialize_user_variables(user_id, session=session)
|
||||
result = service.create_variable(user_id, "OPENAI_API_KEY", "donkey", session=session)
|
||||
new_service = DatabaseVariableService(get_settings_service())
|
||||
new_service.initialize_user_variables(user_id, session=session)
|
||||
field = ""
|
||||
good_vars = {k: f"value{i}" for i, k in enumerate(VARIABLES_TO_GET_FROM_ENVIRONMENT)}
|
||||
bad_vars = {"VAR1": "value1", "VAR2": "value2", "VAR3": "value3"}
|
||||
env_vars = {**good_vars, **bad_vars}
|
||||
|
||||
result = new_service.get_variable(user_id, name, "", session=session)
|
||||
service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session)
|
||||
env_vars["OPENAI_API_KEY"] = "updated_value"
|
||||
|
||||
assert result != value
|
||||
with patch.dict("os.environ", env_vars, clear=True):
|
||||
service.initialize_user_variables(user_id=user_id, session=session)
|
||||
|
||||
variables = service.list_variables(user_id, session=session)
|
||||
for name in variables:
|
||||
value = service.get_variable(user_id, name, field, session=session)
|
||||
assert value == env_vars[name]
|
||||
|
||||
assert all([i in variables for i in good_vars.keys()])
|
||||
assert all([i not in variables for i in bad_vars.keys()])
|
||||
|
||||
|
||||
def test_initialize_user_variables__not_found_variable(service, session):
|
||||
with patch("langflow.services.variable.service.DatabaseVariableService.create_variable") as m:
|
||||
m.side_effect = Exception()
|
||||
service.initialize_user_variables(uuid4(), session=session)
|
||||
|
||||
assert True
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue