From b1ae5e41595092169cdf8751624bbb9f9d33fd1b Mon Sep 17 00:00:00 2001 From: Deon Sanchez <69873175+deon-sanchez@users.noreply.github.com> Date: Thu, 3 Jul 2025 08:55:55 -0600 Subject: [PATCH] refactor: reuse single session when getting variables from db (#8814) * refactor: enhance database session management in custom components - Updated `get_variables` method in `CustomComponent` to accept an optional session parameter, allowing for session reuse and reducing connection pool exhaustion. - Modified `update_params_with_load_from_db_fields` to pass the session when calling `get_variables`. - Adjusted `get_instance_results` to support session management for database operations. - Increased connection pool size and max overflow in settings for improved performance under load. * [autofix.ci] apply automated fixes * Prefer single session by default: * remove unused session * Revert pool size changes * refactor: update get_variables method for backward compatibility - Added a new async `get_variables` method in `CustomComponent` to maintain backward compatibility with the deprecated method, ensuring it calls the existing `get_variable` method with session management. - This change enhances the robustness of the component while preserving existing functionality. * refactor: remove unused session import from endpoints.py - Eliminated the unused `session_scope` import from the `endpoints.py` file to streamline the code and improve clarity. This change contributes to maintaining a clean and efficient codebase. * refactor: update deprecated variables method in CustomComponent - Modified the `variables` method to call the new `get_variables` method for improved clarity and consistency. This change maintains backward compatibility while encouraging the use of the updated async method. * refactor: update method calls to use get_variables because we don't have session in update_build_config - Replaced instances of the deprecated `get_variable` method with the new `get_variables` method in `LMStudioEmbeddingsComponent`, `LMStudioModelComponent`, and `ChatOllamaComponent`. This change enhances code clarity and maintains consistency across components while ensuring backward compatibility. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Jordan Frazier Co-authored-by: Gabriel Luiz Freitas Almeida Co-authored-by: Carlos Coelho <80289056+carlosrcoelho@users.noreply.github.com> --- src/backend/base/langflow/api/v1/endpoints.py | 1 + .../components/lmstudio/lmstudioembeddings.py | 4 +- .../custom_component/custom_component.py | 8 +++- .../langflow/interface/initialize/loading.py | 48 ++++++++++--------- 4 files changed, 34 insertions(+), 27 deletions(-) diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 3e8585148..85e1eac6c 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -717,6 +717,7 @@ async def custom_component_update( for field_name, field_dict in template.items() if isinstance(field_dict, dict) and field_dict.get("load_from_db") and field_dict.get("value") ] + params = await update_params_with_load_from_db_fields(cc_instance, params, load_from_db_fields) cc_instance.set_attributes(params) updated_build_config = code_request.get_template() diff --git a/src/backend/base/langflow/components/lmstudio/lmstudioembeddings.py b/src/backend/base/langflow/components/lmstudio/lmstudioembeddings.py index 57f783fc1..e3e86c121 100644 --- a/src/backend/base/langflow/components/lmstudio/lmstudioembeddings.py +++ b/src/backend/base/langflow/components/lmstudio/lmstudioembeddings.py @@ -2,7 +2,6 @@ from typing import Any from urllib.parse import urljoin import httpx -from typing_extensions import override from langflow.base.embeddings.model import LCEmbeddingsModel from langflow.field_typing import Embeddings @@ -15,8 +14,7 @@ class LMStudioEmbeddingsComponent(LCEmbeddingsModel): description: str = "Generate embeddings using LM Studio." icon = "LMStudio" - @override - async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): + async def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None): # noqa: ARG002 if field_name == "model": base_url_dict = build_config.get("base_url", {}) base_url_load_from_db = base_url_dict.get("load_from_db", False) diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 885604fba..3ddea976e 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -405,6 +405,11 @@ class CustomComponent(BaseComponent): return run_until_complete(self.get_variables(name, field)) async def get_variables(self, name: str, field: str): + """DEPRECATED - This is kept for backward compatibility. Use get_variable instead.""" + async with session_scope() as session: + return await self.get_variable(name, field, session) + + async def get_variable(self, name: str, field: str, session): """Returns the variable for the current user with the specified name. Raises: @@ -425,8 +430,7 @@ class CustomComponent(BaseComponent): else: msg = f"Invalid user id: {self.user_id}" raise TypeError(msg) - async with session_scope() as session: - return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session) + return await variable_service.get_variable(user_id=user_id, name=name, field=field, session=session) async def list_key_names(self): """Lists the names of the variables for the current user. diff --git a/src/backend/base/langflow/interface/initialize/loading.py b/src/backend/base/langflow/interface/initialize/loading.py index f20629c0b..5a25f8a75 100644 --- a/src/backend/base/langflow/interface/initialize/loading.py +++ b/src/backend/base/langflow/interface/initialize/loading.py @@ -12,7 +12,7 @@ from pydantic import PydanticDeprecatedSince20 from langflow.custom.eval import eval_custom_component_code from langflow.schema.artifact import get_artifact_type, post_process_raw from langflow.schema.data import Data -from langflow.services.deps import get_tracing_service +from langflow.services.deps import get_tracing_service, session_scope if TYPE_CHECKING: from langflow.custom.custom_component.component import Component @@ -59,7 +59,10 @@ async def get_instance_results( base_type: str = "component", ): custom_params = await update_params_with_load_from_db_fields( - custom_component, custom_params, vertex.load_from_db_fields, fallback_to_env_vars=fallback_to_env_vars + custom_component, + custom_params, + vertex.load_from_db_fields, + fallback_to_env_vars=fallback_to_env_vars, ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) @@ -111,30 +114,31 @@ async def update_params_with_load_from_db_fields( *, fallback_to_env_vars=False, ): - for field in load_from_db_fields: - if field not in params or not params[field]: - continue + async with session_scope() as session: + for field in load_from_db_fields: + if field not in params or not params[field]: + continue - try: - key = await custom_component.get_variables(params[field], field) - except ValueError as e: - if any(reason in str(e) for reason in ["User id is not set", "variable not found."]): - raise - logger.debug(str(e)) - key = None + try: + key = await custom_component.get_variable(name=params[field], field=field, session=session) + except ValueError as e: + if any(reason in str(e) for reason in ["User id is not set", "variable not found."]): + raise + logger.debug(str(e)) + key = None - if fallback_to_env_vars and key is None: - key = os.getenv(params[field]) - if key: - logger.info(f"Using environment variable {params[field]} for {field}") - else: - logger.error(f"Environment variable {params[field]} is not set.") + if fallback_to_env_vars and key is None: + key = os.getenv(params[field]) + if key: + logger.info(f"Using environment variable {params[field]} for {field}") + else: + logger.error(f"Environment variable {params[field]} is not set.") - params[field] = key if key is not None else None - if key is None: - logger.warning(f"Could not get value for {field}. Setting it to None.") + params[field] = key if key is not None else None + if key is None: + logger.warning(f"Could not get value for {field}. Setting it to None.") - return params + return params async def build_component(