Refactor CustomComponent class methods and improve code readability

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-29 19:27:16 -03:00
commit 925127a5a3

View file

@ -119,7 +119,9 @@ class CustomComponent(Component):
def tree(self):
return self.get_code_tree(self.code or "")
def to_records(self, data: Any, text_key: str = "text", data_key: str = "data") -> List[Record]:
def to_records(
self, data: Any, text_key: str = "text", data_key: str = "data"
) -> List[Record]:
"""
Convert data into a list of records.
@ -146,7 +148,9 @@ class CustomComponent(Component):
return records
def create_references_from_records(self, records: List[Record], include_data: bool = False) -> str:
def create_references_from_records(
self, records: List[Record], include_data: bool = False
) -> str:
"""
Create references from a list of records.
@ -181,7 +185,8 @@ class CustomComponent(Component):
detail={
"error": "Type hint Error",
"traceback": (
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
"Prompt type is not supported in the build method."
" Try using PromptTemplate instead."
),
},
)
@ -195,14 +200,20 @@ class CustomComponent(Component):
if not self.code:
return {}
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
component_classes = [
cls
for cls in self.tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
if not component_classes:
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
]
return build_methods[0] if build_methods else {}
@ -259,7 +270,9 @@ class CustomComponent(Component):
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return credential_service.get_credential(
user_id=self._user_id or "", name=name, session=session
)
return get_credential
@ -269,7 +282,9 @@ class CustomComponent(Component):
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(user_id=self._user_id, session=session)
return credential_service.list_credentials(
user_id=self._user_id, session=session
)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable."""
@ -313,7 +328,9 @@ class CustomComponent(Component):
get_session = get_session or session_getter
db_service = get_db_service()
with get_session(db_service) as session:
flows = session.exec(select(Flow).where(Flow.user_id == self._user_id)).all()
flows = session.exec(
select(Flow).where(Flow.user_id == self._user_id)
).all()
return flows
except Exception as e:
raise ValueError("Session is invalid") from e