ref: Some ruff fixes from preview (#5420)

* Some ruff fixes from preview

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Christophe Bornet 2024-12-28 22:25:35 +01:00 committed by GitHub
commit e91bcc2520
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
79 changed files with 402 additions and 374 deletions

View file

@ -90,7 +90,7 @@ class AgentQL(Component):
except httpx.HTTPStatusError as e:
response = e.response
if response.status_code in [401, 403]:
if response.status_code in {401, 403}:
self.status = "Please, provide a valid API Key. You can create one at https://dev.agentql.com."
else:
try:

View file

@ -71,7 +71,7 @@ class DirectoryComponent(Component):
def load_directory(self) -> list[Data]:
path = self.path
types = self.types if self.types else TEXT_FILE_TYPES
types = self.types or TEXT_FILE_TYPES
depth = self.depth
max_concurrency = self.max_concurrency
load_hidden = self.load_hidden

View file

@ -131,12 +131,12 @@ class LangWatchComponent(Component):
# Clear component's dynamic attributes
for attr in list(self.__dict__.keys()):
if attr not in default_keys and attr not in [
if attr not in default_keys and attr not in {
"evaluators",
"dynamic_inputs",
"_code",
"current_evaluator",
]:
}:
delattr(self, attr)
# Add new dynamic inputs
@ -177,7 +177,7 @@ class LangWatchComponent(Component):
input_fields = [
field
for field in evaluator.get("requiredFields", []) + evaluator.get("optionalFields", [])
if field not in ["input", "output"]
if field not in {"input", "output"}
]
for field in input_fields:

View file

@ -126,8 +126,7 @@ class ConditionalRouterComponent(Component):
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:
if field_name == "operator":
if field_value == "matches regex":
if "case_sensitive" in build_config:
del build_config["case_sensitive"]
build_config.pop("case_sensitive", None)
# Ensure case_sensitive is present for all other operators
elif "case_sensitive" not in build_config:
case_sensitive_input = next(

View file

@ -22,7 +22,7 @@ class OpenRouterComponent(LCModelComponent):
display_name = "OpenRouter"
description = (
"OpenRouter provides unified access to multiple AI models " "from different providers through a single API."
"OpenRouter provides unified access to multiple AI models from different providers through a single API."
)
icon = "OpenRouter"
@ -180,9 +180,7 @@ class OpenRouterComponent(LCModelComponent):
build_config["model_name"]["value"] = models[0]["id"]
tooltips = {
model["id"]: (
f"{model['name']}\n" f"Context Length: {model['context_length']}\n" f"{model['description']}"
)
model["id"]: (f"{model['name']}\nContext Length: {model['context_length']}\n{model['description']}")
for model in models
}
build_config["model_name"]["tooltips"] = tooltips

View file

@ -10,7 +10,7 @@ from langflow.utils.constants import MESSAGE_SENDER_AI
class NeedleComponent(Component):
display_name = "Needle Retriever"
description = "A retriever that uses the Needle API to search collections " "and generates responses using OpenAI."
description = "A retriever that uses the Needle API to search collections and generates responses using OpenAI."
documentation = "https://docs.needle-ai.com"
icon = "search"
name = "needle"
@ -105,8 +105,8 @@ class NeedleComponent(Component):
if str(output_type).lower().strip() == "chunks":
# If chunks selected, include full context and answer
docs = result["source_documents"]
context = "\n\n".join([f"Document {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
text_content = f"Question: {query}\n\n" f"Context:\n{context}\n\n" f"Answer: {result['answer']}"
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
text_content = f"Question: {query}\n\nContext:\n{context}\n\nAnswer: {result['answer']}"
else:
# If answer selected, only include the answer
text_content = result["answer"]

View file

@ -144,7 +144,7 @@ class DataFrameOperationsComponent(Component):
build_config["new_column_value"]["show"] = True
elif field_value == "Select Columns":
build_config["columns_to_select"]["show"] = True
elif field_value in ["Head", "Tail"]:
elif field_value in {"Head", "Tail"}:
build_config["num_rows"]["show"] = True
elif field_value == "Replace Value":
build_config["column_name"]["show"] = True

View file

@ -74,7 +74,7 @@ class MergeDataComponent(Component):
for key, value in data_input.data.items():
if key in result_data and isinstance(value, str):
if isinstance(result_data[key], list):
cast(list[str], result_data[key]).append(value)
cast("list[str]", result_data[key]).append(value)
else:
result_data[key] = [result_data[key], value]
else:

View file

@ -295,7 +295,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
# Always attempt to update the database list
if field_name in ["token", "api_endpoint", "collection_name"]:
if field_name in {"token", "api_endpoint", "collection_name"}:
# Update the database selector
build_config["api_endpoint"]["options"] = self._initialize_database_options()

View file

@ -1101,7 +1101,8 @@ class Graph:
return False
return self.vertex_edges_are_identical(vertex, other_vertex)
def vertex_edges_are_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool:
@staticmethod
def vertex_edges_are_identical(vertex: Vertex, other_vertex: Vertex) -> bool:
same_length = len(vertex.edges) == len(other_vertex.edges)
if not same_length:
return False
@ -1747,7 +1748,8 @@ class Graph:
new_edge = Edge(source, target, edge)
return new_edge
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> type[Vertex]:
@staticmethod
def _get_vertex_class(node_type: str, node_base_type: str, node_id: str) -> type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -1830,7 +1832,8 @@ class Graph:
self._record_snapshot()
return self
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> list[Vertex]:
@staticmethod
def get_children_by_vertex_type(vertex: Vertex, vertex_type: str) -> list[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -2059,7 +2062,8 @@ class Graph:
self._first_layer = first_layer
return first_layer
def sort_interface_components_first(self, vertices_layers: list[list[str]]) -> list[list[str]]:
@staticmethod
def sort_interface_components_first(vertices_layers: list[list[str]]) -> list[list[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
@ -2097,9 +2101,6 @@ class Graph:
This method is responsible for building the run map for the graph,
which maps each node in the graph to its corresponding run function.
Returns:
None
"""
self.run_manager.build_run_map(predecessor_map=self.predecessor_map, vertices_to_run=self.vertices_to_run)
@ -2169,7 +2170,8 @@ class Graph:
in_degree[vertex.id] = 0
return in_degree
def build_adjacency_maps(self, edges: list[CycleEdge]) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
@staticmethod
def build_adjacency_maps(edges: list[CycleEdge]) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
"""Returns the adjacency maps for the graph."""
predecessor_map: dict[str, list[str]] = defaultdict(list)
successor_map: dict[str, list[str]] = defaultdict(list)

View file

@ -110,9 +110,6 @@ def update_template(template, g_nodes) -> None:
Args:
template (dict): The new template to update the node with.
g_nodes (list): The list of nodes in the graph.
Returns:
None
"""
for value in template.values():
if not value.get("proxy"):
@ -161,9 +158,6 @@ def set_new_target_handle(proxy_id, new_edge, target_handle, node) -> None:
new_edge (dict): The new edge to be created.
target_handle (dict): The target handle of the edge.
node (dict): The node containing the edge.
Returns:
None
"""
new_edge["target"] = proxy_id
type_ = target_handle.get("type")

View file

@ -423,7 +423,7 @@ class Vertex:
else:
msg = f"Invalid value type {type(val)} for field {field_name}"
raise ValueError(msg)
elif val is not None and val != "":
elif val:
params[field_name] = val
if field.get("load_from_db"):
@ -596,7 +596,8 @@ class Vertex:
result = await value.get_result(self, target_handle_name=key)
self.params[key][sub_key] = result
def _is_vertex(self, value):
@staticmethod
def _is_vertex(value):
"""Checks if the provided value is an instance of Vertex."""
return isinstance(value, Vertex)

File diff suppressed because one or more lines are too long

View file

@ -1,3 +1,5 @@
from typing_extensions import override
from langflow.services.deps import get_settings_service
from langflow.utils.lazy_load import LazyLoadDictBase
@ -13,6 +15,7 @@ class AllTypesDict(LazyLoadDictBase):
"Custom": ["Custom Tool", "Python Function"],
}
@override
def get_type_dict(self):
from langflow.interface.types import get_all_types_dict

View file

@ -15,7 +15,7 @@ from loguru._file_sink import FileSink
from loguru._simple_sinks import AsyncSink
from platformdirs import user_cache_dir
from rich.logging import RichHandler
from typing_extensions import NotRequired
from typing_extensions import NotRequired, override
from langflow.settings import DEV
@ -285,6 +285,7 @@ class InterceptHandler(logging.Handler):
See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging.
"""
@override
def emit(self, record) -> None:
# Get corresponding Loguru level if it exists
try:

View file

@ -19,7 +19,7 @@ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from pydantic import PydanticDeprecatedSince20
from pydantic_core import PydanticSerializationError
from rich import print as rprint
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from langflow.api import health_check_router, log_router, router
from langflow.initial_setup.setup import (
@ -45,7 +45,7 @@ class RequestCancelledMiddleware(BaseHTTPMiddleware):
def __init__(self, app) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
sentinel = object()
async def cancel_handler():
@ -68,7 +68,7 @@ class RequestCancelledMiddleware(BaseHTTPMiddleware):
class JavaScriptMIMETypeMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
try:
response = await call_next(request)
except Exception as exc:

View file

@ -26,7 +26,8 @@ class ContentSizeLimitMiddleware:
self.app = app
self.logger = logger
def receive_wrapper(self, receive):
@staticmethod
def receive_wrapper(receive):
received = 0
async def inner():

View file

@ -32,7 +32,7 @@ class DataFrame(pandas_DataFrame):
>>> dataset = DataFrame({"name": ["John", "Jane"], "age": [30, 25]})
"""
def __init__(self, data: None | list[dict | Data] | dict | pd.DataFrame = None, **kwargs):
def __init__(self, data: list[dict | Data] | dict | pd.DataFrame | None = None, **kwargs):
if data is None:
super().__init__(**kwargs)
return

View file

@ -6,7 +6,7 @@ from typing_extensions import Protocol
from langflow.schema.message import ContentBlock, Message
from langflow.schema.playground_events import PlaygroundEvent
LoggableType: TypeAlias = str | dict | list | int | float | bool | None | BaseModel | PlaygroundEvent
LoggableType: TypeAlias = str | dict | list | int | float | bool | BaseModel | PlaygroundEvent | None
class LogFunctionType(Protocol):

View file

@ -1,3 +1,5 @@
from typing_extensions import override
from langflow.services.auth.service import AuthService
from langflow.services.factory import ServiceFactory
@ -8,5 +10,6 @@ class AuthServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(AuthService)
@override
def create(self, settings_service):
return AuthService(settings_service)

View file

@ -2,6 +2,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from typing_extensions import override
from langflow.logging.logger import logger
from langflow.services.cache.disk import AsyncDiskCache
from langflow.services.cache.service import AsyncInMemoryCache, CacheService, RedisCache, ThreadingInMemoryCache
@ -15,6 +17,7 @@ class CacheServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(CacheService)
@override
def create(self, settings_service: SettingsService):
# Here you would have logic to create and configure a CacheService
# based on the settings_service

View file

@ -187,7 +187,8 @@ class DatabaseService(Service):
await session.commit()
logger.debug("Successfully assigned orphaned flows to the default superuser")
def _generate_unique_flow_name(self, original_name: str, existing_names: set[str]) -> str:
@staticmethod
def _generate_unique_flow_name(original_name: str, existing_names: set[str]) -> str:
"""Generate a unique flow name by adding or incrementing a suffix."""
if original_name not in existing_names:
return original_name
@ -212,7 +213,8 @@ class DatabaseService(Service):
return new_name
def _check_schema_health(self, connection) -> bool:
@staticmethod
def _check_schema_health(connection) -> bool:
inspector = inspect(connection)
model_mapping: dict[str, type[SQLModel]] = {
@ -249,7 +251,8 @@ class DatabaseService(Service):
async with self.with_session() as session, session.bind.connect() as conn:
await conn.run_sync(self._check_schema_health)
def init_alembic(self, alembic_cfg) -> None:
@staticmethod
def init_alembic(alembic_cfg) -> None:
logger.info("Initializing alembic")
command.ensure_version(alembic_cfg)
# alembic_cfg.attributes["connection"].commit()
@ -317,7 +320,8 @@ class DatabaseService(Service):
should_initialize_alembic = True
await asyncio.to_thread(self._run_migrations, should_initialize_alembic, fix)
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5) -> None:
@staticmethod
def try_downgrade_upgrade_until_success(alembic_cfg, retries=5) -> None:
# Try -1 then head, if it fails, try -2 then head, etc.
# until we reach the number of retries
for i in range(1, retries + 1):

View file

@ -1,5 +1,7 @@
from typing import TYPE_CHECKING
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.session.service import SessionService
@ -11,5 +13,6 @@ class SessionServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(SessionService)
@override
def create(self, cache_service: "CacheService"):
return SessionService(cache_service)

View file

@ -38,7 +38,8 @@ class SessionService(Service):
return graph, artifacts
def build_key(self, session_id, data_graph) -> str:
@staticmethod
def build_key(session_id, data_graph) -> str:
json_hash = compute_dict_hash(data_graph)
return f"{session_id}{':' if session_id else ''}{json_hash}"

View file

@ -1,3 +1,5 @@
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.settings.service import SettingsService
@ -13,6 +15,7 @@ class SettingsServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(SettingsService)
@override
def create(self):
# Here you would have logic to create and configure a SettingsService

View file

@ -1,5 +1,7 @@
from typing import TYPE_CHECKING
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.shared_component_cache.service import SharedComponentCacheService
@ -11,5 +13,6 @@ class SharedComponentCacheServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(SharedComponentCacheService)
@override
def create(self, settings_service: "SettingsService"):
return SharedComponentCacheService(expiration_time=settings_service.settings.cache_expire)

View file

@ -1,5 +1,7 @@
from typing import TYPE_CHECKING
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.socket.service import SocketIOService
@ -13,5 +15,6 @@ class SocketIOFactory(ServiceFactory):
service_class=SocketIOService,
)
@override
def create(self, cache_service: "CacheService"):
return SocketIOService(cache_service)

View file

@ -1,3 +1,5 @@
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.settings.service import SettingsService
from langflow.services.state.service import InMemoryStateService
@ -7,6 +9,7 @@ class StateServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(InMemoryStateService)
@override
def create(self, settings_service: SettingsService):
return InMemoryStateService(
settings_service,

View file

@ -1,4 +1,5 @@
from loguru import logger
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.session.service import SessionService
@ -12,6 +13,7 @@ class StorageServiceFactory(ServiceFactory):
StorageService,
)
@override
def create(self, session_service: SessionService, settings_service: SettingsService):
storage_type = settings_service.settings.storage_type
if storage_type.lower() == "local":

View file

@ -21,12 +21,15 @@ class LocalStorageService(StorageService):
async def save_file(self, flow_id: str, file_name: str, data: bytes) -> None:
"""Save a file in the local storage.
:param flow_id: The identifier for the flow.
:param file_name: The name of the file to be saved.
:param data: The byte content of the file.
:raises FileNotFoundError: If the specified flow does not exist.
:raises IsADirectoryError: If the file name is a directory.
:raises PermissionError: If there is no permission to write the file.
Args:
flow_id: The identifier for the flow.
file_name: The name of the file to be saved.
data: The byte content of the file.
Raises:
FileNotFoundError: If the specified flow does not exist.
IsADirectoryError: If the file name is a directory.
PermissionError: If there is no permission to write the file.
"""
folder_path = self.data_dir / flow_id
await folder_path.mkdir(parents=True, exist_ok=True)
@ -43,10 +46,15 @@ class LocalStorageService(StorageService):
async def get_file(self, flow_id: str, file_name: str) -> bytes:
"""Retrieve a file from the local storage.
:param flow_id: The identifier for the flow.
:param file_name: The name of the file to be retrieved.
:return: The byte content of the file.
:raises FileNotFoundError: If the file does not exist.
Args:
flow_id: The identifier for the flow.
file_name: The name of the file to be retrieved.
Returns:
The byte content of the file.
Raises:
FileNotFoundError: If the file does not exist.
"""
file_path = self.data_dir / flow_id / file_name
if not await file_path.exists():
@ -63,9 +71,14 @@ class LocalStorageService(StorageService):
async def list_files(self, flow_id: str):
"""List all files in a specified flow.
:param flow_id: The identifier for the flow.
:return: A list of file names.
:raises FileNotFoundError: If the flow directory does not exist.
Args:
flow_id: The identifier for the flow.
Returns:
A list of file names.
Raises:
FileNotFoundError: If the flow directory does not exist.
"""
if not isinstance(flow_id, str):
flow_id = str(flow_id)

View file

@ -18,10 +18,13 @@ class S3StorageService(StorageService):
async def save_file(self, folder: str, file_name: str, data) -> None:
"""Save a file to the S3 bucket.
:param folder: The folder in the bucket to save the file.
:param file_name: The name of the file to be saved.
:param data: The byte content of the file.
:raises Exception: If an error occurs during file saving.
Args:
folder: The folder in the bucket to save the file.
file_name: The name of the file to be saved.
data: The byte content of the file.
Raises:
Exception: If an error occurs during file saving.
"""
try:
self.s3_client.put_object(Bucket=self.bucket, Key=f"{folder}/{file_name}", Body=data)
@ -36,10 +39,15 @@ class S3StorageService(StorageService):
async def get_file(self, folder: str, file_name: str):
"""Retrieve a file from the S3 bucket.
:param folder: The folder in the bucket where the file is stored.
:param file_name: The name of the file to be retrieved.
:return: The byte content of the file.
:raises Exception: If an error occurs during file retrieval.
Args:
folder: The folder in the bucket where the file is stored.
file_name: The name of the file to be retrieved.
Returns:
The byte content of the file.
Raises:
Exception: If an error occurs during file retrieval.
"""
try:
response = self.s3_client.get_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")
@ -52,9 +60,14 @@ class S3StorageService(StorageService):
async def list_files(self, folder: str):
"""List all files in a specified folder of the S3 bucket.
:param folder: The folder in the bucket to list files from.
:return: A list of file names.
:raises Exception: If an error occurs during file listing.
Args:
folder: The folder in the bucket to list files from.
Returns:
A list of file names.
Raises:
Exception: If an error occurs during file listing.
"""
try:
response = self.s3_client.list_objects_v2(Bucket=self.bucket, Prefix=folder)
@ -69,9 +82,12 @@ class S3StorageService(StorageService):
async def delete_file(self, folder: str, file_name: str) -> None:
"""Delete a file from the S3 bucket.
:param folder: The folder in the bucket where the file is stored.
:param file_name: The name of the file to be deleted.
:raises Exception: If an error occurs during file deletion.
Args:
folder: The folder in the bucket where the file is stored.
file_name: The name of the file to be deleted.
Raises:
Exception: If an error occurs during file deletion.
"""
try:
self.s3_client.delete_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")

View file

@ -2,6 +2,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.store.service import StoreService
@ -13,5 +15,6 @@ class StoreServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(StoreService)
@override
def create(self, settings_service: SettingsService):
return StoreService(settings_service)

View file

@ -164,7 +164,8 @@ class StoreService(Service):
except Exception: # noqa: BLE001
logger.opt(exception=True).debug("Webhook failed")
def build_tags_filter(self, tags: list[str]):
@staticmethod
def build_tags_filter(tags: list[str]):
tags_filter: dict[str, Any] = {"tags": {"_and": []}}
for tag in tags:
tags_filter["tags"]["_and"].append({"_some": {"tags_id": {"name": {"_eq": tag}}}})
@ -249,7 +250,8 @@ class StoreService(Service):
return filter_conditions
def build_liked_filter(self):
@staticmethod
def build_liked_filter():
user_data = user_data_var.get()
# params["filter"] = json.dumps({"user_created": {"_eq": user_data["id"]}})
if not user_data:

View file

@ -1,3 +1,5 @@
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.task.service import TaskService
@ -6,6 +8,7 @@ class TaskServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(TaskService)
@override
def create(self):
# Here you would have logic to create and configure a TaskService
return TaskService()

View file

@ -2,6 +2,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.telemetry.service import TelemetryService
@ -13,5 +15,6 @@ class TelemetryServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(TelemetryService)
@override
def create(self, settings_service: SettingsService):
return TelemetryService(settings_service)

View file

@ -132,7 +132,8 @@ class TelemetryService(Service):
except Exception: # noqa: BLE001
logger.exception("Error flushing logs")
async def _cancel_task(self, task: asyncio.Task, cancel_msg: str) -> None:
@staticmethod
async def _cancel_task(task: asyncio.Task, cancel_msg: str) -> None:
task.cancel(cancel_msg)
await asyncio.wait([task])
if not task.cancelled():

View file

@ -319,7 +319,8 @@ class ArizePhoenixTracer(BaseTracer):
return value
def _error_to_string(self, error: Exception | None):
@staticmethod
def _error_to_string(error: Exception | None):
"""Converts an error to a string with traceback details."""
error_message = None
if error:
@ -327,11 +328,13 @@ class ArizePhoenixTracer(BaseTracer):
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
return error_message
def _get_current_timestamp(self) -> int:
@staticmethod
def _get_current_timestamp() -> int:
"""Gets the current UTC timestamp in nanoseconds."""
return int(datetime.now(timezone.utc).timestamp() * 1_000_000_000)
def _safe_json_dumps(self, obj: Any, **kwargs: Any) -> str:
@staticmethod
def _safe_json_dumps(obj: Any, **kwargs: Any) -> str:
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
return json.dumps(obj, default=str, ensure_ascii=False, **kwargs)
@ -359,6 +362,7 @@ class ArizePhoenixTracer(BaseTracer):
else:
current_span.set_status(Status(StatusCode.OK))
@override
def get_langchain_callback(self) -> BaseCallbackHandler | None:
"""Returns the LangChain callback handler if applicable."""
return None

View file

@ -2,6 +2,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.tracing.service import TracingService
@ -13,5 +15,6 @@ class TracingServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(TracingService)
@override
def create(self, settings_service: SettingsService):
return TracingService(settings_service)

View file

@ -135,7 +135,8 @@ class LangFuseTracer(BaseTracer):
return None
return None # self._callback
def _get_config(self) -> dict:
@staticmethod
def _get_config() -> dict:
secret_key = os.getenv("LANGFUSE_SECRET_KEY", None)
public_key = os.getenv("LANGFUSE_PUBLIC_KEY", None)
host = os.getenv("LANGFUSE_HOST", None)

View file

@ -7,6 +7,7 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from loguru import logger
from typing_extensions import override
from langflow.schema.data import Data
from langflow.services.tracing.base import BaseTracer
@ -140,7 +141,8 @@ class LangSmithTracer(BaseTracer):
child.post()
self._child_link[trace_name] = child.get_url()
def _error_to_string(self, error: Exception | None):
@staticmethod
def _error_to_string(error: Exception | None):
error_message = None
if error:
string_stacktrace = traceback.format_exception(error)
@ -163,5 +165,6 @@ class LangSmithTracer(BaseTracer):
self._run_tree.post()
self._run_link = self._run_tree.get_url()
@override
def get_langchain_callback(self) -> BaseCallbackHandler | None:
return None

View file

@ -264,7 +264,8 @@ class TracingService(Service):
self.outputs[trace_name] |= outputs or {}
self.outputs_metadata[trace_name] |= output_metadata or {}
def _cleanup_inputs(self, inputs: dict[str, Any]):
@staticmethod
def _cleanup_inputs(inputs: dict[str, Any]):
inputs = inputs.copy()
for key in inputs:
if "api_key" in key:

View file

@ -177,9 +177,6 @@ async def clean_transactions(settings_service: SettingsService, session: AsyncSe
Args:
settings_service: The settings service containing configuration like max_transactions_to_keep
session: The database session to use for the deletion
Returns:
None
"""
try:
# Delete transactions using bulk delete
@ -209,9 +206,6 @@ async def clean_vertex_builds(settings_service: SettingsService, session: AsyncS
Args:
settings_service: The settings service containing configuration like max_vertex_builds_to_keep
session: The database session to use for the deletion
Returns:
None
"""
try:
# Delete vertex builds using bulk delete

View file

@ -2,6 +2,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from typing_extensions import override
from langflow.services.factory import ServiceFactory
from langflow.services.variable.service import DatabaseVariableService, VariableService
@ -13,6 +15,7 @@ class VariableServiceFactory(ServiceFactory):
def __init__(self) -> None:
super().__init__(VariableService)
@override
def create(self, settings_service: SettingsService):
# here you would have logic to create and configure a VariableService
# based on the settings_service

View file

@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
from loguru import logger
from sqlmodel import select
from typing_extensions import override
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
@ -141,6 +142,7 @@ class DatabaseVariableService(VariableService, Service):
await session.refresh(db_variable)
return db_variable
@override
async def delete_variable(
self,
user_id: UUID | str,
@ -155,6 +157,7 @@ class DatabaseVariableService(VariableService, Service):
await session.delete(variable)
await session.commit()
@override
async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None:
stmt = select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)
variable = (await session.exec(stmt)).first()

View file

@ -36,7 +36,8 @@ class KeyedWorkerLockManager:
def __init__(self) -> None:
self.locks_dir = Path(user_cache_dir("langflow"), ensure_exists=True) / "worker_locks"
def _validate_key(self, key: str) -> bool:
@staticmethod
def _validate_key(key: str) -> bool:
"""Validate that the string only contains alphanumeric characters and underscores.
Parameters:

View file

@ -12,6 +12,7 @@ from docstring_parser import parse
from langflow.logging.logger import logger
from langflow.schema import Data
from langflow.services.deps import get_settings_service
from langflow.services.utils import initialize_settings_service
from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS
from langflow.utils import constants
@ -416,8 +417,6 @@ async def update_settings(
max_file_size_upload: int = 100,
) -> None:
"""Update the settings from a config file."""
from langflow.services.utils import initialize_settings_service
# Check for database_url in the environment variables
initialize_settings_service()

View file

@ -4,7 +4,11 @@ from langflow.utils import constants
def truncate_long_strings(data, max_length=None):
"""Recursively traverse the dictionary or list and truncate strings longer than max_length."""
"""Recursively traverse the dictionary or list and truncate strings longer than max_length.
Returns:
The data with strings truncated if they exceed the max length.
"""
if max_length is None:
max_length = constants.MAX_TEXT_LENGTH

View file

@ -170,9 +170,15 @@ def create_function(code, function_name):
def create_class(code, class_name):
"""Dynamically create a class from a string of code and a specified class name.
:param code: String containing the Python code defining the class
:param class_name: Name of the class to be created
:return: A function that, when called, returns an instance of the created class
Args:
code: String containing the Python code defining the class
class_name: Name of the class to be created
Returns:
A function that, when called, returns an instance of the created class
Raises:
ValueError: If the code contains syntax errors or the class definition is invalid
"""
if not hasattr(ast, "TypeIgnore"):
ast.TypeIgnore = create_type_ignore_class()
@ -199,7 +205,8 @@ def create_class(code, class_name):
def create_type_ignore_class():
"""Create a TypeIgnore class for AST module if it doesn't exist.
:return: TypeIgnore class
Returns:
TypeIgnore class
"""
class TypeIgnore(ast.AST):
@ -211,8 +218,15 @@ def create_type_ignore_class():
def prepare_global_scope(code, module):
"""Prepares the global scope with necessary imports from the provided code module.
:param module: AST parsed module
:return: Dictionary representing the global scope with imported modules
Args:
code: The Python code
module: AST parsed module
Returns:
Dictionary representing the global scope with imported modules
Raises:
ModuleNotFoundError: If a module is not found in the code
"""
exec_globals = globals().copy()
exec_globals.update(get_default_imports(code))
@ -250,9 +264,12 @@ def prepare_global_scope(code, module):
def extract_class_code(module, class_name):
"""Extracts the AST node for the specified class from the module.
:param module: AST parsed module
:param class_name: Name of the class to extract
:return: AST node of the specified class
Args:
module: AST parsed module
class_name: Name of the class to extract
Returns:
AST node of the specified class
"""
class_code = next(node for node in module.body if isinstance(node, ast.ClassDef) and node.name == class_name)
@ -263,8 +280,11 @@ def extract_class_code(module, class_name):
def compile_class_code(class_code):
"""Compiles the AST node of a class into a code object.
:param class_code: AST node of the class
:return: Compiled code object of the class
Args:
class_code: AST node of the class
Returns:
Compiled code object of the class
"""
return compile(ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec")
@ -272,10 +292,13 @@ def compile_class_code(class_code):
def build_class_constructor(compiled_class, exec_globals, class_name):
"""Builds a constructor function for the dynamically created class.
:param compiled_class: Compiled code object of the class
:param exec_globals: Global scope with necessary imports
:param class_name: Name of the class
:return: Constructor function for the class
Args:
compiled_class: Compiled code object of the class
exec_globals: Global scope with necessary imports
class_name: Name of the class
Returns:
Constructor function for the class
"""
exec(compiled_class, exec_globals, locals())
exec_globals[class_name] = locals()[class_name]
@ -313,9 +336,12 @@ def get_default_imports(code_string):
def find_names_in_code(code, names):
"""Finds if any of the specified names are present in the given code string.
:param code: The source code as a string.
:param names: A list of names to check for in the code.
:return: A set of names that are found in the code.
Args:
code: The source code as a string.
names: A list of names to check for in the code.
Returns:
A set of names that are found in the code.
"""
return {name for name in names if name in code}
@ -339,7 +365,7 @@ def extract_class_name(code: str) -> str:
str: Name of the first Component subclass found
Raises:
ValueError: If no Component subclass is found in the code
TypeError: If no Component subclass is found in the code
"""
try:
module = ast.parse(code)

View file

@ -1,4 +1,7 @@
from importlib import metadata
import httpx
from packaging import version as pkg_version
from langflow.logging.logger import logger
@ -22,8 +25,6 @@ def _get_version_info():
Raises:
ValueError: If the package is not found from the list of package names.
"""
from importlib import metadata
package_options = [
("langflow", "Langflow"),
("langflow-base", "Langflow Base"),
@ -57,8 +58,9 @@ VERSION_INFO = _get_version_info()
def is_pre_release(v: str) -> bool:
"""Whether the version is a pre-release version.
Returns a boolean indicating whether the version is a pre-release version,
as per the definition of a pre-release segment from PEP 440.
Returns:
Whether the version is a pre-release version,
as per the definition of a pre-release segment from PEP 440.
"""
return any(label in v for label in ["a", "b", "rc"])
@ -66,15 +68,14 @@ def is_pre_release(v: str) -> bool:
def is_nightly(v: str) -> bool:
"""Whether the version is a dev (nightly) version.
Returns a boolean indicating whether the version is a dev (nightly) version,
as per the definition of a dev segment from PEP 440.
Returns:
Whether the version is a dev (nightly) version,
as per the definition of a dev segment from PEP 440.
"""
return "dev" in v
def fetch_latest_version(package_name: str, *, include_prerelease: bool) -> str | None:
from packaging import version as pkg_version
package_name = package_name.replace(" ", "-").lower()
try:
response = httpx.get(f"https://pypi.org/pypi/{package_name}/json")

View file

@ -18,7 +18,11 @@ def test_celery(word: str) -> str:
@celery_app.task(bind=True, soft_time_limit=30, max_retries=3)
def build_vertex(self, vertex: Vertex) -> Vertex:
"""Build a vertex."""
"""Build a vertex.
Returns:
The built vertex.
"""
try:
vertex.task_id = self.request.id
async_to_sync(vertex.build)()

View file

@ -1,6 +1,7 @@
"""Module for package versioning."""
import contextlib
from importlib import metadata
def get_version() -> str:
@ -14,8 +15,6 @@ def get_version() -> str:
Raises:
ValueError: If the package is not found from the list of package names.
"""
from importlib import metadata
pkg_names = [
"langflow",
"langflow-base",

View file

@ -17,8 +17,10 @@ from blockbuster import blockbuster_ctx
from dotenv import load_dotenv
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from langflow.components.inputs import ChatInput
from langflow.graph import Graph
from langflow.initial_setup.constants import STARTER_FOLDER_NAME
from langflow.main import create_app
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.flow.model import Flow, FlowCreate
@ -155,8 +157,6 @@ def caplog(caplog: pytest.LogCaptureFixture):
@pytest.fixture
async def async_client() -> AsyncGenerator:
from langflow.main import create_app
app = create_app()
async with AsyncClient(app=app, base_url="http://testserver", http2=True) as client:
yield client
@ -228,8 +228,6 @@ def distributed_client_fixture(
# def get_session_override():
# return session
from langflow.main import create_app
app = create_app()
# app.dependency_overrides[get_session] = get_session_override
@ -357,8 +355,6 @@ async def client_fixture(
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")
from langflow.main import create_app
app = create_app()
db_service = get_db_service()
db_service.database_url = f"sqlite:///{db_path}"
@ -482,8 +478,6 @@ async def flow(
json_flow: str,
active_user,
):
from langflow.services.database.models.flow.model import FlowCreate
loaded_json = json.loads(json_flow)
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
@ -577,8 +571,6 @@ async def added_webhook_test(client, json_webhook_test, logged_in_headers):
@pytest.fixture
async def flow_component(client: AsyncClient, logged_in_headers):
from langflow.components.inputs import ChatInput
chat_input = ChatInput()
graph = Graph(start=chat_input, end=chat_input)
graph_dict = graph.dump(name="Chat Input Component")

View file

@ -1,12 +1,17 @@
import pytest
from langflow.components.astra_assistants import (
AssistantsCreateAssistant,
AssistantsCreateThread,
AssistantsGetAssistantName,
AssistantsListAssistants,
AssistantsRun,
)
from tests.integration.utils import run_single_component
@pytest.mark.api_key_required
async def test_list_assistants():
from langflow.components.astra_assistants import AssistantsListAssistants
results = await run_single_component(
AssistantsListAssistants,
inputs={},
@ -16,8 +21,6 @@ async def test_list_assistants():
@pytest.mark.api_key_required
async def test_create_assistants():
from langflow.components.astra_assistants import AssistantsCreateAssistant
results = await run_single_component(
AssistantsCreateAssistant,
inputs={
@ -36,8 +39,6 @@ async def test_create_assistants():
@pytest.mark.api_key_required
async def test_create_thread():
from langflow.components.astra_assistants import AssistantsCreateThread
results = await run_single_component(
AssistantsCreateThread,
inputs={},
@ -48,8 +49,6 @@ async def test_create_thread():
async def get_assistant_name(assistant_id):
from langflow.components.astra_assistants import AssistantsGetAssistantName
results = await run_single_component(
AssistantsGetAssistantName,
inputs={
@ -60,8 +59,6 @@ async def get_assistant_name(assistant_id):
async def run_assistant(assistant_id, thread_id):
from langflow.components.astra_assistants import AssistantsRun
results = await run_single_component(
AssistantsRun,
inputs={

View file

@ -2,6 +2,7 @@ import os
import pytest
from astrapy.db import AstraDB
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
from langchain_core.documents import Document
from langflow.components.embeddings import OpenAIEmbeddingsComponent
from langflow.components.vectorstores import AstraDBVectorStoreComponent
@ -37,8 +38,6 @@ def astradb_client():
@pytest.mark.api_key_required
async def test_base(astradb_client: AstraDB):
from langflow.components.embeddings import OpenAIEmbeddingsComponent
application_token = get_astradb_application_token()
api_endpoint = get_astradb_api_endpoint()
@ -88,8 +87,6 @@ async def test_astra_embeds_and_search():
@pytest.mark.api_key_required
def test_astra_vectorize():
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
application_token = get_astradb_application_token()
api_endpoint = get_astradb_api_endpoint()
@ -132,8 +129,6 @@ def test_astra_vectorize():
@pytest.mark.api_key_required
def test_astra_vectorize_with_provider_api_key():
"""Tests vectorize using an openai api key."""
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
application_token = get_astradb_application_token()
api_endpoint = get_astradb_api_endpoint()
@ -189,8 +184,6 @@ def test_astra_vectorize_with_provider_api_key():
@pytest.mark.api_key_required
def test_astra_vectorize_passes_authentication():
"""Tests vectorize using the authentication parameter."""
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
store = None
try:
application_token = get_astradb_application_token()

View file

@ -10,7 +10,7 @@ from tests.integration.utils import ComponentInputHandle, run_single_component
@pytest.mark.api_key_required
async def test_csv_output_parser_openai():
format_instructions = ComponentInputHandle(
format_instructions_ = ComponentInputHandle(
clazz=OutputParserComponent,
inputs={},
output_name="format_instructions",
@ -24,7 +24,7 @@ async def test_csv_output_parser_openai():
clazz=PromptComponent,
inputs={
"template": "List the first five positive integers.\n\n{format_instructions}",
"format_instructions": format_instructions,
"format_instructions": format_instructions_,
},
output_name="prompt",
)

View file

@ -28,7 +28,7 @@ async def test_initialize_services():
@pytest.mark.benchmark
async def test_setup_llm_caching():
def test_setup_llm_caching():
"""Benchmark LLM caching setup."""
from langflow.interface.utils import setup_llm_caching

View file

@ -11,6 +11,7 @@ import anyio
import pytest
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient
from langflow.main import create_app
from langflow.services.deps import get_storage_service
from langflow.services.storage.service import StorageService
from sqlmodel import Session
@ -60,8 +61,6 @@ async def files_client_fixture(
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")
from langflow.main import create_app
app = create_app()
return app, db_path
@ -81,14 +80,14 @@ async def files_client_fixture(
@pytest.fixture
async def max_file_size_upload_fixture(monkeypatch):
def max_file_size_upload_fixture(monkeypatch):
monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "1")
yield
monkeypatch.undo()
@pytest.fixture
async def max_file_size_upload_10mb_fixture(monkeypatch):
def max_file_size_upload_10mb_fixture(monkeypatch):
monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "10")
yield
monkeypatch.undo()

View file

@ -11,7 +11,7 @@ from langflow.schema.data import Data
from pydantic import BaseModel
async def test_component_tool():
def test_component_tool():
calculator_component = CalculatorToolComponent()
component_toolkit = ComponentToolkit(component=calculator_component)
component_tool = component_toolkit.get_tools()[0]

View file

@ -118,7 +118,7 @@ class AllInputsComponent(Component):
return data
async def test_component_inputs_toolkit():
def test_component_inputs_toolkit():
component = AllInputsComponent()
component_toolkit = ComponentToolkit(component=component)
component_tool = component_toolkit.get_tools()[0]

View file

@ -1,8 +1,11 @@
import re
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.language_models import BaseLanguageModel
from langflow.components.helpers.structured_output import StructuredOutputComponent
from langflow.helpers.base_model import build_model_from_schema
from langflow.inputs.inputs import TableInput
from langflow.schema.data import Data
from pydantic import BaseModel
from typing_extensions import override
@ -12,8 +15,6 @@ class TestStructuredOutputComponent:
# Ensure that the structured output is successfully generated with the correct BaseModel instance returned by
# the mock function
def test_successful_structured_output_generation_with_patch_with_config(self):
from unittest.mock import patch
class MockLanguageModel(BaseLanguageModel):
@override
def with_structured_output(self, *args, **kwargs):
@ -87,15 +88,11 @@ class TestStructuredOutputComponent:
multiple=False,
)
with pytest.raises(TypeError, match="Language model does not support structured output."):
with pytest.raises(TypeError, match=re.escape("Language model does not support structured output.")):
component.build_structured_output()
# Correctly builds the output model from the provided schema
def test_correctly_builds_output_model(self):
# Import internal organization modules, packages, and libraries
from langflow.helpers.base_model import build_model_from_schema
from langflow.inputs.inputs import TableInput
# Setup
component = StructuredOutputComponent()
schema = [
@ -134,10 +131,6 @@ class TestStructuredOutputComponent:
# Properly handles multiple outputs when 'multiple' is set to True
def test_handles_multiple_outputs(self):
# Import internal organization modules, packages, and libraries
from langflow.helpers.base_model import build_model_from_schema
from langflow.inputs.inputs import TableInput
# Setup
component = StructuredOutputComponent()
schema = [
@ -261,5 +254,5 @@ class TestStructuredOutputComponent:
multiple=False,
)
with pytest.raises(TypeError, match="Language model does not support structured output."):
with pytest.raises(TypeError, match=re.escape("Language model does not support structured output.")):
component.build_structured_output()

View file

@ -44,7 +44,7 @@ def test_operations(sample_dataframe, operation, expected_columns, expected_valu
component.new_column_name = "Z"
elif operation == "Select Columns":
component.columns_to_select = ["A", "C"]
elif operation in ("Head", "Tail"):
elif operation in {"Head", "Tail"}:
component.num_rows = 1
elif operation == "Replace Value":
component.column_name = "C"

View file

@ -1,3 +1,5 @@
import re
import pytest
from langflow.components.processing import CreateDataComponent
from langflow.schema import Data
@ -48,7 +50,7 @@ def test_update_build_config_exceed_limit(create_data_component):
"value": False,
},
}
with pytest.raises(ValueError, match="Number of fields cannot exceed 15."):
with pytest.raises(ValueError, match=re.escape("Number of fields cannot exceed 15.")):
create_data_component.update_build_config(build_config, 16, "number_of_fields")

View file

@ -1,3 +1,5 @@
import re
import pytest
from langflow.components.processing import UpdateDataComponent
from langflow.schema import Data
@ -48,7 +50,7 @@ def test_update_build_config_exceed_limit(update_data_component):
"value": False,
},
}
with pytest.raises(ValueError, match="Number of fields cannot exceed 15."):
with pytest.raises(ValueError, match=re.escape("Number of fields cannot exceed 15.")):
update_data_component.update_build_config(build_config, 16, "number_of_fields")

View file

@ -14,11 +14,6 @@ from langflow.schema.properties import Properties, Source
from langflow.template.field.base import Output
async def create_event_queue():
"""Create a queue for testing events."""
return asyncio.Queue()
def blocking_cb(manager, event_type, data):
time.sleep(0.01)
manager.send_event(event_type=event_type, data=data)
@ -43,7 +38,7 @@ class ComponentForTesting(Component):
async def test_component_message_sending():
"""Test component's message sending functionality."""
# Create event queue and manager
queue = await create_event_queue()
queue = asyncio.Queue()
event_manager = EventManager(queue)
event_manager.register_event("on_message", "message", callback=blocking_cb)
@ -75,7 +70,7 @@ async def test_component_message_sending():
async def test_component_tool_output():
"""Test component's tool output functionality."""
# Create event queue and manager
queue = await create_event_queue()
queue = asyncio.Queue()
event_manager = EventManager(queue)
# Create component
@ -110,7 +105,7 @@ async def test_component_tool_output():
async def test_component_error_handling():
"""Test component's error handling."""
# Create event queue and manager
queue = await create_event_queue()
queue = asyncio.Queue()
event_manager = EventManager(queue)
# Create component
@ -141,7 +136,7 @@ async def test_component_error_handling():
async def test_component_build_results():
"""Test component's build_results functionality."""
# Create event queue and manager
queue = await create_event_queue()
queue = asyncio.Queue()
event_manager = EventManager(queue)
# Create component
@ -173,7 +168,7 @@ async def test_component_build_results():
async def test_component_logging():
"""Test component's logging functionality."""
# Create event queue and manager
queue = await create_event_queue()
queue = asyncio.Queue()
event_manager = EventManager(queue)
# Create component
@ -207,7 +202,7 @@ async def test_component_logging():
@pytest.mark.usefixtures("client")
async def test_component_streaming_message():
"""Test component's streaming message functionality."""
queue = await create_event_queue()
queue = asyncio.Queue()
event_manager = EventManager(queue)
event_manager.register_event("on_token", "token", blocking_cb)

View file

@ -40,7 +40,7 @@ class TestEventManager:
def test_accessing_non_registered_event_callback_with_recommended_fix(self):
queue = asyncio.Queue()
manager = EventManager(queue)
result = manager.__getattr__("non_registered_event")
result = manager.non_registered_event
assert result == manager.noop
# Accessing a registered event callback via __getattr__
@ -130,8 +130,6 @@ class TestEventManager:
assert len(manager.events) == 1000
# Verifying the uniqueness of event IDs for each event triggered using await with asyncio decorator
import pytest
async def test_event_id_uniqueness_with_await(self):
queue = asyncio.Queue()
manager = EventManager(queue)

View file

@ -1,11 +1,10 @@
from unittest.mock import Mock, patch
from langflow.exceptions.api import APIException, ExceptionBody
from langflow.services.database.models.flow.model import Flow
def test_api_exception():
from langflow.exceptions.api import APIException, ExceptionBody
mock_exception = Exception("Test exception")
mock_flow = Mock(spec=Flow)
mock_outdated_components = ["component1", "component2"]
@ -45,8 +44,6 @@ def test_api_exception():
def test_api_exception_no_flow():
from langflow.exceptions.api import APIException, ExceptionBody
# Mock data
mock_exception = Exception("Test exception")

View file

@ -1,3 +1,5 @@
import re
import pytest
from langflow.components.inputs import ChatInput
from langflow.components.models import OpenAIModelComponent
@ -25,5 +27,7 @@ Answer:
chat_output = ChatOutput()
chat_output.set(input_value=openai_component.text_response)
with pytest.raises(ValueError, match="Component OpenAI field 'input_values' might not be a valid input."):
with pytest.raises(
ValueError, match=re.escape("Component OpenAI field 'input_values' might not be a valid input.")
):
Graph(start=chat_input, end=chat_output)

View file

@ -20,7 +20,7 @@ async def test_graph_not_prepared():
await graph.astep()
async def test_graph(caplog: pytest.LogCaptureFixture):
def test_graph(caplog: pytest.LogCaptureFixture):
chat_input = ChatInput()
chat_output = ChatOutput()
graph = Graph()

View file

@ -2,14 +2,14 @@ from typing import TYPE_CHECKING, Literal
import pytest
from langflow.components.inputs import ChatInput
from langflow.inputs.inputs import DropdownInput, FileInput, IntInput, NestedDictInput, StrInput
from langflow.io.schema import create_input_schema
if TYPE_CHECKING:
from pydantic.fields import FieldInfo
def test_create_input_schema():
from langflow.io.schema import create_input_schema
schema = create_input_schema(ChatInput.inputs)
assert schema.__name__ == "InputSchema"
@ -17,9 +17,6 @@ def test_create_input_schema():
class TestCreateInputSchema:
# Single input type is converted to list and processed correctly
def test_single_input_type_conversion(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field")
schema = create_input_schema([input_instance])
assert schema.__name__ == "InputSchema"
@ -27,9 +24,6 @@ class TestCreateInputSchema:
# Multiple input types are processed and included in the schema
def test_multiple_input_types(self):
from langflow.inputs.inputs import IntInput, StrInput
from langflow.io.schema import create_input_schema
inputs = [StrInput(name="str_field"), IntInput(name="int_field")]
schema = create_input_schema(inputs)
assert schema.__name__ == "InputSchema"
@ -38,9 +32,6 @@ class TestCreateInputSchema:
# Fields are correctly created with appropriate types and attributes
def test_fields_creation_with_correct_types_and_attributes(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field", info="Test Info", required=True)
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
@ -49,18 +40,12 @@ class TestCreateInputSchema:
# Schema model is created and returned successfully
def test_schema_model_creation(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field")
schema = create_input_schema([input_instance])
assert schema.__name__ == "InputSchema"
# Default values are correctly assigned to fields
def test_default_values_assignment(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field", value="default_value")
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
@ -68,16 +53,11 @@ class TestCreateInputSchema:
# Empty list of inputs is handled without errors
def test_empty_list_of_inputs(self):
from langflow.io.schema import create_input_schema
schema = create_input_schema([])
assert schema.__name__ == "InputSchema"
# Input with missing optional attributes (e.g., display_name, info) is processed correctly
def test_missing_optional_attributes(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field")
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
@ -86,9 +66,6 @@ class TestCreateInputSchema:
# Input with is_list attribute set to True is processed correctly
def test_is_list_attribute_processing(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field", is_list=True)
schema = create_input_schema([input_instance])
field_info: FieldInfo = schema.model_fields["test_field"]
@ -96,9 +73,6 @@ class TestCreateInputSchema:
# Input with options attribute is processed correctly
def test_options_attribute_processing(self):
from langflow.inputs.inputs import DropdownInput
from langflow.io.schema import create_input_schema
input_instance = DropdownInput(name="test_field", options=["option1", "option2"])
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
@ -106,9 +80,6 @@ class TestCreateInputSchema:
# Non-standard field types are handled correctly
def test_non_standard_field_types_handling(self):
from langflow.inputs.inputs import FileInput
from langflow.io.schema import create_input_schema
input_instance = FileInput(name="file_field")
schema = create_input_schema([input_instance])
field_info = schema.model_fields["file_field"]
@ -116,9 +87,6 @@ class TestCreateInputSchema:
# Inputs with mixed required and optional fields are processed correctly
def test_mixed_required_optional_fields_processing(self):
from langflow.inputs.inputs import IntInput, StrInput
from langflow.io.schema import create_input_schema
inputs = [
StrInput(name="required_field", required=True),
IntInput(name="optional_field", required=False),
@ -132,9 +100,6 @@ class TestCreateInputSchema:
# Inputs with complex nested structures are handled correctly
def test_complex_nested_structures_handling(self):
from langflow.inputs.inputs import NestedDictInput
from langflow.io.schema import create_input_schema
nested_input = NestedDictInput(name="nested_field", value={"key": "value"})
schema = create_input_schema([nested_input])
@ -145,9 +110,6 @@ class TestCreateInputSchema:
# Creating a schema from a single input type
def test_single_input_type_replica(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field")
schema = create_input_schema([input_instance])
assert schema.__name__ == "InputSchema"
@ -155,18 +117,12 @@ class TestCreateInputSchema:
# Creating a schema from a list of input types
def test_passing_input_type_directly(self):
from langflow.inputs.inputs import IntInput, StrInput
from langflow.io.schema import create_input_schema
inputs = StrInput(name="str_field"), IntInput(name="int_field")
with pytest.raises(TypeError):
create_input_schema(inputs)
# Handling input types with options correctly
def test_options_handling(self):
from langflow.inputs.inputs import DropdownInput
from langflow.io.schema import create_input_schema
input_instance = DropdownInput(name="test_field", options=["option1", "option2"])
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
@ -174,9 +130,6 @@ class TestCreateInputSchema:
# Handling input types with is_list attribute correctly
def test_is_list_handling(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field", is_list=True)
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
@ -184,9 +137,6 @@ class TestCreateInputSchema:
# Converting FieldTypes to corresponding Python types
def test_field_types_conversion(self):
from langflow.inputs.inputs import IntInput
from langflow.io.schema import create_input_schema
input_instance = IntInput(name="int_field")
schema = create_input_schema([input_instance])
field_info = schema.model_fields["int_field"]
@ -194,9 +144,6 @@ class TestCreateInputSchema:
# Setting default values for non-required fields
def test_default_values_for_non_required_fields(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field", value="default_value")
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
@ -204,9 +151,6 @@ class TestCreateInputSchema:
# Handling input types with missing attributes
def test_missing_attributes_handling(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field")
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
@ -217,9 +161,6 @@ class TestCreateInputSchema:
# Handling input types with None as default value
def test_none_default_value_handling(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test_field", value=None)
schema = create_input_schema([input_instance])
field_info = schema.model_fields["test_field"]
@ -227,9 +168,6 @@ class TestCreateInputSchema:
# Handling input types with special characters in names
def test_special_characters_in_names_handling(self):
from langflow.inputs.inputs import StrInput
from langflow.io.schema import create_input_schema
input_instance = StrInput(name="test@field#name")
schema = create_input_schema([input_instance])
assert "test@field#name" in schema.model_fields

View file

@ -1,3 +1,5 @@
import base64
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langflow.schema.data import Data
@ -9,8 +11,6 @@ def sample_image(tmp_path):
"""Create a sample image file for testing."""
image_path = tmp_path / "test_image.png"
# Create a small black 1x1 pixel PNG file
import base64
image_content = base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAACklEQVR4nGMAAQAABQABDQottAAAAABJRU5ErkJggg=="
)

View file

@ -1,3 +1,4 @@
import base64
import shutil
from datetime import datetime, timezone
from pathlib import Path
@ -29,8 +30,6 @@ def sample_image(langflow_cache_dir):
# Create the image in the flow directory
image_path = flow_dir / "test_image.png"
# Create a small black 1x1 pixel PNG file
import base64
image_content = base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAACklEQVR4nGMAAQAABQABDQottAAAAABJRU5ErkJggg=="
)

View file

@ -1,3 +1,5 @@
import copy
import pytest
from langchain_core.documents import Document
from langflow.schema import Data
@ -62,8 +64,6 @@ def test_custom_attribute_get_set_del():
def test_deep_copy():
import copy
record1 = Data(data={"text": "Hello", "number": 10})
record2 = copy.deepcopy(record1)
assert record2.text == "Hello"

View file

@ -13,6 +13,7 @@ from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
from langflow.services.database.models.folder.model import FolderCreate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from sqlalchemy import text
@pytest.fixture(scope="module")
@ -619,8 +620,6 @@ async def test_sqlite_pragmas():
db_service = get_db_service()
async with db_service.with_session() as session:
from sqlalchemy import text
assert (await session.exec(text("PRAGMA journal_mode;"))).scalar() == "wal"
assert (await session.exec(text("PRAGMA synchronous;"))).scalar() == 1

View file

@ -64,7 +64,7 @@ class TestInput:
# Empty lists and edge cases
assert set(post_process_type(list)) == {list}
assert set(post_process_type(Union[int, None])) == {int, NoneType} # noqa: UP007
assert set(post_process_type(Union[None, list[None]])) == {None, NoneType} # noqa: UP007
assert set(post_process_type(Union[list[None], None])) == {None, NoneType} # noqa: UP007
# Handling complex nested structures
assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float} # noqa: UP007

View file

@ -1,3 +1,4 @@
import re
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
@ -48,7 +49,7 @@ def test_increment_counter_empty_label(opentelemetry_instance):
def test_increment_counter_missing_mandatory_label(opentelemetry_instance):
with pytest.raises(ValueError, match="Missing required labels: {'flow_id'}"):
with pytest.raises(ValueError, match=re.escape("Missing required labels: {'flow_id'}")):
opentelemetry_instance.increment_counter(metric_name="num_files_uploaded", value=5, labels={"service": "one"})

View file

@ -16,51 +16,57 @@ def sample_image(tmp_path):
return image_path
class TestImageUtils:
def test_convert_image_to_base64_success(self, sample_image):
"""Test successful conversion of image to base64."""
base64_str = convert_image_to_base64(sample_image)
assert isinstance(base64_str, str)
# Verify it's valid base64
assert base64.b64decode(base64_str)
def test_convert_image_to_base64_success(sample_image):
"""Test successful conversion of image to base64."""
base64_str = convert_image_to_base64(sample_image)
assert isinstance(base64_str, str)
# Verify it's valid base64
assert base64.b64decode(base64_str)
def test_convert_image_to_base64_empty_path(self):
"""Test conversion with empty path."""
with pytest.raises(ValueError, match="Image path cannot be empty"):
convert_image_to_base64("")
def test_convert_image_to_base64_nonexistent_file(self):
"""Test conversion with non-existent file."""
with pytest.raises(FileNotFoundError, match="Image file not found"):
convert_image_to_base64("nonexistent.png")
def test_convert_image_to_base64_empty_path():
"""Test conversion with empty path."""
with pytest.raises(ValueError, match="Image path cannot be empty"):
convert_image_to_base64("")
def test_convert_image_to_base64_directory(self, tmp_path):
"""Test conversion with directory path instead of file."""
with pytest.raises(ValueError, match="Path is not a file"):
convert_image_to_base64(tmp_path)
def test_create_data_url_success(self, sample_image):
"""Test successful creation of data URL."""
data_url = create_data_url(sample_image)
assert data_url.startswith("data:image/png;base64,")
# Verify the base64 part is valid
base64_part = data_url.split(",")[1]
assert base64.b64decode(base64_part)
def test_convert_image_to_base64_nonexistent_file():
"""Test conversion with non-existent file."""
with pytest.raises(FileNotFoundError, match="Image file not found"):
convert_image_to_base64("nonexistent.png")
def test_create_data_url_with_custom_mime(self, sample_image):
"""Test creation of data URL with custom MIME type."""
custom_mime = "image/custom"
data_url = create_data_url(sample_image, mime_type=custom_mime)
assert data_url.startswith(f"data:{custom_mime};base64,")
def test_create_data_url_invalid_file(self):
"""Test creation of data URL with invalid file."""
with pytest.raises(FileNotFoundError):
create_data_url("nonexistent.jpg")
def test_convert_image_to_base64_directory(tmp_path):
"""Test conversion with directory path instead of file."""
with pytest.raises(ValueError, match="Path is not a file"):
convert_image_to_base64(tmp_path)
def test_create_data_url_unrecognized_extension(self, tmp_path):
"""Test creation of data URL with unrecognized file extension."""
invalid_file = tmp_path / "test.unknown"
invalid_file.touch()
with pytest.raises(ValueError, match="Could not determine MIME type"):
create_data_url(invalid_file)
def test_create_data_url_success(sample_image):
"""Test successful creation of data URL."""
data_url = create_data_url(sample_image)
assert data_url.startswith("data:image/png;base64,")
# Verify the base64 part is valid
base64_part = data_url.split(",")[1]
assert base64.b64decode(base64_part)
def test_create_data_url_with_custom_mime(sample_image):
"""Test creation of data URL with custom MIME type."""
custom_mime = "image/custom"
data_url = create_data_url(sample_image, mime_type=custom_mime)
assert data_url.startswith(f"data:{custom_mime};base64,")
def test_create_data_url_invalid_file():
"""Test creation of data URL with invalid file."""
with pytest.raises(FileNotFoundError):
create_data_url("nonexistent.jpg")
def test_create_data_url_unrecognized_extension(tmp_path):
"""Test creation of data URL with unrecognized file extension."""
invalid_file = tmp_path / "test.unknown"
invalid_file.touch()
with pytest.raises(ValueError, match="Could not determine MIME type"):
create_data_url(invalid_file)

View file

@ -1,6 +1,7 @@
import math
import pytest
from langflow.utils.constants import MAX_TEXT_LENGTH
from langflow.utils.util_strings import truncate_long_strings
@ -48,8 +49,6 @@ def test_truncate_long_strings_negative_max_length():
# Test for None max_length (should use default MAX_TEXT_LENGTH)
def test_truncate_long_strings_none_max_length():
from langflow.utils.constants import MAX_TEXT_LENGTH
long_string = "a" * (MAX_TEXT_LENGTH + 10)
result = truncate_long_strings(long_string, None)
assert len(result) == MAX_TEXT_LENGTH + 3 # +3 for "..."