Update CustomComponent class methods to use dotdict

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-07 12:26:57 -03:00
commit 2930b4ed1a

View file

@ -24,6 +24,7 @@ from langflow.interface.custom.code_parser.utils import (
)
from langflow.interface.custom.custom_component.component import Component
from langflow.schema import Record
from langflow.schema.dotdict import dotdict
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import (
@ -77,13 +78,17 @@ class CustomComponent(Component):
def update_state(self, name: str, value: Any):
try:
self.vertex.graph.update_state(name=name, record=value, caller=self.vertex.id)
self.vertex.graph.update_state(
name=name, record=value, caller=self.vertex.id
)
except Exception as e:
raise ValueError(f"Error updating state: {e}")
def append_state(self, name: str, value: Any):
try:
self.vertex.graph.append_state(name=name, record=value, caller=self.vertex.id)
self.vertex.graph.append_state(
name=name, record=value, caller=self.vertex.id
)
except Exception as e:
raise ValueError(f"Error appending state: {e}")
@ -134,7 +139,9 @@ class CustomComponent(Component):
def build_config(self):
return self.field_config
def update_build_config(self, build_config: dict, field_name: str, field_value: Any):
def update_build_config(
self, build_config: dotdict, field_name: str, field_value: Any
):
build_config[field_name] = field_value
return build_config
@ -142,7 +149,9 @@ class CustomComponent(Component):
def tree(self):
return self.get_code_tree(self.code or "")
def to_records(self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False) -> List[Record]:
def to_records(
self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False
) -> List[Record]:
"""
Converts input data into a list of Record objects.
@ -191,7 +200,9 @@ class CustomComponent(Component):
return records
def create_references_from_records(self, records: List[Record], include_data: bool = False) -> str:
def create_references_from_records(
self, records: List[Record], include_data: bool = False
) -> str:
"""
Create references from a list of records.
@ -230,14 +241,20 @@ class CustomComponent(Component):
if not self.code:
return {}
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
component_classes = [
cls
for cls in self.tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
if not component_classes:
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
]
return build_methods[0] if build_methods else {}
@ -294,7 +311,9 @@ class CustomComponent(Component):
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return credential_service.get_credential(
user_id=self._user_id or "", name=name, session=session
)
return get_credential
@ -304,7 +323,9 @@ class CustomComponent(Component):
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(user_id=self._user_id, session=session)
return credential_service.list_credentials(
user_id=self._user_id, session=session
)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable."""
@ -343,7 +364,11 @@ class CustomComponent(Component):
if not self._flows_records:
self.list_flows()
if not flow_id and self._flows_records:
flow_ids = [flow.data["id"] for flow in self._flows_records if flow.data["name"] == flow_name]
flow_ids = [
flow.data["id"]
for flow in self._flows_records
if flow.data["name"] == flow_name
]
if not flow_ids:
raise ValueError(f"Flow {flow_name} not found")
elif len(flow_ids) > 1:
@ -365,7 +390,9 @@ class CustomComponent(Component):
db_service = get_db_service()
with get_session(db_service) as session:
flows = session.exec(
select(Flow).where(Flow.user_id == self._user_id).where(Flow.is_component == False) # noqa
select(Flow)
.where(Flow.user_id == self._user_id)
.where(Flow.is_component == False) # noqa
).all()
flows_records = [flow.to_record() for flow in flows]