feat: Add ruff rules SIM (#3979)

Christophe Bornet 2024-10-02 14:45:41 +02:00 committed by GitHub
commit 942c8dca36
62 changed files with 228 additions and 374 deletions
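
Note: SIM selects ruff's flake8-simplify rules, which flag control flow that has a shorter equivalent — collapsible ifs, if/else blocks that only assign, nested with statements, try/except/pass, redundant .keys() calls, and so on. Each hunk below is one such fix. A representative rewrite, illustrative rather than taken from the commit:

count = 3
# Before: assignment split across an if/else (flagged by SIM108)
if count > 0:
    label = "items"
else:
    label = "empty"
# After: the conditional expression ruff suggests
label = "items" if count > 0 else "empty"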

View file

@ -457,11 +457,10 @@ def migration(
"""
Run or test migrations.
"""
if fix:
if not typer.confirm(
"This will delete all data necessary to fix migrations. Are you sure you want to continue?"
):
raise typer.Abort
if fix and not typer.confirm(
"This will delete all data necessary to fix migrations. Are you sure you want to continue?"
):
raise typer.Abort
initialize_services(fix_migration=fix)
db_service = get_db_service()
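
Note: the hunk above is SIM102 (collapsible-if): an outer if whose body is nothing but another if merges into a single condition with `and`. In miniature, with hypothetical names:

fix, confirmed = True, False
if fix and not confirmed:  # was: if fix: / if not confirmed:
    print("aborting")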

View file

@ -90,10 +90,7 @@ async def logs(
status_code=HTTPStatus.BAD_REQUEST,
detail="Timestamp is required when requesting logs after the timestamp",
)
if lines_before <= 0:
content = log_buffer.get_last_n(10)
else:
content = log_buffer.get_last_n(lines_before)
content = log_buffer.get_last_n(10) if lines_before <= 0 else log_buffer.get_last_n(lines_before)
else:
if lines_before > 0:
content = log_buffer.get_before_timestamp(timestamp=timestamp, lines=lines_before)
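
Note: SIM108 again — an if/else whose only job is assigning content becomes a conditional expression. Since both branches call the same method, the ternary could even move inside the call; a sketch over the hunk's own names, not part of the commit:

content = log_buffer.get_last_n(10 if lines_before <= 0 else lines_before)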

View file

@ -344,20 +344,17 @@ async def build_flow(
raise ValueError(msg) from exc
event_manager.on_end_vertex(data={"build_data": build_data})
await client_consumed_queue.get()
if vertex_build_response.valid:
if vertex_build_response.next_vertices_ids:
tasks = []
for next_vertex_id in vertex_build_response.next_vertices_ids:
task = asyncio.create_task(
build_vertices(next_vertex_id, graph, client_consumed_queue, event_manager)
)
tasks.append(task)
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
for task in tasks:
task.cancel()
return
if vertex_build_response.valid and vertex_build_response.next_vertices_ids:
tasks = []
for next_vertex_id in vertex_build_response.next_vertices_ids:
task = asyncio.create_task(build_vertices(next_vertex_id, graph, client_consumed_queue, event_manager))
tasks.append(task)
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
for task in tasks:
task.cancel()
return
async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None:
if not data:

View file

@ -95,13 +95,12 @@ def validate_input_and_tweaks(input_request: SimplifiedAPIRequest):
if has_input_value and input_value_is_chat:
msg = "If you pass an input_value to the chat input, you cannot pass a tweak with the same name."
raise InvalidChatInputException(msg)
elif "Text Input" in key or "TextInput" in key:
if isinstance(value, dict):
has_input_value = value.get("input_value") is not None
input_value_is_text = input_request.input_value is not None and input_request.input_type == "text"
if has_input_value and input_value_is_text:
msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name."
raise InvalidChatInputException(msg)
elif ("Text Input" in key or "TextInput" in key) and isinstance(value, dict):
has_input_value = value.get("input_value") is not None
input_value_is_text = input_request.input_value is not None and input_request.input_type == "text"
if has_input_value and input_value_is_text:
msg = "If you pass an input_value to the text input, you cannot pass a tweak with the same name."
raise InvalidChatInputException(msg)
async def simple_run_flow(

View file

@ -310,10 +310,7 @@ async def upload_file(
contents = await file.read()
data = orjson.loads(contents)
response_list = []
if "flows" in data:
flow_list = FlowListCreate(**data)
else:
flow_list = FlowListCreate(flows=[FlowCreate(**data)])
flow_list = FlowListCreate(**data) if "flows" in data else FlowListCreate(flows=[FlowCreate(**data)])
# Now we set the user_id for all flows
for flow in flow_list.flows:
flow.user_id = current_user.id

View file

@ -53,10 +53,7 @@ class BaseCrewComponent(Component):
self,
) -> Callable:
def task_callback(task_output: TaskOutput):
if self._vertex:
vertex_id = self._vertex.id
else:
vertex_id = self.display_name or self.__class__.__name__
vertex_id = self._vertex.id if self._vertex else self.display_name or self.__class__.__name__
self.log(task_output.model_dump(), name=f"Task (Agent: {task_output.agent}) - {vertex_id}")
return task_callback

View file

@ -67,10 +67,7 @@ def build_data_from_result_data(result_data: ResultData, get_final_results_only:
if isinstance(result_data.results, dict):
for name, result in result_data.results.items():
dataobj: Data | Message | None = None
if isinstance(result, Message):
dataobj = result
else:
dataobj = Data(data=result, text_key=name)
dataobj = result if isinstance(result, Message) else Data(data=result, text_key=name)
data.append(dataobj)
else:

View file

@ -25,12 +25,16 @@ class ChatComponent(Component):
msg = "Only one message can be stored at a time."
raise ValueError(msg)
stored_message = messages[0]
if hasattr(self, "_event_manager") and self._event_manager and stored_message.id:
if not isinstance(message.text, str):
complete_message = self._stream_message(message, stored_message.id)
message_table = update_message(message_id=stored_message.id, message={"text": complete_message})
stored_message = Message(**message_table.model_dump())
self.vertex._added_message = stored_message
if (
hasattr(self, "_event_manager")
and self._event_manager
and stored_message.id
and not isinstance(message.text, str)
):
complete_message = self._stream_message(message, stored_message.id)
message_table = update_message(message_id=stored_message.id, message={"text": complete_message})
stored_message = Message(**message_table.model_dump())
self.vertex._added_message = stored_message
self.status = stored_message
return stored_message
@ -77,8 +81,6 @@ class ChatComponent(Component):
session_id: str | None = None,
return_message: bool | None = False,
) -> Message:
message: Message | None = None
if isinstance(input_value, Data):
# Update the data of the record
message = Message.from_data(input_value)
@ -86,10 +88,7 @@ class ChatComponent(Component):
message = Message(
text=input_value, sender=sender, sender_name=sender_name, files=files, session_id=session_id
)
if not return_message:
message_text = message.text
else:
message_text = message # type: ignore
message_text = message.text if not return_message else message
self.status = message_text
if session_id and isinstance(message, Message) and isinstance(message.text, str):

View file

@ -130,10 +130,7 @@ class SequentialTaskAgentComponent(Component):
# If there's a previous task, create a list of tasks
if self.previous_task:
if isinstance(self.previous_task, list):
tasks = self.previous_task + [task]
else:
tasks = [self.previous_task, task]
tasks = self.previous_task + [task] if isinstance(self.previous_task, list) else [self.previous_task, task]
else:
tasks = [task]

View file

@ -28,10 +28,7 @@ class SQLGeneratorComponent(LCChainComponent):
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
def invoke_chain(self) -> Message:
if self.prompt:
prompt_template = PromptTemplate.from_template(template=self.prompt)
else:
prompt_template = None
prompt_template = PromptTemplate.from_template(template=self.prompt) if self.prompt else None
if self.top_k < 1:
msg = "Top K must be greater than 0."

View file

@ -96,10 +96,7 @@ class GmailLoaderComponent(Component):
msg = "From email not found."
raise ValueError(msg)
if "parts" in msg["payload"]:
parts = msg["payload"]["parts"]
else:
parts = [msg["payload"]]
parts = msg["payload"]["parts"] if "parts" in msg["payload"] else [msg["payload"]]
for part in parts:
if part["mimeType"] == "text/plain":

View file

@ -96,17 +96,13 @@ class GoogleDriveSearchComponent(Component):
"""
Generates the appropriate Google Drive URL for a file based on its MIME type.
"""
if mime_type == "application/vnd.google-apps.document":
return f"https://docs.google.com/document/d/{file_id}/edit"
if mime_type == "application/vnd.google-apps.spreadsheet":
return f"https://docs.google.com/spreadsheets/d/{file_id}/edit"
if mime_type == "application/vnd.google-apps.presentation":
return f"https://docs.google.com/presentation/d/{file_id}/edit"
if mime_type == "application/vnd.google-apps.drawing":
return f"https://docs.google.com/drawings/d/{file_id}/edit"
if mime_type == "application/pdf":
return f"https://drive.google.com/file/d/{file_id}/view?usp=drivesdk"
return f"https://drive.google.com/file/d/{file_id}/view?usp=drivesdk"
return {
"application/vnd.google-apps.document": f"https://docs.google.com/document/d/{file_id}/edit",
"application/vnd.google-apps.spreadsheet": f"https://docs.google.com/spreadsheets/d/{file_id}/edit",
"application/vnd.google-apps.presentation": f"https://docs.google.com/presentation/d/{file_id}/edit",
"application/vnd.google-apps.drawing": f"https://docs.google.com/drawings/d/{file_id}/edit",
"application/pdf": f"https://drive.google.com/file/d/{file_id}/view?usp=drivesdk",
}.get(mime_type, f"https://drive.google.com/file/d/{file_id}/view?usp=drivesdk")
def search_files(self) -> dict:
# Load the token information from the JSON string
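
Note: unlike most hunks in this commit, the URL builder above appears to be a hand refactor in the same spirit rather than a SIM autofix: the chain of early returns becomes a dictionary dispatch, with .get supplying the generic-file URL as the fallback. The pattern in miniature, with illustrative values:

mime_urls = {"doc": "https://docs.example/d", "sheet": "https://sheets.example/d"}
url = mime_urls.get("pdf", "https://drive.example/view")  # unknown key -> fallback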

View file

@ -39,10 +39,7 @@ class TextEmbedderComponent(Component):
embeddings = embedding_model.embed_documents([text_content])
# Assuming the embedding model returns a list of embeddings, we take the first one
if embeddings:
embedding_vector = embeddings[0]
else:
embedding_vector = []
embedding_vector = embeddings[0] if embeddings else []
# Create a Data object to encapsulate the results
result_data = Data(data={"text": text_content, "embeddings": embedding_vector})

View file

@ -21,25 +21,24 @@ class AIMLEmbeddingsImpl(BaseModel, Embeddings):
"Authorization": f"Bearer {self.api_key.get_secret_value()}",
}
with httpx.Client() as client:
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for i, text in enumerate(texts):
futures.append((i, executor.submit(self._embed_text, client, headers, text)))
with httpx.Client() as client, concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for i, text in enumerate(texts):
futures.append((i, executor.submit(self._embed_text, client, headers, text)))
for index, future in futures:
try:
result_data = future.result()
assert len(result_data["data"]) == 1, "Expected one embedding"
embeddings[index] = result_data["data"][0]["embedding"]
except (
httpx.HTTPStatusError,
httpx.RequestError,
json.JSONDecodeError,
KeyError,
) as e:
logger.error(f"Error occurred: {e}")
raise
for index, future in futures:
try:
result_data = future.result()
assert len(result_data["data"]) == 1, "Expected one embedding"
embeddings[index] = result_data["data"][0]["embedding"]
except (
httpx.HTTPStatusError,
httpx.RequestError,
json.JSONDecodeError,
KeyError,
) as e:
logger.error(f"Error occurred: {e}")
raise
return embeddings # type: ignore
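
Note: SIM117 (multiple-with-statements) — the nested httpx.Client and ThreadPoolExecutor blocks collapse into a single with statement with comma-separated context managers, dropping one indentation level. Generic shape, with illustrative file names:

with open("a.txt") as a, open("b.txt") as b:  # was: with open("a.txt") as a: / with open("b.txt") as b:
    data = a.read() + b.read()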

View file

@ -61,15 +61,9 @@ class FirecrawlCrawlApi(CustomComponent):
"Could not import firecrawl integration package. " "Please install it with `pip install firecrawl-py`."
)
raise ImportError(msg)
if crawlerOptions:
crawler_options_dict = crawlerOptions.__dict__["data"]["text"]
else:
crawler_options_dict = {}
crawler_options_dict = crawlerOptions.__dict__["data"]["text"] if crawlerOptions else {}
if pageOptions:
page_options_dict = pageOptions.__dict__["data"]["text"]
else:
page_options_dict = {}
page_options_dict = pageOptions.__dict__["data"]["text"] if pageOptions else {}
if not idempotency_key:
idempotency_key = str(uuid.uuid4())

View file

@ -54,15 +54,9 @@ class FirecrawlScrapeApi(CustomComponent):
"Could not import firecrawl integration package. " "Please install it with `pip install firecrawl-py`."
)
raise ImportError(msg)
if extractorOptions:
extractor_options_dict = extractorOptions.__dict__["data"]["text"]
else:
extractor_options_dict = {}
extractor_options_dict = extractorOptions.__dict__["data"]["text"] if extractorOptions else {}
if pageOptions:
page_options_dict = pageOptions.__dict__["data"]["text"]
else:
page_options_dict = {}
page_options_dict = pageOptions.__dict__["data"]["text"] if pageOptions else {}
app = FirecrawlApp(api_key=api_key)
results = app.scrape_url(

View file

@ -79,10 +79,7 @@ class AIMLModelComponent(LCModelComponent):
aiml_api_base = self.aiml_api_base or "https://api.aimlapi.com"
seed = self.seed
if isinstance(aiml_api_key, SecretStr):
openai_api_key = aiml_api_key.get_secret_value()
else:
openai_api_key = aiml_api_key
openai_api_key = aiml_api_key.get_secret_value() if isinstance(aiml_api_key, SecretStr) else aiml_api_key
return ChatOpenAI(
model=model_name,

View file

@ -36,10 +36,7 @@ class CohereComponent(LCModelComponent):
cohere_api_key = self.cohere_api_key
temperature = self.temperature
if cohere_api_key:
api_key = SecretStr(cohere_api_key)
else:
api_key = None
api_key = SecretStr(cohere_api_key) if cohere_api_key else None
return ChatCohere(
temperature=temperature or 0.75,

View file

@ -78,10 +78,7 @@ class MistralAIModelComponent(LCModelComponent):
random_seed = self.random_seed
safe_mode = self.safe_mode
if mistral_api_key:
api_key = SecretStr(mistral_api_key)
else:
api_key = None
api_key = SecretStr(mistral_api_key) if mistral_api_key else None
return ChatMistralAI(
max_tokens=max_tokens or None,

View file

@ -102,10 +102,7 @@ class OpenAIModelComponent(LCModelComponent):
json_mode = bool(output_schema_dict) or self.json_mode
seed = self.seed
if openai_api_key:
api_key = SecretStr(openai_api_key)
else:
api_key = None
api_key = SecretStr(openai_api_key) if openai_api_key else None
output = ChatOpenAI(
max_tokens=max_tokens or None,
model_kwargs=model_kwargs,

View file

@ -57,7 +57,7 @@ class SubFlowComponent(Component):
for vertex in inputs_vertex:
new_vertex_inputs = []
field_template = vertex.data["node"]["template"]
for inp in field_template.keys():
for inp in field_template:
if inp not in ["code", "_type"]:
field_template[inp]["display_name"] = (
vertex.display_name + " - " + field_template[inp]["display_name"]
@ -84,10 +84,10 @@ class SubFlowComponent(Component):
async def generate_results(self) -> list[Data]:
tweaks: dict = {}
for field in self._attributes.keys():
for field in self._attributes:
if field != "flow_name":
[node, name] = field.split("|")
if node not in tweaks.keys():
if node not in tweaks:
tweaks[node] = {}
tweaks[node][name] = self._attributes[field]
flow_name = self._attributes.get("flow_name")
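
Note: SIM118 (in-dict-keys) — iterating or testing membership on a dict already uses its keys, so the .keys() call is dropped. Equivalent behavior:

tweaks = {"node": {}}
assert ("node" in tweaks) == ("node" in tweaks.keys())
for key in tweaks:  # yields the same keys as tweaks.keys()
    print(key)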

View file

@ -43,10 +43,7 @@ class CharacterTextSplitterComponent(LCTextSplitterComponent):
return self.data_input
def build_text_splitter(self) -> TextSplitter:
if self.separator:
separator = unescape_string(self.separator)
else:
separator = "\n\n"
separator = unescape_string(self.separator) if self.separator else "\n\n"
return CharacterTextSplitter(
chunk_overlap=self.chunk_overlap,
chunk_size=self.chunk_size,

View file

@ -51,10 +51,7 @@ class NaturalLanguageTextSplitterComponent(LCTextSplitterComponent):
return self.data_input
def build_text_splitter(self) -> TextSplitter:
if self.separator:
separator = unescape_string(self.separator)
else:
separator = "\n\n"
separator = unescape_string(self.separator) if self.separator else "\n\n"
return NLTKTextSplitter(
language=self.language.lower() if self.language else "english",
separator=separator,

View file

@ -188,9 +188,7 @@ class PythonCodeStructuredTool(LCToolComponent):
schema_annotation = Any
schema_fields[field_name] = (
schema_annotation,
Field(
default=func_arg["default"] if "default" in func_arg else Undefined, description=field_description
),
Field(default=func_arg.get("default", Undefined), description=field_description),
)
if "temp_annotation_type" in _globals:

View file

@ -174,10 +174,7 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
else:
documents.append(_input)
if self.enable_body_search:
body_index_options = [("index_analyzer", "STANDARD")]
else:
body_index_options = None
body_index_options = [("index_analyzer", "STANDARD")] if self.enable_body_search else None
if self.setup_mode == "Off":
setup_mode = SetupMode.OFF

View file

@ -161,10 +161,7 @@ class CassandraGraphVectorStoreComponent(LCVectorStoreComponent):
else:
documents.append(_input)
if self.setup_mode == "Off":
setup_mode = SetupMode.OFF
else:
setup_mode = SetupMode.SYNC
setup_mode = SetupMode.OFF if self.setup_mode == "Off" else SetupMode.SYNC
if documents:
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")

View file

@ -125,10 +125,7 @@ class ChromaVectorStoreComponent(LCVectorStoreComponent):
client = Client(settings=chroma_settings)
# Check persist_directory and expand it if it is a relative path
if self.persist_directory is not None:
persist_directory = self.resolve_path(self.persist_directory)
else:
persist_directory = None
persist_directory = self.resolve_path(self.persist_directory) if self.persist_directory is not None else None
chroma = Chroma(
persist_directory=persist_directory,

View file

@ -1,4 +1,5 @@
import ast
import contextlib
import inspect
import traceback
from typing import Any
@ -171,11 +172,9 @@ class CodeParser:
return_type_str = ast.unparse(node.returns)
eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"]))
try:
# Handle cases where the type is not found in the constructed environment
with contextlib.suppress(NameError):
return_type = eval(return_type_str, eval_env)
except NameError:
# Handle cases where the type is not found in the constructed environment
pass
func = CallableCodeDetails(
name=node.name,
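
Note: SIM105 (suppressible-exception) — a try/except/pass that only swallows one exception type becomes contextlib.suppress, which is why `import contextlib` is added at the top of the file. Minimal sketch:

import contextlib

with contextlib.suppress(ValueError):  # was: try: ... except ValueError: pass
    number = int("not a number")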

View file

@ -81,7 +81,7 @@ class BaseComponent:
template_config[attribute] = func(value=value)
for key in template_config.copy():
if key not in ATTR_FUNC_MAPPING.keys():
if key not in ATTR_FUNC_MAPPING:
template_config.pop(key, None)
return template_config

View file

@ -226,7 +226,7 @@ def add_extra_fields(frontend_node, field_config, function_args):
field_required,
config,
)
if "kwargs" in function_args_names and not all(key in function_args_names for key in field_config.keys()):
if "kwargs" in function_args_names and not all(key in function_args_names for key in field_config):
for field_name, field_config in _field_config.copy().items():
if "name" not in field_config or field_name == "code":
continue
@ -503,35 +503,35 @@ def update_field_dict(
call: bool = False,
):
"""Update the field dictionary by calling options() or value() if they are callable"""
if ("real_time_refresh" in field_dict or "refresh_button" in field_dict) and any(
(
field_dict.get("real_time_refresh", False),
field_dict.get("refresh_button", False),
if (
("real_time_refresh" in field_dict or "refresh_button" in field_dict)
and any(
(
field_dict.get("real_time_refresh", False),
field_dict.get("refresh_button", False),
)
)
and call
):
if call:
try:
dd_build_config = dotdict(build_config)
custom_component_instance.update_build_config(
build_config=dd_build_config,
field_value=update_field,
field_name=update_field_value,
)
build_config = dd_build_config
except Exception as exc:
logger.error(f"Error while running update_build_config: {str(exc)}")
msg = f"Error while running update_build_config: {str(exc)}"
raise UpdateBuildConfigError(msg) from exc
try:
dd_build_config = dotdict(build_config)
custom_component_instance.update_build_config(
build_config=dd_build_config,
field_value=update_field,
field_name=update_field_value,
)
build_config = dd_build_config
except Exception as exc:
logger.error(f"Error while running update_build_config: {str(exc)}")
msg = f"Error while running update_build_config: {str(exc)}"
raise UpdateBuildConfigError(msg) from exc
return build_config
def sanitize_field_config(field_config: dict | Input):
# If any of the already existing keys are in field_config, remove them
if isinstance(field_config, Input):
field_dict = field_config.to_dict()
else:
field_dict = field_config
field_dict = field_config.to_dict() if isinstance(field_config, Input) else field_config
for key in [
"name",
"field_type",

View file

@ -244,10 +244,10 @@ class CycleEdge(Edge):
await self.honor(source, target)
# If the target vertex is a power component we log messages
if target.vertex_type == "ChatOutput" and (
isinstance(target.params.get(INPUT_FIELD_NAME), str)
or isinstance(target.params.get(INPUT_FIELD_NAME), dict)
if (
target.vertex_type == "ChatOutput"
and isinstance(target.params.get(INPUT_FIELD_NAME), str | dict)
and target.params.get("message") == ""
):
if target.params.get("message") == "":
return self.result
return self.result
return self.result

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import contextlib
import copy
import json
import uuid
@ -1061,10 +1062,7 @@ class Graph:
same_length = len(vertex.edges) == len(other_vertex.edges)
if not same_length:
return False
for edge in vertex.edges:
if edge not in other_vertex.edges:
return False
return True
return all(edge in other_vertex.edges for edge in vertex.edges)
def update(self, other: Graph) -> Graph:
# Existing vertices in self graph
@ -1080,10 +1078,8 @@ class Graph:
# Remove vertices that are not in the other graph
for vertex_id in removed_vertex_ids:
try:
with contextlib.suppress(ValueError):
self.remove_vertex(vertex_id)
except ValueError:
pass
# The order here matters because adding the vertex is required
# if any of them have edges that point to any of the new vertices

View file

@ -55,9 +55,7 @@ class RunnableVerticesManager:
return False
if vertex_id not in self.vertices_to_run:
return False
if not self.are_all_predecessors_fulfilled(vertex_id):
return False
return True
return self.are_all_predecessors_fulfilled(vertex_id)
def are_all_predecessors_fulfilled(self, vertex_id: str) -> bool:
return not any(self.run_predecessors.get(vertex_id, []))

View file

@ -354,12 +354,7 @@ def has_cycle(vertex_ids: list[str], edges: list[tuple[str, str]]) -> bool:
visited: set[str] = set()
rec_stack: set[str] = set()
for vertex in vertex_ids:
if vertex not in visited:
if dfs(vertex, visited, rec_stack):
return True
return False
return any(vertex not in visited and dfs(vertex, visited, rec_stack) for vertex in vertex_ids)
def find_cycle_edge(entry_point: str, edges: list[tuple[str, str]]) -> tuple[str, str]:
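
Note: two related rules show up in these graph hunks — SIM103 (needless-bool) returns the final condition directly instead of an if/return-False/return-True tail, and SIM110 (reimplemented-builtin) folds a search loop into any() or all(). Sketch:

def has_negative(values: list[int]) -> bool:
    # was: for v in values: if v < 0: return True / return False
    return any(v < 0 for v in values)

assert has_negative([3, -1]) and not has_negative([3, 1])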

View file

@ -103,11 +103,10 @@ def get_artifact_type(value, build_result) -> str:
case Message():
result = ArtifactType.MESSAGE
if result == ArtifactType.UNKNOWN:
if isinstance(build_result, Generator):
result = ArtifactType.STREAM
elif isinstance(value, Message) and isinstance(value.text, Generator):
result = ArtifactType.STREAM
if result == ArtifactType.UNKNOWN and (
isinstance(build_result, Generator) or isinstance(value, Message) and isinstance(value.text, Generator)
):
result = ArtifactType.STREAM
return result.value

View file

@ -287,7 +287,7 @@ class Vertex:
if not param_dict or len(param_dict) != 1:
params[param_key] = self.graph.get_vertex(edge.source_id)
else:
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()}
params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict}
else:
params[param_key] = self.graph.get_vertex(edge.source_id)
@ -415,8 +415,6 @@ class Vertex:
elif val is not None and val != "":
params[field_name] = val
elif val is not None and val != "":
params[field_name] = val
if field.get("load_from_db"):
load_from_db_fields.append(field_name)
@ -534,10 +532,7 @@ class Vertex:
# to the frontend
self.set_artifacts()
artifacts = self.artifacts_raw
if isinstance(artifacts, dict):
messages = self.extract_messages_from_artifacts(artifacts)
else:
messages = []
messages = self.extract_messages_from_artifacts(artifacts) if isinstance(artifacts, dict) else []
result_dict = ResultData(
results=result_dict,
artifacts=artifacts,

View file

@ -1,4 +1,5 @@
import asyncio
import contextlib
import json
from collections.abc import AsyncIterator, Generator, Iterator
from typing import TYPE_CHECKING, Any, cast
@ -167,7 +168,7 @@ class ComponentVertex(Vertex):
message_dict = artifact if isinstance(artifact, dict) else artifact.model_dump()
if not message_dict.get("text"):
continue
try:
with contextlib.suppress(KeyError):
messages.append(
ChatOutputResponse(
message=message_dict["text"],
@ -182,8 +183,6 @@ class ComponentVertex(Vertex):
type=self.artifacts_type[key],
).model_dump(exclude_none=True)
)
except KeyError:
pass
return messages
def _finalize_build(self):
@ -440,11 +439,12 @@ class InterfaceVertex(ComponentVertex):
for key, value in origin_vertex.results.items():
if isinstance(value, AsyncIterator | Iterator):
origin_vertex.results[key] = complete_message
if self._custom_component:
if hasattr(self._custom_component, "should_store_message") and hasattr(
self._custom_component, "store_message"
):
self._custom_component.store_message(message)
if (
self._custom_component
and hasattr(self._custom_component, "should_store_message")
and hasattr(self._custom_component, "store_message")
):
self._custom_component.store_message(message)
log_vertex_build(
flow_id=self.graph.flow_id,
vertex_id=self.id,

View file

@ -90,17 +90,19 @@ def update_projects_components_with_latest_component_versions(project_data, all_
)
else:
for attr in NODE_FORMAT_ATTRIBUTES:
if attr in latest_node:
if (
attr in latest_node
# Check if it needs to be updated
if latest_node[attr] != node_data.get(attr):
node_changes_log[node_data["display_name"]].append(
{
"attr": attr,
"old_value": node_data.get(attr),
"new_value": latest_node[attr],
}
)
node_data[attr] = latest_node[attr]
and latest_node[attr] != node_data.get(attr)
):
node_changes_log[node_data["display_name"]].append(
{
"attr": attr,
"old_value": node_data.get(attr),
"new_value": latest_node[attr],
}
)
node_data[attr] = latest_node[attr]
for field_name, field_dict in latest_template.items():
if field_name not in node_data["template"]:
@ -109,17 +111,20 @@ def update_projects_components_with_latest_component_versions(project_data, all_
# The idea here is to update some attributes of the field
to_check_attributes = FIELD_FORMAT_ATTRIBUTES
for attr in to_check_attributes:
if attr in field_dict and attr in node_data["template"].get(field_name):
if (
attr in field_dict
and attr in node_data["template"].get(field_name)
# Check if it needs to be updated
if field_dict[attr] != node_data["template"][field_name][attr]:
node_changes_log[node_data["display_name"]].append(
{
"attr": f"{field_name}.{attr}",
"old_value": node_data["template"][field_name][attr],
"new_value": field_dict[attr],
}
)
node_data["template"][field_name][attr] = field_dict[attr]
and field_dict[attr] != node_data["template"][field_name][attr]
):
node_changes_log[node_data["display_name"]].append(
{
"attr": f"{field_name}.{attr}",
"old_value": node_data["template"][field_name][attr],
"new_value": field_dict[attr],
}
)
node_data["template"][field_name][attr] = field_dict[attr]
# Remove fields that are not in the latest template
if node_data.get("display_name") != "Prompt":
for field_name in list(node_data["template"].keys()):
@ -274,16 +279,17 @@ def update_edges_with_latest_component_versions(project_data):
source_handle["output_types"] = new_output_types
field_name = target_handle.get("fieldName")
if field_name in target_node_data.get("template"):
if target_handle["inputTypes"] != target_node_data.get("template").get(field_name).get("input_types"):
edge_changes_log[target_node_data["display_name"]].append(
{
"attr": "inputTypes",
"old_value": target_handle["inputTypes"],
"new_value": target_node_data.get("template").get(field_name).get("input_types"),
}
)
target_handle["inputTypes"] = target_node_data.get("template").get(field_name).get("input_types")
if field_name in target_node_data.get("template") and target_handle["inputTypes"] != target_node_data.get(
"template"
).get(field_name).get("input_types"):
edge_changes_log[target_node_data["display_name"]].append(
{
"attr": "inputTypes",
"old_value": target_handle["inputTypes"],
"new_value": target_node_data.get("template").get(field_name).get("input_types"),
}
)
target_handle["inputTypes"] = target_node_data.get("template").get(field_name).get("input_types")
escaped_source_handle = escape_json_dump(source_handle)
escaped_target_handle = escape_json_dump(target_handle)
try:
@ -390,10 +396,7 @@ def get_project_data(project):
updated_at_datetime = datetime.strptime(project_updated_at, "%Y-%m-%dT%H:%M:%S.%f")
project_data = project.get("data")
project_icon = project.get("icon")
if project_icon and purely_emoji(project_icon):
project_icon = demojize(project_icon)
else:
project_icon = ""
project_icon = demojize(project_icon) if project_icon and purely_emoji(project_icon) else ""
project_icon_bg_color = project.get("icon_bg_color")
return (
project_name,

View file

@ -131,12 +131,7 @@ class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin, MetadataTr
ValueError: If the value is not of a valid type or if the input is missing a required key.
"""
is_list = _info.data["is_list"]
value = None
if is_list:
value = [cls._validate_value(vv, _info) for vv in v]
else:
value = cls._validate_value(v, _info)
return value
return [cls._validate_value(vv, _info) for vv in v] if is_list else cls._validate_value(v, _info)
class MessageInput(StrInput, InputTraceMixin):

View file

@ -89,12 +89,11 @@ def convert_kwargs(params):
# Loop through items to avoid repeated lookups
items_to_remove = []
for key, value in params.items():
if "kwargs" in key or "config" in key:
if isinstance(value, str):
try:
params[key] = orjson.loads(value)
except orjson.JSONDecodeError:
items_to_remove.append(key)
if ("kwargs" in key or "config" in key) and isinstance(value, str):
try:
params[key] = orjson.loads(value)
except orjson.JSONDecodeError:
items_to_remove.append(key)
# Remove invalid keys outside the loop to avoid modifying dict during iteration
for key in items_to_remove:

View file

@ -69,11 +69,9 @@ def extract_input_variables_from_prompt(prompt: str) -> list[str]:
if not match:
break
# Extract the variable name from either the single or double brace match
if match.group(1): # Match found in double braces
variable_name = "{{" + match.group(1) + "}}" # Re-add single braces for JSON strings
else: # Match found in single braces
variable_name = match.group(2)
# Extract the variable name from either the single or double brace match.
# If match found in double braces, re-add single braces for JSON strings.
variable_name = "{{" + match.group(1) + "}}" if match.group(1) else match.group(2)
if variable_name is not None:
# This means there is a match
# but there is nothing inside the braces

View file

@ -77,22 +77,18 @@ class SizedLogBuffer:
try:
with self._wlock:
as_list = list(self.buffer)
i = 0
max_index = -1
for ts, msg in as_list:
for i, (ts, msg) in enumerate(as_list):
if ts >= timestamp:
max_index = i
break
i += 1
if max_index == -1:
return self.get_last_n(lines)
rc = {}
i = 0
start_from = max(max_index - lines, 0)
for ts, msg in as_list:
for i, (ts, msg) in enumerate(as_list):
if start_from <= i < max_index:
rc[ts] = msg
i += 1
return rc
finally:
self._rsemaphore.release()

View file

@ -46,10 +46,7 @@ def get_messages(
if flow_id:
stmt = stmt.where(MessageTable.flow_id == flow_id)
if order_by:
if order == "DESC":
col = getattr(MessageTable, order_by).desc()
else:
col = getattr(MessageTable, order_by).asc()
col = getattr(MessageTable, order_by).desc() if order == "DESC" else getattr(MessageTable, order_by).asc()
stmt = stmt.order_by(col)
if limit:
stmt = stmt.limit(limit)

View file

@ -29,10 +29,7 @@ async def run_graph_internal(
) -> tuple[list[RunOutputs], str]:
"""Run the graph and generate the result"""
inputs = inputs or []
if session_id is None:
session_id_str = flow_id
else:
session_id_str = session_id
session_id_str = flow_id if session_id is None else session_id
components = []
inputs_list = []
types = []
@ -168,11 +165,7 @@ def process_tweaks(
:return: The modified graph_data dictionary.
:raises ValueError: If the input is not in the expected format.
"""
tweaks_dict = {}
if not isinstance(tweaks, dict):
tweaks_dict = cast(dict[str, Any], tweaks.model_dump())
else:
tweaks_dict = tweaks
tweaks_dict = cast(dict[str, Any], tweaks.model_dump()) if not isinstance(tweaks, dict) else tweaks
if "stream" not in tweaks_dict:
tweaks_dict |= {"stream": stream}
nodes = validate_input(graph_data, cast(dict[str, str | dict[str, Any]], tweaks_dict))
@ -182,9 +175,7 @@ def process_tweaks(
all_nodes_tweaks = {}
for key, value in tweaks_dict.items():
if isinstance(value, dict):
if node := nodes_map.get(key):
apply_tweaks(node, value)
elif node := nodes_display_name_map.get(key):
if (node := nodes_map.get(key)) or (node := nodes_display_name_map.get(key)):
apply_tweaks(node, value)
else:
all_nodes_tweaks[key] = value

View file

@ -41,11 +41,13 @@ def get_artifact_type(value, build_result=None) -> str:
case list():
result = ArtifactType.ARRAY
if result == ArtifactType.UNKNOWN:
if build_result and isinstance(build_result, Generator):
result = ArtifactType.STREAM
elif isinstance(value, Message) and isinstance(value.text, Generator):
result = ArtifactType.STREAM
if result == ArtifactType.UNKNOWN and (
build_result
and isinstance(build_result, Generator)
or isinstance(value, Message)
and isinstance(value.text, Generator)
):
result = ArtifactType.STREAM
return result.value
@ -56,9 +58,7 @@ def post_process_raw(raw, artifact_type: str):
elif artifact_type == ArtifactType.ARRAY.value:
_raw = []
for item in raw:
if hasattr(item, "dict"):
_raw.append(recursive_serialize_or_str(item))
elif hasattr(item, "model_dump"):
if hasattr(item, "dict") or hasattr(item, "model_dump"):
_raw.append(recursive_serialize_or_str(item))
else:
_raw.append(str(item))

View file

@ -103,10 +103,7 @@ class Message(Data):
# they are: "text", "sender"
if self.text is None or not self.sender:
logger.warning("Missing required keys ('text', 'sender') in Message, defaulting to HumanMessage.")
if not isinstance(self.text, str):
text = ""
else:
text = self.text
text = "" if not isinstance(self.text, str) else self.text
if self.sender == MESSAGE_SENDER_USER or not self.sender:
if self.files:
@ -160,9 +157,7 @@ class Message(Data):
@field_serializer("text", mode="plain")
def serialize_text(self, value):
if isinstance(value, AsyncIterator):
return ""
if isinstance(value, Iterator):
if isinstance(value, AsyncIterator | Iterator):
return ""
return value
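
Note: above, two isinstance checks with identical bodies collapse into one call using a PEP 604 union — on Python 3.10+ isinstance accepts X | Y directly, equivalent to the tuple form:

from collections.abc import AsyncIterator, Iterator

def is_stream(value) -> bool:
    # same as isinstance(value, (AsyncIterator, Iterator))
    return isinstance(value, AsyncIterator | Iterator)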

View file

@ -56,12 +56,13 @@ def get_type(payload):
case str():
result = LogType.TEXT
if result == LogType.UNKNOWN:
if payload and isinstance(payload, Generator):
result = LogType.STREAM
elif isinstance(payload, Message) and isinstance(payload.text, Generator):
result = LogType.STREAM
if result == LogType.UNKNOWN and (
payload
and isinstance(payload, Generator)
or isinstance(payload, Message)
and isinstance(payload.text, Generator)
):
result = LogType.STREAM
return result

View file

@ -72,11 +72,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]): # type: ignore
# Move the key to the end to make it recently used
self._cache.move_to_end(key)
# Check if the value is pickled
if isinstance(item["value"], bytes):
value = pickle.loads(item["value"])
else:
value = item["value"]
return value
return pickle.loads(item["value"]) if isinstance(item["value"], bytes) else item["value"]
self.delete(key)
return None

View file

@ -126,10 +126,7 @@ def save_uploaded_file(file: UploadFile, folder_name):
cache_path = Path(CACHE_DIR)
folder_path = cache_path / folder_name
filename = file.filename
if isinstance(filename, str) or isinstance(filename, Path):
file_extension = Path(filename).suffix
else:
file_extension = ""
file_extension = Path(filename).suffix if isinstance(filename, str | Path) else ""
file_object = file.file
# Create the folder if it doesn't exist

View file

@ -93,10 +93,7 @@ class CacheService(Subject, Service):
"image": "png",
"pandas": "csv",
}
if obj_type in object_extensions:
_extension = object_extensions[obj_type]
else:
_extension = type(obj).__name__.lower()
_extension = object_extensions[obj_type] if obj_type in object_extensions else type(obj).__name__.lower()
self.current_cache[name] = {
"obj": obj,
"type": obj_type,

View file

@ -117,10 +117,10 @@ class FlowBase(SQLModel):
raise ValueError(msg)
# data must contain nodes and edges
if "nodes" not in v.keys():
if "nodes" not in v:
msg = "Flow must have nodes"
raise ValueError(msg)
if "edges" not in v.keys():
if "edges" not in v:
msg = "Flow must have edges"
raise ValueError(msg)

View file

@ -46,12 +46,9 @@ class MessageBase(SQLModel):
timestamp = message.timestamp
if not flow_id and message.flow_id:
flow_id = message.flow_id
if not isinstance(message.text, str):
# If the text is not a string, it means it could be
# async iterator so we simply add it as an empty string
message_text = ""
else:
message_text = message.text
# If the text is not a string, it means it could be
# async iterator so we simply add it as an empty string
message_text = "" if not isinstance(message.text, str) else message.text
return cls(
sender=message.sender,
sender_name=message.sender_name,

View file

@ -31,10 +31,7 @@ def is_list_of_any(field: FieldInfo) -> bool:
if field.annotation is None:
return False
try:
if hasattr(field.annotation, "__args__"):
union_args = field.annotation.__args__
else:
union_args = []
union_args = field.annotation.__args__ if hasattr(field.annotation, "__args__") else []
return field.annotation.__origin__ is list or any(
arg.__origin__ is list for arg in union_args if hasattr(arg, "__origin__")
@ -267,10 +264,7 @@ class Settings(BaseSettings):
final_path = new_path
if final_path is None:
if is_pre_release:
final_path = new_pre_path
else:
final_path = new_path
final_path = new_pre_path if is_pre_release else new_path
value = f"sqlite:///{final_path}"
@ -370,7 +364,7 @@ def load_settings_from_yaml(file_path: str) -> Settings:
settings_dict = {k.upper(): v for k, v in settings_dict.items()}
for key in settings_dict:
if key not in Settings.model_fields.keys():
if key not in Settings.model_fields:
msg = f"Key {key} not found in settings"
raise KeyError(msg)
logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}")

View file

@ -30,7 +30,7 @@ class SettingsService(Service):
settings_dict = {k.upper(): v for k, v in settings_dict.items()}
for key in settings_dict:
if key not in Settings.model_fields.keys():
if key not in Settings.model_fields:
msg = f"Key {key} not found in settings"
raise KeyError(msg)
logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}")

View file

@ -126,10 +126,7 @@ class StoreService(Service):
self, url: str, api_key: str | None = None, params: dict[str, Any] | None = None
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
"""Utility method to perform GET requests."""
if api_key:
headers = {"Authorization": f"Bearer {api_key}"}
else:
headers = {}
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers, params=params, timeout=self.timeout)

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from uuid import UUID
from loguru import logger
@ -10,6 +12,7 @@ from langflow.services.tracing.schema import Log
if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackHandler
from langfuse.client import StatefulSpanClient
from langflow.graph.vertex.base import Vertex
@ -23,7 +26,7 @@ class LangFuseTracer(BaseTracer):
self.trace_type = trace_type
self.trace_id = trace_id
self.flow_id = trace_name.split(" - ")[-1]
self.last_span = None
self.last_span: StatefulSpanClient | None = None
self.spans: dict = {}
self._ready: bool = self.setup_langfuse()
@ -68,7 +71,7 @@ class LangFuseTracer(BaseTracer):
trace_type: str,
inputs: dict[str, Any],
metadata: dict[str, Any] | None = None,
vertex: Optional["Vertex"] = None,
vertex: Vertex | None = None,
):
start_time = datetime.utcnow()
if not self._ready:
@ -86,10 +89,7 @@ class LangFuseTracer(BaseTracer):
"start_time": start_time,
}
if self.last_span:
span = self.last_span.span(**content_span)
else:
span = self.trace.span(**content_span)
span = self.last_span.span(**content_span) if self.last_span else self.trace.span(**content_span)
self.last_span = span
self.spans[trace_id] = span
@ -127,7 +127,7 @@ class LangFuseTracer(BaseTracer):
self._client.flush()
def get_langchain_callback(self) -> Optional["BaseCallbackHandler"]:
def get_langchain_callback(self) -> BaseCallbackHandler | None:
if not self._ready:
return None
return None # self._callback

View file

@ -243,7 +243,7 @@ class TracingService(Service):
def _cleanup_inputs(self, inputs: dict[str, Any]):
inputs = inputs.copy()
for key in inputs.keys():
for key in inputs:
if "api_key" in key:
inputs[key] = "*****" # avoid logging api_keys for security reasons
return inputs

View file

@ -20,10 +20,7 @@ def convert_to_langchain_type(value):
else:
value = value.to_lc_document()
elif isinstance(value, Data):
if "text" in value.data:
value = value.to_lc_document()
else:
value = value.data
value = value.to_lc_document() if "text" in value.data else value.data
return value

View file

@ -93,7 +93,7 @@ class KubernetesSecretService(VariableService, Service):
return []
names = []
for key in variables.keys():
for key in variables:
if key.startswith(CREDENTIAL_TYPE + "_"):
names.append(key[len(CREDENTIAL_TYPE) + 1 :])
else:

View file

@ -97,9 +97,8 @@ class Input(BaseModel):
def serialize_model(self, handler):
result = handler(self)
# If the field is str, we add the Text input type
if self.field_type in ["str", "Text"]:
if "input_types" not in result:
result["input_types"] = ["Text"]
if self.field_type in ["str", "Text"] and "input_types" not in result:
result["input_types"] = ["Text"]
if self.field_type == Text:
result["type"] = "str"
else:

View file

@ -56,9 +56,7 @@ def build_template_from_function(name: str, type_to_loader_dict: dict, add_funct
elif name_ not in ["name"]:
variables[class_field_items][name_] = value_
variables[class_field_items]["placeholder"] = (
docs.params[class_field_items] if class_field_items in docs.params else ""
)
variables[class_field_items]["placeholder"] = docs.params.get(class_field_items, "")
# Adding function to base classes to allow
# the output to be a function
base_classes = get_base_classes(_class)

View file

@ -169,6 +169,7 @@ select = [
"Q",
"RET",
"RSE",
"SIM",
"SLOT",
"T10",
"TID",