From 33cca64f322f48ea713cda3ac362f28d657b55cb Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 6 Mar 2024 15:48:43 -0300 Subject: [PATCH] Update schema and utils modules --- src/backend/langflow/schema/schema.py | 6 ++- src/backend/langflow/utils/util.py | 62 +++++++++++++++++++-------- 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/src/backend/langflow/schema/schema.py b/src/backend/langflow/schema/schema.py index fd3bd4ea0..18cf9e13f 100644 --- a/src/backend/langflow/schema/schema.py +++ b/src/backend/langflow/schema/schema.py @@ -13,7 +13,7 @@ class Record(BaseModel): """ data: dict = {} - _default_value = None + _default_value: str = "" @classmethod def from_document(cls, document: Document) -> "Record": @@ -63,7 +63,9 @@ class Record(BaseModel): return self.data.get(key, self._default_value) except KeyError: # Fallback to default behavior to raise AttributeError for undefined attributes - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{key}'" + ) def __setattr__(self, key, value): """ diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index 825b5471c..fb704b2bf 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -15,8 +15,12 @@ def remove_ansi_escape_codes(text): return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text) -def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False): - classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()] +def build_template_from_function( + name: str, type_to_loader_dict: Dict, add_function: bool = False +): + classes = [ + item.__annotations__["return"].__name__ for item in type_to_loader_dict.values() + ] # Raise error if name is not in chains if name not in classes: @@ -37,8 +41,10 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict, add_funct for name_, value_ in value.__repr_args__(): if name_ == "default_factory": try: - variables[class_field_items]["default"] = get_default_factory( - module=_class.__base__.__module__, function=value_ + variables[class_field_items]["default"] = ( + get_default_factory( + module=_class.__base__.__module__, function=value_ + ) ) except Exception: variables[class_field_items]["default"] = None @@ -46,7 +52,9 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict, add_funct variables[class_field_items][name_] = value_ variables[class_field_items]["placeholder"] = ( - docs.params[class_field_items] if class_field_items in docs.params else "" + docs.params[class_field_items] + if class_field_items in docs.params + else "" ) # Adding function to base classes to allow # the output to be a function @@ -61,7 +69,9 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict, add_funct } -def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False): +def build_template_from_class( + name: str, type_to_cls_dict: Dict, add_function: bool = False +): classes = [item.__name__ for item in type_to_cls_dict.values()] # Raise error if name is not in chains @@ -85,9 +95,11 @@ def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: b for name_, value_ in value.__repr_args__(): if name_ == "default_factory": try: - variables[class_field_items]["default"] = get_default_factory( - module=_class.__base__.__module__, - function=value_, + variables[class_field_items]["default"] = ( + get_default_factory( + module=_class.__base__.__module__, + function=value_, + ) ) except Exception: variables[class_field_items]["default"] = None @@ -95,7 +107,9 @@ def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: b variables[class_field_items][name_] = value_ variables[class_field_items]["placeholder"] = ( - docs.params[class_field_items] if class_field_items in docs.params else "" + docs.params[class_field_items] + if class_field_items in docs.params + else "" ) base_classes = get_base_classes(_class) # Adding function to base classes to allow @@ -127,7 +141,9 @@ def build_template_from_method( # Check if the method exists in this class if not hasattr(_class, method_name): - raise ValueError(f"Method {method_name} not found in class {class_name}") + raise ValueError( + f"Method {method_name} not found in class {class_name}" + ) # Get the method method = getattr(_class, method_name) @@ -146,8 +162,14 @@ def build_template_from_method( "_type": _type, **{ name: { - "default": (param.default if param.default != param.empty else None), - "type": (param.annotation if param.annotation != param.empty else None), + "default": ( + param.default if param.default != param.empty else None + ), + "type": ( + param.annotation + if param.annotation != param.empty + else None + ), "required": param.default == param.empty, } for name, param in params.items() @@ -234,7 +256,9 @@ def sync_to_async(func): return async_wrapper -def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]: +def format_dict( + dictionary: Dict[str, Any], class_name: Optional[str] = None +) -> Dict[str, Any]: """ Formats a dictionary by removing certain keys and modifying the values of other keys. @@ -320,7 +344,9 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str: The modified type string. """ if any(list_type in _type for list_type in ["List", "Sequence", "Set"]): - _type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1] + _type = ( + _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1] + ) value["list"] = True else: value["list"] = False @@ -423,7 +449,9 @@ def set_headers_value(value: Dict[str, Any]) -> None: value["value"] = """{"Authorization": "Bearer "}""" -def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None: +def add_options_to_field( + value: Dict[str, Any], class_name: Optional[str], key: str +) -> None: """ Adds options to the field based on the class name and key. """ @@ -442,7 +470,7 @@ def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: def build_loader_repr_from_records(records: List[Record]) -> str: if records: - avg_length = sum(len(doc.page_content) for doc in records) / len(records) + avg_length = sum(len(doc.text) for doc in records) / len(records) return f"""{len(records)} records \nAvg. Record Length (characters): {int(avg_length)} Records: {records[:3]}..."""