parent
140cf890e3
commit
f96f2eaf8a
67 changed files with 421 additions and 361 deletions
|
|
@ -216,6 +216,43 @@ directory = "coverage"
|
|||
exclude = ["src/backend/langflow/alembic/*"]
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
pydocstyle.convention = "google"
|
||||
select = ["ALL"]
|
||||
ignore = [
|
||||
"C90", # McCabe complexity
|
||||
"CPY", # Missing copyright
|
||||
"COM812", # Messes with the formatter
|
||||
"ERA", # Eradicate commented-out code
|
||||
"FIX002", # Line contains TODO
|
||||
"ISC001", # Messes with the formatter
|
||||
"PERF203", # Rarely useful
|
||||
"PLR09", # Too many something (arg, statements, etc)
|
||||
"RUF012", # Pydantic models are currently not well detected. See https://github.com/astral-sh/ruff/issues/13630
|
||||
"TD002", # Missing author in TODO
|
||||
"TD003", # Missing issue link in TODO
|
||||
"TRY301", # A bit too harsh (Abstract `raise` to an inner function)
|
||||
|
||||
# Rules that are TODOs
|
||||
"ANN",
|
||||
]
|
||||
|
||||
# Preview rules that are not yet activated
|
||||
external = ["RUF027"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"scripts/*" = [
|
||||
"D1",
|
||||
"INP",
|
||||
"T201",
|
||||
]
|
||||
"src/backend/tests/*" = [
|
||||
"D1",
|
||||
"PLR2004",
|
||||
"S101",
|
||||
"SLF001",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ["pydantic.mypy"]
|
||||
follow_imports = "skip"
|
||||
|
|
|
|||
|
|
@ -12,8 +12,10 @@ PYPI_LANGFLOW_NIGHTLY_URL = "https://pypi.org/pypi/langflow-nightly/json"
|
|||
PYPI_LANGFLOW_BASE_URL = "https://pypi.org/pypi/langflow-base/json"
|
||||
PYPI_LANGFLOW_BASE_NIGHTLY_URL = "https://pypi.org/pypi/langflow-base-nightly/json"
|
||||
|
||||
ARGUMENT_NUMBER = 2
|
||||
|
||||
def get_latest_published_version(build_type: str, is_nightly: bool) -> Version:
|
||||
|
||||
def get_latest_published_version(build_type: str, *, is_nightly: bool) -> Version:
|
||||
import requests
|
||||
|
||||
url = ""
|
||||
|
|
@ -25,12 +27,12 @@ def get_latest_published_version(build_type: str, is_nightly: bool) -> Version:
|
|||
msg = f"Invalid build type: {build_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
res = requests.get(url)
|
||||
res = requests.get(url, timeout=10)
|
||||
try:
|
||||
version_str = res.json()["info"]["version"]
|
||||
except Exception as e:
|
||||
msg = "Got unexpected response from PyPI"
|
||||
raise RuntimeError(msg, e)
|
||||
raise RuntimeError(msg) from e
|
||||
return Version(version_str)
|
||||
|
||||
|
||||
|
|
@ -74,9 +76,9 @@ def create_tag(build_type: str):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
if len(sys.argv) != ARGUMENT_NUMBER:
|
||||
msg = "Specify base or main"
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
build_type = sys.argv[1]
|
||||
tag = create_tag(build_type)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
|
@ -5,6 +7,7 @@ from pathlib import Path
|
|||
import packaging.version
|
||||
|
||||
BASE_DIR = Path(__file__).parent.parent.parent
|
||||
ARGUMENT_NUMBER = 2
|
||||
|
||||
|
||||
def update_base_dep(pyproject_path: str, new_version: str) -> None:
|
||||
|
|
@ -18,7 +21,7 @@ def update_base_dep(pyproject_path: str, new_version: str) -> None:
|
|||
pattern = re.compile(r'langflow-base = \{ path = "\./src/backend/base", develop = true \}')
|
||||
if not pattern.search(content):
|
||||
msg = f'langflow-base poetry dependency not found in "{filepath}"'
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
content = pattern.sub(replacement, content)
|
||||
filepath.write_text(content, encoding="utf-8")
|
||||
|
||||
|
|
@ -28,16 +31,13 @@ def verify_pep440(version):
|
|||
|
||||
https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191
|
||||
"""
|
||||
try:
|
||||
return packaging.version.Version(version)
|
||||
except packaging.version.InvalidVersion:
|
||||
raise
|
||||
return packaging.version.Version(version)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) != 2:
|
||||
if len(sys.argv) != ARGUMENT_NUMBER:
|
||||
msg = "New version not specified"
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
base_version = sys.argv[1]
|
||||
|
||||
# Strip "v" prefix from version if present
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
BASE_DIR = Path(__file__).parent.parent.parent
|
||||
ARGUMENT_NUMBER = 3
|
||||
|
||||
|
||||
def update_pyproject_name(pyproject_path: str, new_project_name: str) -> None:
|
||||
|
|
@ -15,7 +18,7 @@ def update_pyproject_name(pyproject_path: str, new_project_name: str) -> None:
|
|||
|
||||
if not pattern.search(content):
|
||||
msg = f'Project name not found in "{filepath}"'
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
content = pattern.sub(new_project_name, content)
|
||||
|
||||
filepath.write_text(content, encoding="utf-8")
|
||||
|
|
@ -39,15 +42,15 @@ def update_uv_dep(pyproject_path: str, new_project_name: str) -> None:
|
|||
# Updates the dependency name for uv
|
||||
if not pattern.search(content):
|
||||
msg = f"{replacement} uv dependency not found in {filepath}"
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
content = pattern.sub(replacement, content)
|
||||
filepath.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) != 3:
|
||||
if len(sys.argv) != ARGUMENT_NUMBER:
|
||||
msg = "Must specify project name and build type, e.g. langflow-nightly base"
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
new_project_name = sys.argv[1]
|
||||
build_type = sys.argv[2]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
|
@ -5,6 +7,7 @@ from pathlib import Path
|
|||
import packaging.version
|
||||
|
||||
BASE_DIR = Path(__file__).parent.parent.parent
|
||||
ARGUMENT_NUMBER = 3
|
||||
|
||||
|
||||
def update_pyproject_version(pyproject_path: str, new_version: str) -> None:
|
||||
|
|
@ -17,7 +20,7 @@ def update_pyproject_version(pyproject_path: str, new_version: str) -> None:
|
|||
|
||||
if not pattern.search(content):
|
||||
msg = f'Project version not found in "{filepath}"'
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
content = pattern.sub(new_version, content)
|
||||
|
||||
|
|
@ -29,16 +32,13 @@ def verify_pep440(version):
|
|||
|
||||
https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191
|
||||
"""
|
||||
try:
|
||||
return packaging.version.Version(version)
|
||||
except packaging.version.InvalidVersion:
|
||||
raise
|
||||
return packaging.version.Version(version)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) != 3:
|
||||
if len(sys.argv) != ARGUMENT_NUMBER:
|
||||
msg = "New version not specified"
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
new_version = sys.argv[1]
|
||||
|
||||
# Strip "v" prefix from version if present
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
BASE_DIR = Path(__file__).parent.parent.parent
|
||||
ARGUMENT_NUMBER = 2
|
||||
|
||||
|
||||
def update_uv_dep(base_version: str) -> None:
|
||||
|
|
@ -19,7 +22,7 @@ def update_uv_dep(base_version: str) -> None:
|
|||
# Check if the pattern is found
|
||||
if not pattern.search(content):
|
||||
msg = f"{pattern} UV dependency not found in {pyproject_path}"
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Replace the matched pattern with the new one
|
||||
content = pattern.sub(replacement, content)
|
||||
|
|
@ -29,9 +32,9 @@ def update_uv_dep(base_version: str) -> None:
|
|||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) != 2:
|
||||
if len(sys.argv) != ARGUMENT_NUMBER:
|
||||
msg = "specify base version"
|
||||
raise Exception(msg)
|
||||
raise ValueError(msg)
|
||||
base_version = sys.argv[1]
|
||||
base_version = base_version.lstrip("v")
|
||||
update_uv_dep(base_version)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from pydantic import BaseModel
|
|||
|
||||
from langflow.base.tools.constants import TOOL_OUTPUT_NAME
|
||||
from langflow.custom.tree_visitor import RequiredInputsVisitor
|
||||
from langflow.field_typing import Tool # noqa: TCH001 Needed by add_toolkit_output
|
||||
from langflow.field_typing import Tool # noqa: TCH001 Needed by _add_toolkit_output
|
||||
from langflow.graph.state.model import create_state_model
|
||||
from langflow.helpers.custom import format_type
|
||||
from langflow.schema.artifact import get_artifact_type, post_process_raw
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
"""Version package."""
|
||||
|
|
@ -1,8 +1,11 @@
|
|||
"""Module for package versioning."""
|
||||
|
||||
import contextlib
|
||||
|
||||
|
||||
def get_version() -> str:
|
||||
"""Retrieves the version of the package from a possible list of package names.
|
||||
|
||||
This accounts for after package names are updated for -nightly builds.
|
||||
|
||||
Returns:
|
||||
|
|
@ -32,7 +35,9 @@ def get_version() -> str:
|
|||
|
||||
|
||||
def is_pre_release(v: str) -> bool:
|
||||
"""Returns a boolean indicating whether the version is a pre-release version,
|
||||
"""Returns a boolean indicating whether the version is a pre-release version.
|
||||
|
||||
Returns a boolean indicating whether the version is a pre-release version,
|
||||
as per the definition of a pre-release segment from PEP 440.
|
||||
"""
|
||||
return any(label in v for label in ["a", "b", "rc"])
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ from langflow.services.database.models.vertex_builds.crud import delete_vertex_b
|
|||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from loguru import logger
|
||||
from pytest import LogCaptureFixture
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
|
@ -102,7 +101,7 @@ def _delete_transactions_and_vertex_builds(session, user: User):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def caplog(caplog: LogCaptureFixture):
|
||||
def caplog(caplog: pytest.LogCaptureFixture):
|
||||
handler_id = logger.add(
|
||||
caplog.handler,
|
||||
format="{message}",
|
||||
|
|
@ -144,7 +143,7 @@ def load_flows_dir():
|
|||
|
||||
|
||||
@pytest.fixture(name="distributed_env")
|
||||
def setup_env(monkeypatch):
|
||||
def _setup_env(monkeypatch):
|
||||
monkeypatch.setenv("LANGFLOW_CACHE_TYPE", "redis")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_HOST", "result_backend")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_PORT", "6379")
|
||||
|
|
@ -158,7 +157,11 @@ def setup_env(monkeypatch):
|
|||
|
||||
|
||||
@pytest.fixture(name="distributed_client")
|
||||
def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
|
||||
def distributed_client_fixture(
|
||||
session: Session, # noqa: ARG001
|
||||
monkeypatch,
|
||||
distributed_env, # noqa: ARG001
|
||||
):
|
||||
# Here we load the .env from ../deploy/.env
|
||||
from langflow.core import celery_app
|
||||
|
||||
|
|
@ -273,7 +276,12 @@ def json_memory_chatbot_no_llm():
|
|||
|
||||
|
||||
@pytest.fixture(name="client")
|
||||
async def client_fixture(session: Session, monkeypatch, request, load_flows_dir):
|
||||
async def client_fixture(
|
||||
session: Session, # noqa: ARG001
|
||||
monkeypatch,
|
||||
request,
|
||||
load_flows_dir,
|
||||
):
|
||||
# Set the database url to a test database
|
||||
if "noclient" in request.keywords:
|
||||
yield
|
||||
|
|
@ -296,9 +304,11 @@ async def client_fixture(session: Session, monkeypatch, request, load_flows_dir)
|
|||
db_service.database_url = f"sqlite:///{db_path}"
|
||||
db_service.reload_engine()
|
||||
# app.dependency_overrides[get_session] = get_session_override
|
||||
async with LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager:
|
||||
async with AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client:
|
||||
yield client
|
||||
async with (
|
||||
LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager,
|
||||
AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client,
|
||||
):
|
||||
yield client
|
||||
# app.dependency_overrides.clear()
|
||||
monkeypatch.undo()
|
||||
# clear the temp db
|
||||
|
|
@ -308,7 +318,7 @@ async def client_fixture(session: Session, monkeypatch, request, load_flows_dir)
|
|||
|
||||
# create a fixture for session_getter above
|
||||
@pytest.fixture(name="session_getter")
|
||||
def session_getter_fixture(client):
|
||||
def session_getter_fixture(client): # noqa: ARG001
|
||||
@contextmanager
|
||||
def blank_session_getter(db_service: "DatabaseService"):
|
||||
with Session(db_service.engine) as session:
|
||||
|
|
@ -326,7 +336,7 @@ def runner():
|
|||
async def test_user(client):
|
||||
user_data = UserCreate(
|
||||
username="testuser",
|
||||
password="testpassword",
|
||||
password="testpassword", # noqa: S106
|
||||
)
|
||||
response = await client.post("api/v1/users/", json=user_data.model_dump())
|
||||
assert response.status_code == 201
|
||||
|
|
@ -337,7 +347,7 @@ async def test_user(client):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def active_user(client):
|
||||
def active_user(client): # noqa: ARG001
|
||||
db_manager = get_db_service()
|
||||
with db_manager.with_session() as session:
|
||||
user = User(
|
||||
|
|
@ -375,7 +385,11 @@ async def logged_in_headers(client, active_user):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def flow(client, json_flow: str, active_user):
|
||||
def flow(
|
||||
client, # noqa: ARG001
|
||||
json_flow: str,
|
||||
active_user,
|
||||
):
|
||||
from langflow.services.database.models.flow.model import FlowCreate
|
||||
|
||||
loaded_json = json.loads(json_flow)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ class TestComponent(CustomComponent):
|
|||
def refresh_values(self):
|
||||
# This is a function that will be called every time the component is updated
|
||||
# and should return a list of random strings
|
||||
return [f"Random {random.randint(1, 100)}" for _ in range(5)]
|
||||
return [f"Random {random.randint(1, 100)}" for _ in range(5)] # noqa: S311
|
||||
|
||||
def build_config(self):
|
||||
return {"param": {"display_name": "Param", "options": self.refresh_values}}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class MultipleOutputsComponent(Component):
|
|||
]
|
||||
|
||||
def certain_output(self) -> int:
|
||||
return randint(0, self.number)
|
||||
return randint(0, self.number) # noqa: S311
|
||||
|
||||
def other_output(self) -> int:
|
||||
return self.certain_output()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ class TestComponent(CustomComponent):
|
|||
def refresh_values(self):
|
||||
# This is a function that will be called every time the component is updated
|
||||
# and should return a list of random strings
|
||||
return [f"Random {random.randint(1, 100)}" for _ in range(5)]
|
||||
return [f"Random {random.randint(1, 100)}" for _ in range(5)] # noqa: S311
|
||||
|
||||
def build_config(self):
|
||||
return {"param": Input(display_name="Param", options=self.refresh_values)}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from tests.api_keys import get_openai_api_key
|
||||
from tests.integration.utils import download_flow_from_github, run_json_flow
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langchain_core.documents import Document
|
|||
from langflow.components.embeddings import OpenAIEmbeddingsComponent
|
||||
from langflow.components.vectorstores import AstraVectorStoreComponent
|
||||
from langflow.schema.data import Data
|
||||
|
||||
from tests.api_keys import get_astradb_api_endpoint, get_astradb_application_token, get_openai_api_key
|
||||
from tests.integration.components.mock_components import TextToData
|
||||
from tests.integration.utils import ComponentInputHandle, run_single_component
|
||||
|
|
@ -27,7 +28,7 @@ ALL_COLLECTIONS = [
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def astradb_client(request):
|
||||
def astradb_client():
|
||||
client = AstraDB(api_endpoint=get_astradb_api_endpoint(), token=get_astradb_application_token())
|
||||
yield client
|
||||
for collection in ALL_COLLECTIONS:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
from langflow.components.helpers.ParseJSONData import ParseJSONDataComponent
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.schema import Data
|
||||
|
||||
from tests.integration.components.mock_components import TextToData
|
||||
from tests.integration.utils import ComponentInputHandle, run_single_component
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
from langflow.components.inputs import ChatInput
|
||||
from langflow.memory import get_messages
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from langflow.components.inputs import TextInputComponent
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import pytest
|
|||
from langflow.components.models.OpenAIModel import OpenAIModelComponent
|
||||
from langflow.components.output_parsers.OutputParser import OutputParserComponent
|
||||
from langflow.components.prompts.Prompt import PromptComponent
|
||||
|
||||
from tests.integration.utils import ComponentInputHandle, run_single_component
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
from langflow.components.outputs import ChatOutput
|
||||
from langflow.memory import get_messages
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from langflow.components.outputs import TextOutputComponent
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from langflow.components.prompts import PromptComponent
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,18 +4,19 @@ from langflow.components.outputs import ChatOutput
|
|||
from langflow.components.prompts import PromptComponent
|
||||
from langflow.graph import Graph
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from tests.integration.utils import run_flow
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_no_llm():
|
||||
graph = Graph()
|
||||
input = graph.add_component(ChatInput())
|
||||
output = graph.add_component(ChatOutput())
|
||||
flow_input = graph.add_component(ChatInput())
|
||||
flow_output = graph.add_component(ChatOutput())
|
||||
component = PromptComponent(template="This is the message: {var1}", var1="")
|
||||
prompt = graph.add_component(component)
|
||||
graph.add_component_edge(input, ("message", "var1"), prompt)
|
||||
graph.add_component_edge(prompt, ("prompt", "input_value"), output)
|
||||
graph.add_component_edge(flow_input, ("message", "var1"), prompt)
|
||||
graph.add_component_edge(prompt, ("prompt", "input_value"), flow_output)
|
||||
outputs = await run_flow(graph, run_input="hello!")
|
||||
assert isinstance(outputs["message"], Message)
|
||||
assert outputs["message"].text == "This is the message: hello!"
|
||||
|
|
|
|||
|
|
@ -12,26 +12,26 @@ from langflow.graph import Graph
|
|||
from langflow.processing.process import run_graph_internal
|
||||
|
||||
|
||||
def check_env_vars(*vars):
|
||||
def check_env_vars(*env_vars):
|
||||
"""Check if all specified environment variables are set.
|
||||
|
||||
Args:
|
||||
*vars (str): The environment variables to check.
|
||||
*env_vars (str): The environment variables to check.
|
||||
|
||||
Returns:
|
||||
bool: True if all environment variables are set, False otherwise.
|
||||
"""
|
||||
return all(os.getenv(var) for var in vars)
|
||||
return all(os.getenv(var) for var in env_vars)
|
||||
|
||||
|
||||
def valid_nvidia_vectorize_region(api_endpoint: str) -> bool:
|
||||
"""Check if the specified region is valid.
|
||||
|
||||
Args:
|
||||
region (str): The region to check.
|
||||
api_endpoint: The API endpoint to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the region is contains hosted nvidia models, False otherwise.
|
||||
True if the region contains hosted nvidia models, False otherwise.
|
||||
"""
|
||||
parsed_endpoint = parse_api_endpoint(api_endpoint)
|
||||
if not parsed_endpoint:
|
||||
|
|
@ -63,12 +63,12 @@ class JSONFlow:
|
|||
json: dict
|
||||
|
||||
def get_components_by_type(self, component_type):
|
||||
result = []
|
||||
for node in self.json["data"]["nodes"]:
|
||||
if node["data"]["type"] == component_type:
|
||||
result.append(node["id"])
|
||||
result = [node["id"] for node in self.json["data"]["nodes"] if node["data"]["type"] == component_type]
|
||||
if not result:
|
||||
msg = f"Component of type {component_type} not found, available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}"
|
||||
msg = (
|
||||
f"Component of type {component_type} not found, "
|
||||
f"available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return result
|
||||
|
||||
|
|
@ -97,7 +97,8 @@ class JSONFlow:
|
|||
|
||||
def download_flow_from_github(name: str, version: str) -> JSONFlow:
|
||||
response = requests.get(
|
||||
f"https://raw.githubusercontent.com/langflow-ai/langflow/v{version}/src/backend/base/langflow/initial_setup/starter_projects/{name}.json"
|
||||
f"https://raw.githubusercontent.com/langflow-ai/langflow/v{version}/src/backend/base/langflow/initial_setup/starter_projects/{name}.json",
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
as_json = response.json()
|
||||
|
|
@ -151,7 +152,7 @@ async def run_single_component(
|
|||
raw_inputs[key] = value
|
||||
if isinstance(value, Component):
|
||||
msg = "Component inputs must be wrapped in ComponentInputHandle"
|
||||
raise ValueError(msg)
|
||||
raise TypeError(msg)
|
||||
component = clazz(**raw_inputs, _user_id=user_id)
|
||||
component_id = graph.add_component(component)
|
||||
if inputs:
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class NameTest(FastHttpUser):
|
|||
|
||||
@task
|
||||
def send_name_and_check(self):
|
||||
name = random.choice(self.names)
|
||||
name = random.choice(self.names) # noqa: S311
|
||||
|
||||
payload1 = {
|
||||
"inputs": {"text": f"Hello, My name is {name}"},
|
||||
|
|
|
|||
|
|
@ -17,7 +17,10 @@ def test_get_suggestion_message():
|
|||
|
||||
# Test case 3: Multiple outdated components
|
||||
outdated_components = ["component1", "component2", "component3"]
|
||||
expected_message = "The flow contains 3 outdated components. We recommend updating the following components: component1, component2, component3."
|
||||
expected_message = (
|
||||
"The flow contains 3 outdated components. "
|
||||
"We recommend updating the following components: component1, component2, component3."
|
||||
)
|
||||
assert get_suggestion_message(outdated_components) == expected_message
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClie
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_create_variable__HTTPException(client: AsyncClient, body, logged_in_headers):
|
||||
async def test_create_variable__httpexception(client: AsyncClient, body, logged_in_headers):
|
||||
status_code = 418
|
||||
generic_message = "I'm a teapot"
|
||||
|
||||
|
|
@ -89,7 +89,7 @@ async def test_create_variable__HTTPException(client: AsyncClient, body, logged_
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_create_variable__Exception(client: AsyncClient, body, logged_in_headers):
|
||||
async def test_create_variable__exception(client: AsyncClient, body, logged_in_headers):
|
||||
generic_message = "Generic error message"
|
||||
|
||||
with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m:
|
||||
|
|
@ -133,16 +133,10 @@ async def test_read_variables__empty(client: AsyncClient, logged_in_headers):
|
|||
async def test_read_variables__(client: AsyncClient, logged_in_headers):
|
||||
generic_message = "Generic error message"
|
||||
|
||||
with pytest.raises(Exception) as exc, mock.patch("sqlmodel.Session.exec") as m:
|
||||
with mock.patch("sqlmodel.Session.exec") as m:
|
||||
m.side_effect = Exception(generic_message)
|
||||
|
||||
response = await client.get("api/v1/variables/", headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert generic_message in result["detail"]
|
||||
|
||||
assert generic_message in str(exc.value)
|
||||
with pytest.raises(Exception, match=generic_message):
|
||||
await client.get("api/v1/variables/", headers=logged_in_headers)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
|
|
@ -165,7 +159,7 @@ async def test_update_variable(client: AsyncClient, body, logged_in_headers):
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_update_variable__Exception(client: AsyncClient, body, logged_in_headers):
|
||||
async def test_update_variable__exception(client: AsyncClient, body, logged_in_headers):
|
||||
wrong_id = uuid4()
|
||||
body["id"] = str(wrong_id)
|
||||
|
||||
|
|
@ -186,7 +180,7 @@ async def test_delete_variable(client: AsyncClient, body, logged_in_headers):
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_delete_variable__Exception(client: AsyncClient, logged_in_headers):
|
||||
async def test_delete_variable__exception(client: AsyncClient, logged_in_headers):
|
||||
wrong_id = uuid4()
|
||||
|
||||
response = await client.delete(f"api/v1/variables/{wrong_id}", headers=logged_in_headers)
|
||||
|
|
|
|||
|
|
@ -26,4 +26,5 @@ def test_run_flow_from_json_params():
|
|||
params = func_spec.args + func_spec.kwonlyargs
|
||||
assert expected_params.issubset(params), "Not all expected parameters are present in run_flow_from_json"
|
||||
|
||||
# TODO: Add tests by loading a flow and running it need to text with fake llm and check if it returns the correct output
|
||||
# TODO: Add tests by loading a flow and running it need to text with fake llm and check if it returns the
|
||||
# correct output
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def add_toolkit_output():
|
||||
def _add_toolkit_output():
|
||||
FEATURE_FLAGS.add_toolkit_output = True
|
||||
yield
|
||||
FEATURE_FLAGS.add_toolkit_output = False
|
||||
|
|
@ -81,7 +81,7 @@ def test_component_tool():
|
|||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
@pytest.mark.usefixtures("add_toolkit_output", "client")
|
||||
@pytest.mark.usefixtures("_add_toolkit_output", "client")
|
||||
def test_component_tool_with_api_key():
|
||||
chat_output = ChatOutput()
|
||||
openai_llm = OpenAIModelComponent()
|
||||
|
|
|
|||
|
|
@ -1,34 +1,60 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langflow.components.helpers.structured_output import StructuredOutputComponent
|
||||
from langflow.schema.data import Data
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
pass
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class TestStructuredOutputComponent:
|
||||
# Ensure that the structured output is successfully generated with the correct BaseModel instance returned by the mock function
|
||||
# Ensure that the structured output is successfully generated with the correct BaseModel instance returned by
|
||||
# the mock function
|
||||
def test_successful_structured_output_generation_with_patch_with_config(self):
|
||||
from unittest.mock import patch
|
||||
|
||||
class MockLanguageModel:
|
||||
def with_structured_output(self, schema):
|
||||
class MockLanguageModel(BaseLanguageModel):
|
||||
@override
|
||||
def with_structured_output(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def with_config(self, config):
|
||||
@override
|
||||
def with_config(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def invoke(self, inputs):
|
||||
@override
|
||||
def invoke(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def mock_get_chat_result(runnable, input_value, config):
|
||||
@override
|
||||
def generate_prompt(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def agenerate_prompt(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def predict(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def predict_messages(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def apredict(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def apredict_messages(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def mock_get_chat_result(runnable, input_value, config): # noqa: ARG001
|
||||
class MockBaseModel(BaseModel):
|
||||
def model_dump(self):
|
||||
@override
|
||||
def model_dump(self, **kwargs):
|
||||
return {"field": "value"}
|
||||
|
||||
return MockBaseModel()
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ def test_update_build_config_keep_alive(component):
|
|||
"langchain_community.chat_models.ChatOllama",
|
||||
return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
|
||||
)
|
||||
def test_build_model(_mock_chat_ollama, component):
|
||||
def test_build_model(_mock_chat_ollama, component): # noqa: PT019
|
||||
component.base_url = "http://localhost:11434"
|
||||
component.model_name = "llama3.1"
|
||||
component.mirostat = "Mirostat 2.0"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from langflow.components.prompts.Prompt import PromptComponent # type: ignore
|
||||
from langflow.components.prompts.Prompt import PromptComponent
|
||||
|
||||
|
||||
class TestPromptComponent:
|
||||
|
|
|
|||
|
|
@ -109,9 +109,5 @@ def test_validate_text_key_invalid(create_data_component):
|
|||
create_data_component.text_key = "invalid_key"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValueError, match="Text Key: 'invalid_key' not found in the Data keys: 'key1, key2'"):
|
||||
create_data_component.validate_text_key()
|
||||
|
||||
# Check for the exact error message
|
||||
expected_error_message = f"Text Key: '{create_data_component.text_key}' not found in the Data keys: '{', '.join(create_data_component.get_data().keys())}'"
|
||||
assert str(exc_info.value) == expected_error_message
|
||||
|
|
|
|||
|
|
@ -96,10 +96,5 @@ def test_validate_text_key_invalid(update_data_component):
|
|||
data = Data(data={"key1": "value1", "key2": "value2"}, text_key="key1")
|
||||
update_data_component.text_key = "invalid_key"
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValueError, match="Text Key: invalid_key not found in the Data keys: key1,key2"):
|
||||
update_data_component.validate_text_key(data)
|
||||
|
||||
expected_error_message = (
|
||||
f"Text Key: {update_data_component.text_key} not found in the Data keys: {','.join(data.data.keys())}"
|
||||
)
|
||||
assert str(exc_info.value) == expected_error_message
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from langflow.template.field.base import Output
|
|||
def test_set_invalid_output():
|
||||
chatinput = ChatInput()
|
||||
chatoutput = ChatOutput()
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Method build_config is not a valid output of ChatInput"):
|
||||
chatoutput.set(input_value=chatinput.build_config)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -90,9 +90,9 @@ class TestEventManager:
|
|||
|
||||
queue = asyncio.Queue()
|
||||
manager = EventManager(queue)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Event name cannot be empty"):
|
||||
manager.register_event("", "test_type", mock_callback)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Event name must start with 'on_'"):
|
||||
manager.register_event("invalid_name", "test_type", mock_callback)
|
||||
|
||||
# Sending an event with complex data and verifying successful event transmission
|
||||
|
|
|
|||
|
|
@ -16,28 +16,32 @@ def test_api_exception():
|
|||
}
|
||||
# Expected result
|
||||
|
||||
with patch(
|
||||
"langflow.services.database.models.flow.utils.get_outdated_components", return_value=mock_outdated_components
|
||||
with (
|
||||
patch(
|
||||
"langflow.services.database.models.flow.utils.get_outdated_components",
|
||||
return_value=mock_outdated_components,
|
||||
),
|
||||
patch("langflow.api.utils.get_suggestion_message", return_value=mock_suggestion_message),
|
||||
patch(
|
||||
"langflow.services.database.models.flow.utils.get_components_versions",
|
||||
return_value=mock_component_versions,
|
||||
),
|
||||
):
|
||||
with patch("langflow.api.utils.get_suggestion_message", return_value=mock_suggestion_message):
|
||||
with patch(
|
||||
"langflow.services.database.models.flow.utils.get_components_versions",
|
||||
return_value=mock_component_versions,
|
||||
):
|
||||
# Create an APIException instance
|
||||
api_exception = APIException(mock_exception, mock_flow)
|
||||
# Create an APIException instance
|
||||
api_exception = APIException(mock_exception, mock_flow)
|
||||
|
||||
# Expected body
|
||||
expected_body = ExceptionBody(
|
||||
message="Test exception",
|
||||
suggestion="The flow contains 2 outdated components. We recommend updating the following components: component1, component2.",
|
||||
)
|
||||
# Expected body
|
||||
expected_body = ExceptionBody(
|
||||
message="Test exception",
|
||||
suggestion="The flow contains 2 outdated components. "
|
||||
"We recommend updating the following components: component1, component2.",
|
||||
)
|
||||
|
||||
# Assert the status code
|
||||
assert api_exception.status_code == 500
|
||||
# Assert the status code
|
||||
assert api_exception.status_code == 500
|
||||
|
||||
# Assert the detail
|
||||
assert api_exception.detail == expected_body.model_dump_json()
|
||||
# Assert the detail
|
||||
assert api_exception.detail == expected_body.model_dump_json()
|
||||
|
||||
|
||||
def test_api_exception_no_flow():
|
||||
|
|
|
|||
|
|
@ -22,27 +22,27 @@ class TestCreateStateModel:
|
|||
# Successfully create a model with valid method return type annotations
|
||||
|
||||
def test_create_model_with_valid_return_type_annotations(self, chat_input_component):
|
||||
StateModel = create_state_model(method_one=chat_input_component.message_response)
|
||||
state_model = create_state_model(method_one=chat_input_component.message_response)
|
||||
|
||||
state_instance = StateModel()
|
||||
state_instance = state_model()
|
||||
assert state_instance.method_one is UNDEFINED
|
||||
chat_input_component.set_output_value("message", "test")
|
||||
assert state_instance.method_one == "test"
|
||||
|
||||
def test_create_model_and_assign_values_fails(self, chat_input_component):
|
||||
StateModel = create_state_model(method_one=chat_input_component.message_response)
|
||||
state_model = create_state_model(method_one=chat_input_component.message_response)
|
||||
|
||||
state_instance = StateModel()
|
||||
state_instance = state_model()
|
||||
state_instance.method_one = "test"
|
||||
assert state_instance.method_one == "test"
|
||||
|
||||
def test_create_with_multiple_components(self, chat_input_component, chat_output_component):
|
||||
NewStateModel = create_state_model(
|
||||
new_state_model = create_state_model(
|
||||
model_name="NewStateModel",
|
||||
first_method=chat_input_component.message_response,
|
||||
second_method=chat_output_component.message_response,
|
||||
)
|
||||
state_instance = NewStateModel()
|
||||
state_instance = new_state_model()
|
||||
assert state_instance.first_method is UNDEFINED
|
||||
assert state_instance.second_method is UNDEFINED
|
||||
state_instance.first_method = "test"
|
||||
|
|
@ -51,9 +51,9 @@ class TestCreateStateModel:
|
|||
assert state_instance.second_method == 123
|
||||
|
||||
def test_create_with_pydantic_field(self, chat_input_component):
|
||||
StateModel = create_state_model(method_one=chat_input_component.message_response, my_attribute=Field(None))
|
||||
state_model = create_state_model(method_one=chat_input_component.message_response, my_attribute=Field(None))
|
||||
|
||||
state_instance = StateModel()
|
||||
state_instance = state_model()
|
||||
state_instance.method_one = "test"
|
||||
state_instance.my_attribute = "test"
|
||||
assert state_instance.method_one == "test"
|
||||
|
|
@ -64,8 +64,8 @@ class TestCreateStateModel:
|
|||
|
||||
# Creates a model with fields based on provided keyword arguments
|
||||
def test_create_model_with_fields_from_kwargs(self):
|
||||
StateModel = create_state_model(field_one=(str, "default"), field_two=(int, 123))
|
||||
state_instance = StateModel()
|
||||
state_model = create_state_model(field_one=(str, "default"), field_two=(int, 123))
|
||||
state_instance = state_model()
|
||||
assert state_instance.field_one == "default"
|
||||
assert state_instance.field_two == 123
|
||||
|
||||
|
|
@ -81,16 +81,16 @@ class TestCreateStateModel:
|
|||
|
||||
# Handles empty keyword arguments gracefully
|
||||
def test_handle_empty_kwargs_gracefully(self):
|
||||
StateModel = create_state_model()
|
||||
state_instance = StateModel()
|
||||
state_model = create_state_model()
|
||||
state_instance = state_model()
|
||||
assert state_instance is not None
|
||||
|
||||
# Ensures model name defaults to "State" if not provided
|
||||
def test_default_model_name_to_state(self):
|
||||
StateModel = create_state_model()
|
||||
assert StateModel.__name__ == "State"
|
||||
OtherNameModel = create_state_model(model_name="OtherName")
|
||||
assert OtherNameModel.__name__ == "OtherName"
|
||||
state_model = create_state_model()
|
||||
assert state_model.__name__ == "State"
|
||||
other_name_model = create_state_model(model_name="OtherName")
|
||||
assert other_name_model.__name__ == "OtherName"
|
||||
|
||||
# Validates that callable values are properly type-annotated
|
||||
|
||||
|
|
@ -110,8 +110,7 @@ class TestCreateStateModel:
|
|||
chat_input = ChatInput(_id="chat_input")
|
||||
chat_output = ChatOutput(input_value="test", _id="chat_output")
|
||||
chat_output.set(sender_name=chat_input.message_response)
|
||||
ChatStateModel = create_state_model(model_name="ChatState", message=chat_output.message_response)
|
||||
chat_state_model = ChatStateModel()
|
||||
chat_state_model = create_state_model(model_name="ChatState", message=chat_output.message_response)()
|
||||
assert chat_state_model.__class__.__name__ == "ChatState"
|
||||
assert chat_state_model.message is UNDEFINED
|
||||
|
||||
|
|
@ -121,9 +120,7 @@ class TestCreateStateModel:
|
|||
# and check that the graph is running
|
||||
# correctly
|
||||
ids = ["chat_input", "chat_output"]
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
results = list(graph.start())
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from langflow.components.outputs.TextOutput import TextOutputComponent
|
|||
from langflow.components.tools.YfinanceTool import YfinanceToolComponent
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.graph.constants import Finish
|
||||
from pytest import LogCaptureFixture
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -19,12 +18,12 @@ async def test_graph_not_prepared():
|
|||
graph = Graph()
|
||||
graph.add_component(chat_input)
|
||||
graph.add_component(chat_output)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Graph not prepared"):
|
||||
await graph.astep()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph(caplog: LogCaptureFixture):
|
||||
async def test_graph(caplog: pytest.LogCaptureFixture):
|
||||
chat_input = ChatInput()
|
||||
chat_output = ChatOutput()
|
||||
graph = Graph()
|
||||
|
|
@ -83,9 +82,7 @@ async def test_graph_functional_async_start():
|
|||
# and check that the graph is running
|
||||
# correctly
|
||||
ids = ["chat_input", "chat_output"]
|
||||
results = []
|
||||
async for result in graph.async_start():
|
||||
results.append(result)
|
||||
results = [result async for result in graph.async_start()]
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
|
|
@ -102,9 +99,7 @@ def test_graph_functional_start():
|
|||
# and check that the graph is running
|
||||
# correctly
|
||||
ids = ["chat_input", "chat_output"]
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
results = list(graph.start())
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
|
|
@ -123,9 +118,7 @@ def test_graph_functional_start_end():
|
|||
# and check that the graph is running
|
||||
# correctly
|
||||
ids = ["chat_input", "text_output"]
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
results = list(graph.start())
|
||||
|
||||
assert len(results) == len(ids) + 1
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
|
|
|
|||
|
|
@ -103,11 +103,9 @@ def test_cycle_in_graph_max_iterations():
|
|||
# Run queue should contain chat_input and not router
|
||||
assert "chat_input" in graph._run_queue
|
||||
assert "router" not in graph._run_queue
|
||||
results = []
|
||||
|
||||
with pytest.raises(ValueError, match="Max iterations reached"):
|
||||
for result in graph.start(max_iterations=2, config={"output": {"cache": False}}):
|
||||
results.append(result)
|
||||
list(graph.start(max_iterations=2, config={"output": {"cache": False}}))
|
||||
|
||||
|
||||
def test_that_outputs_cache_is_set_to_false_in_cycle():
|
||||
|
|
@ -149,7 +147,10 @@ def test_updated_graph_with_prompts():
|
|||
|
||||
# First prompt: Guessing game with hints
|
||||
prompt_component_1 = PromptComponent(_id="prompt_component_1").set(
|
||||
template="Try to guess a word. I will give you hints if you get it wrong.\nHint: {hint}\nLast try: {last_try}\nAnswer:",
|
||||
template="Try to guess a word. I will give you hints if you get it wrong.\n"
|
||||
"Hint: {hint}\n"
|
||||
"Last try: {last_try}\n"
|
||||
"Answer:",
|
||||
)
|
||||
|
||||
# First OpenAI LLM component (Processes the guessing prompt)
|
||||
|
|
@ -168,7 +169,10 @@ def test_updated_graph_with_prompts():
|
|||
# Second prompt: After the last try, provide a new hint
|
||||
prompt_component_2 = PromptComponent(_id="prompt_component_2")
|
||||
prompt_component_2.set(
|
||||
template="Given the following word and the following last try. Give the guesser a new hint.\nLast try: {last_try}\nWord: {word}\nHint:",
|
||||
template="Given the following word and the following last try. Give the guesser a new hint.\n"
|
||||
"Last try: {last_try}\n"
|
||||
"Word: {word}\n"
|
||||
"Hint:",
|
||||
word=chat_input.message_response,
|
||||
last_try=router.false_response,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,9 +38,9 @@ AI: """
|
|||
|
||||
graph = Graph(chat_input, chat_output)
|
||||
|
||||
GraphStateModel = create_state_model_from_graph(graph)
|
||||
assert GraphStateModel.__name__ == "GraphStateModel"
|
||||
assert list(GraphStateModel.model_computed_fields.keys()) == [
|
||||
graph_state_model = create_state_model_from_graph(graph)
|
||||
assert graph_state_model.__name__ == "GraphStateModel"
|
||||
assert list(graph_state_model.model_computed_fields.keys()) == [
|
||||
"chat_input",
|
||||
"chat_output",
|
||||
"openai",
|
||||
|
|
@ -60,12 +60,9 @@ def test_graph_functional_start_graph_state_update():
|
|||
# Now iterate through the graph
|
||||
# and check that the graph is running
|
||||
# correctly
|
||||
GraphStateModel = create_state_model_from_graph(graph)
|
||||
graph_state_model = GraphStateModel()
|
||||
graph_state_model = create_state_model_from_graph(graph)()
|
||||
ids = ["chat_input", "chat_output"]
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
results = list(graph.start())
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
|
|
@ -87,12 +84,9 @@ def test_graph_state_model_serialization():
|
|||
# Now iterate through the graph
|
||||
# and check that the graph is running
|
||||
# correctly
|
||||
GraphStateModel = create_state_model_from_graph(graph)
|
||||
graph_state_model = GraphStateModel()
|
||||
graph_state_model = create_state_model_from_graph(graph)()
|
||||
ids = ["chat_input", "chat_output"]
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
results = list(graph.start())
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
|
|
@ -116,8 +110,7 @@ def test_graph_state_model_json_schema():
|
|||
graph = Graph(chat_input, chat_output)
|
||||
graph.prepare()
|
||||
|
||||
GraphStateModel = create_state_model_from_graph(graph)
|
||||
graph_state_model: BaseModel = GraphStateModel()
|
||||
graph_state_model: BaseModel = create_state_model_from_graph(graph)()
|
||||
json_schema = graph_state_model.model_json_schema(mode="serialization")
|
||||
|
||||
# Test main schema structure
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ def test_pickle(data):
|
|||
manager = RunnableVerticesManager.from_dict(data)
|
||||
|
||||
binary = pickle.dumps(manager)
|
||||
result = pickle.loads(binary)
|
||||
result = pickle.loads(binary) # noqa: S301
|
||||
|
||||
assert result.run_map == manager.run_map
|
||||
assert result.run_predecessors == manager.run_predecessors
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ def test_sort_up_to_vertex_a(graph):
|
|||
def test_sort_up_to_vertex_invalid_vertex(graph):
|
||||
vertex_id = "7"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Parent node map is required to find the root of a group node"):
|
||||
utils.sort_up_to_vertex(graph, vertex_id)
|
||||
|
||||
|
||||
|
|
@ -432,7 +432,7 @@ class TestFindCycleVertices:
|
|||
assert sorted(result) == sorted(expected_output)
|
||||
|
||||
@pytest.mark.parametrize("_", range(5))
|
||||
def test_handle_two_inputs_in_cycle(self, _):
|
||||
def test_handle_two_inputs_in_cycle(self, _): # noqa: PT019
|
||||
edges = [
|
||||
("chat_input", "router"),
|
||||
("chat_input", "concatenate"),
|
||||
|
|
|
|||
|
|
@ -79,8 +79,8 @@ def test_invalid_node_types():
|
|||
],
|
||||
"edges": [],
|
||||
}
|
||||
with pytest.raises(Exception):
|
||||
g = Graph()
|
||||
g = Graph()
|
||||
with pytest.raises(KeyError):
|
||||
g.add_nodes_and_edges(graph_data["nodes"], graph_data["edges"])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -79,13 +79,13 @@ class TestBuildModelFromSchema:
|
|||
{"name": "field1", "type": "str", "default": "default_value1"},
|
||||
{"name": "field2", "type": "unknown_type", "default": "default_value2"},
|
||||
]
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Invalid type: unknown_type"):
|
||||
build_model_from_schema(schema)
|
||||
|
||||
# Confirms that the function raises a specific exception for invalid input
|
||||
def test_raises_error_for_invalid_input_different_exception_with_specific_exception(self):
|
||||
with pytest.raises(ValueError):
|
||||
schema = [{"name": "field1", "type": "invalid_type", "default": "default_value"}]
|
||||
schema = [{"name": "field1", "type": "invalid_type", "default": "default_value"}]
|
||||
with pytest.raises(ValueError, match="Invalid type: invalid_type"):
|
||||
build_model_from_schema(schema)
|
||||
|
||||
# Processes schemas with missing optional keys like description or multiple
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ AI: """
|
|||
return graph
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
def test_memory_chatbot(memory_chatbot_graph):
|
||||
# Now we run step by step
|
||||
expected_order = deque(["chat_input", "chat_memory", "prompt", "openai", "chat_output"])
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def ingestion_graph():
|
|||
embedding=openai_embeddings.build_embeddings,
|
||||
ingest_data=text_splitter.split_text,
|
||||
api_endpoint="https://astra.example.com",
|
||||
token="token",
|
||||
token="token", # noqa: S106
|
||||
)
|
||||
vector_store.set_on_output(name="vector_store", value="mock_vector_store", cache=True)
|
||||
vector_store.set_on_output(name="base_retriever", value="mock_retriever", cache=True)
|
||||
|
|
@ -53,7 +53,7 @@ def rag_graph():
|
|||
rag_vector_store.set(
|
||||
search_input=chat_input.message_response,
|
||||
api_endpoint="https://astra.example.com",
|
||||
token="token",
|
||||
token="token", # noqa: S106
|
||||
embedding=openai_embeddings.build_embeddings,
|
||||
)
|
||||
# Mock search_documents
|
||||
|
|
@ -110,9 +110,7 @@ def test_vector_store_rag(ingestion_graph, rag_graph):
|
|||
"openai-embeddings-124",
|
||||
]
|
||||
for ids, graph, len_results in [(ingestion_ids, ingestion_graph, 5), (rag_ids, rag_graph, 8)]:
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
results = list(graph.start())
|
||||
|
||||
assert len(results) == len_results
|
||||
vids = [result.vertex.id for result in results if hasattr(result, "vertex")]
|
||||
|
|
@ -217,12 +215,14 @@ def test_vector_store_rag_add(ingestion_graph: Graph, rag_graph: Graph):
|
|||
rag_graph_copy = copy.deepcopy(rag_graph)
|
||||
ingestion_graph_copy += rag_graph_copy
|
||||
|
||||
assert (
|
||||
len(ingestion_graph_copy.vertices) == len(ingestion_graph.vertices) + len(rag_graph.vertices)
|
||||
), f"Vertices mismatch: {len(ingestion_graph_copy.vertices)} != {len(ingestion_graph.vertices)} + {len(rag_graph.vertices)}"
|
||||
assert len(ingestion_graph_copy.edges) == len(ingestion_graph.edges) + len(
|
||||
rag_graph.edges
|
||||
), f"Edges mismatch: {len(ingestion_graph_copy.edges)} != {len(ingestion_graph.edges)} + {len(rag_graph.edges)}"
|
||||
assert len(ingestion_graph_copy.vertices) == len(ingestion_graph.vertices) + len(rag_graph.vertices), (
|
||||
f"Vertices mismatch: {len(ingestion_graph_copy.vertices)} "
|
||||
f"!= {len(ingestion_graph.vertices)} + {len(rag_graph.vertices)}"
|
||||
)
|
||||
assert len(ingestion_graph_copy.edges) == len(ingestion_graph.edges) + len(rag_graph.edges), (
|
||||
f"Edges mismatch: {len(ingestion_graph_copy.edges)} "
|
||||
f"!= {len(ingestion_graph.edges)} + {len(rag_graph.edges)}"
|
||||
)
|
||||
|
||||
combined_graph_dump = ingestion_graph_copy.dump(
|
||||
name="Combined Graph", description="Graph for data ingestion and RAG", endpoint_name="combined"
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ def test_instantiate_input_valid():
|
|||
|
||||
|
||||
def test_instantiate_input_invalid():
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Invalid input type: InvalidInput"):
|
||||
instantiate_input("InvalidInput", {"name": "invalid_input", "value": "This is a string"})
|
||||
|
||||
|
||||
|
|
@ -224,5 +224,5 @@ def test_instantiate_input_comprehensive():
|
|||
input_instance = instantiate_input(input_type, data)
|
||||
assert isinstance(input_instance, InputTypesMap[input_type])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Invalid input type: InvalidInput"):
|
||||
instantiate_input("InvalidInput", {"name": "invalid_input", "value": "Invalid"})
|
||||
|
|
|
|||
|
|
@ -180,7 +180,7 @@ class TestCreateInputSchema:
|
|||
input_instance = StrInput(name="test_field", is_list=True)
|
||||
schema = create_input_schema([input_instance])
|
||||
field_info = schema.model_fields["test_field"]
|
||||
assert field_info.annotation == list[str] # type: ignore
|
||||
assert field_info.annotation == list[str]
|
||||
|
||||
# Converting FieldTypes to corresponding Python types
|
||||
def test_field_types_conversion(self):
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class TestColumn:
|
|||
|
||||
# Invalid formatter raises ValueError
|
||||
def test_invalid_formatter_raises_value_error(self):
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="'invalid' is not a valid FormatterType"):
|
||||
Column(display_name="Invalid Column", name="invalid_column", formatter="invalid")
|
||||
|
||||
# Formatter is None when not provided
|
||||
|
|
|
|||
|
|
@ -72,19 +72,16 @@ def test_get_variable(service, session):
|
|||
assert result == value
|
||||
|
||||
|
||||
def test_get_variable__ValueError(service, session):
|
||||
def test_get_variable__valueerror(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
field = ""
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.get_variable(user_id, name, field, session)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_get_variable__TypeError(service, session):
|
||||
def test_get_variable__typeerror(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
|
|
@ -142,17 +139,14 @@ def test_update_variable(service, session):
|
|||
assert isinstance(result.updated_at, datetime)
|
||||
|
||||
|
||||
def test_update_variable__ValueError(service, session):
|
||||
def test_update_variable__valueerror(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.update_variable(user_id, name, value, session=session)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_update_variable_fields(service, session):
|
||||
user_id = uuid4()
|
||||
|
|
@ -192,26 +186,21 @@ def test_delete_variable(service, session):
|
|||
service.create_variable(user_id, name, value, session=session)
|
||||
recovered = service.get_variable(user_id, name, field, session=session)
|
||||
service.delete_variable(user_id, name, session=session)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.get_variable(user_id, name, field, session)
|
||||
|
||||
assert recovered == value
|
||||
assert name in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_delete_variable__ValueError(service, session):
|
||||
def test_delete_variable__valueerror(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.delete_variable(user_id, name, session=session)
|
||||
|
||||
assert name in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_delete_varaible_by_id(service, session):
|
||||
def test_delete_variable_by_id(service, session):
|
||||
user_id = uuid4()
|
||||
name = "name"
|
||||
value = "value"
|
||||
|
|
@ -220,24 +209,19 @@ def test_delete_varaible_by_id(service, session):
|
|||
saved = service.create_variable(user_id, name, value, session=session)
|
||||
recovered = service.get_variable(user_id, name, field, session=session)
|
||||
service.delete_variable_by_id(user_id, saved.id, session=session)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
with pytest.raises(ValueError, match=f"{name} variable not found."):
|
||||
service.get_variable(user_id, name, field, session)
|
||||
|
||||
assert recovered == value
|
||||
assert name in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_delete_variable_by_id__ValueError(service, session):
|
||||
def test_delete_variable_by_id__valueerror(service, session):
|
||||
user_id = uuid4()
|
||||
variable_id = uuid4()
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
with pytest.raises(ValueError, match=f"{variable_id} variable not found."):
|
||||
service.delete_variable_by_id(user_id, variable_id, session=session)
|
||||
|
||||
assert str(variable_id) in str(exc.value)
|
||||
assert "variable not found" in str(exc.value)
|
||||
|
||||
|
||||
def test_create_variable(service, session):
|
||||
user_id = uuid4()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,11 @@ from langflow.services.database.models.api_key import ApiKeyCreate
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def api_key(client, logged_in_headers, active_user):
|
||||
async def api_key(
|
||||
client,
|
||||
logged_in_headers,
|
||||
active_user, # noqa: ARG001
|
||||
):
|
||||
api_key = ApiKeyCreate(name="test-api-key")
|
||||
|
||||
response = await client.post("api/v1/api_key/", data=api_key.model_dump_json(), headers=logged_in_headers)
|
||||
|
|
|
|||
|
|
@ -51,9 +51,7 @@ def test_code_parser_get_tree():
|
|||
|
||||
|
||||
def test_code_parser_syntax_error():
|
||||
"""Test the __get_tree method raises the
|
||||
CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
"""Test the __get_tree method raises the CodeSyntaxError when given incorrect syntax."""
|
||||
code_syntax_error = "zzz import os"
|
||||
|
||||
parser = CodeParser(code_syntax_error)
|
||||
|
|
@ -76,9 +74,7 @@ def test_component_get_code_tree():
|
|||
|
||||
|
||||
def test_component_code_null_error():
|
||||
"""Test the get_function method raises the
|
||||
ComponentCodeNullError when the code is empty.
|
||||
"""
|
||||
"""Test the get_function method raises the ComponentCodeNullError when the code is empty."""
|
||||
component = BaseComponent(_code="", _function_entrypoint_name="")
|
||||
with pytest.raises(ComponentCodeNullError):
|
||||
component.get_function()
|
||||
|
|
@ -108,9 +104,7 @@ def test_custom_component_get_function():
|
|||
|
||||
|
||||
def test_code_parser_parse_imports_import():
|
||||
"""Test the parse_imports method of the CodeParser
|
||||
class with an import statement.
|
||||
"""
|
||||
"""Test the parse_imports method of the CodeParser class with an import statement."""
|
||||
parser = CodeParser(code_default)
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
|
|
@ -120,9 +114,7 @@ def test_code_parser_parse_imports_import():
|
|||
|
||||
|
||||
def test_code_parser_parse_imports_importfrom():
|
||||
"""Test the parse_imports method of the CodeParser
|
||||
class with an import from statement.
|
||||
"""
|
||||
"""Test the parse_imports method of the CodeParser class with an import from statement."""
|
||||
parser = CodeParser("from os import path")
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
|
|
@ -157,9 +149,9 @@ def test_code_parser_parse_classes_raises():
|
|||
"""Test the parse_classes method of the CodeParser class."""
|
||||
parser = CodeParser("class Test: pass")
|
||||
tree = parser.get_tree()
|
||||
with pytest.raises(TypeError):
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
with pytest.raises(TypeError):
|
||||
parser.parse_classes(node)
|
||||
|
||||
|
||||
|
|
@ -175,18 +167,14 @@ def test_code_parser_parse_global_vars():
|
|||
|
||||
|
||||
def test_component_get_function_valid():
|
||||
"""Test the get_function method of the Component
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
"""Test the get_function method of the Component class with valid code and function_entrypoint_name."""
|
||||
component = BaseComponent(_code="def build(): pass", _function_entrypoint_name="build")
|
||||
my_function = component.get_function()
|
||||
assert callable(my_function)
|
||||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_args():
|
||||
"""Test the get_function_entrypoint_args
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
"""Test the get_function_entrypoint_args property of the CustomComponent class."""
|
||||
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
|
||||
args = custom_component.get_function_entrypoint_args
|
||||
assert len(args) == 3
|
||||
|
|
@ -196,9 +184,7 @@ def test_custom_component_get_function_entrypoint_args():
|
|||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_return_type():
|
||||
"""Test the get_function_entrypoint_return_type
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
"""Test the get_function_entrypoint_return_type property of the CustomComponent class."""
|
||||
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
|
||||
return_type = custom_component.get_function_entrypoint_return_type
|
||||
assert return_type == [Document]
|
||||
|
|
@ -212,9 +198,7 @@ def test_custom_component_get_main_class_name():
|
|||
|
||||
|
||||
def test_custom_component_get_function_valid():
|
||||
"""Test the get_function property of the CustomComponent
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
"""Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name."""
|
||||
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
|
||||
my_function = custom_component.get_function
|
||||
assert callable(my_function)
|
||||
|
|
@ -239,9 +223,7 @@ def test_code_parser_parse_arg_with_annotation():
|
|||
|
||||
|
||||
def test_code_parser_parse_callable_details_no_args():
|
||||
"""Test the parse_callable_details method of the
|
||||
CodeParser class with a function with no arguments.
|
||||
"""
|
||||
"""Test the parse_callable_details method of the CodeParser class with a function with no arguments."""
|
||||
parser = CodeParser("")
|
||||
node = ast.FunctionDef(
|
||||
name="test",
|
||||
|
|
@ -280,9 +262,7 @@ def test_code_parser_parse_ann_assign():
|
|||
|
||||
|
||||
def test_code_parser_parse_function_def_not_init():
|
||||
"""Test the parse_function_def method of the
|
||||
CodeParser class with a function that is not __init__.
|
||||
"""
|
||||
"""Test the parse_function_def method of the CodeParser class with a function that is not __init__."""
|
||||
parser = CodeParser("")
|
||||
stmt = ast.FunctionDef(
|
||||
name="test",
|
||||
|
|
@ -297,9 +277,7 @@ def test_code_parser_parse_function_def_not_init():
|
|||
|
||||
|
||||
def test_code_parser_parse_function_def_init():
|
||||
"""Test the parse_function_def method of the
|
||||
CodeParser class with an __init__ function.
|
||||
"""
|
||||
"""Test the parse_function_def method of the CodeParser class with an __init__ function."""
|
||||
parser = CodeParser("")
|
||||
stmt = ast.FunctionDef(
|
||||
name="__init__",
|
||||
|
|
@ -314,36 +292,28 @@ def test_code_parser_parse_function_def_init():
|
|||
|
||||
|
||||
def test_component_get_code_tree_syntax_error():
|
||||
"""Test the get_code_tree method of the Component class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
"""Test the get_code_tree method of the Component class raises the CodeSyntaxError when given incorrect syntax."""
|
||||
component = BaseComponent(_code="import os as", _function_entrypoint_name="build")
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
component.get_code_tree(component._code)
|
||||
|
||||
|
||||
def test_custom_component_class_template_validation_no_code():
|
||||
"""Test the _class_template_validation method of the CustomComponent class
|
||||
raises the HTTPException when the code is None.
|
||||
"""
|
||||
"""Test CustomComponent._class_template_validation raises the HTTPException when the code is None."""
|
||||
custom_component = CustomComponent(_code=None, _function_entrypoint_name="build")
|
||||
with pytest.raises(TypeError):
|
||||
custom_component.get_function()
|
||||
|
||||
|
||||
def test_custom_component_get_code_tree_syntax_error():
|
||||
"""Test the get_code_tree method of the CustomComponent class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
"""Test CustomComponent.get_code_tree raises the CodeSyntaxError when given incorrect syntax."""
|
||||
custom_component = CustomComponent(_code="import os as", _function_entrypoint_name="build")
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
custom_component.get_code_tree(custom_component._code)
|
||||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_args_no_args():
|
||||
"""Test the get_function_entrypoint_args property of
|
||||
the CustomComponent class with a build method with no arguments.
|
||||
"""
|
||||
"""Test CustomComponent.get_function_entrypoint_args with a build method with no arguments."""
|
||||
my_code = """
|
||||
from langflow.custom import CustomComponent
|
||||
class MyMainClass(CustomComponent):
|
||||
|
|
@ -356,9 +326,7 @@ class MyMainClass(CustomComponent):
|
|||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_return_type_no_return_type():
|
||||
"""Test the get_function_entrypoint_return_type property of the
|
||||
CustomComponent class with a build method with no return type.
|
||||
"""
|
||||
"""Test CustomComponent.get_function_entrypoint_return_type with a build method with no return type."""
|
||||
my_code = """
|
||||
from langflow.custom import CustomComponent
|
||||
class MyClass(CustomComponent):
|
||||
|
|
@ -371,9 +339,7 @@ class MyClass(CustomComponent):
|
|||
|
||||
|
||||
def test_custom_component_get_main_class_name_no_main_class():
|
||||
"""Test the get_main_class_name property of the
|
||||
CustomComponent class when there is no main class.
|
||||
"""
|
||||
"""Test the get_main_class_name property of the CustomComponent class when there is no main class."""
|
||||
my_code = """
|
||||
def build():
|
||||
pass"""
|
||||
|
|
@ -384,9 +350,7 @@ def build():
|
|||
|
||||
|
||||
def test_custom_component_build_not_implemented():
|
||||
"""Test the build method of the CustomComponent
|
||||
class raises the NotImplementedError.
|
||||
"""
|
||||
"""Test the build method of the CustomComponent class raises the NotImplementedError."""
|
||||
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
|
||||
with pytest.raises(NotImplementedError):
|
||||
custom_component.build()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,10 @@ def code_component_with_multiple_outputs():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def component(client, active_user):
|
||||
def component(
|
||||
client, # noqa: ARG001
|
||||
active_user,
|
||||
):
|
||||
return CustomComponent(
|
||||
user_id=active_user.id,
|
||||
field_config={
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from collections import namedtuple
|
||||
from typing import NamedTuple
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import orjson
|
||||
|
|
@ -13,7 +13,6 @@ from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
|
|||
from langflow.services.database.models.folder.model import FolderCreate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
@ -179,11 +178,11 @@ async def test_delete_flows_with_transaction_and_build(client: TestClient, logge
|
|||
assert response.status_code == 201
|
||||
flow_ids.append(response.json()["id"])
|
||||
|
||||
class VertexTuple(NamedTuple):
|
||||
id: str
|
||||
|
||||
# Create a transaction for each flow
|
||||
|
||||
for flow_id in flow_ids:
|
||||
VertexTuple = namedtuple("VertexTuple", ["id"])
|
||||
|
||||
await log_transaction(
|
||||
str(flow_id), source=VertexTuple(id="vid"), target=VertexTuple(id="tid"), status="success"
|
||||
)
|
||||
|
|
@ -249,10 +248,11 @@ async def test_delete_folder_with_flows_with_transaction_and_build(client: TestC
|
|||
assert response.status_code == 201
|
||||
flow_ids.append(response.json()["id"])
|
||||
|
||||
class VertexTuple(NamedTuple):
|
||||
id: str
|
||||
|
||||
# Create a transaction for each flow
|
||||
for flow_id in flow_ids:
|
||||
VertexTuple = namedtuple("VertexTuple", ["id"])
|
||||
|
||||
await log_transaction(
|
||||
str(flow_id), source=VertexTuple(id="vid"), target=VertexTuple(id="tid"), status="success"
|
||||
)
|
||||
|
|
@ -400,9 +400,9 @@ async def test_upload_file(client: TestClient, json_flow: str, logged_in_headers
|
|||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("session")
|
||||
async def test_download_file(
|
||||
client: TestClient,
|
||||
session: Session,
|
||||
json_flow,
|
||||
active_user,
|
||||
logged_in_headers,
|
||||
|
|
@ -419,14 +419,14 @@ async def test_download_file(
|
|||
]
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
with session_getter(db_manager) as _session:
|
||||
saved_flows = []
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = active_user.id
|
||||
db_flow = Flow.model_validate(flow, from_attributes=True)
|
||||
session.add(db_flow)
|
||||
_session.add(db_flow)
|
||||
saved_flows.append(db_flow)
|
||||
session.commit()
|
||||
_session.commit()
|
||||
# Make request to endpoint inside the session context
|
||||
flow_ids = [str(db_flow.id) for db_flow in saved_flows] # Convert UUIDs to strings
|
||||
flow_ids_json = json.dumps(flow_ids)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import time
|
||||
import asyncio
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
|
@ -27,7 +27,7 @@ async def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1)
|
|||
)
|
||||
if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS":
|
||||
return task_status_response.json()
|
||||
time.sleep(sleep_time)
|
||||
await asyncio.sleep(sleep_time)
|
||||
return None # Return None if task did not complete in time
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,13 @@ def mock_storage_service():
|
|||
|
||||
|
||||
@pytest.fixture(name="files_client")
|
||||
async def files_client_fixture(session: Session, monkeypatch, request, load_flows_dir, mock_storage_service):
|
||||
async def files_client_fixture(
|
||||
session: Session, # noqa: ARG001
|
||||
monkeypatch,
|
||||
request,
|
||||
load_flows_dir,
|
||||
mock_storage_service,
|
||||
):
|
||||
# Set the database url to a test database
|
||||
if "noclient" in request.keywords:
|
||||
yield
|
||||
|
|
@ -47,9 +53,11 @@ async def files_client_fixture(session: Session, monkeypatch, request, load_flow
|
|||
app = create_app()
|
||||
|
||||
app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
|
||||
async with LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager:
|
||||
async with AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client:
|
||||
yield client
|
||||
async with (
|
||||
LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager,
|
||||
AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client,
|
||||
):
|
||||
yield client
|
||||
# app.dependency_overrides.clear()
|
||||
monkeypatch.undo()
|
||||
# clear the temp db
|
||||
|
|
|
|||
|
|
@ -59,7 +59,8 @@ async def test_create_or_update_starter_projects():
|
|||
assert folder is not None
|
||||
num_db_projects = len(folder.flows)
|
||||
|
||||
# Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects
|
||||
# Check that the number of projects in the database is the same as the number of projects returned by
|
||||
# load_starter_projects
|
||||
assert num_db_projects == num_projects
|
||||
|
||||
|
||||
|
|
@ -76,7 +77,8 @@ async def test_create_or_update_starter_projects():
|
|||
# # Get the number of projects in the database
|
||||
# num_db_projects = session.exec(select(func.count(Flow.id)).where(Flow.folder == STARTER_FOLDER_NAME)).one()
|
||||
|
||||
# # Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects
|
||||
# # Check that the number of projects in the database is the same as the number of projects returned by
|
||||
# # load_starter_projects
|
||||
# assert num_db_projects == num_projects
|
||||
|
||||
# # Get all the starter projects
|
||||
|
|
@ -99,7 +101,7 @@ async def test_create_or_update_starter_projects():
|
|||
# delete_messages(session_id="test")
|
||||
|
||||
|
||||
def find_componeny_by_name(components, name):
|
||||
def find_component_by_name(components, name):
|
||||
for children in components.values():
|
||||
if name in children:
|
||||
return children[name]
|
||||
|
|
@ -111,17 +113,17 @@ def set_value(component, input_name, value):
|
|||
component["template"][input_name]["value"] = value
|
||||
|
||||
|
||||
def component_to_node(id, type, component):
|
||||
return {"id": type + id, "data": {"node": component, "type": type, "id": id}}
|
||||
def component_to_node(node_id, node_type, component):
|
||||
return {"id": node_type + node_id, "data": {"node": component, "type": node_type, "id": node_id}}
|
||||
|
||||
|
||||
def add_edge(input, output, from_output, to_input):
|
||||
def add_edge(source, target, from_output, to_input):
|
||||
return {
|
||||
"source": input,
|
||||
"target": output,
|
||||
"source": source,
|
||||
"target": target,
|
||||
"data": {
|
||||
"sourceHandle": {"dataType": "ChatInput", "id": input, "name": from_output, "output_types": ["Message"]},
|
||||
"targetHandle": {"fieldName": to_input, "id": output, "inputTypes": ["Message"], "type": "str"},
|
||||
"sourceHandle": {"dataType": "ChatInput", "id": source, "name": from_output, "output_types": ["Message"]},
|
||||
"targetHandle": {"fieldName": to_input, "id": target, "inputTypes": ["Message"], "type": "str"},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -131,8 +133,8 @@ async def test_refresh_starter_projects():
|
|||
data_path = str(Path(__file__).parent.parent.parent.absolute() / "base" / "langflow" / "components")
|
||||
components = build_custom_component_list_from_path(data_path)
|
||||
|
||||
chat_input = find_componeny_by_name(components, "ChatInput")
|
||||
chat_output = find_componeny_by_name(components, "ChatOutput")
|
||||
chat_input = find_component_by_name(components, "ChatInput")
|
||||
chat_output = find_component_by_name(components, "ChatOutput")
|
||||
chat_output["template"]["code"]["value"] = "changed !"
|
||||
del chat_output["template"]["should_store_message"]
|
||||
graph_data = {
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ from langflow.services.variable.kubernetes_secrets import KubernetesSecretManage
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kube_config(mocker):
|
||||
def _mock_kube_config(mocker):
|
||||
mocker.patch("kubernetes.config.load_kube_config")
|
||||
mocker.patch("kubernetes.config.load_incluster_config")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def secret_manager(mock_kube_config):
|
||||
def secret_manager(_mock_kube_config):
|
||||
return KubernetesSecretManager(namespace="test-namespace")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,15 +19,15 @@ def created_message():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def created_messages(session):
|
||||
with session_scope() as session:
|
||||
def created_messages(session): # noqa: ARG001
|
||||
with session_scope() as _session:
|
||||
messages = [
|
||||
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
|
||||
messagetables = add_messagetables(messagetables, session)
|
||||
messagetables = add_messagetables(messagetables, _session)
|
||||
return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,15 +20,15 @@ async def created_message():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def created_messages(session):
|
||||
with session_scope() as session:
|
||||
def created_messages(session): # noqa: ARG001
|
||||
with session_scope() as _session:
|
||||
messages = [
|
||||
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"),
|
||||
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
|
||||
return add_messagetables(messagetables, session)
|
||||
return add_messagetables(messagetables, _session)
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
|
|
|
|||
|
|
@ -282,12 +282,16 @@ async def test_load_langchain_object_with_cached_session(basic_graph_data):
|
|||
# session_service = get_session_service()
|
||||
# session_id1 = "non-existent-session-id"
|
||||
# session_id = session_service.build_key(session_id1, basic_graph_data)
|
||||
# graph1, artifacts1 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
|
||||
# graph1, artifacts1 = await session_service.load_session(
|
||||
# session_id, data_graph=basic_graph_data, flow_id="flow_id"
|
||||
# )
|
||||
# # Clear the cache
|
||||
# await session_service.clear_session(session_id)
|
||||
# # Use the new session_id to get the graph again
|
||||
# graph2, artifacts2 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id")
|
||||
|
||||
# graph2, artifacts2 = await session_service.load_session(
|
||||
# session_id, data_graph=basic_graph_data, flow_id="flow_id"
|
||||
# )
|
||||
#
|
||||
# # Since the cache was cleared, objects should be different
|
||||
# assert id(graph1) != id(graph2)
|
||||
|
||||
|
|
@ -297,8 +301,12 @@ async def test_load_langchain_object_with_cached_session(basic_graph_data):
|
|||
# # Provide a non-existent session_id
|
||||
# session_service = get_session_service()
|
||||
# session_id1 = None
|
||||
# graph1, artifacts1 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
|
||||
# graph1, artifacts1 = await session_service.load_session(
|
||||
# session_id1, data_graph=basic_graph_data, flow_id="flow_id"
|
||||
# )
|
||||
# # Use the new session_id to get the langchain_object again
|
||||
# graph2, artifacts2 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id")
|
||||
|
||||
# graph2, artifacts2 = await session_service.load_session(
|
||||
# session_id1, data_graph=basic_graph_data, flow_id="flow_id"
|
||||
# )
|
||||
#
|
||||
# assert graph1 == graph2
|
||||
|
|
|
|||
|
|
@ -45,52 +45,52 @@ class TestInput:
|
|||
assert set(post_process_type(SequenceABC[float])) == {float}
|
||||
|
||||
# Union types
|
||||
assert set(post_process_type(Union[int, str])) == {int, str}
|
||||
assert set(post_process_type(Union[int, SequenceABC[str]])) == {int, str}
|
||||
assert set(post_process_type(Union[int, SequenceABC[int]])) == {int}
|
||||
assert set(post_process_type(Union[int, str])) == {int, str} # noqa: UP007
|
||||
assert set(post_process_type(Union[int, SequenceABC[str]])) == {int, str} # noqa: UP007
|
||||
assert set(post_process_type(Union[int, SequenceABC[int]])) == {int} # noqa: UP007
|
||||
|
||||
# Nested Union with lists
|
||||
assert set(post_process_type(Union[list[int], list[str]])) == {int, str}
|
||||
assert set(post_process_type(Union[int, list[str], list[float]])) == {int, str, float}
|
||||
assert set(post_process_type(Union[list[int], list[str]])) == {int, str} # noqa: UP007
|
||||
assert set(post_process_type(Union[int, list[str], list[float]])) == {int, str, float} # noqa: UP007
|
||||
|
||||
# Custom data types
|
||||
assert set(post_process_type(Data)) == {Data}
|
||||
assert set(post_process_type(list[Data])) == {Data}
|
||||
|
||||
# Union with custom types
|
||||
assert set(post_process_type(Union[Data, str])) == {Data, str}
|
||||
assert set(post_process_type(Union[Data, int, list[str]])) == {Data, int, str}
|
||||
assert set(post_process_type(Union[Data, str])) == {Data, str} # noqa: UP007
|
||||
assert set(post_process_type(Union[Data, int, list[str]])) == {Data, int, str} # noqa: UP007
|
||||
|
||||
# Empty lists and edge cases
|
||||
assert set(post_process_type(list)) == {list}
|
||||
assert set(post_process_type(Union[int, None])) == {int, NoneType}
|
||||
assert set(post_process_type(Union[None, list[None]])) == {None, NoneType}
|
||||
assert set(post_process_type(Union[int, None])) == {int, NoneType} # noqa: UP007
|
||||
assert set(post_process_type(Union[None, list[None]])) == {None, NoneType} # noqa: UP007
|
||||
|
||||
# Handling complex nested structures
|
||||
assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float}
|
||||
assert set(post_process_type(Union[int | list[str] | list[float], str])) == {int, str, float}
|
||||
assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float} # noqa: UP007
|
||||
assert set(post_process_type(Union[int | list[str] | list[float], str])) == {int, str, float} # noqa: UP007
|
||||
|
||||
# Non-generic types should return as is
|
||||
assert set(post_process_type(dict)) == {dict}
|
||||
assert set(post_process_type(tuple)) == {tuple}
|
||||
|
||||
# Union with custom types
|
||||
assert set(post_process_type(Union[Data, str])) == {Data, str}
|
||||
assert set(post_process_type(Union[Data, str])) == {Data, str} # noqa: UP007
|
||||
assert set(post_process_type(Data | str)) == {Data, str}
|
||||
assert set(post_process_type(Data | int | list[str])) == {Data, int, str}
|
||||
|
||||
# More complex combinations with Data
|
||||
assert set(post_process_type(Data | list[float])) == {Data, float}
|
||||
assert set(post_process_type(Data | Union[int, str])) == {Data, int, str}
|
||||
assert set(post_process_type(Data | Union[int, str])) == {Data, int, str} # noqa: UP007
|
||||
assert set(post_process_type(Data | list[int] | None)) == {Data, int, type(None)}
|
||||
assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)}
|
||||
assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)} # noqa: UP007
|
||||
|
||||
# Multiple Data types combined
|
||||
assert set(post_process_type(Union[Data, str | float])) == {Data, str, float}
|
||||
assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str}
|
||||
assert set(post_process_type(Union[Data, str | float])) == {Data, str, float} # noqa: UP007
|
||||
assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str} # noqa: UP007
|
||||
|
||||
# Testing with nested unions and lists
|
||||
assert set(post_process_type(Union[list[Data], list[int | str]])) == {Data, int, str}
|
||||
assert set(post_process_type(Union[list[Data], list[int | str]])) == {Data, int, str} # noqa: UP007
|
||||
assert set(post_process_type(Data | list[float | str])) == {Data, float, str}
|
||||
|
||||
def test_input_to_dict(self):
|
||||
|
|
@ -157,7 +157,7 @@ class TestPostProcessType:
|
|||
assert post_process_type(list[int]) == [int]
|
||||
|
||||
def test_union_type(self):
|
||||
assert set(post_process_type(Union[int, str])) == {int, str}
|
||||
assert set(post_process_type(Union[int, str])) == {int, str} # noqa: UP007
|
||||
|
||||
def test_custom_type(self):
|
||||
class CustomType:
|
||||
|
|
@ -175,4 +175,4 @@ class TestPostProcessType:
|
|||
class CustomType:
|
||||
pass
|
||||
|
||||
assert set(post_process_type(Union[CustomType, int])) == {CustomType, int}
|
||||
assert set(post_process_type(Union[CustomType, int])) == {CustomType, int} # noqa: UP007
|
||||
|
|
|
|||
|
|
@ -112,11 +112,11 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting
|
|||
@patch("langflow.services.deps.get_settings_service")
|
||||
@patch("langflow.services.deps.get_session")
|
||||
def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service):
|
||||
ADMIN_USER_NAME = "admin_user"
|
||||
admin_user_name = "admin_user"
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
|
||||
mock_settings_service.auth_settings.SUPERUSER = admin_user_name
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password" # noqa: S105
|
||||
mock_get_settings_service.return_value = mock_settings_service
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
|
@ -11,7 +11,7 @@ from sqlmodel import select
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def super_user(client):
|
||||
def super_user(client): # noqa: ARG001
|
||||
settings_manager = get_settings_service()
|
||||
auth_settings = settings_manager.auth_settings
|
||||
with session_getter(get_db_service()) as session:
|
||||
|
|
@ -23,7 +23,10 @@ def super_user(client):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def super_user_headers(client: AsyncClient, super_user):
|
||||
async def super_user_headers(
|
||||
client: AsyncClient,
|
||||
super_user, # noqa: ARG001
|
||||
):
|
||||
settings_service = get_settings_service()
|
||||
auth_settings = settings_service.auth_settings
|
||||
login_data = {
|
||||
|
|
@ -38,14 +41,14 @@ async def super_user_headers(client: AsyncClient, super_user):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def deactivated_user(client):
|
||||
def deactivated_user(client): # noqa: ARG001
|
||||
with session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username="deactivateduser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
is_superuser=False,
|
||||
last_login_at=datetime.now(),
|
||||
last_login_at=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
|
@ -55,7 +58,7 @@ def deactivated_user(client):
|
|||
|
||||
async def test_user_waiting_for_approval(client):
|
||||
username = "waitingforapproval"
|
||||
password = "testpassword"
|
||||
password = "testpassword" # noqa: S105
|
||||
|
||||
# Debug: Check if the user already exists
|
||||
with session_getter(get_db_service()) as session:
|
||||
|
|
@ -140,7 +143,7 @@ async def test_inactive_user(client: AsyncClient):
|
|||
username="inactiveuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
last_login_at=datetime(2023, 1, 1, 0, 0, 0),
|
||||
last_login_at=datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
|
@ -205,7 +208,7 @@ async def test_patch_user(client: AsyncClient, active_user, logged_in_headers):
|
|||
async def test_patch_reset_password(client: AsyncClient, active_user, logged_in_headers):
|
||||
user_id = active_user.id
|
||||
update_data = UserUpdate(
|
||||
password="newpassword",
|
||||
password="newpassword", # noqa: S106
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def check_openai_api_key_in_environment_variables():
|
||||
def _check_openai_api_key_in_environment_variables():
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue