fix: environment variable loading and improve error handling in DB retrieval (#4168)

* test: rewrite and enable variable loading test

* refactor: simplify environment variable storage logic

* refactor: simplify parameter loading logic from db

* fix: handle additional error case when loading variables
This commit is contained in:
Ítalo Johnny 2024-10-16 08:53:40 -03:00 committed by GitHub
commit 1a72aa71b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 60 additions and 78 deletions

View file

@ -109,38 +109,28 @@ def update_params_with_load_from_db_fields(
*,
fallback_to_env_vars=False,
):
# For each field in load_from_db_fields, we will check if it's in the params
# and if it is, we will get the value from the custom_component.keys(name)
# and update the params with the value
for field in load_from_db_fields:
if field in params:
try:
key = None
try:
key = custom_component.variables(params[field], field)
except ValueError as e:
# check if "User id is not set" is in the error message, this is an internal bug
if "User id is not set" in str(e):
raise
logger.debug(str(e))
if fallback_to_env_vars and key is None:
key = os.getenv(params[field])
if key is None:
msg = f"Environment variable {params[field]} is not set."
logger.error(msg)
else:
logger.info(f"Using environment variable {params[field]} for {field}")
if key is None:
logger.warning(f"Could not get value for {field}. Setting it to None.")
if field not in params:
continue
params[field] = key
except TypeError:
try:
key = custom_component.variables(params[field], field)
except ValueError as e:
if any(reason in str(e) for reason in ["User id is not set", "variable not found."]):
raise
logger.debug(str(e))
key = None
except Exception: # noqa: BLE001
logger.exception(f"Failed to get value for {field} from custom component. Setting it to None.")
params[field] = None
if fallback_to_env_vars and key is None:
key = os.getenv(params[field])
if key:
logger.info(f"Using environment variable {params[field]} for {field}")
else:
logger.error(f"Environment variable {params[field]} is not set.")
params[field] = key if key is not None else None
if key is None:
logger.warning(f"Could not get value for {field}. Setting it to None.")
return params

View file

@ -27,46 +27,31 @@ class DatabaseVariableService(VariableService, Service):
self.settings_service = settings_service
def initialize_user_variables(self, user_id: UUID | str, session: Session = Depends(get_session)):
# Check for environment variables that should be stored in the database
should_or_should_not = "Should" if self.settings_service.settings.store_environment_variables else "Should not"
logger.info(f"{should_or_should_not} store environment variables in the database.")
if self.settings_service.settings.store_environment_variables:
for var in self.settings_service.settings.variables_to_get_from_environment:
if var in os.environ:
logger.debug(f"Creating {var} variable from environment.")
if found_variable := session.exec(
select(Variable).where(Variable.user_id == user_id, Variable.name == var)
).first():
# Update it
value = os.environ[var]
if isinstance(value, str):
value = value.strip()
# If the secret_key changes the stored value could be invalid
# so we need to re-encrypt it
encrypted = auth_utils.encrypt_api_key(value, settings_service=self.settings_service)
found_variable.value = encrypted
session.add(found_variable)
session.commit()
else:
# Create it
try:
value = os.environ[var]
if isinstance(value, str):
value = value.strip()
self.create_variable(
user_id=user_id,
name=var,
value=value,
default_fields=[],
_type=CREDENTIAL_TYPE,
session=session,
)
except Exception: # noqa: BLE001
logger.exception(f"Error creating {var} variable")
else:
if not self.settings_service.settings.store_environment_variables:
logger.info("Skipping environment variable storage.")
return
logger.info("Storing environment variables in the database.")
for var_name in self.settings_service.settings.variables_to_get_from_environment:
if var_name in os.environ and os.environ[var_name].strip():
value = os.environ[var_name].strip()
query = select(Variable).where(Variable.user_id == user_id, Variable.name == var_name)
existing = session.exec(query).first()
try:
if existing:
self.update_variable(user_id, var_name, value, session)
else:
self.create_variable(
user_id=user_id,
name=var_name,
value=value,
default_fields=[],
_type=CREDENTIAL_TYPE,
session=session,
)
logger.info(f"Processed {var_name} variable from environment.")
except Exception as e: # noqa: BLE001
logger.exception(f"Error processing {var_name} variable: {e!s}")
def get_variable(
self,

View file

@ -9,6 +9,7 @@ from langflow.services.database.models.variable.model import VariableUpdate
from langflow.services.deps import get_settings_service
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
from langflow.services.variable.service import DatabaseVariableService
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
@pytest.fixture
@ -25,26 +26,32 @@ def session():
yield session
@pytest.mark.skip(reason="Temporarily disabled")
def test_initialize_user_variables__donkey(service, session):
def test_initialize_user_variables__create_and_update(service, session):
user_id = uuid4()
name = "OPENAI_API_KEY"
value = "donkey"
service.initialize_user_variables(user_id, session=session)
result = service.create_variable(user_id, "OPENAI_API_KEY", "donkey", session=session)
new_service = DatabaseVariableService(get_settings_service())
new_service.initialize_user_variables(user_id, session=session)
field = ""
good_vars = {k: f"value{i}" for i, k in enumerate(VARIABLES_TO_GET_FROM_ENVIRONMENT)}
bad_vars = {"VAR1": "value1", "VAR2": "value2", "VAR3": "value3"}
env_vars = {**good_vars, **bad_vars}
result = new_service.get_variable(user_id, name, "", session=session)
service.create_variable(user_id, "OPENAI_API_KEY", "outdate", session=session)
env_vars["OPENAI_API_KEY"] = "updated_value"
assert result != value
with patch.dict("os.environ", env_vars, clear=True):
service.initialize_user_variables(user_id=user_id, session=session)
variables = service.list_variables(user_id, session=session)
for name in variables:
value = service.get_variable(user_id, name, field, session=session)
assert value == env_vars[name]
assert all([i in variables for i in good_vars.keys()])
assert all([i not in variables for i in bad_vars.keys()])
def test_initialize_user_variables__not_found_variable(service, session):
with patch("langflow.services.variable.service.DatabaseVariableService.create_variable") as m:
m.side_effect = Exception()
service.initialize_user_variables(uuid4(), session=session)
assert True