feat: Add ruff rules for comprehensions (C4) (#3958)

Add ruff rules for comprehensions (C4)
This commit is contained in:
Christophe Bornet 2024-09-30 22:00:00 +02:00 committed by GitHub
commit e82de5d89c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 23 additions and 23 deletions

View file

@ -27,7 +27,7 @@ class ChatComponent(Component):
if hasattr(self, "_event_manager") and self._event_manager and stored_message.id:
if not isinstance(message.text, str):
complete_message = self._stream_message(message, stored_message.id)
message_table = update_message(message_id=stored_message.id, message=dict(text=complete_message))
message_table = update_message(message_id=stored_message.id, message={"text": complete_message})
stored_message = Message(**message_table.model_dump())
self.vertex._added_message = stored_message
self.status = stored_message

View file

@ -18,7 +18,7 @@ class ShouldRunNextComponent(CustomComponent):
error_message = ""
for i in range(retries):
result = chain.invoke(
dict(question=question, context=context, error_message=error_message),
{"question": question, "context": context, "error_message": error_message},
config={"callbacks": self.get_langchain_callbacks()},
)
if isinstance(result, BaseMessage):

View file

@ -65,7 +65,7 @@ class LangChainHubPromptComponent(Component):
full_template = full_template + "\n" + message.template
# No need to reprocess if we have them already
if all(["param_" + custom_field in build_config for custom_field in custom_fields]):
if all("param_" + custom_field in build_config for custom_field in custom_fields):
return build_config
# Easter egg: Show template in info popup

View file

@ -28,7 +28,7 @@ class ComposioAPIComponent(LCToolComponent):
DropdownInput(
name="app_names",
display_name="App Name",
options=[app_name for app_name in App.__annotations__],
options=list(App.__annotations__),
value="",
info="The app name to use. Please refresh after selecting app name",
refresh_button=True,
@ -128,7 +128,7 @@ class ComposioAPIComponent(LCToolComponent):
def _get_connected_app_names_for_entity(self) -> list[str]:
toolset = self._build_wrapper()
connections = toolset.client.get_entity(id=self.entity_id).get_connections()
return list(set(connection.appUniqueId for connection in connections))
return list({connection.appUniqueId for connection in connections})
def _update_app_names_with_connected_status(self, build_config: dict) -> dict:
connected_app_names = self._get_connected_app_names_for_entity()
@ -157,7 +157,7 @@ class ComposioAPIComponent(LCToolComponent):
build_config["auth_status_config"]["value"] = self._check_for_authorization(
self._get_normalized_app_name()
)
all_action_names = [action_name for action_name in Action.__annotations__]
all_action_names = list(Action.__annotations__)
app_action_names = [
action_name
for action_name in all_action_names

View file

@ -1875,7 +1875,7 @@ class Graph:
def __to_dict(self) -> dict[str, dict[str, list[str]]]:
"""Converts the graph to a dictionary."""
result: dict = dict()
result: dict = {}
for vertex in self.vertices:
vertex_id = vertex.id
sucessors = [i.id for i in self.get_all_successors(vertex)]
@ -1922,7 +1922,7 @@ class Graph:
first_layer = vertices_layers[0]
# save the only the rest
self.vertices_layers = vertices_layers[1:]
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
self.vertices_to_run = set(chain.from_iterable(vertices_layers))
self.build_run_map()
# Return just the first layer
self._first_layer = first_layer

View file

@ -56,7 +56,7 @@ class SizedLogBuffer:
return len(self.buffer)
def get_after_timestamp(self, timestamp: int, lines: int = 5) -> dict[int, str]:
rc = dict()
rc = {}
self._rsemaphore.acquire()
try:

View file

@ -81,7 +81,7 @@ def get_message(payload):
def build_output_logs(vertex, result) -> dict:
outputs: dict[str, OutputValue] = dict()
outputs: dict[str, OutputValue] = {}
component_instance = result[0]
for index, output in enumerate(vertex.outputs):
if component_instance.status is None:

View file

@ -22,7 +22,7 @@ class Folder(FolderBase, table=True): # type: ignore
parent: Optional["Folder"] = Relationship(
back_populates="children",
sa_relationship_kwargs=dict(remote_side="Folder.id"),
sa_relationship_kwargs={"remote_side": "Folder.id"},
)
children: list["Folder"] = Relationship(back_populates="parent")
user_id: UUID | None = Field(default=None, foreign_key="user.id")

View file

@ -15,7 +15,7 @@ def get_transactions_by_flow_id(db: Session, flow_id: UUID, limit: int | None =
)
transactions = db.exec(stmt)
return [t for t in transactions]
return list(transactions)
def log_transaction(db: Session, transaction: TransactionBase) -> TransactionTable:

View file

@ -15,7 +15,7 @@ def get_vertex_builds_by_flow_id(db: Session, flow_id: UUID, limit: int | None =
)
builds = db.exec(stmt)
return [t for t in builds]
return list(builds)
def log_vertex_build(db: Session, vertex_build: VertexBuildBase) -> VertexBuildTable:

View file

@ -43,7 +43,7 @@ class ListComponentResponse(BaseModel):
# if so, return v else transform to TagResponse
if not v:
return v
if all(["id" in tag and "name" in tag for tag in v]):
if all("id" in tag and "name" in tag for tag in v):
return v
else:
return [TagResponse(**tag.get("tags_id")) for tag in v if tag.get("tags_id")]

View file

@ -74,7 +74,7 @@ class Metric:
self.unit = unit
self.labels = labels
self.mandatory_labels = [label for label, required in labels.items() if required]
self.allowed_labels = [label for label in labels.keys()]
self.allowed_labels = list(labels.keys())
def validate_labels(self, labels: Mapping[str, str]):
"""
@ -109,7 +109,7 @@ class ThreadSafeSingletonMetaUsingWeakref(type):
class OpenTelemetry(metaclass=ThreadSafeSingletonMetaUsingWeakref):
_metrics_registry: dict[str, Metric] = dict()
_metrics_registry: dict[str, Metric] = {}
def _add_metric(self, name: str, description: str, unit: str, metric_type: MetricType, labels: dict[str, bool]):
metric = Metric(name=name, description=description, type=metric_type, unit=unit, labels=labels)

View file

@ -64,7 +64,7 @@ class FrontendNode(BaseModel):
def process_base_classes(self, base_classes: list[str]) -> list[str]:
"""Removes unwanted base classes from the list of base classes."""
sorted_base_classes = sorted(list(set(base_classes)), key=lambda x: x.lower())
sorted_base_classes = sorted(set(base_classes), key=lambda x: x.lower())
return sorted_base_classes
@field_serializer("display_name")
@ -123,7 +123,7 @@ class FrontendNode(BaseModel):
input_names = [input_.name for input_ in self.template.fields]
overlap = set(output_names).intersection(input_names)
if overlap:
overlap_str = ", ".join(map(lambda x: f"'{x}'", overlap))
overlap_str = ", ".join(f"'{x}'" for x in overlap)
raise ValueError(
f"There should be no overlap between input and output names. Names {overlap_str} are duplicated."
)
@ -151,10 +151,10 @@ class FrontendNode(BaseModel):
input_overlap = set(input_names).intersection(attributes)
error_message = ""
if output_overlap:
output_overlap_str = ", ".join(map(lambda x: f"'{x}'", output_overlap))
output_overlap_str = ", ".join(f"'{x}'" for x in output_overlap)
error_message += f"Output names {output_overlap_str} are reserved attributes.\n"
if input_overlap:
input_overlap_str = ", ".join(map(lambda x: f"'{x}'", input_overlap))
input_overlap_str = ", ".join(f"'{x}'" for x in input_overlap)
error_message += f"Input names {input_overlap_str} are reserved attributes."
def add_base_class(self, base_class: str | list[str]) -> None:

View file

@ -42,7 +42,7 @@ def extract_uniont_types_from_generic_alias(return_type: GenericAlias) -> list:
_inner_arg
for _type in return_type
for _inner_arg in _type.__args__
if _inner_arg not in set((Any, type(None), type(Any)))
if _inner_arg not in {Any, type(None), type(Any)}
]
return list(return_type.__args__)
@ -81,7 +81,7 @@ def extract_union_types_from_generic_alias(return_type: GenericAlias) -> list:
_inner_arg
for _type in return_type
for _inner_arg in _type.__args__
if _inner_arg not in set((Any, type(None), type(Any)))
if _inner_arg not in {Any, type(None), type(Any)}
]
return list(return_type.__args__)

View file

@ -148,7 +148,7 @@ exclude = ["langflow/alembic"]
line-length = 120
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "UP"]
select = ["C4", "E4", "E7", "E9", "F", "I", "UP"]
[build-system]
requires = ["hatchling"]