Refactor CustomComponent class methods

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-26 16:22:45 -03:00
commit d65563dd62
2 changed files with 26 additions and 75 deletions

View file

@ -88,9 +88,7 @@ class CustomComponent(Component):
def tree(self):
return self.get_code_tree(self.code or "")
def to_records(
self, data: Any, text_key: str = "text", data_key: str = "data"
) -> List[dict]:
def to_records(self, data: Any, text_key: str = "text", data_key: str = "data") -> List[dict]:
"""
Convert data into a list of records.
@ -117,9 +115,7 @@ class CustomComponent(Component):
return records
def create_references_from_records(
self, records: List[dict], include_data: bool = False
) -> str:
def create_references_from_records(self, records: List[dict], include_data: bool = False) -> str:
"""
Create references from a list of records.
@ -130,6 +126,8 @@ class CustomComponent(Component):
Returns:
str: A string containing the references in markdown format.
"""
if not records:
return ""
markdown_string = "---\n"
for record in records:
markdown_string += f"- Text: {record['text']}"
@ -152,8 +150,7 @@ class CustomComponent(Component):
detail={
"error": "Type hint Error",
"traceback": (
"Prompt type is not supported in the build method."
" Try using PromptTemplate instead."
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
),
},
)
@ -167,20 +164,14 @@ class CustomComponent(Component):
if not self.code:
return {}
component_classes = [
cls
for cls in self.tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
if not component_classes:
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
]
return build_methods[0] if build_methods else {}
@ -237,9 +228,7 @@ class CustomComponent(Component):
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(
user_id=self._user_id or "", name=name, session=session
)
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return get_credential
@ -249,9 +238,7 @@ class CustomComponent(Component):
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(
user_id=self._user_id, session=session
)
return credential_service.list_credentials(user_id=self._user_id, session=session)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable."""
@ -302,11 +289,7 @@ class CustomComponent(Component):
if flow_id:
flow = session.query(Flow).get(flow_id)
elif flow_name:
flow = (
session.query(Flow)
.filter(Flow.name == flow_name)
.filter(Flow.user_id == self.user_id)
).first()
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
else:
raise ValueError("Either flow_name or flow_id must be provided")

View file

@ -27,18 +27,14 @@ from langflow.utils import validate
from langflow.utils.util import get_base_classes
def add_output_types(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add output types to the frontend node"""
for return_type in return_types:
if return_type is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type. Please check your code and try again."
),
"error": ("Invalid return type. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
)
@ -67,18 +63,14 @@ def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: List
frontend_node.template.fields = reordered_fields
def add_base_classes(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add base classes to the frontend node"""
for return_type_instance in return_types:
if return_type_instance is None:
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid return type. Please check your code and try again."
),
"error": ("Invalid return type. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
)
@ -153,14 +145,10 @@ def add_new_custom_field(
# If options is a list, then it's a dropdown
# If options is None, then it's a list of strings
is_list = isinstance(field_config.get("options"), list)
field_config["is_list"] = (
is_list or field_config.get("is_list", False) or field_contains_list
)
field_config["is_list"] = is_list or field_config.get("is_list", False) or field_contains_list
if "name" in field_config:
warnings.warn(
"The 'name' key in field_config is used to build the object and can't be changed."
)
warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.")
required = field_config.pop("required", field_required)
placeholder = field_config.pop("placeholder", "")
@ -191,9 +179,7 @@ def add_extra_fields(frontend_node, field_config, function_args):
if "name" not in extra_field or extra_field["name"] == "self":
continue
field_name, field_type, field_value, field_required = get_field_properties(
extra_field
)
field_name, field_type, field_value, field_required = get_field_properties(extra_field)
config = field_config.get(field_name, {})
frontend_node = add_new_custom_field(
frontend_node,
@ -231,9 +217,7 @@ def run_build_config(
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -261,9 +245,7 @@ def run_build_config(
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -318,24 +300,16 @@ def build_custom_component_template(
frontend_node = build_frontend_node(custom_component.template_config)
logger.debug("Updated attributes")
field_config, custom_instance = run_build_config(
custom_component, user_id=user_id, update_field=update_field
)
field_config, custom_instance = run_build_config(custom_component, user_id=user_id, update_field=update_field)
logger.debug("Built field config")
entrypoint_args = custom_component.get_function_entrypoint_args
add_extra_fields(frontend_node, field_config, entrypoint_args)
frontend_node = add_code_field(
frontend_node, custom_component.code, field_config.get("code", {})
)
frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {}))
add_base_classes(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_output_types(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type)
add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type)
logger.debug("Added base classes")
reorder_fields(frontend_node, custom_instance._get_field_order())
@ -347,9 +321,7 @@ def build_custom_component_template(
raise HTTPException(
status_code=400,
detail={
"error": (
"Invalid type convertion. Please check your code and try again."
),
"error": ("Invalid type convertion. Please check your code and try again."),
"traceback": traceback.format_exc(),
},
) from exc
@ -373,9 +345,7 @@ def build_custom_components(settings_service):
if not settings_service.settings.COMPONENTS_PATH:
return {}
logger.info(
f"Building custom components from {settings_service.settings.COMPONENTS_PATH}"
)
logger.info(f"Building custom components from {settings_service.settings.COMPONENTS_PATH}")
custom_components_from_file = {}
processed_paths = set()
for path in settings_service.settings.COMPONENTS_PATH:
@ -386,9 +356,7 @@ def build_custom_components(settings_service):
custom_component_dict = build_custom_component_list_from_path(path_str)
if custom_component_dict:
category = next(iter(custom_component_dict))
logger.info(
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
)
logger.info(f"Loading {len(custom_component_dict[category])} component(s) from category {category}")
custom_components_from_file = merge_nested_dicts_with_renaming(
custom_components_from_file, custom_component_dict
)