Refactor code and update documentation links

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-06 15:33:42 -03:00
commit eef5043045
4 changed files with 59 additions and 63 deletions

View file

@ -59,9 +59,9 @@ def read_flows(
if auth_settings.AUTO_LOGIN:
flows = session.exec(
select(Flow).where(
(Flow.user_id == None) | (Flow.user_id == current_user.id)
(Flow.user_id == None) | (Flow.user_id == current_user.id) # noqa
)
).all() # noqa
).all()
else:
flows = current_user.flows
@ -71,10 +71,10 @@ def read_flows(
try:
example_flows = session.exec(
select(Flow).where(
Flow.user_id == None,
Flow.folder == STARTER_FOLDER_NAME, # noqa
Flow.user_id == None, # noqa
Flow.folder == STARTER_FOLDER_NAME,
)
).all() # noqa
).all()
for example_flow in example_flows:
if example_flow.id not in flow_ids:
flows.append(example_flow)

View file

@ -11,9 +11,7 @@ from langflow.utils.util import build_loader_repr_from_records
class RecursiveCharacterTextSplitterComponent(CustomComponent):
display_name: str = "Recursive Character Text Splitter"
description: str = "Split text into chunks of a specified length."
documentation: str = (
"https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
)
documentation: str = "https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
def build_config(self):
return {

View file

@ -77,13 +77,17 @@ class CustomComponent(Component):
def update_state(self, name: str, value: Any):
try:
self.vertex.graph.update_state(name=name, record=value, caller=self.vertex.id)
self.vertex.graph.update_state(
name=name, record=value, caller=self.vertex.id
)
except Exception as e:
raise ValueError(f"Error updating state: {e}")
def append_state(self, name: str, value: Any):
try:
self.vertex.graph.append_state(name=name, record=value, caller=self.vertex.id)
self.vertex.graph.append_state(
name=name, record=value, caller=self.vertex.id
)
except Exception as e:
raise ValueError(f"Error appending state: {e}")
@ -134,7 +138,9 @@ class CustomComponent(Component):
def build_config(self):
return self.field_config
def update_build_config(self, build_config: dict, field_name: str, field_value: Any):
def update_build_config(
self, build_config: dict, field_name: str, field_value: Any
):
build_config[field_name] = field_value
return build_config
@ -142,7 +148,9 @@ class CustomComponent(Component):
def tree(self):
return self.get_code_tree(self.code or "")
def to_records(self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False) -> List[Record]:
def to_records(
self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False
) -> List[Record]:
"""
Converts input data into a list of Record objects.
@ -191,7 +199,9 @@ class CustomComponent(Component):
return records
def create_references_from_records(self, records: List[Record], include_data: bool = False) -> str:
def create_references_from_records(
self, records: List[Record], include_data: bool = False
) -> str:
"""
Create references from a list of records.
@ -230,14 +240,20 @@ class CustomComponent(Component):
if not self.code:
return {}
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
component_classes = [
cls
for cls in self.tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
if not component_classes:
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
]
return build_methods[0] if build_methods else {}
@ -294,7 +310,9 @@ class CustomComponent(Component):
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return credential_service.get_credential(
user_id=self._user_id or "", name=name, session=session
)
return get_credential
@ -304,7 +322,9 @@ class CustomComponent(Component):
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(user_id=self._user_id, session=session)
return credential_service.list_credentials(
user_id=self._user_id, session=session
)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable."""
@ -343,7 +363,11 @@ class CustomComponent(Component):
if not self._flows_records:
self.list_flows()
if not flow_id and self._flows_records:
flow_ids = [flow.data["id"] for flow in self._flows_records if flow.data["name"] == flow_name]
flow_ids = [
flow.data["id"]
for flow in self._flows_records
if flow.data["name"] == flow_name
]
if not flow_ids:
raise ValueError(f"Flow {flow_name} not found")
elif len(flow_ids) > 1:
@ -365,7 +389,9 @@ class CustomComponent(Component):
db_service = get_db_service()
with get_session(db_service) as session:
flows = session.exec(
select(Flow).where(Flow.user_id == self._user_id).where(Flow.is_component == False) # noqa
select(Flow)
.where(Flow.user_id == self._user_id)
.where(Flow.is_component == False) # noqa
).all()
flows_records = [flow.to_record() for flow in flows]

View file

@ -15,12 +15,8 @@ def remove_ansi_escape_codes(text):
return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text)
def build_template_from_function(
name: str, type_to_loader_dict: Dict, add_function: bool = False
):
classes = [
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
]
def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False):
classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()]
# Raise error if name is not in chains
if name not in classes:
@ -41,10 +37,8 @@ def build_template_from_function(
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items]["default"] = (
get_default_factory(
module=_class.__base__.__module__, function=value_
)
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__, function=value_
)
except Exception:
variables[class_field_items]["default"] = None
@ -52,9 +46,7 @@ def build_template_from_function(
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items]
if class_field_items in docs.params
else ""
docs.params[class_field_items] if class_field_items in docs.params else ""
)
# Adding function to base classes to allow
# the output to be a function
@ -69,9 +61,7 @@ def build_template_from_function(
}
def build_template_from_class(
name: str, type_to_cls_dict: Dict, add_function: bool = False
):
def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False):
classes = [item.__name__ for item in type_to_cls_dict.values()]
# Raise error if name is not in chains
@ -95,11 +85,9 @@ def build_template_from_class(
for name_, value_ in value.__repr_args__():
if name_ == "default_factory":
try:
variables[class_field_items]["default"] = (
get_default_factory(
module=_class.__base__.__module__,
function=value_,
)
variables[class_field_items]["default"] = get_default_factory(
module=_class.__base__.__module__,
function=value_,
)
except Exception:
variables[class_field_items]["default"] = None
@ -107,9 +95,7 @@ def build_template_from_class(
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items]
if class_field_items in docs.params
else ""
docs.params[class_field_items] if class_field_items in docs.params else ""
)
base_classes = get_base_classes(_class)
# Adding function to base classes to allow
@ -141,9 +127,7 @@ def build_template_from_method(
# Check if the method exists in this class
if not hasattr(_class, method_name):
raise ValueError(
f"Method {method_name} not found in class {class_name}"
)
raise ValueError(f"Method {method_name} not found in class {class_name}")
# Get the method
method = getattr(_class, method_name)
@ -162,14 +146,8 @@ def build_template_from_method(
"_type": _type,
**{
name: {
"default": (
param.default if param.default != param.empty else None
),
"type": (
param.annotation
if param.annotation != param.empty
else None
),
"default": (param.default if param.default != param.empty else None),
"type": (param.annotation if param.annotation != param.empty else None),
"required": param.default == param.empty,
}
for name, param in params.items()
@ -256,9 +234,7 @@ def sync_to_async(func):
return async_wrapper
def format_dict(
dictionary: Dict[str, Any], class_name: Optional[str] = None
) -> Dict[str, Any]:
def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]:
"""
Formats a dictionary by removing certain keys and modifying the
values of other keys.
@ -344,9 +320,7 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str:
The modified type string.
"""
if any(list_type in _type for list_type in ["List", "Sequence", "Set"]):
_type = (
_type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
)
_type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
value["list"] = True
else:
value["list"] = False
@ -449,9 +423,7 @@ def set_headers_value(value: Dict[str, Any]) -> None:
value["value"] = """{"Authorization": "Bearer <token>"}"""
def add_options_to_field(
value: Dict[str, Any], class_name: Optional[str], key: str
) -> None:
def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None:
"""
Adds options to the field based on the class name and key.
"""