ref: Some ruff rule fixes from preview mode (#4131)

Some ruff rule fixes from preview mode
This commit is contained in:
Christophe Bornet 2024-10-14 16:42:09 +02:00 committed by GitHub
commit 3e181b91c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 54 additions and 54 deletions

View file

@ -166,7 +166,7 @@ def update_folder(
folder_data = existing_folder.model_dump(exclude_unset=True)
for key, value in folder_data.items():
if key not in ("components", "flows"):
if key not in {"components", "flows"}:
setattr(existing_folder, key, value)
session.add(existing_folder)
session.commit()

View file

@ -109,7 +109,7 @@ class ChatResponse(ChatMessage):
@field_validator("type")
@classmethod
def validate_message_type(cls, v):
if v not in ["start", "stream", "end", "error", "info", "file"]:
if v not in {"start", "stream", "end", "error", "info", "file"}:
msg = "type must be start, stream, end, error, info, or file"
raise ValueError(msg)
return v
@ -134,7 +134,7 @@ class FileResponse(ChatMessage):
@field_validator("data_type")
@classmethod
def validate_data_type(cls, v):
if v not in ["image", "csv"]:
if v not in {"image", "csv"}:
msg = "data_type must be image or csv"
raise ValueError(msg)
return v

View file

@ -83,30 +83,30 @@ def parse_curl_command(curl_command):
i += 1
args["method"] = tokens[i].lower()
method_on_curl = tokens[i].lower()
elif token in ("-d", "--data"):
elif token in {"-d", "--data"}:
i += 1
args["data"] = tokens[i]
elif token in ("-b", "--data-binary", "--data-raw"):
elif token in {"-b", "--data-binary", "--data-raw"}:
i += 1
args["data_binary"] = tokens[i]
elif token in ("-H", "--header"):
elif token in {"-H", "--header"}:
i += 1
args["headers"].append(tokens[i])
elif token == "--compressed":
args["compressed"] = True
elif token in ("-k", "--insecure"):
elif token in {"-k", "--insecure"}:
args["insecure"] = True
elif token in ("-u", "--user"):
elif token in {"-u", "--user"}:
i += 1
args["user"] = tuple(tokens[i].split(":"))
elif token in ("-I", "--include"):
elif token in {"-I", "--include"}:
args["include"] = True
elif token in ("-s", "--silent"):
elif token in {"-s", "--silent"}:
args["silent"] = True
elif token in ("-x", "--proxy"):
elif token in {"-x", "--proxy"}:
i += 1
args["proxy"] = tokens[i]
elif token in ("-U", "--proxy-user"):
elif token in {"-U", "--proxy-user"}:
i += 1
args["proxy_user"] = tokens[i]
elif not token.startswith("-"):

View file

@ -100,7 +100,7 @@ def read_text_file(file_path: str) -> str:
result = chardet.detect(raw_data)
encoding = result["encoding"]
if encoding in ["Windows-1252", "Windows-1254", "MacRoman"]:
if encoding in {"Windows-1252", "Windows-1254", "MacRoman"}:
encoding = "utf-8"
with _file_path.open(encoding=encoding) as f:

View file

@ -174,7 +174,7 @@ class LCModelComponent(Component):
inputs: list | dict = messages or {}
try:
if self.output_parser is not None:
runnable = runnable | self.output_parser
runnable |= self.output_parser
runnable = runnable.with_config(
{

View file

@ -216,7 +216,7 @@ class AddContentToPage(LCToolComponent):
block_type: {},
}
if block_type in [
if block_type in {
"paragraph",
"heading_1",
"heading_2",
@ -224,7 +224,7 @@ class AddContentToPage(LCToolComponent):
"bulleted_list_item",
"numbered_list_item",
"quote",
]:
}:
block[block_type]["rich_text"] = [
{
"type": "text",

View file

@ -72,9 +72,9 @@ class NotionPageContent(LCToolComponent):
content = ""
for block in blocks:
block_type = block.get("type")
if block_type in ["paragraph", "heading_1", "heading_2", "heading_3", "quote"]:
if block_type in {"paragraph", "heading_1", "heading_2", "heading_3", "quote"}:
content += self.parse_rich_text(block[block_type].get("rich_text", [])) + "\n\n"
elif block_type in ["bulleted_list_item", "numbered_list_item"]:
elif block_type in {"bulleted_list_item", "numbered_list_item"}:
content += self.parse_rich_text(block[block_type].get("rich_text", [])) + "\n"
elif block_type == "to_do":
content += self.parse_rich_text(block["to_do"].get("rich_text", [])) + "\n"

View file

@ -8,12 +8,12 @@ from .list_assistants import AssistantsListAssistants
from .run import AssistantsRun
__all__ = [
"AstraAssistantManager",
"AssistantsCreateAssistant",
"AssistantsCreateThread",
"AssistantsGetAssistantName",
"AssistantsListAssistants",
"AssistantsRun",
"AstraAssistantManager",
"Dotenv",
"GetEnvVar",
]

View file

@ -112,7 +112,7 @@ class APIRequestComponent(Component):
timeout: int = 5,
) -> Data:
method = method.upper()
if method not in ["GET", "POST", "PATCH", "PUT", "DELETE"]:
if method not in {"GET", "POST", "PATCH", "PUT", "DELETE"}:
msg = f"Unsupported method: {method}"
raise ValueError(msg)

View file

@ -31,7 +31,7 @@ class ShouldRunNextComponent(CustomComponent):
content = result.content
elif isinstance(result, str):
content = result
if isinstance(content, str) and content.lower().strip() in ["yes", "no"]:
if isinstance(content, str) and content.lower().strip() in {"yes", "no"}:
break
condition = str(content).lower().strip() == "yes"
self.status = f"Should Run Next: {condition}"

View file

@ -41,7 +41,7 @@ class SubFlowComponent(CustomComponent):
build_config["flow_name"]["options"] = self.get_flow_names()
# Clean up the build config
for key in list(build_config.keys()):
if key not in [*self.field_order, "code", "_type", "get_final_results_only"]:
if key not in {*self.field_order, "code", "_type", "get_final_results_only"}:
del build_config[key]
if field_value is not None and field_name == "flow_name":
try:

View file

@ -62,7 +62,7 @@ class DataConditionalRouterComponent(Component):
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ["true", "1", "yes", "y", "on"]
return value.lower() in {"true", "1", "yes", "y", "on"}
return bool(value)
def validate_input(self, data_item: Data) -> bool:

View file

@ -76,7 +76,7 @@ class GroqModel(LCModelComponent):
return []
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
if field_name in ("groq_api_key", "groq_api_base", "model_name"):
if field_name in {"groq_api_key", "groq_api_base", "model_name"}:
models = self.get_models()
build_config["model_name"]["options"] = models
return build_config

View file

@ -62,7 +62,7 @@ class LangChainHubPromptComponent(Component):
for message in prompt_template:
# Find all matches
matches = re.findall(pattern, message.template)
custom_fields = custom_fields + matches
custom_fields += matches
# Create a string version of the full template
full_template = full_template + "\n" + message.template

View file

@ -63,7 +63,7 @@ class SubFlowComponent(Component):
new_vertex_inputs = []
field_template = vertex.data["node"]["template"]
for inp in field_template:
if inp not in ["code", "_type"]:
if inp not in {"code", "_type"}:
field_template[inp]["display_name"] = (
vertex.display_name + " - " + field_template[inp]["display_name"]
)

View file

@ -87,7 +87,7 @@ class PythonCodeStructuredTool(LCToolComponent):
if field_name is None:
return build_config
if field_name not in ("tool_code", "tool_function"):
if field_name not in {"tool_code", "tool_function"}:
return build_config
try:

View file

@ -55,7 +55,7 @@ class RedisVectorStoreComponent(LCVectorStoreComponent):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
with Path("docuemnts.txt").open("w") as f:
with Path("docuemnts.txt").open("w", encoding="utf-8") as f:
f.write(str(documents))
if not documents:

View file

@ -32,7 +32,7 @@ def find_class_ast_node(class_obj):
return None, []
# Read the source code from the file
with Path(source_file).open() as file:
with Path(source_file).open(encoding="utf-8") as file:
source_code = file.read()
# Parse the source code into an AST
@ -333,7 +333,7 @@ class CodeParser:
bases = self.get_base_classes()
nodes = []
for base in bases:
if base.__name__ == node.name or base.__name__ in ["CustomComponent", "Component", "BaseComponent"]:
if base.__name__ == node.name or base.__name__ in {"CustomComponent", "Component", "BaseComponent"}:
continue
try:
class_node, import_nodes = find_class_ast_node(base)

View file

@ -131,7 +131,7 @@ def process_type(field_type: str):
# field_type is a string can be Prompt or Code too
# so we just need to lower if it is the case
lowercase_type = field_type.lower()
if lowercase_type in ["prompt", "code"]:
if lowercase_type in {"prompt", "code"}:
return lowercase_type
return field_type
@ -206,11 +206,11 @@ def add_extra_fields(frontend_node, field_config, function_args):
# then we need to add the extra fields
for extra_field in function_args:
if "name" not in extra_field or extra_field["name"] in [
if "name" not in extra_field or extra_field["name"] in {
"self",
"kwargs",
"args",
]:
}:
continue
field_name, field_type, field_value, field_required = get_field_properties(extra_field)

View file

@ -358,7 +358,7 @@ class Graph:
nest_asyncio.apply()
loop = asyncio.get_event_loop()
async_gen = self.async_start(inputs, max_iterations, event_manager)
async_gen_task = asyncio.ensure_future(async_gen.__anext__())
async_gen_task = asyncio.ensure_future(anext(async_gen))
while True:
try:
@ -366,7 +366,7 @@ class Graph:
yield result
if isinstance(result, Finish):
return
async_gen_task = asyncio.ensure_future(async_gen.__anext__())
async_gen_task = asyncio.ensure_future(anext(async_gen))
except StopAsyncIteration:
break
@ -1058,7 +1058,7 @@ class Graph:
"""Updates the edges of a vertex in the Graph."""
new_edges = []
for edge in self.edges:
if other_vertex.id in (edge.source_id, edge.target_id):
if other_vertex.id in {edge.source_id, edge.target_id}:
continue
new_edges.append(edge)
new_edges += other_vertex.edges
@ -1210,7 +1210,7 @@ class Graph:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [edge for edge in self.edges if vertex_id not in (edge.source_id, edge.target_id)]
self.edges = [edge for edge in self.edges if vertex_id not in {edge.source_id, edge.target_id}]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
@ -1707,7 +1707,7 @@ class Graph:
node_name = node_id.split("-")[0]
if node_name in InterfaceComponentTypes:
return InterfaceVertex
if node_name in ["SharedState", "Notify", "Listen"]:
if node_name in {"SharedState", "Notify", "Listen"}:
return StateVertex
if node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type]

View file

@ -235,7 +235,7 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
if new_edge["source"] == group_node_id:
new_edge = update_source_handle(new_edge, g_nodes, g_edges)
if group_node_id in (edge["target"], edge["source"]):
if group_node_id in {edge["target"], edge["source"]}:
updated_edges.append(new_edge)
return updated_edges

View file

@ -380,7 +380,7 @@ class Vertex:
except Exception: # noqa: BLE001
logger.debug(f"Error evaluating code for {field_name}")
params[field_name] = val
elif field.get("type") in ["dict", "NestedDict"]:
elif field.get("type") in {"dict", "NestedDict"}:
# When dict comes from the frontend it comes as a
# list of dicts, so we need to convert it to a dict
# before passing it to the build method

View file

@ -27,13 +27,13 @@ def update_memory_keys(langchain_object, possible_new_mem_key):
input_key = next(
key
for key in langchain_object.input_keys
if key not in [langchain_object.memory.memory_key, possible_new_mem_key]
if key not in {langchain_object.memory.memory_key, possible_new_mem_key}
)
output_key = next(
key
for key in langchain_object.output_keys
if key not in [langchain_object.memory.memory_key, possible_new_mem_key]
if key not in {langchain_object.memory.memory_key, possible_new_mem_key}
)
for key, attr in [(input_key, "input_key"), (output_key, "output_key"), (possible_new_mem_key, "memory_key")]:

View file

@ -28,7 +28,7 @@ def upload(file_path: str, host: str, flow_id: str):
url = f"{host}/api/v1/upload/{flow_id}"
with Path(file_path).open("rb") as file:
response = httpx.post(url, files={"file": file})
if response.status_code in (httpx.codes.OK, httpx.codes.CREATED):
if response.status_code in {httpx.codes.OK, httpx.codes.CREATED}:
return response.json()
except Exception as e:
msg = f"Error uploading file: {e}"

View file

@ -31,9 +31,9 @@ class Column(BaseModel):
@field_validator("formatter", mode="before")
@classmethod
def validate_formatter(cls, value):
if value in ["integer", "int", "float"]:
if value in {"integer", "int", "float"}:
value = FormatterType.number
if value in ["str", "string"]:
if value in {"str", "string"}:
value = FormatterType.text
if value == "dict":
value = FormatterType.json

View file

@ -37,7 +37,7 @@ class PluginService(Service):
if (
inspect.isclass(attr)
and issubclass(attr, BasePlugin)
and attr not in [CallbackPlugin, BasePlugin]
and attr not in {CallbackPlugin, BasePlugin}
):
self.register_plugin(plugin_name, attr())
except Exception: # noqa: BLE001

View file

@ -365,7 +365,7 @@ class Settings(BaseSettings):
def save_settings_to_yaml(settings: Settings, file_path: str):
with Path(file_path).open("w") as f:
with Path(file_path).open("w", encoding="utf-8") as f:
settings_dict = settings.model_dump()
yaml.dump(settings_dict, f)

View file

@ -5,7 +5,7 @@ from loguru import logger
def set_secure_permissions(file_path: Path):
if platform.system() in ["Linux", "Darwin"]: # Unix/Linux/Mac
if platform.system() in {"Linux", "Darwin"}: # Unix/Linux/Mac
file_path.chmod(0o600)
elif platform.system() == "Windows":
import win32api

View file

@ -116,7 +116,7 @@ class StoreService(Service):
return "id" in user_data[0]
except HTTPStatusError as exc:
if exc.response.status_code in [403, 401]:
if exc.response.status_code in {403, 401}:
return False
msg = f"Unexpected status code: {exc.response.status_code}"
raise ValueError(msg) from exc

View file

@ -99,7 +99,7 @@ class Input(BaseModel):
def serialize_model(self, handler):
result = handler(self)
# If the field is str, we add the Text input type
if self.field_type in ["str", "Text"] and "input_types" not in result:
if self.field_type in {"str", "Text"} and "input_types" not in result:
result["input_types"] = ["Text"]
if self.field_type == Text:
result["type"] = "str"

View file

@ -9,7 +9,7 @@ def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
"""
Extracts the inner type from a type hint that is a list or a Optional.
"""
if return_type.__origin__ in [list, SequenceABC]:
if return_type.__origin__ in {list, SequenceABC}:
return list(return_type.__args__)
return return_type
@ -59,7 +59,7 @@ def post_process_type(_type):
Union[List[Any], Any]: The processed return type.
"""
if hasattr(_type, "__origin__") and _type.__origin__ in [list, list, SequenceABC]:
if hasattr(_type, "__origin__") and _type.__origin__ in {list, list, SequenceABC}:
_type = extract_inner_type_from_generic_alias(_type)
# If the return type is not a Union, then we just return it as a list

View file

@ -116,7 +116,7 @@ def build_template_from_method(
"required": param.default == param.empty,
}
for name, param in params.items()
if name not in ["self", "kwargs", "args"]
if name not in {"self", "kwargs", "args"}
},
}

View file

@ -80,7 +80,7 @@ def fetch_latest_version(package_name: str, include_prerelease: bool) -> str | N
valid_versions = [v for v in versions if include_prerelease or not is_pre_release(v)]
if not valid_versions:
return None # Handle case where no valid versions are found
return max(valid_versions, key=lambda v: pkg_version.parse(v))
return max(valid_versions, key=pkg_version.parse)
except Exception: # noqa: BLE001
logger.exception("Error fetching latest version")