diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 3e8585148..85e1eac6c 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -717,6 +717,7 @@ async def custom_component_update( for field_name, field_dict in template.items() if isinstance(field_dict, dict) and field_dict.get("load_from_db") and field_dict.get("value") ] + params = await update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields) cc_instance.set_attributes(params) updated_build_config = code_request.get_template() diff --git a/src/backend/base/langflow/components/lmstudio/lmstudioembeddings.py b/src/backend/base/langflow/components/lmstudio/lmstudioembeddings.py index 57f783fc1..e3e86c121 100644 --- a/src/backend/base/langflow/components/lmstudio/lmstudioembeddings.py +++ b/src/backend/base/langflow/components/lmstudio/lmstudioembeddings.py @@ -2,7 +2,6 @@ from typing import Any from urllib.parse import urljoin import httpx -from typing_extensions import override from langflow.base.embeddings.model import LCEmbeddingsModel from langflow.field_typing import Embeddings @@ -15,8 +14,7 @@ class LMStudioEmbeddingsComponent(LCEmbeddingsModel): description: str = "Generate embeddings using LM Studio." icon = "LMStudio" - @override - async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): + async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): # noqa: ARG002 if field_name == "model": base_url_dict = build_config.get("base_url", {}) base_url_load_from_db = base_url_dict.get("load_from_db", False) diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 885604fba..3ddea976e 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -405,6 +405,11 @@ class CustomComponent(BaseComponent): return run_until_complete(self.get_variables(name, field)) async def get_variables(self, name: str, field: str): + """DEPRECATED - This is kept for backward compatibility. Use get_variable instead.""" + async with session_scope() as session: + return await self.get_variable(name, field, session) + + async def get_variable(self, name: str, field: str, session): """Returns the variable for the current user with the specified name. Raises: @@ -425,8 +430,7 @@ class CustomComponent(BaseComponent): else: msg = f"Invalid user id: {self.user_id}" raise TypeError(msg) - async with session_scope() as session: - return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session) + return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session) async def list_key_names(self): """Lists the names of the variables for the current user. diff --git a/src/backend/base/langflow/interface/initialize/loading.py b/src/backend/base/langflow/interface/initialize/loading.py index f20629c0b..5a25f8a75 100644 --- a/src/backend/base/langflow/interface/initialize/loading.py +++ b/src/backend/base/langflow/interface/initialize/loading.py @@ -12,7 +12,7 @@ from pydantic import PydanticDeprecatedSince20 from langflow.custom.eval import eval_custom_component_code from langflow.schema.artifact import get_artifact_type, post_process_raw from langflow.schema.data import Data -from langflow.services.deps import get_tracing_service +from langflow.services.deps import get_tracing_service, session_scope if TYPE_CHECKING: from langflow.custom.custom_component.component import Component @@ -59,7 +59,10 @@ async def get_instance_results( base_type: str = "component", ): custom_params = await update_params_with_load_from_db_fields( - custom_component, custom_params, vertex.load_from_db_fields, fallback_to_env_vars=fallback_to_env_vars + custom_component, + custom_params, + vertex.load_from_db_fields, + fallback_to_env_vars=fallback_to_env_vars, ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) @@ -111,30 +114,31 @@ async def update_params_with_load_from_db_fields( *, fallback_to_env_vars=False, ): - for field in load_from_db_fields: - if field not in params or not params[field]: - continue + async with session_scope() as session: + for field in load_from_db_fields: + if field not in params or not params[field]: + continue - try: - key = await custom_component.get_variables(params[field], field) - except ValueError as e: - if any(reason in str(e) for reason in ["User id is not set", "variable not found."]): - raise - logger.debug(str(e)) - key = None + try: + key = await custom_component.get_variable(name=params[field], field=field, session=session) + except ValueError as e: + if any(reason in str(e) for reason in ["User id is not set", "variable not found."]): + raise + logger.debug(str(e)) + key = None - if fallback_to_env_vars and key is None: - key = os.getenv(params[field]) - if key: - logger.info(f"Using environment variable {params[field]} for {field}") - else: - logger.error(f"Environment variable {params[field]} is not set.") + if fallback_to_env_vars and key is None: + key = os.getenv(params[field]) + if key: + logger.info(f"Using environment variable {params[field]} for {field}") + else: + logger.error(f"Environment variable {params[field]} is not set.") - params[field] = key if key is not None else None - if key is None: - logger.warning(f"Could not get value for {field}. Setting it to None.") + params[field] = key if key is not None else None + if key is None: + logger.warning(f"Could not get value for {field}. Setting it to None.") - return params + return params async def build_component(