Refactor class instantiation and update params with load_from_db_fields
This commit is contained in:
parent
937a50498a
commit
108d0b2e12
1 changed files with 29 additions and 9 deletions
|
|
@ -11,9 +11,6 @@ from langchain.chains.base import Chain
|
|||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain_community.vectorstores import VectorStore
|
||||
from langchain_core.documents import Document
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.interface.custom_lists import CUSTOM_NODES
|
||||
from langflow.interface.importing.utils import eval_custom_component_code, get_function, import_by_type
|
||||
from langflow.interface.initialize.llm import initialize_vertexai
|
||||
|
|
@ -25,6 +22,8 @@ from langflow.interface.toolkits.base import toolkits_creator
|
|||
from langflow.interface.utils import load_file_into_dict
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.utils import validate
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow import CustomComponent
|
||||
|
|
@ -37,7 +36,9 @@ def build_vertex_in_params(params: Dict) -> Dict:
|
|||
return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()}
|
||||
|
||||
|
||||
async def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any:
|
||||
async def instantiate_class(
|
||||
node_type: str, base_type: str, load_from_db_fields: list[str], params: Dict, user_id=None
|
||||
) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
|
|
@ -49,7 +50,9 @@ async def instantiate_class(node_type: str, base_type: str, params: Dict, user_i
|
|||
return custom_node(**params)
|
||||
logger.debug(f"Instantiating {node_type} of type {base_type}")
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
return await instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id)
|
||||
return await instantiate_based_on_type(
|
||||
class_object, base_type, node_type, load_from_db_fields, params, user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def convert_params_to_sets(params):
|
||||
|
|
@ -76,7 +79,7 @@ def convert_kwargs(params):
|
|||
return params
|
||||
|
||||
|
||||
async def instantiate_based_on_type(class_object, base_type, node_type, params, user_id):
|
||||
async def instantiate_based_on_type(class_object, base_type, node_type, load_from_db_fields, params, user_id):
|
||||
if base_type == "agents":
|
||||
return instantiate_agent(node_type, class_object, params)
|
||||
elif base_type == "prompts":
|
||||
|
|
@ -110,18 +113,32 @@ async def instantiate_based_on_type(class_object, base_type, node_type, params,
|
|||
elif base_type == "memory":
|
||||
return instantiate_memory(node_type, class_object, params)
|
||||
elif base_type == "custom_components":
|
||||
return await instantiate_custom_component(node_type, class_object, params, user_id)
|
||||
return await instantiate_custom_component(node_type, class_object, load_from_db_fields, params, user_id)
|
||||
elif base_type == "wrappers":
|
||||
return instantiate_wrapper(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
async def instantiate_custom_component(node_type, class_object, params, user_id):
|
||||
def update_params_with_load_from_db_fields(custom_component, params, load_from_db_fields):
|
||||
# For each field in load_from_db_fields, we will check if it's in the params
|
||||
# and if it is, we will get the value from the custom_component.keys(name)
|
||||
# and update the params with the value
|
||||
for field in load_from_db_fields:
|
||||
if field in params:
|
||||
try:
|
||||
params[field] = custom_component.keys(field)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to get value for {field} from custom component. Error: {exc}")
|
||||
pass
|
||||
return params
|
||||
|
||||
|
||||
async def instantiate_custom_component(node_type, class_object, load_from_db_fields, params, user_id):
|
||||
params_copy = params.copy()
|
||||
class_object: "CustomComponent" = eval_custom_component_code(params_copy.pop("code"))
|
||||
custom_component = class_object(user_id=user_id)
|
||||
|
||||
params_copy = update_params_with_load_from_db_fields(custom_component, params_copy, load_from_db_fields)
|
||||
if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"):
|
||||
params_copy["retriever"] = params_copy["retriever"].as_retriever()
|
||||
|
||||
|
|
@ -497,3 +514,6 @@ def build_prompt_template(prompt, tools):
|
|||
}
|
||||
|
||||
return prompt
|
||||
return prompt
|
||||
return prompt
|
||||
return prompt
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue