Refactor class instantiation and update params with load_from_db_fields

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-06 21:16:34 -03:00
commit 108d0b2e12

View file

@ -11,9 +11,6 @@ from langchain.chains.base import Chain
from langchain.document_loaders.base import BaseLoader
from langchain_community.vectorstores import VectorStore
from langchain_core.documents import Document
from loguru import logger
from pydantic import ValidationError
from langflow.interface.custom_lists import CUSTOM_NODES
from langflow.interface.importing.utils import eval_custom_component_code, get_function, import_by_type
from langflow.interface.initialize.llm import initialize_vertexai
@ -25,6 +22,8 @@ from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.utils import load_file_into_dict
from langflow.interface.wrappers.base import wrapper_creator
from langflow.utils import validate
from loguru import logger
from pydantic import ValidationError
if TYPE_CHECKING:
from langflow import CustomComponent
@ -37,7 +36,9 @@ def build_vertex_in_params(params: Dict) -> Dict:
return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()}
async def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any:
async def instantiate_class(
node_type: str, base_type: str, load_from_db_fields: list[str], params: Dict, user_id=None
) -> Any:
"""Instantiate class from module type and key, and params"""
params = convert_params_to_sets(params)
params = convert_kwargs(params)
@ -49,7 +50,9 @@ async def instantiate_class(node_type: str, base_type: str, params: Dict, user_i
return custom_node(**params)
logger.debug(f"Instantiating {node_type} of type {base_type}")
class_object = import_by_type(_type=base_type, name=node_type)
return await instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id)
return await instantiate_based_on_type(
class_object, base_type, node_type, load_from_db_fields, params, user_id=user_id
)
def convert_params_to_sets(params):
@ -76,7 +79,7 @@ def convert_kwargs(params):
return params
async def instantiate_based_on_type(class_object, base_type, node_type, params, user_id):
async def instantiate_based_on_type(class_object, base_type, node_type, load_from_db_fields, params, user_id):
if base_type == "agents":
return instantiate_agent(node_type, class_object, params)
elif base_type == "prompts":
@ -110,18 +113,32 @@ async def instantiate_based_on_type(class_object, base_type, node_type, params,
elif base_type == "memory":
return instantiate_memory(node_type, class_object, params)
elif base_type == "custom_components":
return await instantiate_custom_component(node_type, class_object, params, user_id)
return await instantiate_custom_component(node_type, class_object, load_from_db_fields, params, user_id)
elif base_type == "wrappers":
return instantiate_wrapper(node_type, class_object, params)
else:
return class_object(**params)
async def instantiate_custom_component(node_type, class_object, params, user_id):
def update_params_with_load_from_db_fields(custom_component, params, load_from_db_fields):
# For each field in load_from_db_fields, we will check if it's in the params
# and if it is, we will get the value from the custom_component.keys(name)
# and update the params with the value
for field in load_from_db_fields:
if field in params:
try:
params[field] = custom_component.keys(field)
except Exception as exc:
logger.error(f"Failed to get value for {field} from custom component. Error: {exc}")
pass
return params
async def instantiate_custom_component(node_type, class_object, load_from_db_fields, params, user_id):
params_copy = params.copy()
class_object: "CustomComponent" = eval_custom_component_code(params_copy.pop("code"))
custom_component = class_object(user_id=user_id)
params_copy = update_params_with_load_from_db_fields(custom_component, params_copy, load_from_db_fields)
if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"):
params_copy["retriever"] = params_copy["retriever"].as_retriever()
@ -497,3 +514,6 @@ def build_prompt_template(prompt, tools):
}
return prompt
return prompt
return prompt
return prompt