Refactor code and update documentation links
This commit is contained in:
parent
f735e50fd2
commit
eef5043045
4 changed files with 59 additions and 63 deletions
|
|
@ -59,9 +59,9 @@ def read_flows(
|
|||
if auth_settings.AUTO_LOGIN:
|
||||
flows = session.exec(
|
||||
select(Flow).where(
|
||||
(Flow.user_id == None) | (Flow.user_id == current_user.id)
|
||||
(Flow.user_id == None) | (Flow.user_id == current_user.id) # noqa
|
||||
)
|
||||
).all() # noqa
|
||||
).all()
|
||||
else:
|
||||
flows = current_user.flows
|
||||
|
||||
|
|
@ -71,10 +71,10 @@ def read_flows(
|
|||
try:
|
||||
example_flows = session.exec(
|
||||
select(Flow).where(
|
||||
Flow.user_id == None,
|
||||
Flow.folder == STARTER_FOLDER_NAME, # noqa
|
||||
Flow.user_id == None, # noqa
|
||||
Flow.folder == STARTER_FOLDER_NAME,
|
||||
)
|
||||
).all() # noqa
|
||||
).all()
|
||||
for example_flow in example_flows:
|
||||
if example_flow.id not in flow_ids:
|
||||
flows.append(example_flow)
|
||||
|
|
|
|||
|
|
@ -11,9 +11,7 @@ from langflow.utils.util import build_loader_repr_from_records
|
|||
class RecursiveCharacterTextSplitterComponent(CustomComponent):
|
||||
display_name: str = "Recursive Character Text Splitter"
|
||||
description: str = "Split text into chunks of a specified length."
|
||||
documentation: str = (
|
||||
"https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
|
||||
)
|
||||
documentation: str = "https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -77,13 +77,17 @@ class CustomComponent(Component):
|
|||
|
||||
def update_state(self, name: str, value: Any):
|
||||
try:
|
||||
self.vertex.graph.update_state(name=name, record=value, caller=self.vertex.id)
|
||||
self.vertex.graph.update_state(
|
||||
name=name, record=value, caller=self.vertex.id
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error updating state: {e}")
|
||||
|
||||
def append_state(self, name: str, value: Any):
|
||||
try:
|
||||
self.vertex.graph.append_state(name=name, record=value, caller=self.vertex.id)
|
||||
self.vertex.graph.append_state(
|
||||
name=name, record=value, caller=self.vertex.id
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error appending state: {e}")
|
||||
|
||||
|
|
@ -134,7 +138,9 @@ class CustomComponent(Component):
|
|||
def build_config(self):
|
||||
return self.field_config
|
||||
|
||||
def update_build_config(self, build_config: dict, field_name: str, field_value: Any):
|
||||
def update_build_config(
|
||||
self, build_config: dict, field_name: str, field_value: Any
|
||||
):
|
||||
build_config[field_name] = field_value
|
||||
return build_config
|
||||
|
||||
|
|
@ -142,7 +148,9 @@ class CustomComponent(Component):
|
|||
def tree(self):
|
||||
return self.get_code_tree(self.code or "")
|
||||
|
||||
def to_records(self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False) -> List[Record]:
|
||||
def to_records(
|
||||
self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False
|
||||
) -> List[Record]:
|
||||
"""
|
||||
Converts input data into a list of Record objects.
|
||||
|
||||
|
|
@ -191,7 +199,9 @@ class CustomComponent(Component):
|
|||
|
||||
return records
|
||||
|
||||
def create_references_from_records(self, records: List[Record], include_data: bool = False) -> str:
|
||||
def create_references_from_records(
|
||||
self, records: List[Record], include_data: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Create references from a list of records.
|
||||
|
||||
|
|
@ -230,14 +240,20 @@ class CustomComponent(Component):
|
|||
if not self.code:
|
||||
return {}
|
||||
|
||||
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in self.tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
if not component_classes:
|
||||
return {}
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
return build_methods[0] if build_methods else {}
|
||||
|
|
@ -294,7 +310,9 @@ class CustomComponent(Component):
|
|||
# Retrieve and decrypt the credential by name for the current user
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
|
||||
return credential_service.get_credential(
|
||||
user_id=self._user_id or "", name=name, session=session
|
||||
)
|
||||
|
||||
return get_credential
|
||||
|
||||
|
|
@ -304,7 +322,9 @@ class CustomComponent(Component):
|
|||
credential_service = get_credential_service()
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.list_credentials(user_id=self._user_id, session=session)
|
||||
return credential_service.list_credentials(
|
||||
user_id=self._user_id, session=session
|
||||
)
|
||||
|
||||
def index(self, value: int = 0):
|
||||
"""Returns a function that returns the value at the given index in the iterable."""
|
||||
|
|
@ -343,7 +363,11 @@ class CustomComponent(Component):
|
|||
if not self._flows_records:
|
||||
self.list_flows()
|
||||
if not flow_id and self._flows_records:
|
||||
flow_ids = [flow.data["id"] for flow in self._flows_records if flow.data["name"] == flow_name]
|
||||
flow_ids = [
|
||||
flow.data["id"]
|
||||
for flow in self._flows_records
|
||||
if flow.data["name"] == flow_name
|
||||
]
|
||||
if not flow_ids:
|
||||
raise ValueError(f"Flow {flow_name} not found")
|
||||
elif len(flow_ids) > 1:
|
||||
|
|
@ -365,7 +389,9 @@ class CustomComponent(Component):
|
|||
db_service = get_db_service()
|
||||
with get_session(db_service) as session:
|
||||
flows = session.exec(
|
||||
select(Flow).where(Flow.user_id == self._user_id).where(Flow.is_component == False) # noqa
|
||||
select(Flow)
|
||||
.where(Flow.user_id == self._user_id)
|
||||
.where(Flow.is_component == False) # noqa
|
||||
).all()
|
||||
|
||||
flows_records = [flow.to_record() for flow in flows]
|
||||
|
|
|
|||
|
|
@ -15,12 +15,8 @@ def remove_ansi_escape_codes(text):
|
|||
return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text)
|
||||
|
||||
|
||||
def build_template_from_function(
|
||||
name: str, type_to_loader_dict: Dict, add_function: bool = False
|
||||
):
|
||||
classes = [
|
||||
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
|
||||
]
|
||||
def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False):
|
||||
classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()]
|
||||
|
||||
# Raise error if name is not in chains
|
||||
if name not in classes:
|
||||
|
|
@ -41,10 +37,8 @@ def build_template_from_function(
|
|||
for name_, value_ in value.__repr_args__():
|
||||
if name_ == "default_factory":
|
||||
try:
|
||||
variables[class_field_items]["default"] = (
|
||||
get_default_factory(
|
||||
module=_class.__base__.__module__, function=value_
|
||||
)
|
||||
variables[class_field_items]["default"] = get_default_factory(
|
||||
module=_class.__base__.__module__, function=value_
|
||||
)
|
||||
except Exception:
|
||||
variables[class_field_items]["default"] = None
|
||||
|
|
@ -52,9 +46,7 @@ def build_template_from_function(
|
|||
variables[class_field_items][name_] = value_
|
||||
|
||||
variables[class_field_items]["placeholder"] = (
|
||||
docs.params[class_field_items]
|
||||
if class_field_items in docs.params
|
||||
else ""
|
||||
docs.params[class_field_items] if class_field_items in docs.params else ""
|
||||
)
|
||||
# Adding function to base classes to allow
|
||||
# the output to be a function
|
||||
|
|
@ -69,9 +61,7 @@ def build_template_from_function(
|
|||
}
|
||||
|
||||
|
||||
def build_template_from_class(
|
||||
name: str, type_to_cls_dict: Dict, add_function: bool = False
|
||||
):
|
||||
def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False):
|
||||
classes = [item.__name__ for item in type_to_cls_dict.values()]
|
||||
|
||||
# Raise error if name is not in chains
|
||||
|
|
@ -95,11 +85,9 @@ def build_template_from_class(
|
|||
for name_, value_ in value.__repr_args__():
|
||||
if name_ == "default_factory":
|
||||
try:
|
||||
variables[class_field_items]["default"] = (
|
||||
get_default_factory(
|
||||
module=_class.__base__.__module__,
|
||||
function=value_,
|
||||
)
|
||||
variables[class_field_items]["default"] = get_default_factory(
|
||||
module=_class.__base__.__module__,
|
||||
function=value_,
|
||||
)
|
||||
except Exception:
|
||||
variables[class_field_items]["default"] = None
|
||||
|
|
@ -107,9 +95,7 @@ def build_template_from_class(
|
|||
variables[class_field_items][name_] = value_
|
||||
|
||||
variables[class_field_items]["placeholder"] = (
|
||||
docs.params[class_field_items]
|
||||
if class_field_items in docs.params
|
||||
else ""
|
||||
docs.params[class_field_items] if class_field_items in docs.params else ""
|
||||
)
|
||||
base_classes = get_base_classes(_class)
|
||||
# Adding function to base classes to allow
|
||||
|
|
@ -141,9 +127,7 @@ def build_template_from_method(
|
|||
|
||||
# Check if the method exists in this class
|
||||
if not hasattr(_class, method_name):
|
||||
raise ValueError(
|
||||
f"Method {method_name} not found in class {class_name}"
|
||||
)
|
||||
raise ValueError(f"Method {method_name} not found in class {class_name}")
|
||||
|
||||
# Get the method
|
||||
method = getattr(_class, method_name)
|
||||
|
|
@ -162,14 +146,8 @@ def build_template_from_method(
|
|||
"_type": _type,
|
||||
**{
|
||||
name: {
|
||||
"default": (
|
||||
param.default if param.default != param.empty else None
|
||||
),
|
||||
"type": (
|
||||
param.annotation
|
||||
if param.annotation != param.empty
|
||||
else None
|
||||
),
|
||||
"default": (param.default if param.default != param.empty else None),
|
||||
"type": (param.annotation if param.annotation != param.empty else None),
|
||||
"required": param.default == param.empty,
|
||||
}
|
||||
for name, param in params.items()
|
||||
|
|
@ -256,9 +234,7 @@ def sync_to_async(func):
|
|||
return async_wrapper
|
||||
|
||||
|
||||
def format_dict(
|
||||
dictionary: Dict[str, Any], class_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Formats a dictionary by removing certain keys and modifying the
|
||||
values of other keys.
|
||||
|
|
@ -344,9 +320,7 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str:
|
|||
The modified type string.
|
||||
"""
|
||||
if any(list_type in _type for list_type in ["List", "Sequence", "Set"]):
|
||||
_type = (
|
||||
_type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
|
||||
)
|
||||
_type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
|
||||
value["list"] = True
|
||||
else:
|
||||
value["list"] = False
|
||||
|
|
@ -449,9 +423,7 @@ def set_headers_value(value: Dict[str, Any]) -> None:
|
|||
value["value"] = """{"Authorization": "Bearer <token>"}"""
|
||||
|
||||
|
||||
def add_options_to_field(
|
||||
value: Dict[str, Any], class_name: Optional[str], key: str
|
||||
) -> None:
|
||||
def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None:
|
||||
"""
|
||||
Adds options to the field based on the class name and key.
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue