From 43fdc7bbba37808c8dca51db4432165c0976d08b Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 5 Mar 2024 23:01:54 -0300 Subject: [PATCH] Refactor CustomComponent class and update to_records method --- .../custom_component/custom_component.py | 61 ++++++++++++------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index dde147d77..35bdd2c5f 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -1,7 +1,15 @@ import operator from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, - Sequence, Union) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + List, + Optional, + Sequence, + Union, +) from uuid import UUID import yaml @@ -12,13 +20,17 @@ from sqlmodel import select from langflow.interface.custom.code_parser.utils import ( extract_inner_type_from_generic_alias, - extract_union_types_from_generic_alias) + extract_union_types_from_generic_alias, +) from langflow.interface.custom.custom_component.component import Component from langflow.schema import Record from langflow.services.database.models.flow import Flow from langflow.services.database.utils import session_getter -from langflow.services.deps import (get_credential_service, get_db_service, - get_storage_service) +from langflow.services.deps import ( + get_credential_service, + get_db_service, + get_storage_service, +) from langflow.services.storage.service import StorageService from langflow.utils import validate @@ -126,7 +138,9 @@ class CustomComponent(Component): def build_config(self): return self.field_config - def update_build_config(self, build_config: dict, field_name: str, field_value: Any): + def update_build_config( + self, build_config: dict, field_name: str, field_value: Any + ): build_config[field_name] = field_value return build_config @@ -135,7 +149,7 @@ class CustomComponent(Component): return self.get_code_tree(self.code or "") def to_records( - self, data: Any, text_key: str = "text", data_key: str = "data" + self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False ) -> List[Record]: """ Converts input data into a list of Record objects. @@ -144,8 +158,9 @@ class CustomComponent(Component): data (Any): The input data to be converted. It can be a single item or a sequence of items. If the input data is a Langchain Document, text_key and data_key are ignored. - text_key (str, optional): The key to access the text value in each item. Defaults to "text". - data_key (str, optional): The key to access the data value in each item. Defaults to "data". + keys (List[str], optional): The keys to access the text and data values in each item. + It should be a list of strings where the first element is the text key and the second element is the data key. + Defaults to None, in which case the default keys "text" and "data" are used. Returns: List[Record]: A list of Record objects. @@ -158,27 +173,29 @@ class CustomComponent(Component): if not isinstance(data, Sequence): data = [data] for item in data: + data_dict = {} if isinstance(item, Document): - item = {"text": item.page_content, "data": item.metadata} + data_dict = item.metadata + data_dict["text"] = item.page_content elif isinstance(item, BaseModel): model_dump = item.model_dump() - if text_key not in model_dump: - raise ValueError(f"Key '{text_key}' not found in BaseModel item.") - if data_key not in model_dump: - raise ValueError(f"Key '{data_key}' not found in BaseModel item.") - item = {"text": model_dump[text_key], "data": model_dump[data_key]} + for key in keys: + if silent_errors: + data_dict[key] = model_dump.get(key, "") + else: + try: + data_dict[key] = model_dump[key] + except KeyError: + raise ValueError(f"Key {key} not found in {item}") + elif isinstance(item, str): - item = {"text": item, "data": {}} + data_dict = {"text": item} elif isinstance(item, dict): - if text_key not in item: - raise ValueError(f"Key '{text_key}' not found in dictionary item.") - if data_key not in item: - raise ValueError(f"Key '{data_key}' not found in dictionary item.") - item = {"text": item[text_key], "data": item[data_key]} + data_dict = item.copy() else: raise ValueError(f"Invalid data type: {type(item)}") - records.append(Record(**item)) + records.append(Record(data=data_dict)) return records