From 2930b4ed1aa4a2c2da15ff61434e06278fd5e5b6 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 7 Mar 2024 12:26:57 -0300 Subject: [PATCH] Update CustomComponent class methods to use dotdict --- .../custom_component/custom_component.py | 49 ++++++++++++++----- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index f0c3bfa80..7c2b6b873 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -24,6 +24,7 @@ from langflow.interface.custom.code_parser.utils import ( ) from langflow.interface.custom.custom_component.component import Component from langflow.schema import Record +from langflow.schema.dotdict import dotdict from langflow.services.database.models.flow import Flow from langflow.services.database.utils import session_getter from langflow.services.deps import ( @@ -77,13 +78,17 @@ class CustomComponent(Component): def update_state(self, name: str, value: Any): try: - self.vertex.graph.update_state(name=name, record=value, caller=self.vertex.id) + self.vertex.graph.update_state( + name=name, record=value, caller=self.vertex.id + ) except Exception as e: raise ValueError(f"Error updating state: {e}") def append_state(self, name: str, value: Any): try: - self.vertex.graph.append_state(name=name, record=value, caller=self.vertex.id) + self.vertex.graph.append_state( + name=name, record=value, caller=self.vertex.id + ) except Exception as e: raise ValueError(f"Error appending state: {e}") @@ -134,7 +139,9 @@ class CustomComponent(Component): def build_config(self): return self.field_config - def update_build_config(self, build_config: dict, field_name: str, field_value: Any): + def update_build_config( + self, build_config: dotdict, field_name: str, field_value: Any + ): build_config[field_name] = field_value return build_config @@ -142,7 +149,9 @@ class CustomComponent(Component): def tree(self): return self.get_code_tree(self.code or "") - def to_records(self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False) -> List[Record]: + def to_records( + self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False + ) -> List[Record]: """ Converts input data into a list of Record objects. @@ -191,7 +200,9 @@ class CustomComponent(Component): return records - def create_references_from_records(self, records: List[Record], include_data: bool = False) -> str: + def create_references_from_records( + self, records: List[Record], include_data: bool = False + ) -> str: """ Create references from a list of records. @@ -230,14 +241,20 @@ class CustomComponent(Component): if not self.code: return {} - component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]] + component_classes = [ + cls + for cls in self.tree["classes"] + if self.code_class_base_inheritance in cls["bases"] + ] if not component_classes: return {} # Assume the first Component class is the one we're interested in component_class = component_classes[0] build_methods = [ - method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name + method + for method in component_class["methods"] + if method["name"] == self.function_entrypoint_name ] return build_methods[0] if build_methods else {} @@ -294,7 +311,9 @@ class CustomComponent(Component): # Retrieve and decrypt the credential by name for the current user db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session) + return credential_service.get_credential( + user_id=self._user_id or "", name=name, session=session + ) return get_credential @@ -304,7 +323,9 @@ class CustomComponent(Component): credential_service = get_credential_service() db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.list_credentials(user_id=self._user_id, session=session) + return credential_service.list_credentials( + user_id=self._user_id, session=session + ) def index(self, value: int = 0): """Returns a function that returns the value at the given index in the iterable.""" @@ -343,7 +364,11 @@ class CustomComponent(Component): if not self._flows_records: self.list_flows() if not flow_id and self._flows_records: - flow_ids = [flow.data["id"] for flow in self._flows_records if flow.data["name"] == flow_name] + flow_ids = [ + flow.data["id"] + for flow in self._flows_records + if flow.data["name"] == flow_name + ] if not flow_ids: raise ValueError(f"Flow {flow_name} not found") elif len(flow_ids) > 1: @@ -365,7 +390,9 @@ class CustomComponent(Component): db_service = get_db_service() with get_session(db_service) as session: flows = session.exec( - select(Flow).where(Flow.user_id == self._user_id).where(Flow.is_component == False) # noqa + select(Flow) + .where(Flow.user_id == self._user_id) + .where(Flow.is_component == False) # noqa ).all() flows_records = [flow.to_record() for flow in flows]