From eef5043045d0495087f1035f38b2e3d173a74a2e Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 6 Mar 2024 15:33:42 -0300 Subject: [PATCH] Refactor code and update documentation links --- src/backend/langflow/api/v1/flows.py | 10 ++-- .../RecursiveCharacterTextSplitter.py | 4 +- .../custom_component/custom_component.py | 48 +++++++++++---- src/backend/langflow/utils/util.py | 60 +++++-------------- 4 files changed, 59 insertions(+), 63 deletions(-) diff --git a/src/backend/langflow/api/v1/flows.py b/src/backend/langflow/api/v1/flows.py index 19bd86555..dd60d5fed 100644 --- a/src/backend/langflow/api/v1/flows.py +++ b/src/backend/langflow/api/v1/flows.py @@ -59,9 +59,9 @@ def read_flows( if auth_settings.AUTO_LOGIN: flows = session.exec( select(Flow).where( - (Flow.user_id == None) | (Flow.user_id == current_user.id) + (Flow.user_id == None) | (Flow.user_id == current_user.id) # noqa ) - ).all() # noqa + ).all() else: flows = current_user.flows @@ -71,10 +71,10 @@ def read_flows( try: example_flows = session.exec( select(Flow).where( - Flow.user_id == None, - Flow.folder == STARTER_FOLDER_NAME, # noqa + Flow.user_id == None, # noqa + Flow.folder == STARTER_FOLDER_NAME, ) - ).all() # noqa + ).all() for example_flow in example_flows: if example_flow.id not in flow_ids: flows.append(example_flow) diff --git a/src/backend/langflow/components/textsplitters/RecursiveCharacterTextSplitter.py b/src/backend/langflow/components/textsplitters/RecursiveCharacterTextSplitter.py index 508cb0023..6b9cb865b 100644 --- a/src/backend/langflow/components/textsplitters/RecursiveCharacterTextSplitter.py +++ b/src/backend/langflow/components/textsplitters/RecursiveCharacterTextSplitter.py @@ -11,9 +11,7 @@ from langflow.utils.util import build_loader_repr_from_records class RecursiveCharacterTextSplitterComponent(CustomComponent): display_name: str = "Recursive Character Text Splitter" description: str = "Split text into chunks of a specified length." - documentation: str = ( - "https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter" - ) + documentation: str = "https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter" def build_config(self): return { diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index f0c3bfa80..11ec59bf3 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -77,13 +77,17 @@ class CustomComponent(Component): def update_state(self, name: str, value: Any): try: - self.vertex.graph.update_state(name=name, record=value, caller=self.vertex.id) + self.vertex.graph.update_state( + name=name, record=value, caller=self.vertex.id + ) except Exception as e: raise ValueError(f"Error updating state: {e}") def append_state(self, name: str, value: Any): try: - self.vertex.graph.append_state(name=name, record=value, caller=self.vertex.id) + self.vertex.graph.append_state( + name=name, record=value, caller=self.vertex.id + ) except Exception as e: raise ValueError(f"Error appending state: {e}") @@ -134,7 +138,9 @@ class CustomComponent(Component): def build_config(self): return self.field_config - def update_build_config(self, build_config: dict, field_name: str, field_value: Any): + def update_build_config( + self, build_config: dict, field_name: str, field_value: Any + ): build_config[field_name] = field_value return build_config @@ -142,7 +148,9 @@ class CustomComponent(Component): def tree(self): return self.get_code_tree(self.code or "") - def to_records(self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False) -> List[Record]: + def to_records( + self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False + ) -> List[Record]: """ Converts input data into a list of Record objects. @@ -191,7 +199,9 @@ class CustomComponent(Component): return records - def create_references_from_records(self, records: List[Record], include_data: bool = False) -> str: + def create_references_from_records( + self, records: List[Record], include_data: bool = False + ) -> str: """ Create references from a list of records. @@ -230,14 +240,20 @@ class CustomComponent(Component): if not self.code: return {} - component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]] + component_classes = [ + cls + for cls in self.tree["classes"] + if self.code_class_base_inheritance in cls["bases"] + ] if not component_classes: return {} # Assume the first Component class is the one we're interested in component_class = component_classes[0] build_methods = [ - method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name + method + for method in component_class["methods"] + if method["name"] == self.function_entrypoint_name ] return build_methods[0] if build_methods else {} @@ -294,7 +310,9 @@ class CustomComponent(Component): # Retrieve and decrypt the credential by name for the current user db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session) + return credential_service.get_credential( + user_id=self._user_id or "", name=name, session=session + ) return get_credential @@ -304,7 +322,9 @@ class CustomComponent(Component): credential_service = get_credential_service() db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.list_credentials(user_id=self._user_id, session=session) + return credential_service.list_credentials( + user_id=self._user_id, session=session + ) def index(self, value: int = 0): """Returns a function that returns the value at the given index in the iterable.""" @@ -343,7 +363,11 @@ class CustomComponent(Component): if not self._flows_records: self.list_flows() if not flow_id and self._flows_records: - flow_ids = [flow.data["id"] for flow in self._flows_records if flow.data["name"] == flow_name] + flow_ids = [ + flow.data["id"] + for flow in self._flows_records + if flow.data["name"] == flow_name + ] if not flow_ids: raise ValueError(f"Flow {flow_name} not found") elif len(flow_ids) > 1: @@ -365,7 +389,9 @@ class CustomComponent(Component): db_service = get_db_service() with get_session(db_service) as session: flows = session.exec( - select(Flow).where(Flow.user_id == self._user_id).where(Flow.is_component == False) # noqa + select(Flow) + .where(Flow.user_id == self._user_id) + .where(Flow.is_component == False) # noqa ).all() flows_records = [flow.to_record() for flow in flows] diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index af4f02b59..825b5471c 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -15,12 +15,8 @@ def remove_ansi_escape_codes(text): return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text) -def build_template_from_function( - name: str, type_to_loader_dict: Dict, add_function: bool = False -): - classes = [ - item.__annotations__["return"].__name__ for item in type_to_loader_dict.values() - ] +def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False): + classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()] # Raise error if name is not in chains if name not in classes: @@ -41,10 +37,8 @@ def build_template_from_function( for name_, value_ in value.__repr_args__(): if name_ == "default_factory": try: - variables[class_field_items]["default"] = ( - get_default_factory( - module=_class.__base__.__module__, function=value_ - ) + variables[class_field_items]["default"] = get_default_factory( + module=_class.__base__.__module__, function=value_ ) except Exception: variables[class_field_items]["default"] = None @@ -52,9 +46,7 @@ def build_template_from_function( variables[class_field_items][name_] = value_ variables[class_field_items]["placeholder"] = ( - docs.params[class_field_items] - if class_field_items in docs.params - else "" + docs.params[class_field_items] if class_field_items in docs.params else "" ) # Adding function to base classes to allow # the output to be a function @@ -69,9 +61,7 @@ def build_template_from_function( } -def build_template_from_class( - name: str, type_to_cls_dict: Dict, add_function: bool = False -): +def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False): classes = [item.__name__ for item in type_to_cls_dict.values()] # Raise error if name is not in chains @@ -95,11 +85,9 @@ def build_template_from_class( for name_, value_ in value.__repr_args__(): if name_ == "default_factory": try: - variables[class_field_items]["default"] = ( - get_default_factory( - module=_class.__base__.__module__, - function=value_, - ) + variables[class_field_items]["default"] = get_default_factory( + module=_class.__base__.__module__, + function=value_, ) except Exception: variables[class_field_items]["default"] = None @@ -107,9 +95,7 @@ def build_template_from_class( variables[class_field_items][name_] = value_ variables[class_field_items]["placeholder"] = ( - docs.params[class_field_items] - if class_field_items in docs.params - else "" + docs.params[class_field_items] if class_field_items in docs.params else "" ) base_classes = get_base_classes(_class) # Adding function to base classes to allow @@ -141,9 +127,7 @@ def build_template_from_method( # Check if the method exists in this class if not hasattr(_class, method_name): - raise ValueError( - f"Method {method_name} not found in class {class_name}" - ) + raise ValueError(f"Method {method_name} not found in class {class_name}") # Get the method method = getattr(_class, method_name) @@ -162,14 +146,8 @@ def build_template_from_method( "_type": _type, **{ name: { - "default": ( - param.default if param.default != param.empty else None - ), - "type": ( - param.annotation - if param.annotation != param.empty - else None - ), + "default": (param.default if param.default != param.empty else None), + "type": (param.annotation if param.annotation != param.empty else None), "required": param.default == param.empty, } for name, param in params.items() @@ -256,9 +234,7 @@ def sync_to_async(func): return async_wrapper -def format_dict( - dictionary: Dict[str, Any], class_name: Optional[str] = None -) -> Dict[str, Any]: +def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]: """ Formats a dictionary by removing certain keys and modifying the values of other keys. @@ -344,9 +320,7 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str: The modified type string. """ if any(list_type in _type for list_type in ["List", "Sequence", "Set"]): - _type = ( - _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1] - ) + _type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1] value["list"] = True else: value["list"] = False @@ -449,9 +423,7 @@ def set_headers_value(value: Dict[str, Any]) -> None: value["value"] = """{"Authorization": "Bearer "}""" -def add_options_to_field( - value: Dict[str, Any], class_name: Optional[str], key: str -) -> None: +def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None: """ Adds options to the field based on the class name and key. """