ref: Some ruff rule fixes from preview mode (#4131)
Some ruff rule fixes from preview mode
This commit is contained in:
parent
c7d80f3bc7
commit
3e181b91c6
33 changed files with 54 additions and 54 deletions
|
|
@ -166,7 +166,7 @@ def update_folder(
|
|||
|
||||
folder_data = existing_folder.model_dump(exclude_unset=True)
|
||||
for key, value in folder_data.items():
|
||||
if key not in ("components", "flows"):
|
||||
if key not in {"components", "flows"}:
|
||||
setattr(existing_folder, key, value)
|
||||
session.add(existing_folder)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class ChatResponse(ChatMessage):
|
|||
@field_validator("type")
|
||||
@classmethod
|
||||
def validate_message_type(cls, v):
|
||||
if v not in ["start", "stream", "end", "error", "info", "file"]:
|
||||
if v not in {"start", "stream", "end", "error", "info", "file"}:
|
||||
msg = "type must be start, stream, end, error, info, or file"
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
|
@ -134,7 +134,7 @@ class FileResponse(ChatMessage):
|
|||
@field_validator("data_type")
|
||||
@classmethod
|
||||
def validate_data_type(cls, v):
|
||||
if v not in ["image", "csv"]:
|
||||
if v not in {"image", "csv"}:
|
||||
msg = "data_type must be image or csv"
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -83,30 +83,30 @@ def parse_curl_command(curl_command):
|
|||
i += 1
|
||||
args["method"] = tokens[i].lower()
|
||||
method_on_curl = tokens[i].lower()
|
||||
elif token in ("-d", "--data"):
|
||||
elif token in {"-d", "--data"}:
|
||||
i += 1
|
||||
args["data"] = tokens[i]
|
||||
elif token in ("-b", "--data-binary", "--data-raw"):
|
||||
elif token in {"-b", "--data-binary", "--data-raw"}:
|
||||
i += 1
|
||||
args["data_binary"] = tokens[i]
|
||||
elif token in ("-H", "--header"):
|
||||
elif token in {"-H", "--header"}:
|
||||
i += 1
|
||||
args["headers"].append(tokens[i])
|
||||
elif token == "--compressed":
|
||||
args["compressed"] = True
|
||||
elif token in ("-k", "--insecure"):
|
||||
elif token in {"-k", "--insecure"}:
|
||||
args["insecure"] = True
|
||||
elif token in ("-u", "--user"):
|
||||
elif token in {"-u", "--user"}:
|
||||
i += 1
|
||||
args["user"] = tuple(tokens[i].split(":"))
|
||||
elif token in ("-I", "--include"):
|
||||
elif token in {"-I", "--include"}:
|
||||
args["include"] = True
|
||||
elif token in ("-s", "--silent"):
|
||||
elif token in {"-s", "--silent"}:
|
||||
args["silent"] = True
|
||||
elif token in ("-x", "--proxy"):
|
||||
elif token in {"-x", "--proxy"}:
|
||||
i += 1
|
||||
args["proxy"] = tokens[i]
|
||||
elif token in ("-U", "--proxy-user"):
|
||||
elif token in {"-U", "--proxy-user"}:
|
||||
i += 1
|
||||
args["proxy_user"] = tokens[i]
|
||||
elif not token.startswith("-"):
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ def read_text_file(file_path: str) -> str:
|
|||
result = chardet.detect(raw_data)
|
||||
encoding = result["encoding"]
|
||||
|
||||
if encoding in ["Windows-1252", "Windows-1254", "MacRoman"]:
|
||||
if encoding in {"Windows-1252", "Windows-1254", "MacRoman"}:
|
||||
encoding = "utf-8"
|
||||
|
||||
with _file_path.open(encoding=encoding) as f:
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ class LCModelComponent(Component):
|
|||
inputs: list | dict = messages or {}
|
||||
try:
|
||||
if self.output_parser is not None:
|
||||
runnable = runnable | self.output_parser
|
||||
runnable |= self.output_parser
|
||||
|
||||
runnable = runnable.with_config(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -216,7 +216,7 @@ class AddContentToPage(LCToolComponent):
|
|||
block_type: {},
|
||||
}
|
||||
|
||||
if block_type in [
|
||||
if block_type in {
|
||||
"paragraph",
|
||||
"heading_1",
|
||||
"heading_2",
|
||||
|
|
@ -224,7 +224,7 @@ class AddContentToPage(LCToolComponent):
|
|||
"bulleted_list_item",
|
||||
"numbered_list_item",
|
||||
"quote",
|
||||
]:
|
||||
}:
|
||||
block[block_type]["rich_text"] = [
|
||||
{
|
||||
"type": "text",
|
||||
|
|
|
|||
|
|
@ -72,9 +72,9 @@ class NotionPageContent(LCToolComponent):
|
|||
content = ""
|
||||
for block in blocks:
|
||||
block_type = block.get("type")
|
||||
if block_type in ["paragraph", "heading_1", "heading_2", "heading_3", "quote"]:
|
||||
if block_type in {"paragraph", "heading_1", "heading_2", "heading_3", "quote"}:
|
||||
content += self.parse_rich_text(block[block_type].get("rich_text", [])) + "\n\n"
|
||||
elif block_type in ["bulleted_list_item", "numbered_list_item"]:
|
||||
elif block_type in {"bulleted_list_item", "numbered_list_item"}:
|
||||
content += self.parse_rich_text(block[block_type].get("rich_text", [])) + "\n"
|
||||
elif block_type == "to_do":
|
||||
content += self.parse_rich_text(block["to_do"].get("rich_text", [])) + "\n"
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ from .list_assistants import AssistantsListAssistants
|
|||
from .run import AssistantsRun
|
||||
|
||||
__all__ = [
|
||||
"AstraAssistantManager",
|
||||
"AssistantsCreateAssistant",
|
||||
"AssistantsCreateThread",
|
||||
"AssistantsGetAssistantName",
|
||||
"AssistantsListAssistants",
|
||||
"AssistantsRun",
|
||||
"AstraAssistantManager",
|
||||
"Dotenv",
|
||||
"GetEnvVar",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ class APIRequestComponent(Component):
|
|||
timeout: int = 5,
|
||||
) -> Data:
|
||||
method = method.upper()
|
||||
if method not in ["GET", "POST", "PATCH", "PUT", "DELETE"]:
|
||||
if method not in {"GET", "POST", "PATCH", "PUT", "DELETE"}:
|
||||
msg = f"Unsupported method: {method}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class ShouldRunNextComponent(CustomComponent):
|
|||
content = result.content
|
||||
elif isinstance(result, str):
|
||||
content = result
|
||||
if isinstance(content, str) and content.lower().strip() in ["yes", "no"]:
|
||||
if isinstance(content, str) and content.lower().strip() in {"yes", "no"}:
|
||||
break
|
||||
condition = str(content).lower().strip() == "yes"
|
||||
self.status = f"Should Run Next: {condition}"
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class SubFlowComponent(CustomComponent):
|
|||
build_config["flow_name"]["options"] = self.get_flow_names()
|
||||
# Clean up the build config
|
||||
for key in list(build_config.keys()):
|
||||
if key not in [*self.field_order, "code", "_type", "get_final_results_only"]:
|
||||
if key not in {*self.field_order, "code", "_type", "get_final_results_only"}:
|
||||
del build_config[key]
|
||||
if field_value is not None and field_name == "flow_name":
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class DataConditionalRouterComponent(Component):
|
|||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ["true", "1", "yes", "y", "on"]
|
||||
return value.lower() in {"true", "1", "yes", "y", "on"}
|
||||
return bool(value)
|
||||
|
||||
def validate_input(self, data_item: Data) -> bool:
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class GroqModel(LCModelComponent):
|
|||
return []
|
||||
|
||||
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
|
||||
if field_name in ("groq_api_key", "groq_api_base", "model_name"):
|
||||
if field_name in {"groq_api_key", "groq_api_base", "model_name"}:
|
||||
models = self.get_models()
|
||||
build_config["model_name"]["options"] = models
|
||||
return build_config
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class LangChainHubPromptComponent(Component):
|
|||
for message in prompt_template:
|
||||
# Find all matches
|
||||
matches = re.findall(pattern, message.template)
|
||||
custom_fields = custom_fields + matches
|
||||
custom_fields += matches
|
||||
|
||||
# Create a string version of the full template
|
||||
full_template = full_template + "\n" + message.template
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class SubFlowComponent(Component):
|
|||
new_vertex_inputs = []
|
||||
field_template = vertex.data["node"]["template"]
|
||||
for inp in field_template:
|
||||
if inp not in ["code", "_type"]:
|
||||
if inp not in {"code", "_type"}:
|
||||
field_template[inp]["display_name"] = (
|
||||
vertex.display_name + " - " + field_template[inp]["display_name"]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ class PythonCodeStructuredTool(LCToolComponent):
|
|||
if field_name is None:
|
||||
return build_config
|
||||
|
||||
if field_name not in ("tool_code", "tool_function"):
|
||||
if field_name not in {"tool_code", "tool_function"}:
|
||||
return build_config
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class RedisVectorStoreComponent(LCVectorStoreComponent):
|
|||
documents.append(_input.to_lc_document())
|
||||
else:
|
||||
documents.append(_input)
|
||||
with Path("docuemnts.txt").open("w") as f:
|
||||
with Path("docuemnts.txt").open("w", encoding="utf-8") as f:
|
||||
f.write(str(documents))
|
||||
|
||||
if not documents:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def find_class_ast_node(class_obj):
|
|||
return None, []
|
||||
|
||||
# Read the source code from the file
|
||||
with Path(source_file).open() as file:
|
||||
with Path(source_file).open(encoding="utf-8") as file:
|
||||
source_code = file.read()
|
||||
|
||||
# Parse the source code into an AST
|
||||
|
|
@ -333,7 +333,7 @@ class CodeParser:
|
|||
bases = self.get_base_classes()
|
||||
nodes = []
|
||||
for base in bases:
|
||||
if base.__name__ == node.name or base.__name__ in ["CustomComponent", "Component", "BaseComponent"]:
|
||||
if base.__name__ == node.name or base.__name__ in {"CustomComponent", "Component", "BaseComponent"}:
|
||||
continue
|
||||
try:
|
||||
class_node, import_nodes = find_class_ast_node(base)
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ def process_type(field_type: str):
|
|||
# field_type is a string can be Prompt or Code too
|
||||
# so we just need to lower if it is the case
|
||||
lowercase_type = field_type.lower()
|
||||
if lowercase_type in ["prompt", "code"]:
|
||||
if lowercase_type in {"prompt", "code"}:
|
||||
return lowercase_type
|
||||
return field_type
|
||||
|
||||
|
|
@ -206,11 +206,11 @@ def add_extra_fields(frontend_node, field_config, function_args):
|
|||
# then we need to add the extra fields
|
||||
|
||||
for extra_field in function_args:
|
||||
if "name" not in extra_field or extra_field["name"] in [
|
||||
if "name" not in extra_field or extra_field["name"] in {
|
||||
"self",
|
||||
"kwargs",
|
||||
"args",
|
||||
]:
|
||||
}:
|
||||
continue
|
||||
|
||||
field_name, field_type, field_value, field_required = get_field_properties(extra_field)
|
||||
|
|
|
|||
|
|
@ -358,7 +358,7 @@ class Graph:
|
|||
nest_asyncio.apply()
|
||||
loop = asyncio.get_event_loop()
|
||||
async_gen = self.async_start(inputs, max_iterations, event_manager)
|
||||
async_gen_task = asyncio.ensure_future(async_gen.__anext__())
|
||||
async_gen_task = asyncio.ensure_future(anext(async_gen))
|
||||
|
||||
while True:
|
||||
try:
|
||||
|
|
@ -366,7 +366,7 @@ class Graph:
|
|||
yield result
|
||||
if isinstance(result, Finish):
|
||||
return
|
||||
async_gen_task = asyncio.ensure_future(async_gen.__anext__())
|
||||
async_gen_task = asyncio.ensure_future(anext(async_gen))
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
|
|
@ -1058,7 +1058,7 @@ class Graph:
|
|||
"""Updates the edges of a vertex in the Graph."""
|
||||
new_edges = []
|
||||
for edge in self.edges:
|
||||
if other_vertex.id in (edge.source_id, edge.target_id):
|
||||
if other_vertex.id in {edge.source_id, edge.target_id}:
|
||||
continue
|
||||
new_edges.append(edge)
|
||||
new_edges += other_vertex.edges
|
||||
|
|
@ -1210,7 +1210,7 @@ class Graph:
|
|||
return
|
||||
self.vertices.remove(vertex)
|
||||
self.vertex_map.pop(vertex_id)
|
||||
self.edges = [edge for edge in self.edges if vertex_id not in (edge.source_id, edge.target_id)]
|
||||
self.edges = [edge for edge in self.edges if vertex_id not in {edge.source_id, edge.target_id}]
|
||||
|
||||
def _build_vertex_params(self) -> None:
|
||||
"""Identifies and handles the LLM vertex within the graph."""
|
||||
|
|
@ -1707,7 +1707,7 @@ class Graph:
|
|||
node_name = node_id.split("-")[0]
|
||||
if node_name in InterfaceComponentTypes:
|
||||
return InterfaceVertex
|
||||
if node_name in ["SharedState", "Notify", "Listen"]:
|
||||
if node_name in {"SharedState", "Notify", "Listen"}:
|
||||
return StateVertex
|
||||
if node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
|
||||
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type]
|
||||
|
|
|
|||
|
|
@ -235,7 +235,7 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
|
|||
if new_edge["source"] == group_node_id:
|
||||
new_edge = update_source_handle(new_edge, g_nodes, g_edges)
|
||||
|
||||
if group_node_id in (edge["target"], edge["source"]):
|
||||
if group_node_id in {edge["target"], edge["source"]}:
|
||||
updated_edges.append(new_edge)
|
||||
return updated_edges
|
||||
|
||||
|
|
|
|||
|
|
@ -380,7 +380,7 @@ class Vertex:
|
|||
except Exception: # noqa: BLE001
|
||||
logger.debug(f"Error evaluating code for {field_name}")
|
||||
params[field_name] = val
|
||||
elif field.get("type") in ["dict", "NestedDict"]:
|
||||
elif field.get("type") in {"dict", "NestedDict"}:
|
||||
# When dict comes from the frontend it comes as a
|
||||
# list of dicts, so we need to convert it to a dict
|
||||
# before passing it to the build method
|
||||
|
|
|
|||
|
|
@ -27,13 +27,13 @@ def update_memory_keys(langchain_object, possible_new_mem_key):
|
|||
input_key = next(
|
||||
key
|
||||
for key in langchain_object.input_keys
|
||||
if key not in [langchain_object.memory.memory_key, possible_new_mem_key]
|
||||
if key not in {langchain_object.memory.memory_key, possible_new_mem_key}
|
||||
)
|
||||
|
||||
output_key = next(
|
||||
key
|
||||
for key in langchain_object.output_keys
|
||||
if key not in [langchain_object.memory.memory_key, possible_new_mem_key]
|
||||
if key not in {langchain_object.memory.memory_key, possible_new_mem_key}
|
||||
)
|
||||
|
||||
for key, attr in [(input_key, "input_key"), (output_key, "output_key"), (possible_new_mem_key, "memory_key")]:
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def upload(file_path: str, host: str, flow_id: str):
|
|||
url = f"{host}/api/v1/upload/{flow_id}"
|
||||
with Path(file_path).open("rb") as file:
|
||||
response = httpx.post(url, files={"file": file})
|
||||
if response.status_code in (httpx.codes.OK, httpx.codes.CREATED):
|
||||
if response.status_code in {httpx.codes.OK, httpx.codes.CREATED}:
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
msg = f"Error uploading file: {e}"
|
||||
|
|
|
|||
|
|
@ -31,9 +31,9 @@ class Column(BaseModel):
|
|||
@field_validator("formatter", mode="before")
|
||||
@classmethod
|
||||
def validate_formatter(cls, value):
|
||||
if value in ["integer", "int", "float"]:
|
||||
if value in {"integer", "int", "float"}:
|
||||
value = FormatterType.number
|
||||
if value in ["str", "string"]:
|
||||
if value in {"str", "string"}:
|
||||
value = FormatterType.text
|
||||
if value == "dict":
|
||||
value = FormatterType.json
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class PluginService(Service):
|
|||
if (
|
||||
inspect.isclass(attr)
|
||||
and issubclass(attr, BasePlugin)
|
||||
and attr not in [CallbackPlugin, BasePlugin]
|
||||
and attr not in {CallbackPlugin, BasePlugin}
|
||||
):
|
||||
self.register_plugin(plugin_name, attr())
|
||||
except Exception: # noqa: BLE001
|
||||
|
|
|
|||
|
|
@ -365,7 +365,7 @@ class Settings(BaseSettings):
|
|||
|
||||
|
||||
def save_settings_to_yaml(settings: Settings, file_path: str):
|
||||
with Path(file_path).open("w") as f:
|
||||
with Path(file_path).open("w", encoding="utf-8") as f:
|
||||
settings_dict = settings.model_dump()
|
||||
yaml.dump(settings_dict, f)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from loguru import logger
|
|||
|
||||
|
||||
def set_secure_permissions(file_path: Path):
|
||||
if platform.system() in ["Linux", "Darwin"]: # Unix/Linux/Mac
|
||||
if platform.system() in {"Linux", "Darwin"}: # Unix/Linux/Mac
|
||||
file_path.chmod(0o600)
|
||||
elif platform.system() == "Windows":
|
||||
import win32api
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class StoreService(Service):
|
|||
|
||||
return "id" in user_data[0]
|
||||
except HTTPStatusError as exc:
|
||||
if exc.response.status_code in [403, 401]:
|
||||
if exc.response.status_code in {403, 401}:
|
||||
return False
|
||||
msg = f"Unexpected status code: {exc.response.status_code}"
|
||||
raise ValueError(msg) from exc
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class Input(BaseModel):
|
|||
def serialize_model(self, handler):
|
||||
result = handler(self)
|
||||
# If the field is str, we add the Text input type
|
||||
if self.field_type in ["str", "Text"] and "input_types" not in result:
|
||||
if self.field_type in {"str", "Text"} and "input_types" not in result:
|
||||
result["input_types"] = ["Text"]
|
||||
if self.field_type == Text:
|
||||
result["type"] = "str"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
|
|||
"""
|
||||
Extracts the inner type from a type hint that is a list or a Optional.
|
||||
"""
|
||||
if return_type.__origin__ in [list, SequenceABC]:
|
||||
if return_type.__origin__ in {list, SequenceABC}:
|
||||
return list(return_type.__args__)
|
||||
return return_type
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ def post_process_type(_type):
|
|||
Union[List[Any], Any]: The processed return type.
|
||||
|
||||
"""
|
||||
if hasattr(_type, "__origin__") and _type.__origin__ in [list, list, SequenceABC]:
|
||||
if hasattr(_type, "__origin__") and _type.__origin__ in {list, list, SequenceABC}:
|
||||
_type = extract_inner_type_from_generic_alias(_type)
|
||||
|
||||
# If the return type is not a Union, then we just return it as a list
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ def build_template_from_method(
|
|||
"required": param.default == param.empty,
|
||||
}
|
||||
for name, param in params.items()
|
||||
if name not in ["self", "kwargs", "args"]
|
||||
if name not in {"self", "kwargs", "args"}
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ def fetch_latest_version(package_name: str, include_prerelease: bool) -> str | N
|
|||
valid_versions = [v for v in versions if include_prerelease or not is_pre_release(v)]
|
||||
if not valid_versions:
|
||||
return None # Handle case where no valid versions are found
|
||||
return max(valid_versions, key=lambda v: pkg_version.parse(v))
|
||||
return max(valid_versions, key=pkg_version.parse)
|
||||
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error fetching latest version")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue