From 108d0b2e125068b0facd467dd6514ee72dbef3b9 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Feb 2024 21:16:34 -0300 Subject: [PATCH] Refactor class instantiation and update params with load_from_db_fields --- .../langflow/interface/initialize/loading.py | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 652cd1a0c..00fdf0773 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -11,9 +11,6 @@ from langchain.chains.base import Chain from langchain.document_loaders.base import BaseLoader from langchain_community.vectorstores import VectorStore from langchain_core.documents import Document -from loguru import logger -from pydantic import ValidationError - from langflow.interface.custom_lists import CUSTOM_NODES from langflow.interface.importing.utils import eval_custom_component_code, get_function, import_by_type from langflow.interface.initialize.llm import initialize_vertexai @@ -25,6 +22,8 @@ from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.utils import load_file_into_dict from langflow.interface.wrappers.base import wrapper_creator from langflow.utils import validate +from loguru import logger +from pydantic import ValidationError if TYPE_CHECKING: from langflow import CustomComponent @@ -37,7 +36,9 @@ def build_vertex_in_params(params: Dict) -> Dict: return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()} -async def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any: +async def instantiate_class( + node_type: str, base_type: str, load_from_db_fields: list[str], params: Dict, user_id=None +) -> Any: """Instantiate class from module type and key, and params""" params = convert_params_to_sets(params) params = convert_kwargs(params) @@ -49,7 +50,9 @@ async def instantiate_class(node_type: str, base_type: str, params: Dict, user_i return custom_node(**params) logger.debug(f"Instantiating {node_type} of type {base_type}") class_object = import_by_type(_type=base_type, name=node_type) - return await instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id) + return await instantiate_based_on_type( + class_object, base_type, node_type, load_from_db_fields, params, user_id=user_id + ) def convert_params_to_sets(params): @@ -76,7 +79,7 @@ def convert_kwargs(params): return params -async def instantiate_based_on_type(class_object, base_type, node_type, params, user_id): +async def instantiate_based_on_type(class_object, base_type, node_type, load_from_db_fields, params, user_id): if base_type == "agents": return instantiate_agent(node_type, class_object, params) elif base_type == "prompts": @@ -110,18 +113,32 @@ async def instantiate_based_on_type(class_object, base_type, node_type, params, elif base_type == "memory": return instantiate_memory(node_type, class_object, params) elif base_type == "custom_components": - return await instantiate_custom_component(node_type, class_object, params, user_id) + return await instantiate_custom_component(node_type, class_object, load_from_db_fields, params, user_id) elif base_type == "wrappers": return instantiate_wrapper(node_type, class_object, params) else: return class_object(**params) -async def instantiate_custom_component(node_type, class_object, params, user_id): +def update_params_with_load_from_db_fields(custom_component, params, load_from_db_fields): + # For each field in load_from_db_fields, we will check if it's in the params + # and if it is, we will get the value from the custom_component.keys(name) + # and update the params with the value + for field in load_from_db_fields: + if field in params: + try: + params[field] = custom_component.keys(field) + except Exception as exc: + logger.error(f"Failed to get value for {field} from custom component. Error: {exc}") + pass + return params + + +async def instantiate_custom_component(node_type, class_object, load_from_db_fields, params, user_id): params_copy = params.copy() class_object: "CustomComponent" = eval_custom_component_code(params_copy.pop("code")) custom_component = class_object(user_id=user_id) - + params_copy = update_params_with_load_from_db_fields(custom_component, params_copy, load_from_db_fields) if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"): params_copy["retriever"] = params_copy["retriever"].as_retriever() @@ -497,3 +514,6 @@ def build_prompt_template(prompt, tools): } return prompt + return prompt + return prompt + return prompt