ref: Some ruff fixes from preview (#5420)
* Some ruff fixes from preview * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3454ede5a5
commit
e91bcc2520
79 changed files with 402 additions and 374 deletions
|
|
@ -90,7 +90,7 @@ class AgentQL(Component):
|
|||
|
||||
except httpx.HTTPStatusError as e:
|
||||
response = e.response
|
||||
if response.status_code in [401, 403]:
|
||||
if response.status_code in {401, 403}:
|
||||
self.status = "Please, provide a valid API Key. You can create one at https://dev.agentql.com."
|
||||
else:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class DirectoryComponent(Component):
|
|||
|
||||
def load_directory(self) -> list[Data]:
|
||||
path = self.path
|
||||
types = self.types if self.types else TEXT_FILE_TYPES
|
||||
types = self.types or TEXT_FILE_TYPES
|
||||
depth = self.depth
|
||||
max_concurrency = self.max_concurrency
|
||||
load_hidden = self.load_hidden
|
||||
|
|
|
|||
|
|
@ -131,12 +131,12 @@ class LangWatchComponent(Component):
|
|||
|
||||
# Clear component's dynamic attributes
|
||||
for attr in list(self.__dict__.keys()):
|
||||
if attr not in default_keys and attr not in [
|
||||
if attr not in default_keys and attr not in {
|
||||
"evaluators",
|
||||
"dynamic_inputs",
|
||||
"_code",
|
||||
"current_evaluator",
|
||||
]:
|
||||
}:
|
||||
delattr(self, attr)
|
||||
|
||||
# Add new dynamic inputs
|
||||
|
|
@ -177,7 +177,7 @@ class LangWatchComponent(Component):
|
|||
input_fields = [
|
||||
field
|
||||
for field in evaluator.get("requiredFields", []) + evaluator.get("optionalFields", [])
|
||||
if field not in ["input", "output"]
|
||||
if field not in {"input", "output"}
|
||||
]
|
||||
|
||||
for field in input_fields:
|
||||
|
|
|
|||
|
|
@ -126,8 +126,7 @@ class ConditionalRouterComponent(Component):
|
|||
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:
|
||||
if field_name == "operator":
|
||||
if field_value == "matches regex":
|
||||
if "case_sensitive" in build_config:
|
||||
del build_config["case_sensitive"]
|
||||
build_config.pop("case_sensitive", None)
|
||||
# Ensure case_sensitive is present for all other operators
|
||||
elif "case_sensitive" not in build_config:
|
||||
case_sensitive_input = next(
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class OpenRouterComponent(LCModelComponent):
|
|||
|
||||
display_name = "OpenRouter"
|
||||
description = (
|
||||
"OpenRouter provides unified access to multiple AI models " "from different providers through a single API."
|
||||
"OpenRouter provides unified access to multiple AI models from different providers through a single API."
|
||||
)
|
||||
icon = "OpenRouter"
|
||||
|
||||
|
|
@ -180,9 +180,7 @@ class OpenRouterComponent(LCModelComponent):
|
|||
build_config["model_name"]["value"] = models[0]["id"]
|
||||
|
||||
tooltips = {
|
||||
model["id"]: (
|
||||
f"{model['name']}\n" f"Context Length: {model['context_length']}\n" f"{model['description']}"
|
||||
)
|
||||
model["id"]: (f"{model['name']}\nContext Length: {model['context_length']}\n{model['description']}")
|
||||
for model in models
|
||||
}
|
||||
build_config["model_name"]["tooltips"] = tooltips
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from langflow.utils.constants import MESSAGE_SENDER_AI
|
|||
|
||||
class NeedleComponent(Component):
|
||||
display_name = "Needle Retriever"
|
||||
description = "A retriever that uses the Needle API to search collections " "and generates responses using OpenAI."
|
||||
description = "A retriever that uses the Needle API to search collections and generates responses using OpenAI."
|
||||
documentation = "https://docs.needle-ai.com"
|
||||
icon = "search"
|
||||
name = "needle"
|
||||
|
|
@ -105,8 +105,8 @@ class NeedleComponent(Component):
|
|||
if str(output_type).lower().strip() == "chunks":
|
||||
# If chunks selected, include full context and answer
|
||||
docs = result["source_documents"]
|
||||
context = "\n\n".join([f"Document {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
|
||||
text_content = f"Question: {query}\n\n" f"Context:\n{context}\n\n" f"Answer: {result['answer']}"
|
||||
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
|
||||
text_content = f"Question: {query}\n\nContext:\n{context}\n\nAnswer: {result['answer']}"
|
||||
else:
|
||||
# If answer selected, only include the answer
|
||||
text_content = result["answer"]
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ class DataFrameOperationsComponent(Component):
|
|||
build_config["new_column_value"]["show"] = True
|
||||
elif field_value == "Select Columns":
|
||||
build_config["columns_to_select"]["show"] = True
|
||||
elif field_value in ["Head", "Tail"]:
|
||||
elif field_value in {"Head", "Tail"}:
|
||||
build_config["num_rows"]["show"] = True
|
||||
elif field_value == "Replace Value":
|
||||
build_config["column_name"]["show"] = True
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class MergeDataComponent(Component):
|
|||
for key, value in data_input.data.items():
|
||||
if key in result_data and isinstance(value, str):
|
||||
if isinstance(result_data[key], list):
|
||||
cast(list[str], result_data[key]).append(value)
|
||||
cast("list[str]", result_data[key]).append(value)
|
||||
else:
|
||||
result_data[key] = [result_data[key], value]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -295,7 +295,7 @@ class AstraDBVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
|
||||
# Always attempt to update the database list
|
||||
if field_name in ["token", "api_endpoint", "collection_name"]:
|
||||
if field_name in {"token", "api_endpoint", "collection_name"}:
|
||||
# Update the database selector
|
||||
build_config["api_endpoint"]["options"] = self._initialize_database_options()
|
||||
|
||||
|
|
|
|||
|
|
@ -1101,7 +1101,8 @@ class Graph:
|
|||
return False
|
||||
return self.vertex_edges_are_identical(vertex, other_vertex)
|
||||
|
||||
def vertex_edges_are_identical(self, vertex: Vertex, other_vertex: Vertex) -> bool:
|
||||
@staticmethod
|
||||
def vertex_edges_are_identical(vertex: Vertex, other_vertex: Vertex) -> bool:
|
||||
same_length = len(vertex.edges) == len(other_vertex.edges)
|
||||
if not same_length:
|
||||
return False
|
||||
|
|
@ -1747,7 +1748,8 @@ class Graph:
|
|||
new_edge = Edge(source, target, edge)
|
||||
return new_edge
|
||||
|
||||
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> type[Vertex]:
|
||||
@staticmethod
|
||||
def _get_vertex_class(node_type: str, node_base_type: str, node_id: str) -> type[Vertex]:
|
||||
"""Returns the node class based on the node type."""
|
||||
# First we check for the node_base_type
|
||||
node_name = node_id.split("-")[0]
|
||||
|
|
@ -1830,7 +1832,8 @@ class Graph:
|
|||
self._record_snapshot()
|
||||
return self
|
||||
|
||||
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> list[Vertex]:
|
||||
@staticmethod
|
||||
def get_children_by_vertex_type(vertex: Vertex, vertex_type: str) -> list[Vertex]:
|
||||
"""Returns the children of a vertex based on the vertex type."""
|
||||
children = []
|
||||
vertex_types = [vertex.data["type"]]
|
||||
|
|
@ -2059,7 +2062,8 @@ class Graph:
|
|||
self._first_layer = first_layer
|
||||
return first_layer
|
||||
|
||||
def sort_interface_components_first(self, vertices_layers: list[list[str]]) -> list[list[str]]:
|
||||
@staticmethod
|
||||
def sort_interface_components_first(vertices_layers: list[list[str]]) -> list[list[str]]:
|
||||
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
|
||||
|
||||
def contains_interface_component(vertex):
|
||||
|
|
@ -2097,9 +2101,6 @@ class Graph:
|
|||
|
||||
This method is responsible for building the run map for the graph,
|
||||
which maps each node in the graph to its corresponding run function.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.run_manager.build_run_map(predecessor_map=self.predecessor_map, vertices_to_run=self.vertices_to_run)
|
||||
|
||||
|
|
@ -2169,7 +2170,8 @@ class Graph:
|
|||
in_degree[vertex.id] = 0
|
||||
return in_degree
|
||||
|
||||
def build_adjacency_maps(self, edges: list[CycleEdge]) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
|
||||
@staticmethod
|
||||
def build_adjacency_maps(edges: list[CycleEdge]) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
|
||||
"""Returns the adjacency maps for the graph."""
|
||||
predecessor_map: dict[str, list[str]] = defaultdict(list)
|
||||
successor_map: dict[str, list[str]] = defaultdict(list)
|
||||
|
|
|
|||
|
|
@ -110,9 +110,6 @@ def update_template(template, g_nodes) -> None:
|
|||
Args:
|
||||
template (dict): The new template to update the node with.
|
||||
g_nodes (list): The list of nodes in the graph.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for value in template.values():
|
||||
if not value.get("proxy"):
|
||||
|
|
@ -161,9 +158,6 @@ def set_new_target_handle(proxy_id, new_edge, target_handle, node) -> None:
|
|||
new_edge (dict): The new edge to be created.
|
||||
target_handle (dict): The target handle of the edge.
|
||||
node (dict): The node containing the edge.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
new_edge["target"] = proxy_id
|
||||
type_ = target_handle.get("type")
|
||||
|
|
|
|||
|
|
@ -423,7 +423,7 @@ class Vertex:
|
|||
else:
|
||||
msg = f"Invalid value type {type(val)} for field {field_name}"
|
||||
raise ValueError(msg)
|
||||
elif val is not None and val != "":
|
||||
elif val:
|
||||
params[field_name] = val
|
||||
|
||||
if field.get("load_from_db"):
|
||||
|
|
@ -596,7 +596,8 @@ class Vertex:
|
|||
result = await value.get_result(self, target_handle_name=key)
|
||||
self.params[key][sub_key] = result
|
||||
|
||||
def _is_vertex(self, value):
|
||||
@staticmethod
|
||||
def _is_vertex(value):
|
||||
"""Checks if the provided value is an instance of Vertex."""
|
||||
return isinstance(value, Vertex)
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -1,3 +1,5 @@
|
|||
from typing_extensions import override
|
||||
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.utils.lazy_load import LazyLoadDictBase
|
||||
|
||||
|
|
@ -13,6 +15,7 @@ class AllTypesDict(LazyLoadDictBase):
|
|||
"Custom": ["Custom Tool", "Python Function"],
|
||||
}
|
||||
|
||||
@override
|
||||
def get_type_dict(self):
|
||||
from langflow.interface.types import get_all_types_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from loguru._file_sink import FileSink
|
|||
from loguru._simple_sinks import AsyncSink
|
||||
from platformdirs import user_cache_dir
|
||||
from rich.logging import RichHandler
|
||||
from typing_extensions import NotRequired
|
||||
from typing_extensions import NotRequired, override
|
||||
|
||||
from langflow.settings import DEV
|
||||
|
||||
|
|
@ -285,6 +285,7 @@ class InterceptHandler(logging.Handler):
|
|||
See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging.
|
||||
"""
|
||||
|
||||
@override
|
||||
def emit(self, record) -> None:
|
||||
# Get corresponding Loguru level if it exists
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
|||
from pydantic import PydanticDeprecatedSince20
|
||||
from pydantic_core import PydanticSerializationError
|
||||
from rich import print as rprint
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
|
||||
from langflow.api import health_check_router, log_router, router
|
||||
from langflow.initial_setup.setup import (
|
||||
|
|
@ -45,7 +45,7 @@ class RequestCancelledMiddleware(BaseHTTPMiddleware):
|
|||
def __init__(self, app) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
sentinel = object()
|
||||
|
||||
async def cancel_handler():
|
||||
|
|
@ -68,7 +68,7 @@ class RequestCancelledMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
|
||||
class JavaScriptMIMETypeMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
try:
|
||||
response = await call_next(request)
|
||||
except Exception as exc:
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ class ContentSizeLimitMiddleware:
|
|||
self.app = app
|
||||
self.logger = logger
|
||||
|
||||
def receive_wrapper(self, receive):
|
||||
@staticmethod
|
||||
def receive_wrapper(receive):
|
||||
received = 0
|
||||
|
||||
async def inner():
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class DataFrame(pandas_DataFrame):
|
|||
>>> dataset = DataFrame({"name": ["John", "Jane"], "age": [30, 25]})
|
||||
"""
|
||||
|
||||
def __init__(self, data: None | list[dict | Data] | dict | pd.DataFrame = None, **kwargs):
|
||||
def __init__(self, data: list[dict | Data] | dict | pd.DataFrame | None = None, **kwargs):
|
||||
if data is None:
|
||||
super().__init__(**kwargs)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing_extensions import Protocol
|
|||
from langflow.schema.message import ContentBlock, Message
|
||||
from langflow.schema.playground_events import PlaygroundEvent
|
||||
|
||||
LoggableType: TypeAlias = str | dict | list | int | float | bool | None | BaseModel | PlaygroundEvent
|
||||
LoggableType: TypeAlias = str | dict | list | int | float | bool | BaseModel | PlaygroundEvent | None
|
||||
|
||||
|
||||
class LogFunctionType(Protocol):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing_extensions import override
|
||||
|
||||
from langflow.services.auth.service import AuthService
|
||||
from langflow.services.factory import ServiceFactory
|
||||
|
||||
|
|
@ -8,5 +10,6 @@ class AuthServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(AuthService)
|
||||
|
||||
@override
|
||||
def create(self, settings_service):
|
||||
return AuthService(settings_service)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from __future__ import annotations
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.logging.logger import logger
|
||||
from langflow.services.cache.disk import AsyncDiskCache
|
||||
from langflow.services.cache.service import AsyncInMemoryCache, CacheService, RedisCache, ThreadingInMemoryCache
|
||||
|
|
@ -15,6 +17,7 @@ class CacheServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(CacheService)
|
||||
|
||||
@override
|
||||
def create(self, settings_service: SettingsService):
|
||||
# Here you would have logic to create and configure a CacheService
|
||||
# based on the settings_service
|
||||
|
|
|
|||
|
|
@ -187,7 +187,8 @@ class DatabaseService(Service):
|
|||
await session.commit()
|
||||
logger.debug("Successfully assigned orphaned flows to the default superuser")
|
||||
|
||||
def _generate_unique_flow_name(self, original_name: str, existing_names: set[str]) -> str:
|
||||
@staticmethod
|
||||
def _generate_unique_flow_name(original_name: str, existing_names: set[str]) -> str:
|
||||
"""Generate a unique flow name by adding or incrementing a suffix."""
|
||||
if original_name not in existing_names:
|
||||
return original_name
|
||||
|
|
@ -212,7 +213,8 @@ class DatabaseService(Service):
|
|||
|
||||
return new_name
|
||||
|
||||
def _check_schema_health(self, connection) -> bool:
|
||||
@staticmethod
|
||||
def _check_schema_health(connection) -> bool:
|
||||
inspector = inspect(connection)
|
||||
|
||||
model_mapping: dict[str, type[SQLModel]] = {
|
||||
|
|
@ -249,7 +251,8 @@ class DatabaseService(Service):
|
|||
async with self.with_session() as session, session.bind.connect() as conn:
|
||||
await conn.run_sync(self._check_schema_health)
|
||||
|
||||
def init_alembic(self, alembic_cfg) -> None:
|
||||
@staticmethod
|
||||
def init_alembic(alembic_cfg) -> None:
|
||||
logger.info("Initializing alembic")
|
||||
command.ensure_version(alembic_cfg)
|
||||
# alembic_cfg.attributes["connection"].commit()
|
||||
|
|
@ -317,7 +320,8 @@ class DatabaseService(Service):
|
|||
should_initialize_alembic = True
|
||||
await asyncio.to_thread(self._run_migrations, should_initialize_alembic, fix)
|
||||
|
||||
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5) -> None:
|
||||
@staticmethod
|
||||
def try_downgrade_upgrade_until_success(alembic_cfg, retries=5) -> None:
|
||||
# Try -1 then head, if it fails, try -2 then head, etc.
|
||||
# until we reach the number of retries
|
||||
for i in range(1, retries + 1):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.session.service import SessionService
|
||||
|
||||
|
|
@ -11,5 +13,6 @@ class SessionServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(SessionService)
|
||||
|
||||
@override
|
||||
def create(self, cache_service: "CacheService"):
|
||||
return SessionService(cache_service)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ class SessionService(Service):
|
|||
|
||||
return graph, artifacts
|
||||
|
||||
def build_key(self, session_id, data_graph) -> str:
|
||||
@staticmethod
|
||||
def build_key(session_id, data_graph) -> str:
|
||||
json_hash = compute_dict_hash(data_graph)
|
||||
return f"{session_id}{':' if session_id else ''}{json_hash}"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.settings.service import SettingsService
|
||||
|
||||
|
|
@ -13,6 +15,7 @@ class SettingsServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(SettingsService)
|
||||
|
||||
@override
|
||||
def create(self):
|
||||
# Here you would have logic to create and configure a SettingsService
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.shared_component_cache.service import SharedComponentCacheService
|
||||
|
||||
|
|
@ -11,5 +13,6 @@ class SharedComponentCacheServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(SharedComponentCacheService)
|
||||
|
||||
@override
|
||||
def create(self, settings_service: "SettingsService"):
|
||||
return SharedComponentCacheService(expiration_time=settings_service.settings.cache_expire)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.socket.service import SocketIOService
|
||||
|
||||
|
|
@ -13,5 +15,6 @@ class SocketIOFactory(ServiceFactory):
|
|||
service_class=SocketIOService,
|
||||
)
|
||||
|
||||
@override
|
||||
def create(self, cache_service: "CacheService"):
|
||||
return SocketIOService(cache_service)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.settings.service import SettingsService
|
||||
from langflow.services.state.service import InMemoryStateService
|
||||
|
|
@ -7,6 +9,7 @@ class StateServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(InMemoryStateService)
|
||||
|
||||
@override
|
||||
def create(self, settings_service: SettingsService):
|
||||
return InMemoryStateService(
|
||||
settings_service,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from loguru import logger
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.session.service import SessionService
|
||||
|
|
@ -12,6 +13,7 @@ class StorageServiceFactory(ServiceFactory):
|
|||
StorageService,
|
||||
)
|
||||
|
||||
@override
|
||||
def create(self, session_service: SessionService, settings_service: SettingsService):
|
||||
storage_type = settings_service.settings.storage_type
|
||||
if storage_type.lower() == "local":
|
||||
|
|
|
|||
|
|
@ -21,12 +21,15 @@ class LocalStorageService(StorageService):
|
|||
async def save_file(self, flow_id: str, file_name: str, data: bytes) -> None:
|
||||
"""Save a file in the local storage.
|
||||
|
||||
:param flow_id: The identifier for the flow.
|
||||
:param file_name: The name of the file to be saved.
|
||||
:param data: The byte content of the file.
|
||||
:raises FileNotFoundError: If the specified flow does not exist.
|
||||
:raises IsADirectoryError: If the file name is a directory.
|
||||
:raises PermissionError: If there is no permission to write the file.
|
||||
Args:
|
||||
flow_id: The identifier for the flow.
|
||||
file_name: The name of the file to be saved.
|
||||
data: The byte content of the file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified flow does not exist.
|
||||
IsADirectoryError: If the file name is a directory.
|
||||
PermissionError: If there is no permission to write the file.
|
||||
"""
|
||||
folder_path = self.data_dir / flow_id
|
||||
await folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -43,10 +46,15 @@ class LocalStorageService(StorageService):
|
|||
async def get_file(self, flow_id: str, file_name: str) -> bytes:
|
||||
"""Retrieve a file from the local storage.
|
||||
|
||||
:param flow_id: The identifier for the flow.
|
||||
:param file_name: The name of the file to be retrieved.
|
||||
:return: The byte content of the file.
|
||||
:raises FileNotFoundError: If the file does not exist.
|
||||
Args:
|
||||
flow_id: The identifier for the flow.
|
||||
file_name: The name of the file to be retrieved.
|
||||
|
||||
Returns:
|
||||
The byte content of the file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file does not exist.
|
||||
"""
|
||||
file_path = self.data_dir / flow_id / file_name
|
||||
if not await file_path.exists():
|
||||
|
|
@ -63,9 +71,14 @@ class LocalStorageService(StorageService):
|
|||
async def list_files(self, flow_id: str):
|
||||
"""List all files in a specified flow.
|
||||
|
||||
:param flow_id: The identifier for the flow.
|
||||
:return: A list of file names.
|
||||
:raises FileNotFoundError: If the flow directory does not exist.
|
||||
Args:
|
||||
flow_id: The identifier for the flow.
|
||||
|
||||
Returns:
|
||||
A list of file names.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the flow directory does not exist.
|
||||
"""
|
||||
if not isinstance(flow_id, str):
|
||||
flow_id = str(flow_id)
|
||||
|
|
|
|||
|
|
@ -18,10 +18,13 @@ class S3StorageService(StorageService):
|
|||
async def save_file(self, folder: str, file_name: str, data) -> None:
|
||||
"""Save a file to the S3 bucket.
|
||||
|
||||
:param folder: The folder in the bucket to save the file.
|
||||
:param file_name: The name of the file to be saved.
|
||||
:param data: The byte content of the file.
|
||||
:raises Exception: If an error occurs during file saving.
|
||||
Args:
|
||||
folder: The folder in the bucket to save the file.
|
||||
file_name: The name of the file to be saved.
|
||||
data: The byte content of the file.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs during file saving.
|
||||
"""
|
||||
try:
|
||||
self.s3_client.put_object(Bucket=self.bucket, Key=f"{folder}/{file_name}", Body=data)
|
||||
|
|
@ -36,10 +39,15 @@ class S3StorageService(StorageService):
|
|||
async def get_file(self, folder: str, file_name: str):
|
||||
"""Retrieve a file from the S3 bucket.
|
||||
|
||||
:param folder: The folder in the bucket where the file is stored.
|
||||
:param file_name: The name of the file to be retrieved.
|
||||
:return: The byte content of the file.
|
||||
:raises Exception: If an error occurs during file retrieval.
|
||||
Args:
|
||||
folder: The folder in the bucket where the file is stored.
|
||||
file_name: The name of the file to be retrieved.
|
||||
|
||||
Returns:
|
||||
The byte content of the file.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs during file retrieval.
|
||||
"""
|
||||
try:
|
||||
response = self.s3_client.get_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")
|
||||
|
|
@ -52,9 +60,14 @@ class S3StorageService(StorageService):
|
|||
async def list_files(self, folder: str):
|
||||
"""List all files in a specified folder of the S3 bucket.
|
||||
|
||||
:param folder: The folder in the bucket to list files from.
|
||||
:return: A list of file names.
|
||||
:raises Exception: If an error occurs during file listing.
|
||||
Args:
|
||||
folder: The folder in the bucket to list files from.
|
||||
|
||||
Returns:
|
||||
A list of file names.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs during file listing.
|
||||
"""
|
||||
try:
|
||||
response = self.s3_client.list_objects_v2(Bucket=self.bucket, Prefix=folder)
|
||||
|
|
@ -69,9 +82,12 @@ class S3StorageService(StorageService):
|
|||
async def delete_file(self, folder: str, file_name: str) -> None:
|
||||
"""Delete a file from the S3 bucket.
|
||||
|
||||
:param folder: The folder in the bucket where the file is stored.
|
||||
:param file_name: The name of the file to be deleted.
|
||||
:raises Exception: If an error occurs during file deletion.
|
||||
Args:
|
||||
folder: The folder in the bucket where the file is stored.
|
||||
file_name: The name of the file to be deleted.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs during file deletion.
|
||||
"""
|
||||
try:
|
||||
self.s3_client.delete_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from __future__ import annotations
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.store.service import StoreService
|
||||
|
||||
|
|
@ -13,5 +15,6 @@ class StoreServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(StoreService)
|
||||
|
||||
@override
|
||||
def create(self, settings_service: SettingsService):
|
||||
return StoreService(settings_service)
|
||||
|
|
|
|||
|
|
@ -164,7 +164,8 @@ class StoreService(Service):
|
|||
except Exception: # noqa: BLE001
|
||||
logger.opt(exception=True).debug("Webhook failed")
|
||||
|
||||
def build_tags_filter(self, tags: list[str]):
|
||||
@staticmethod
|
||||
def build_tags_filter(tags: list[str]):
|
||||
tags_filter: dict[str, Any] = {"tags": {"_and": []}}
|
||||
for tag in tags:
|
||||
tags_filter["tags"]["_and"].append({"_some": {"tags_id": {"name": {"_eq": tag}}}})
|
||||
|
|
@ -249,7 +250,8 @@ class StoreService(Service):
|
|||
|
||||
return filter_conditions
|
||||
|
||||
def build_liked_filter(self):
|
||||
@staticmethod
|
||||
def build_liked_filter():
|
||||
user_data = user_data_var.get()
|
||||
# params["filter"] = json.dumps({"user_created": {"_eq": user_data["id"]}})
|
||||
if not user_data:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.task.service import TaskService
|
||||
|
||||
|
|
@ -6,6 +8,7 @@ class TaskServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(TaskService)
|
||||
|
||||
@override
|
||||
def create(self):
|
||||
# Here you would have logic to create and configure a TaskService
|
||||
return TaskService()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from __future__ import annotations
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.telemetry.service import TelemetryService
|
||||
|
||||
|
|
@ -13,5 +15,6 @@ class TelemetryServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(TelemetryService)
|
||||
|
||||
@override
|
||||
def create(self, settings_service: SettingsService):
|
||||
return TelemetryService(settings_service)
|
||||
|
|
|
|||
|
|
@ -132,7 +132,8 @@ class TelemetryService(Service):
|
|||
except Exception: # noqa: BLE001
|
||||
logger.exception("Error flushing logs")
|
||||
|
||||
async def _cancel_task(self, task: asyncio.Task, cancel_msg: str) -> None:
|
||||
@staticmethod
|
||||
async def _cancel_task(task: asyncio.Task, cancel_msg: str) -> None:
|
||||
task.cancel(cancel_msg)
|
||||
await asyncio.wait([task])
|
||||
if not task.cancelled():
|
||||
|
|
|
|||
|
|
@ -319,7 +319,8 @@ class ArizePhoenixTracer(BaseTracer):
|
|||
|
||||
return value
|
||||
|
||||
def _error_to_string(self, error: Exception | None):
|
||||
@staticmethod
|
||||
def _error_to_string(error: Exception | None):
|
||||
"""Converts an error to a string with traceback details."""
|
||||
error_message = None
|
||||
if error:
|
||||
|
|
@ -327,11 +328,13 @@ class ArizePhoenixTracer(BaseTracer):
|
|||
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
|
||||
return error_message
|
||||
|
||||
def _get_current_timestamp(self) -> int:
|
||||
@staticmethod
|
||||
def _get_current_timestamp() -> int:
|
||||
"""Gets the current UTC timestamp in nanoseconds."""
|
||||
return int(datetime.now(timezone.utc).timestamp() * 1_000_000_000)
|
||||
|
||||
def _safe_json_dumps(self, obj: Any, **kwargs: Any) -> str:
|
||||
@staticmethod
|
||||
def _safe_json_dumps(obj: Any, **kwargs: Any) -> str:
|
||||
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
|
||||
return json.dumps(obj, default=str, ensure_ascii=False, **kwargs)
|
||||
|
||||
|
|
@ -359,6 +362,7 @@ class ArizePhoenixTracer(BaseTracer):
|
|||
else:
|
||||
current_span.set_status(Status(StatusCode.OK))
|
||||
|
||||
@override
|
||||
def get_langchain_callback(self) -> BaseCallbackHandler | None:
|
||||
"""Returns the LangChain callback handler if applicable."""
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from __future__ import annotations
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.tracing.service import TracingService
|
||||
|
||||
|
|
@ -13,5 +15,6 @@ class TracingServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(TracingService)
|
||||
|
||||
@override
|
||||
def create(self, settings_service: SettingsService):
|
||||
return TracingService(settings_service)
|
||||
|
|
|
|||
|
|
@ -135,7 +135,8 @@ class LangFuseTracer(BaseTracer):
|
|||
return None
|
||||
return None # self._callback
|
||||
|
||||
def _get_config(self) -> dict:
|
||||
@staticmethod
|
||||
def _get_config() -> dict:
|
||||
secret_key = os.getenv("LANGFUSE_SECRET_KEY", None)
|
||||
public_key = os.getenv("LANGFUSE_PUBLIC_KEY", None)
|
||||
host = os.getenv("LANGFUSE_HOST", None)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from datetime import datetime, timezone
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from loguru import logger
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.schema.data import Data
|
||||
from langflow.services.tracing.base import BaseTracer
|
||||
|
|
@ -140,7 +141,8 @@ class LangSmithTracer(BaseTracer):
|
|||
child.post()
|
||||
self._child_link[trace_name] = child.get_url()
|
||||
|
||||
def _error_to_string(self, error: Exception | None):
|
||||
@staticmethod
|
||||
def _error_to_string(error: Exception | None):
|
||||
error_message = None
|
||||
if error:
|
||||
string_stacktrace = traceback.format_exception(error)
|
||||
|
|
@ -163,5 +165,6 @@ class LangSmithTracer(BaseTracer):
|
|||
self._run_tree.post()
|
||||
self._run_link = self._run_tree.get_url()
|
||||
|
||||
@override
|
||||
def get_langchain_callback(self) -> BaseCallbackHandler | None:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -264,7 +264,8 @@ class TracingService(Service):
|
|||
self.outputs[trace_name] |= outputs or {}
|
||||
self.outputs_metadata[trace_name] |= output_metadata or {}
|
||||
|
||||
def _cleanup_inputs(self, inputs: dict[str, Any]):
|
||||
@staticmethod
|
||||
def _cleanup_inputs(inputs: dict[str, Any]):
|
||||
inputs = inputs.copy()
|
||||
for key in inputs:
|
||||
if "api_key" in key:
|
||||
|
|
|
|||
|
|
@ -177,9 +177,6 @@ async def clean_transactions(settings_service: SettingsService, session: AsyncSe
|
|||
Args:
|
||||
settings_service: The settings service containing configuration like max_transactions_to_keep
|
||||
session: The database session to use for the deletion
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
# Delete transactions using bulk delete
|
||||
|
|
@ -209,9 +206,6 @@ async def clean_vertex_builds(settings_service: SettingsService, session: AsyncS
|
|||
Args:
|
||||
settings_service: The settings service containing configuration like max_vertex_builds_to_keep
|
||||
session: The database session to use for the deletion
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
# Delete vertex builds using bulk delete
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from __future__ import annotations
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.variable.service import DatabaseVariableService, VariableService
|
||||
|
||||
|
|
@ -13,6 +15,7 @@ class VariableServiceFactory(ServiceFactory):
|
|||
def __init__(self) -> None:
|
||||
super().__init__(VariableService)
|
||||
|
||||
@override
|
||||
def create(self, settings_service: SettingsService):
|
||||
# here you would have logic to create and configure a VariableService
|
||||
# based on the settings_service
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
from typing_extensions import override
|
||||
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.base import Service
|
||||
|
|
@ -141,6 +142,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
await session.refresh(db_variable)
|
||||
return db_variable
|
||||
|
||||
@override
|
||||
async def delete_variable(
|
||||
self,
|
||||
user_id: UUID | str,
|
||||
|
|
@ -155,6 +157,7 @@ class DatabaseVariableService(VariableService, Service):
|
|||
await session.delete(variable)
|
||||
await session.commit()
|
||||
|
||||
@override
|
||||
async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None:
|
||||
stmt = select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)
|
||||
variable = (await session.exec(stmt)).first()
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ class KeyedWorkerLockManager:
|
|||
def __init__(self) -> None:
|
||||
self.locks_dir = Path(user_cache_dir("langflow"), ensure_exists=True) / "worker_locks"
|
||||
|
||||
def _validate_key(self, key: str) -> bool:
|
||||
@staticmethod
|
||||
def _validate_key(key: str) -> bool:
|
||||
"""Validate that the string only contains alphanumeric characters and underscores.
|
||||
|
||||
Parameters:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from docstring_parser import parse
|
|||
from langflow.logging.logger import logger
|
||||
from langflow.schema import Data
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.utils import initialize_settings_service
|
||||
from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS
|
||||
from langflow.utils import constants
|
||||
|
||||
|
|
@ -416,8 +417,6 @@ async def update_settings(
|
|||
max_file_size_upload: int = 100,
|
||||
) -> None:
|
||||
"""Update the settings from a config file."""
|
||||
from langflow.services.utils import initialize_settings_service
|
||||
|
||||
# Check for database_url in the environment variables
|
||||
|
||||
initialize_settings_service()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,11 @@ from langflow.utils import constants
|
|||
|
||||
|
||||
def truncate_long_strings(data, max_length=None):
|
||||
"""Recursively traverse the dictionary or list and truncate strings longer than max_length."""
|
||||
"""Recursively traverse the dictionary or list and truncate strings longer than max_length.
|
||||
|
||||
Returns:
|
||||
The data with strings truncated if they exceed the max length.
|
||||
"""
|
||||
if max_length is None:
|
||||
max_length = constants.MAX_TEXT_LENGTH
|
||||
|
||||
|
|
|
|||
|
|
@ -170,9 +170,15 @@ def create_function(code, function_name):
|
|||
def create_class(code, class_name):
|
||||
"""Dynamically create a class from a string of code and a specified class name.
|
||||
|
||||
:param code: String containing the Python code defining the class
|
||||
:param class_name: Name of the class to be created
|
||||
:return: A function that, when called, returns an instance of the created class
|
||||
Args:
|
||||
code: String containing the Python code defining the class
|
||||
class_name: Name of the class to be created
|
||||
|
||||
Returns:
|
||||
A function that, when called, returns an instance of the created class
|
||||
|
||||
Raises:
|
||||
ValueError: If the code contains syntax errors or the class definition is invalid
|
||||
"""
|
||||
if not hasattr(ast, "TypeIgnore"):
|
||||
ast.TypeIgnore = create_type_ignore_class()
|
||||
|
|
@ -199,7 +205,8 @@ def create_class(code, class_name):
|
|||
def create_type_ignore_class():
|
||||
"""Create a TypeIgnore class for AST module if it doesn't exist.
|
||||
|
||||
:return: TypeIgnore class
|
||||
Returns:
|
||||
TypeIgnore class
|
||||
"""
|
||||
|
||||
class TypeIgnore(ast.AST):
|
||||
|
|
@ -211,8 +218,15 @@ def create_type_ignore_class():
|
|||
def prepare_global_scope(code, module):
|
||||
"""Prepares the global scope with necessary imports from the provided code module.
|
||||
|
||||
:param module: AST parsed module
|
||||
:return: Dictionary representing the global scope with imported modules
|
||||
Args:
|
||||
code: The Python code
|
||||
module: AST parsed module
|
||||
|
||||
Returns:
|
||||
Dictionary representing the global scope with imported modules
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: If a module is not found in the code
|
||||
"""
|
||||
exec_globals = globals().copy()
|
||||
exec_globals.update(get_default_imports(code))
|
||||
|
|
@ -250,9 +264,12 @@ def prepare_global_scope(code, module):
|
|||
def extract_class_code(module, class_name):
|
||||
"""Extracts the AST node for the specified class from the module.
|
||||
|
||||
:param module: AST parsed module
|
||||
:param class_name: Name of the class to extract
|
||||
:return: AST node of the specified class
|
||||
Args:
|
||||
module: AST parsed module
|
||||
class_name: Name of the class to extract
|
||||
|
||||
Returns:
|
||||
AST node of the specified class
|
||||
"""
|
||||
class_code = next(node for node in module.body if isinstance(node, ast.ClassDef) and node.name == class_name)
|
||||
|
||||
|
|
@ -263,8 +280,11 @@ def extract_class_code(module, class_name):
|
|||
def compile_class_code(class_code):
|
||||
"""Compiles the AST node of a class into a code object.
|
||||
|
||||
:param class_code: AST node of the class
|
||||
:return: Compiled code object of the class
|
||||
Args:
|
||||
class_code: AST node of the class
|
||||
|
||||
Returns:
|
||||
Compiled code object of the class
|
||||
"""
|
||||
return compile(ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec")
|
||||
|
||||
|
|
@ -272,10 +292,13 @@ def compile_class_code(class_code):
|
|||
def build_class_constructor(compiled_class, exec_globals, class_name):
|
||||
"""Builds a constructor function for the dynamically created class.
|
||||
|
||||
:param compiled_class: Compiled code object of the class
|
||||
:param exec_globals: Global scope with necessary imports
|
||||
:param class_name: Name of the class
|
||||
:return: Constructor function for the class
|
||||
Args:
|
||||
compiled_class: Compiled code object of the class
|
||||
exec_globals: Global scope with necessary imports
|
||||
class_name: Name of the class
|
||||
|
||||
Returns:
|
||||
Constructor function for the class
|
||||
"""
|
||||
exec(compiled_class, exec_globals, locals())
|
||||
exec_globals[class_name] = locals()[class_name]
|
||||
|
|
@ -313,9 +336,12 @@ def get_default_imports(code_string):
|
|||
def find_names_in_code(code, names):
|
||||
"""Finds if any of the specified names are present in the given code string.
|
||||
|
||||
:param code: The source code as a string.
|
||||
:param names: A list of names to check for in the code.
|
||||
:return: A set of names that are found in the code.
|
||||
Args:
|
||||
code: The source code as a string.
|
||||
names: A list of names to check for in the code.
|
||||
|
||||
Returns:
|
||||
A set of names that are found in the code.
|
||||
"""
|
||||
return {name for name in names if name in code}
|
||||
|
||||
|
|
@ -339,7 +365,7 @@ def extract_class_name(code: str) -> str:
|
|||
str: Name of the first Component subclass found
|
||||
|
||||
Raises:
|
||||
ValueError: If no Component subclass is found in the code
|
||||
TypeError: If no Component subclass is found in the code
|
||||
"""
|
||||
try:
|
||||
module = ast.parse(code)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from importlib import metadata
|
||||
|
||||
import httpx
|
||||
from packaging import version as pkg_version
|
||||
|
||||
from langflow.logging.logger import logger
|
||||
|
||||
|
|
@ -22,8 +25,6 @@ def _get_version_info():
|
|||
Raises:
|
||||
ValueError: If the package is not found from the list of package names.
|
||||
"""
|
||||
from importlib import metadata
|
||||
|
||||
package_options = [
|
||||
("langflow", "Langflow"),
|
||||
("langflow-base", "Langflow Base"),
|
||||
|
|
@ -57,8 +58,9 @@ VERSION_INFO = _get_version_info()
|
|||
def is_pre_release(v: str) -> bool:
|
||||
"""Whether the version is a pre-release version.
|
||||
|
||||
Returns a boolean indicating whether the version is a pre-release version,
|
||||
as per the definition of a pre-release segment from PEP 440.
|
||||
Returns:
|
||||
Whether the version is a pre-release version,
|
||||
as per the definition of a pre-release segment from PEP 440.
|
||||
"""
|
||||
return any(label in v for label in ["a", "b", "rc"])
|
||||
|
||||
|
|
@ -66,15 +68,14 @@ def is_pre_release(v: str) -> bool:
|
|||
def is_nightly(v: str) -> bool:
|
||||
"""Whether the version is a dev (nightly) version.
|
||||
|
||||
Returns a boolean indicating whether the version is a dev (nightly) version,
|
||||
as per the definition of a dev segment from PEP 440.
|
||||
Returns:
|
||||
Whether the version is a dev (nightly) version,
|
||||
as per the definition of a dev segment from PEP 440.
|
||||
"""
|
||||
return "dev" in v
|
||||
|
||||
|
||||
def fetch_latest_version(package_name: str, *, include_prerelease: bool) -> str | None:
|
||||
from packaging import version as pkg_version
|
||||
|
||||
package_name = package_name.replace(" ", "-").lower()
|
||||
try:
|
||||
response = httpx.get(f"https://pypi.org/pypi/{package_name}/json")
|
||||
|
|
|
|||
|
|
@ -18,7 +18,11 @@ def test_celery(word: str) -> str:
|
|||
|
||||
@celery_app.task(bind=True, soft_time_limit=30, max_retries=3)
|
||||
def build_vertex(self, vertex: Vertex) -> Vertex:
|
||||
"""Build a vertex."""
|
||||
"""Build a vertex.
|
||||
|
||||
Returns:
|
||||
The built vertex.
|
||||
"""
|
||||
try:
|
||||
vertex.task_id = self.request.id
|
||||
async_to_sync(vertex.build)()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Module for package versioning."""
|
||||
|
||||
import contextlib
|
||||
from importlib import metadata
|
||||
|
||||
|
||||
def get_version() -> str:
|
||||
|
|
@ -14,8 +15,6 @@ def get_version() -> str:
|
|||
Raises:
|
||||
ValueError: If the package is not found from the list of package names.
|
||||
"""
|
||||
from importlib import metadata
|
||||
|
||||
pkg_names = [
|
||||
"langflow",
|
||||
"langflow-base",
|
||||
|
|
|
|||
|
|
@ -17,8 +17,10 @@ from blockbuster import blockbuster_ctx
|
|||
from dotenv import load_dotenv
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.graph import Graph
|
||||
from langflow.initial_setup.constants import STARTER_FOLDER_NAME
|
||||
from langflow.main import create_app
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.models.flow.model import Flow, FlowCreate
|
||||
|
|
@ -155,8 +157,6 @@ def caplog(caplog: pytest.LogCaptureFixture):
|
|||
|
||||
@pytest.fixture
|
||||
async def async_client() -> AsyncGenerator:
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
async with AsyncClient(app=app, base_url="http://testserver", http2=True) as client:
|
||||
yield client
|
||||
|
|
@ -228,8 +228,6 @@ def distributed_client_fixture(
|
|||
# def get_session_override():
|
||||
# return session
|
||||
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
|
||||
# app.dependency_overrides[get_session] = get_session_override
|
||||
|
|
@ -357,8 +355,6 @@ async def client_fixture(
|
|||
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")
|
||||
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
db_service = get_db_service()
|
||||
db_service.database_url = f"sqlite:///{db_path}"
|
||||
|
|
@ -482,8 +478,6 @@ async def flow(
|
|||
json_flow: str,
|
||||
active_user,
|
||||
):
|
||||
from langflow.services.database.models.flow.model import FlowCreate
|
||||
|
||||
loaded_json = json.loads(json_flow)
|
||||
flow_data = FlowCreate(name="test_flow", data=loaded_json.get("data"), user_id=active_user.id)
|
||||
|
||||
|
|
@ -577,8 +571,6 @@ async def added_webhook_test(client, json_webhook_test, logged_in_headers):
|
|||
|
||||
@pytest.fixture
|
||||
async def flow_component(client: AsyncClient, logged_in_headers):
|
||||
from langflow.components.inputs import ChatInput
|
||||
|
||||
chat_input = ChatInput()
|
||||
graph = Graph(start=chat_input, end=chat_input)
|
||||
graph_dict = graph.dump(name="Chat Input Component")
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
import pytest
|
||||
from langflow.components.astra_assistants import (
|
||||
AssistantsCreateAssistant,
|
||||
AssistantsCreateThread,
|
||||
AssistantsGetAssistantName,
|
||||
AssistantsListAssistants,
|
||||
AssistantsRun,
|
||||
)
|
||||
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
async def test_list_assistants():
|
||||
from langflow.components.astra_assistants import AssistantsListAssistants
|
||||
|
||||
results = await run_single_component(
|
||||
AssistantsListAssistants,
|
||||
inputs={},
|
||||
|
|
@ -16,8 +21,6 @@ async def test_list_assistants():
|
|||
|
||||
@pytest.mark.api_key_required
|
||||
async def test_create_assistants():
|
||||
from langflow.components.astra_assistants import AssistantsCreateAssistant
|
||||
|
||||
results = await run_single_component(
|
||||
AssistantsCreateAssistant,
|
||||
inputs={
|
||||
|
|
@ -36,8 +39,6 @@ async def test_create_assistants():
|
|||
|
||||
@pytest.mark.api_key_required
|
||||
async def test_create_thread():
|
||||
from langflow.components.astra_assistants import AssistantsCreateThread
|
||||
|
||||
results = await run_single_component(
|
||||
AssistantsCreateThread,
|
||||
inputs={},
|
||||
|
|
@ -48,8 +49,6 @@ async def test_create_thread():
|
|||
|
||||
|
||||
async def get_assistant_name(assistant_id):
|
||||
from langflow.components.astra_assistants import AssistantsGetAssistantName
|
||||
|
||||
results = await run_single_component(
|
||||
AssistantsGetAssistantName,
|
||||
inputs={
|
||||
|
|
@ -60,8 +59,6 @@ async def get_assistant_name(assistant_id):
|
|||
|
||||
|
||||
async def run_assistant(assistant_id, thread_id):
|
||||
from langflow.components.astra_assistants import AssistantsRun
|
||||
|
||||
results = await run_single_component(
|
||||
AssistantsRun,
|
||||
inputs={
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import os
|
|||
|
||||
import pytest
|
||||
from astrapy.db import AstraDB
|
||||
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
|
||||
from langchain_core.documents import Document
|
||||
from langflow.components.embeddings import OpenAIEmbeddingsComponent
|
||||
from langflow.components.vectorstores import AstraDBVectorStoreComponent
|
||||
|
|
@ -37,8 +38,6 @@ def astradb_client():
|
|||
|
||||
@pytest.mark.api_key_required
|
||||
async def test_base(astradb_client: AstraDB):
|
||||
from langflow.components.embeddings import OpenAIEmbeddingsComponent
|
||||
|
||||
application_token = get_astradb_application_token()
|
||||
api_endpoint = get_astradb_api_endpoint()
|
||||
|
||||
|
|
@ -88,8 +87,6 @@ async def test_astra_embeds_and_search():
|
|||
|
||||
@pytest.mark.api_key_required
|
||||
def test_astra_vectorize():
|
||||
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
|
||||
|
||||
application_token = get_astradb_application_token()
|
||||
api_endpoint = get_astradb_api_endpoint()
|
||||
|
||||
|
|
@ -132,8 +129,6 @@ def test_astra_vectorize():
|
|||
@pytest.mark.api_key_required
|
||||
def test_astra_vectorize_with_provider_api_key():
|
||||
"""Tests vectorize using an openai api key."""
|
||||
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
|
||||
|
||||
application_token = get_astradb_application_token()
|
||||
api_endpoint = get_astradb_api_endpoint()
|
||||
|
||||
|
|
@ -189,8 +184,6 @@ def test_astra_vectorize_with_provider_api_key():
|
|||
@pytest.mark.api_key_required
|
||||
def test_astra_vectorize_passes_authentication():
|
||||
"""Tests vectorize using the authentication parameter."""
|
||||
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
|
||||
|
||||
store = None
|
||||
try:
|
||||
application_token = get_astradb_application_token()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from tests.integration.utils import ComponentInputHandle, run_single_component
|
|||
|
||||
@pytest.mark.api_key_required
|
||||
async def test_csv_output_parser_openai():
|
||||
format_instructions = ComponentInputHandle(
|
||||
format_instructions_ = ComponentInputHandle(
|
||||
clazz=OutputParserComponent,
|
||||
inputs={},
|
||||
output_name="format_instructions",
|
||||
|
|
@ -24,7 +24,7 @@ async def test_csv_output_parser_openai():
|
|||
clazz=PromptComponent,
|
||||
inputs={
|
||||
"template": "List the first five positive integers.\n\n{format_instructions}",
|
||||
"format_instructions": format_instructions,
|
||||
"format_instructions": format_instructions_,
|
||||
},
|
||||
output_name="prompt",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ async def test_initialize_services():
|
|||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
async def test_setup_llm_caching():
|
||||
def test_setup_llm_caching():
|
||||
"""Benchmark LLM caching setup."""
|
||||
from langflow.interface.utils import setup_llm_caching
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import anyio
|
|||
import pytest
|
||||
from asgi_lifespan import LifespanManager
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from langflow.main import create_app
|
||||
from langflow.services.deps import get_storage_service
|
||||
from langflow.services.storage.service import StorageService
|
||||
from sqlmodel import Session
|
||||
|
|
@ -60,8 +61,6 @@ async def files_client_fixture(
|
|||
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")
|
||||
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
return app, db_path
|
||||
|
||||
|
|
@ -81,14 +80,14 @@ async def files_client_fixture(
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def max_file_size_upload_fixture(monkeypatch):
|
||||
def max_file_size_upload_fixture(monkeypatch):
|
||||
monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "1")
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def max_file_size_upload_10mb_fixture(monkeypatch):
|
||||
def max_file_size_upload_10mb_fixture(monkeypatch):
|
||||
monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "10")
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from langflow.schema.data import Data
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
async def test_component_tool():
|
||||
def test_component_tool():
|
||||
calculator_component = CalculatorToolComponent()
|
||||
component_toolkit = ComponentToolkit(component=calculator_component)
|
||||
component_tool = component_toolkit.get_tools()[0]
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class AllInputsComponent(Component):
|
|||
return data
|
||||
|
||||
|
||||
async def test_component_inputs_toolkit():
|
||||
def test_component_inputs_toolkit():
|
||||
component = AllInputsComponent()
|
||||
component_toolkit = ComponentToolkit(component=component)
|
||||
component_tool = component_toolkit.get_tools()[0]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import re
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langflow.components.helpers.structured_output import StructuredOutputComponent
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
from langflow.inputs.inputs import TableInput
|
||||
from langflow.schema.data import Data
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
|
@ -12,8 +15,6 @@ class TestStructuredOutputComponent:
|
|||
# Ensure that the structured output is successfully generated with the correct BaseModel instance returned by
|
||||
# the mock function
|
||||
def test_successful_structured_output_generation_with_patch_with_config(self):
|
||||
from unittest.mock import patch
|
||||
|
||||
class MockLanguageModel(BaseLanguageModel):
|
||||
@override
|
||||
def with_structured_output(self, *args, **kwargs):
|
||||
|
|
@ -87,15 +88,11 @@ class TestStructuredOutputComponent:
|
|||
multiple=False,
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match="Language model does not support structured output."):
|
||||
with pytest.raises(TypeError, match=re.escape("Language model does not support structured output.")):
|
||||
component.build_structured_output()
|
||||
|
||||
# Correctly builds the output model from the provided schema
|
||||
def test_correctly_builds_output_model(self):
|
||||
# Import internal organization modules, packages, and libraries
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
from langflow.inputs.inputs import TableInput
|
||||
|
||||
# Setup
|
||||
component = StructuredOutputComponent()
|
||||
schema = [
|
||||
|
|
@ -134,10 +131,6 @@ class TestStructuredOutputComponent:
|
|||
|
||||
# Properly handles multiple outputs when 'multiple' is set to True
|
||||
def test_handles_multiple_outputs(self):
|
||||
# Import internal organization modules, packages, and libraries
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
from langflow.inputs.inputs import TableInput
|
||||
|
||||
# Setup
|
||||
component = StructuredOutputComponent()
|
||||
schema = [
|
||||
|
|
@ -261,5 +254,5 @@ class TestStructuredOutputComponent:
|
|||
multiple=False,
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match="Language model does not support structured output."):
|
||||
with pytest.raises(TypeError, match=re.escape("Language model does not support structured output.")):
|
||||
component.build_structured_output()
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ def test_operations(sample_dataframe, operation, expected_columns, expected_valu
|
|||
component.new_column_name = "Z"
|
||||
elif operation == "Select Columns":
|
||||
component.columns_to_select = ["A", "C"]
|
||||
elif operation in ("Head", "Tail"):
|
||||
elif operation in {"Head", "Tail"}:
|
||||
component.num_rows = 1
|
||||
elif operation == "Replace Value":
|
||||
component.column_name = "C"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import re
|
||||
|
||||
import pytest
|
||||
from langflow.components.processing import CreateDataComponent
|
||||
from langflow.schema import Data
|
||||
|
|
@ -48,7 +50,7 @@ def test_update_build_config_exceed_limit(create_data_component):
|
|||
"value": False,
|
||||
},
|
||||
}
|
||||
with pytest.raises(ValueError, match="Number of fields cannot exceed 15."):
|
||||
with pytest.raises(ValueError, match=re.escape("Number of fields cannot exceed 15.")):
|
||||
create_data_component.update_build_config(build_config, 16, "number_of_fields")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import re
|
||||
|
||||
import pytest
|
||||
from langflow.components.processing import UpdateDataComponent
|
||||
from langflow.schema import Data
|
||||
|
|
@ -48,7 +50,7 @@ def test_update_build_config_exceed_limit(update_data_component):
|
|||
"value": False,
|
||||
},
|
||||
}
|
||||
with pytest.raises(ValueError, match="Number of fields cannot exceed 15."):
|
||||
with pytest.raises(ValueError, match=re.escape("Number of fields cannot exceed 15.")):
|
||||
update_data_component.update_build_config(build_config, 16, "number_of_fields")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,11 +14,6 @@ from langflow.schema.properties import Properties, Source
|
|||
from langflow.template.field.base import Output
|
||||
|
||||
|
||||
async def create_event_queue():
|
||||
"""Create a queue for testing events."""
|
||||
return asyncio.Queue()
|
||||
|
||||
|
||||
def blocking_cb(manager, event_type, data):
|
||||
time.sleep(0.01)
|
||||
manager.send_event(event_type=event_type, data=data)
|
||||
|
|
@ -43,7 +38,7 @@ class ComponentForTesting(Component):
|
|||
async def test_component_message_sending():
|
||||
"""Test component's message sending functionality."""
|
||||
# Create event queue and manager
|
||||
queue = await create_event_queue()
|
||||
queue = asyncio.Queue()
|
||||
event_manager = EventManager(queue)
|
||||
|
||||
event_manager.register_event("on_message", "message", callback=blocking_cb)
|
||||
|
|
@ -75,7 +70,7 @@ async def test_component_message_sending():
|
|||
async def test_component_tool_output():
|
||||
"""Test component's tool output functionality."""
|
||||
# Create event queue and manager
|
||||
queue = await create_event_queue()
|
||||
queue = asyncio.Queue()
|
||||
event_manager = EventManager(queue)
|
||||
|
||||
# Create component
|
||||
|
|
@ -110,7 +105,7 @@ async def test_component_tool_output():
|
|||
async def test_component_error_handling():
|
||||
"""Test component's error handling."""
|
||||
# Create event queue and manager
|
||||
queue = await create_event_queue()
|
||||
queue = asyncio.Queue()
|
||||
event_manager = EventManager(queue)
|
||||
|
||||
# Create component
|
||||
|
|
@ -141,7 +136,7 @@ async def test_component_error_handling():
|
|||
async def test_component_build_results():
|
||||
"""Test component's build_results functionality."""
|
||||
# Create event queue and manager
|
||||
queue = await create_event_queue()
|
||||
queue = asyncio.Queue()
|
||||
event_manager = EventManager(queue)
|
||||
|
||||
# Create component
|
||||
|
|
@ -173,7 +168,7 @@ async def test_component_build_results():
|
|||
async def test_component_logging():
|
||||
"""Test component's logging functionality."""
|
||||
# Create event queue and manager
|
||||
queue = await create_event_queue()
|
||||
queue = asyncio.Queue()
|
||||
event_manager = EventManager(queue)
|
||||
|
||||
# Create component
|
||||
|
|
@ -207,7 +202,7 @@ async def test_component_logging():
|
|||
@pytest.mark.usefixtures("client")
|
||||
async def test_component_streaming_message():
|
||||
"""Test component's streaming message functionality."""
|
||||
queue = await create_event_queue()
|
||||
queue = asyncio.Queue()
|
||||
event_manager = EventManager(queue)
|
||||
|
||||
event_manager.register_event("on_token", "token", blocking_cb)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class TestEventManager:
|
|||
def test_accessing_non_registered_event_callback_with_recommended_fix(self):
|
||||
queue = asyncio.Queue()
|
||||
manager = EventManager(queue)
|
||||
result = manager.__getattr__("non_registered_event")
|
||||
result = manager.non_registered_event
|
||||
assert result == manager.noop
|
||||
|
||||
# Accessing a registered event callback via __getattr__
|
||||
|
|
@ -130,8 +130,6 @@ class TestEventManager:
|
|||
assert len(manager.events) == 1000
|
||||
|
||||
# Verifying the uniqueness of event IDs for each event triggered using await with asyncio decorator
|
||||
import pytest
|
||||
|
||||
async def test_event_id_uniqueness_with_await(self):
|
||||
queue = asyncio.Queue()
|
||||
manager = EventManager(queue)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from langflow.exceptions.api import APIException, ExceptionBody
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
|
||||
|
||||
def test_api_exception():
|
||||
from langflow.exceptions.api import APIException, ExceptionBody
|
||||
|
||||
mock_exception = Exception("Test exception")
|
||||
mock_flow = Mock(spec=Flow)
|
||||
mock_outdated_components = ["component1", "component2"]
|
||||
|
|
@ -45,8 +44,6 @@ def test_api_exception():
|
|||
|
||||
|
||||
def test_api_exception_no_flow():
|
||||
from langflow.exceptions.api import APIException, ExceptionBody
|
||||
|
||||
# Mock data
|
||||
mock_exception = Exception("Test exception")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import re
|
||||
|
||||
import pytest
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.components.models import OpenAIModelComponent
|
||||
|
|
@ -25,5 +27,7 @@ Answer:
|
|||
|
||||
chat_output = ChatOutput()
|
||||
chat_output.set(input_value=openai_component.text_response)
|
||||
with pytest.raises(ValueError, match="Component OpenAI field 'input_values' might not be a valid input."):
|
||||
with pytest.raises(
|
||||
ValueError, match=re.escape("Component OpenAI field 'input_values' might not be a valid input.")
|
||||
):
|
||||
Graph(start=chat_input, end=chat_output)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ async def test_graph_not_prepared():
|
|||
await graph.astep()
|
||||
|
||||
|
||||
async def test_graph(caplog: pytest.LogCaptureFixture):
|
||||
def test_graph(caplog: pytest.LogCaptureFixture):
|
||||
chat_input = ChatInput()
|
||||
chat_output = ChatOutput()
|
||||
graph = Graph()
|
||||
|
|
|
|||
|
|
@ -2,14 +2,14 @@ from typing import TYPE_CHECKING, Literal
|
|||
|
||||
import pytest
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.inputs.inputs import DropdownInput, FileInput, IntInput, NestedDictInput, StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
def test_create_input_schema():
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
schema = create_input_schema(ChatInput.inputs)
|
||||
assert schema.__name__ == "InputSchema"
|
||||
|
||||
|
|
@ -17,9 +17,6 @@ def test_create_input_schema():
|
|||
class TestCreateInputSchema:
|
||||
# Single input type is converted to list and processed correctly
|
||||
def test_single_input_type_conversion(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field")
|
||||
schema = create_input_schema([input_instance])
|
||||
assert schema.__name__ == "InputSchema"
|
||||
|
|
@ -27,9 +24,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Multiple input types are processed and included in the schema
|
||||
def test_multiple_input_types(self):
|
||||
from langflow.inputs.inputs import IntInput, StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
inputs = [StrInput(name="str_field"), IntInput(name="int_field")]
|
||||
schema = create_input_schema(inputs)
|
||||
assert schema.__name__ == "InputSchema"
|
||||
|
|
@ -38,9 +32,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Fields are correctly created with appropriate types and attributes
|
||||
def test_fields_creation_with_correct_types_and_attributes(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field", info="Test Info", required=True)
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
|
|
@ -49,18 +40,12 @@ class TestCreateInputSchema:
|
|||
|
||||
# Schema model is created and returned successfully
|
||||
def test_schema_model_creation(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field")
|
||||
schema = create_input_schema([input_instance])
|
||||
assert schema.__name__ == "InputSchema"
|
||||
|
||||
# Default values are correctly assigned to fields
|
||||
def test_default_values_assignment(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field", value="default_value")
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
|
|
@ -68,16 +53,11 @@ class TestCreateInputSchema:
|
|||
|
||||
# Empty list of inputs is handled without errors
|
||||
def test_empty_list_of_inputs(self):
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
schema = create_input_schema([])
|
||||
assert schema.__name__ == "InputSchema"
|
||||
|
||||
# Input with missing optional attributes (e.g., display_name, info) is processed correctly
|
||||
def test_missing_optional_attributes(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field")
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
|
|
@ -86,9 +66,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Input with is_list attribute set to True is processed correctly
|
||||
def test_is_list_attribute_processing(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field", is_list=True)
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info: FieldInfo = schema.model_fields["test_field"]
|
||||
|
|
@ -96,9 +73,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Input with options attribute is processed correctly
|
||||
def test_options_attribute_processing(self):
|
||||
from langflow.inputs.inputs import DropdownInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = DropdownInput(name="test_field", options=["option1", "option2"])
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
|
|
@ -106,9 +80,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Non-standard field types are handled correctly
|
||||
def test_non_standard_field_types_handling(self):
|
||||
from langflow.inputs.inputs import FileInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = FileInput(name="file_field")
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["file_field"]
|
||||
|
|
@ -116,9 +87,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Inputs with mixed required and optional fields are processed correctly
|
||||
def test_mixed_required_optional_fields_processing(self):
|
||||
from langflow.inputs.inputs import IntInput, StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
inputs = [
|
||||
StrInput(name="required_field", required=True),
|
||||
IntInput(name="optional_field", required=False),
|
||||
|
|
@ -132,9 +100,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Inputs with complex nested structures are handled correctly
|
||||
def test_complex_nested_structures_handling(self):
|
||||
from langflow.inputs.inputs import NestedDictInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
nested_input = NestedDictInput(name="nested_field", value={"key": "value"})
|
||||
schema = create_input_schema([nested_input])
|
||||
|
||||
|
|
@ -145,9 +110,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Creating a schema from a single input type
|
||||
def test_single_input_type_replica(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field")
|
||||
schema = create_input_schema([input_instance])
|
||||
assert schema.__name__ == "InputSchema"
|
||||
|
|
@ -155,18 +117,12 @@ class TestCreateInputSchema:
|
|||
|
||||
# Creating a schema from a list of input types
|
||||
def test_passing_input_type_directly(self):
|
||||
from langflow.inputs.inputs import IntInput, StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
inputs = StrInput(name="str_field"), IntInput(name="int_field")
|
||||
with pytest.raises(TypeError):
|
||||
create_input_schema(inputs)
|
||||
|
||||
# Handling input types with options correctly
|
||||
def test_options_handling(self):
|
||||
from langflow.inputs.inputs import DropdownInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = DropdownInput(name="test_field", options=["option1", "option2"])
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
|
|
@ -174,9 +130,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Handling input types with is_list attribute correctly
|
||||
def test_is_list_handling(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field", is_list=True)
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
|
|
@ -184,9 +137,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Converting FieldTypes to corresponding Python types
|
||||
def test_field_types_conversion(self):
|
||||
from langflow.inputs.inputs import IntInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = IntInput(name="int_field")
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["int_field"]
|
||||
|
|
@ -194,9 +144,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Setting default values for non-required fields
|
||||
def test_default_values_for_non_required_fields(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field", value="default_value")
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
|
|
@ -204,9 +151,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Handling input types with missing attributes
|
||||
def test_missing_attributes_handling(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field")
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
|
|
@ -217,9 +161,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Handling input types with None as default value
|
||||
def test_none_default_value_handling(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test_field", value=None)
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
|
|
@ -227,9 +168,6 @@ class TestCreateInputSchema:
|
|||
|
||||
# Handling input types with special characters in names
|
||||
def test_special_characters_in_names_handling(self):
|
||||
from langflow.inputs.inputs import StrInput
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
||||
input_instance = StrInput(name="test@field#name")
|
||||
schema = create_input_schema([input_instance])
|
||||
assert "test@field#name" in schema.model_fields
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import base64
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langflow.schema.data import Data
|
||||
|
|
@ -9,8 +11,6 @@ def sample_image(tmp_path):
|
|||
"""Create a sample image file for testing."""
|
||||
image_path = tmp_path / "test_image.png"
|
||||
# Create a small black 1x1 pixel PNG file
|
||||
import base64
|
||||
|
||||
image_content = base64.b64decode(
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAACklEQVR4nGMAAQAABQABDQottAAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import base64
|
||||
import shutil
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
|
@ -29,8 +30,6 @@ def sample_image(langflow_cache_dir):
|
|||
# Create the image in the flow directory
|
||||
image_path = flow_dir / "test_image.png"
|
||||
# Create a small black 1x1 pixel PNG file
|
||||
import base64
|
||||
|
||||
image_content = base64.b64decode(
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAACklEQVR4nGMAAQAABQABDQottAAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import copy
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langflow.schema import Data
|
||||
|
|
@ -62,8 +64,6 @@ def test_custom_attribute_get_set_del():
|
|||
|
||||
|
||||
def test_deep_copy():
|
||||
import copy
|
||||
|
||||
record1 = Data(data={"text": "Hello", "number": 10})
|
||||
record2 = copy.deepcopy(record1)
|
||||
assert record2.text == "Hello"
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
|
|||
from langflow.services.database.models.folder.model import FolderCreate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
@ -619,8 +620,6 @@ async def test_sqlite_pragmas():
|
|||
db_service = get_db_service()
|
||||
|
||||
async with db_service.with_session() as session:
|
||||
from sqlalchemy import text
|
||||
|
||||
assert (await session.exec(text("PRAGMA journal_mode;"))).scalar() == "wal"
|
||||
assert (await session.exec(text("PRAGMA synchronous;"))).scalar() == 1
|
||||
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ class TestInput:
|
|||
# Empty lists and edge cases
|
||||
assert set(post_process_type(list)) == {list}
|
||||
assert set(post_process_type(Union[int, None])) == {int, NoneType} # noqa: UP007
|
||||
assert set(post_process_type(Union[None, list[None]])) == {None, NoneType} # noqa: UP007
|
||||
assert set(post_process_type(Union[list[None], None])) == {None, NoneType} # noqa: UP007
|
||||
|
||||
# Handling complex nested structures
|
||||
assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float} # noqa: UP007
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import re
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
|
|
@ -48,7 +49,7 @@ def test_increment_counter_empty_label(opentelemetry_instance):
|
|||
|
||||
|
||||
def test_increment_counter_missing_mandatory_label(opentelemetry_instance):
|
||||
with pytest.raises(ValueError, match="Missing required labels: {'flow_id'}"):
|
||||
with pytest.raises(ValueError, match=re.escape("Missing required labels: {'flow_id'}")):
|
||||
opentelemetry_instance.increment_counter(metric_name="num_files_uploaded", value=5, labels={"service": "one"})
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,51 +16,57 @@ def sample_image(tmp_path):
|
|||
return image_path
|
||||
|
||||
|
||||
class TestImageUtils:
|
||||
def test_convert_image_to_base64_success(self, sample_image):
|
||||
"""Test successful conversion of image to base64."""
|
||||
base64_str = convert_image_to_base64(sample_image)
|
||||
assert isinstance(base64_str, str)
|
||||
# Verify it's valid base64
|
||||
assert base64.b64decode(base64_str)
|
||||
def test_convert_image_to_base64_success(sample_image):
|
||||
"""Test successful conversion of image to base64."""
|
||||
base64_str = convert_image_to_base64(sample_image)
|
||||
assert isinstance(base64_str, str)
|
||||
# Verify it's valid base64
|
||||
assert base64.b64decode(base64_str)
|
||||
|
||||
def test_convert_image_to_base64_empty_path(self):
|
||||
"""Test conversion with empty path."""
|
||||
with pytest.raises(ValueError, match="Image path cannot be empty"):
|
||||
convert_image_to_base64("")
|
||||
|
||||
def test_convert_image_to_base64_nonexistent_file(self):
|
||||
"""Test conversion with non-existent file."""
|
||||
with pytest.raises(FileNotFoundError, match="Image file not found"):
|
||||
convert_image_to_base64("nonexistent.png")
|
||||
def test_convert_image_to_base64_empty_path():
|
||||
"""Test conversion with empty path."""
|
||||
with pytest.raises(ValueError, match="Image path cannot be empty"):
|
||||
convert_image_to_base64("")
|
||||
|
||||
def test_convert_image_to_base64_directory(self, tmp_path):
|
||||
"""Test conversion with directory path instead of file."""
|
||||
with pytest.raises(ValueError, match="Path is not a file"):
|
||||
convert_image_to_base64(tmp_path)
|
||||
|
||||
def test_create_data_url_success(self, sample_image):
|
||||
"""Test successful creation of data URL."""
|
||||
data_url = create_data_url(sample_image)
|
||||
assert data_url.startswith("data:image/png;base64,")
|
||||
# Verify the base64 part is valid
|
||||
base64_part = data_url.split(",")[1]
|
||||
assert base64.b64decode(base64_part)
|
||||
def test_convert_image_to_base64_nonexistent_file():
|
||||
"""Test conversion with non-existent file."""
|
||||
with pytest.raises(FileNotFoundError, match="Image file not found"):
|
||||
convert_image_to_base64("nonexistent.png")
|
||||
|
||||
def test_create_data_url_with_custom_mime(self, sample_image):
|
||||
"""Test creation of data URL with custom MIME type."""
|
||||
custom_mime = "image/custom"
|
||||
data_url = create_data_url(sample_image, mime_type=custom_mime)
|
||||
assert data_url.startswith(f"data:{custom_mime};base64,")
|
||||
|
||||
def test_create_data_url_invalid_file(self):
|
||||
"""Test creation of data URL with invalid file."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
create_data_url("nonexistent.jpg")
|
||||
def test_convert_image_to_base64_directory(tmp_path):
|
||||
"""Test conversion with directory path instead of file."""
|
||||
with pytest.raises(ValueError, match="Path is not a file"):
|
||||
convert_image_to_base64(tmp_path)
|
||||
|
||||
def test_create_data_url_unrecognized_extension(self, tmp_path):
|
||||
"""Test creation of data URL with unrecognized file extension."""
|
||||
invalid_file = tmp_path / "test.unknown"
|
||||
invalid_file.touch()
|
||||
with pytest.raises(ValueError, match="Could not determine MIME type"):
|
||||
create_data_url(invalid_file)
|
||||
|
||||
def test_create_data_url_success(sample_image):
|
||||
"""Test successful creation of data URL."""
|
||||
data_url = create_data_url(sample_image)
|
||||
assert data_url.startswith("data:image/png;base64,")
|
||||
# Verify the base64 part is valid
|
||||
base64_part = data_url.split(",")[1]
|
||||
assert base64.b64decode(base64_part)
|
||||
|
||||
|
||||
def test_create_data_url_with_custom_mime(sample_image):
|
||||
"""Test creation of data URL with custom MIME type."""
|
||||
custom_mime = "image/custom"
|
||||
data_url = create_data_url(sample_image, mime_type=custom_mime)
|
||||
assert data_url.startswith(f"data:{custom_mime};base64,")
|
||||
|
||||
|
||||
def test_create_data_url_invalid_file():
|
||||
"""Test creation of data URL with invalid file."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
create_data_url("nonexistent.jpg")
|
||||
|
||||
|
||||
def test_create_data_url_unrecognized_extension(tmp_path):
|
||||
"""Test creation of data URL with unrecognized file extension."""
|
||||
invalid_file = tmp_path / "test.unknown"
|
||||
invalid_file.touch()
|
||||
with pytest.raises(ValueError, match="Could not determine MIME type"):
|
||||
create_data_url(invalid_file)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import math
|
||||
|
||||
import pytest
|
||||
from langflow.utils.constants import MAX_TEXT_LENGTH
|
||||
from langflow.utils.util_strings import truncate_long_strings
|
||||
|
||||
|
||||
|
|
@ -48,8 +49,6 @@ def test_truncate_long_strings_negative_max_length():
|
|||
|
||||
# Test for None max_length (should use default MAX_TEXT_LENGTH)
|
||||
def test_truncate_long_strings_none_max_length():
|
||||
from langflow.utils.constants import MAX_TEXT_LENGTH
|
||||
|
||||
long_string = "a" * (MAX_TEXT_LENGTH + 10)
|
||||
result = truncate_long_strings(long_string, None)
|
||||
assert len(result) == MAX_TEXT_LENGTH + 3 # +3 for "..."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue