ref: Auto-fix ruff rules in tests (#4154)
This commit is contained in:
parent
51b3909d60
commit
45c8f98692
80 changed files with 359 additions and 456 deletions
|
|
@ -1,7 +1,5 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Idea from https://github.com/streamlit/streamlit/blob/4841cf91f1c820a392441092390c4c04907f9944/scripts/pypi_nightly_create_tag.py
|
||||
"""
|
||||
"""Idea from https://github.com/streamlit/streamlit/blob/4841cf91f1c820a392441092390c4c04907f9944/scripts/pypi_nightly_create_tag.py."""
|
||||
|
||||
import sys
|
||||
|
||||
|
|
@ -24,13 +22,15 @@ def get_latest_published_version(build_type: str, is_nightly: bool) -> Version:
|
|||
elif build_type == "main":
|
||||
url = PYPI_LANGFLOW_NIGHTLY_URL if is_nightly else PYPI_LANGFLOW_URL
|
||||
else:
|
||||
raise ValueError(f"Invalid build type: {build_type}")
|
||||
msg = f"Invalid build type: {build_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
res = requests.get(url)
|
||||
try:
|
||||
version_str = res.json()["info"]["version"]
|
||||
except Exception as e:
|
||||
raise RuntimeError("Got unexpected response from PyPI", e)
|
||||
msg = "Got unexpected response from PyPI"
|
||||
raise RuntimeError(msg, e)
|
||||
return Version(version_str)
|
||||
|
||||
|
||||
|
|
@ -75,7 +75,8 @@ def create_tag(build_type: str):
|
|||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
raise Exception("Specify base or main")
|
||||
msg = "Specify base or main"
|
||||
raise Exception(msg)
|
||||
|
||||
build_type = sys.argv[1]
|
||||
tag = create_tag(build_type)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import sys
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import packaging.version
|
||||
|
|
@ -10,39 +10,38 @@ BASE_DIR = Path(__file__).parent.parent.parent
|
|||
def update_base_dep(pyproject_path: str, new_version: str) -> None:
|
||||
"""Update the langflow-base dependency in pyproject.toml."""
|
||||
filepath = BASE_DIR / pyproject_path
|
||||
content = filepath.read_text()
|
||||
content = filepath.read_text(encoding="utf-8")
|
||||
|
||||
replacement = f'langflow-base-nightly = "{new_version}"'
|
||||
|
||||
# Updates the pattern for poetry
|
||||
pattern = re.compile(r'langflow-base = \{ path = "\./src/backend/base", develop = true \}')
|
||||
if not pattern.search(content):
|
||||
raise Exception(f'langflow-base poetry dependency not found in "{filepath}"')
|
||||
msg = f'langflow-base poetry dependency not found in "{filepath}"'
|
||||
raise Exception(msg)
|
||||
content = pattern.sub(replacement, content)
|
||||
filepath.write_text(content)
|
||||
filepath.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def verify_pep440(version):
|
||||
"""
|
||||
Verify if version is PEP440 compliant.
|
||||
"""Verify if version is PEP440 compliant.
|
||||
|
||||
https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191
|
||||
"""
|
||||
|
||||
try:
|
||||
return packaging.version.Version(version)
|
||||
except packaging.version.InvalidVersion as e:
|
||||
raise e
|
||||
except packaging.version.InvalidVersion:
|
||||
raise
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) != 2:
|
||||
raise Exception("New version not specified")
|
||||
msg = "New version not specified"
|
||||
raise Exception(msg)
|
||||
base_version = sys.argv[1]
|
||||
|
||||
# Strip "v" prefix from version if present
|
||||
if base_version.startswith("v"):
|
||||
base_version = base_version[1:]
|
||||
base_version = base_version.removeprefix("v")
|
||||
|
||||
verify_pep440(base_version)
|
||||
update_base_dep("pyproject.toml", base_version)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import sys
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
BASE_DIR = Path(__file__).parent.parent.parent
|
||||
|
|
@ -8,22 +8,23 @@ BASE_DIR = Path(__file__).parent.parent.parent
|
|||
def update_pyproject_name(pyproject_path: str, new_project_name: str) -> None:
|
||||
"""Update the project name in pyproject.toml."""
|
||||
filepath = BASE_DIR / pyproject_path
|
||||
content = filepath.read_text()
|
||||
content = filepath.read_text(encoding="utf-8")
|
||||
|
||||
# Regex to match the version line under [tool.poetry]
|
||||
pattern = re.compile(r'(?<=^name = ")[^"]+(?=")', re.MULTILINE)
|
||||
|
||||
if not pattern.search(content):
|
||||
raise Exception(f'Project name not found in "{filepath}"')
|
||||
msg = f'Project name not found in "{filepath}"'
|
||||
raise Exception(msg)
|
||||
content = pattern.sub(new_project_name, content)
|
||||
|
||||
filepath.write_text(content)
|
||||
filepath.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def update_uv_dep(pyproject_path: str, new_project_name: str) -> None:
|
||||
"""Update the langflow-base dependency in pyproject.toml."""
|
||||
filepath = BASE_DIR / pyproject_path
|
||||
content = filepath.read_text()
|
||||
content = filepath.read_text(encoding="utf-8")
|
||||
|
||||
if new_project_name == "langflow-nightly":
|
||||
pattern = re.compile(r"langflow = \{ workspace = true \}")
|
||||
|
|
@ -32,18 +33,21 @@ def update_uv_dep(pyproject_path: str, new_project_name: str) -> None:
|
|||
pattern = re.compile(r"langflow-base = \{ workspace = true \}")
|
||||
replacement = "langflow-base-nightly = { workspace = true }"
|
||||
else:
|
||||
raise ValueError(f"Invalid project name: {new_project_name}")
|
||||
msg = f"Invalid project name: {new_project_name}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Updates the dependency name for uv
|
||||
if not pattern.search(content):
|
||||
raise Exception(f"{replacement} uv dependency not found in {filepath}")
|
||||
msg = f"{replacement} uv dependency not found in {filepath}"
|
||||
raise Exception(msg)
|
||||
content = pattern.sub(replacement, content)
|
||||
filepath.write_text(content)
|
||||
filepath.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) != 3:
|
||||
raise Exception("Must specify project name and build type, e.g. langflow-nightly base")
|
||||
msg = "Must specify project name and build type, e.g. langflow-nightly base"
|
||||
raise Exception(msg)
|
||||
new_project_name = sys.argv[1]
|
||||
build_type = sys.argv[2]
|
||||
|
||||
|
|
@ -54,7 +58,8 @@ def main() -> None:
|
|||
update_pyproject_name("pyproject.toml", new_project_name)
|
||||
update_uv_dep("pyproject.toml", new_project_name)
|
||||
else:
|
||||
raise ValueError(f"Invalid build type: {build_type}")
|
||||
msg = f"Invalid build type: {build_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import sys
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import packaging.version
|
||||
|
|
@ -10,40 +10,39 @@ BASE_DIR = Path(__file__).parent.parent.parent
|
|||
def update_pyproject_version(pyproject_path: str, new_version: str) -> None:
|
||||
"""Update the version in pyproject.toml."""
|
||||
filepath = BASE_DIR / pyproject_path
|
||||
content = filepath.read_text()
|
||||
content = filepath.read_text(encoding="utf-8")
|
||||
|
||||
# Regex to match the version line under [tool.poetry]
|
||||
pattern = re.compile(r'(?<=^version = ")[^"]+(?=")', re.MULTILINE)
|
||||
|
||||
if not pattern.search(content):
|
||||
raise Exception(f'Project version not found in "{filepath}"')
|
||||
msg = f'Project version not found in "{filepath}"'
|
||||
raise Exception(msg)
|
||||
|
||||
content = pattern.sub(new_version, content)
|
||||
|
||||
filepath.write_text(content)
|
||||
filepath.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def verify_pep440(version):
|
||||
"""
|
||||
Verify if version is PEP440 compliant.
|
||||
"""Verify if version is PEP440 compliant.
|
||||
|
||||
https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191
|
||||
"""
|
||||
|
||||
try:
|
||||
return packaging.version.Version(version)
|
||||
except packaging.version.InvalidVersion as e:
|
||||
raise e
|
||||
except packaging.version.InvalidVersion:
|
||||
raise
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) != 3:
|
||||
raise Exception("New version not specified")
|
||||
msg = "New version not specified"
|
||||
raise Exception(msg)
|
||||
new_version = sys.argv[1]
|
||||
|
||||
# Strip "v" prefix from version if present
|
||||
if new_version.startswith("v"):
|
||||
new_version = new_version[1:]
|
||||
new_version = new_version.removeprefix("v")
|
||||
|
||||
build_type = sys.argv[2]
|
||||
|
||||
|
|
@ -54,7 +53,8 @@ def main() -> None:
|
|||
elif build_type == "main":
|
||||
update_pyproject_version("pyproject.toml", new_version)
|
||||
else:
|
||||
raise ValueError(f"Invalid build type: {build_type}")
|
||||
msg = f"Invalid build type: {build_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import sys
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
BASE_DIR = Path(__file__).parent.parent.parent
|
||||
|
|
@ -7,30 +7,31 @@ BASE_DIR = Path(__file__).parent.parent.parent
|
|||
|
||||
def update_uv_dep(base_version: str) -> None:
|
||||
"""Update the langflow-base dependency in pyproject.toml."""
|
||||
|
||||
pyproject_path = BASE_DIR / "pyproject.toml"
|
||||
|
||||
# Read the pyproject.toml file content
|
||||
content = pyproject_path.read_text()
|
||||
content = pyproject_path.read_text(encoding="utf-8")
|
||||
|
||||
# For the main project, update the langflow-base dependency in the UV section
|
||||
pattern = re.compile(r'(dependencies\s*=\s*\[\s*\n\s*)("langflow-base==[\d.]+")')
|
||||
replacement = r'\1"langflow-base-nightly=={}"'.format(base_version)
|
||||
replacement = rf'\1"langflow-base-nightly=={base_version}"'
|
||||
|
||||
# Check if the pattern is found
|
||||
if not pattern.search(content):
|
||||
raise Exception(f"{pattern} UV dependency not found in {pyproject_path}")
|
||||
msg = f"{pattern} UV dependency not found in {pyproject_path}"
|
||||
raise Exception(msg)
|
||||
|
||||
# Replace the matched pattern with the new one
|
||||
content = pattern.sub(replacement, content)
|
||||
|
||||
# Write the updated content back to the file
|
||||
pyproject_path.write_text(content)
|
||||
pyproject_path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) != 2:
|
||||
raise Exception("specify base version")
|
||||
msg = "specify base version"
|
||||
raise Exception(msg)
|
||||
base_version = sys.argv[1]
|
||||
base_version = base_version.lstrip("v")
|
||||
update_uv_dep(base_version)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
# ]
|
||||
# ///
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from huggingface_hub import HfApi, list_models
|
||||
from rich import print
|
||||
|
|
@ -23,11 +24,11 @@ space = parsed_args.space
|
|||
|
||||
if not space:
|
||||
print("Please provide a space to restart.")
|
||||
exit()
|
||||
sys.exit()
|
||||
|
||||
if not parsed_args.token:
|
||||
print("Please provide an API token.")
|
||||
exit()
|
||||
sys.exit()
|
||||
|
||||
# Or configure a HfApi client
|
||||
hf_api = HfApi(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import contextlib
|
||||
|
||||
|
||||
def get_version() -> str:
|
||||
"""
|
||||
Retrieves the version of the package from a possible list of package names.
|
||||
"""Retrieves the version of the package from a possible list of package names.
|
||||
This accounts for after package names are updated for -nightly builds.
|
||||
|
||||
Returns:
|
||||
|
|
@ -19,20 +21,18 @@ def get_version() -> str:
|
|||
]
|
||||
_version = None
|
||||
for pkg_name in pkg_names:
|
||||
try:
|
||||
with contextlib.suppress(ImportError, metadata.PackageNotFoundError):
|
||||
_version = metadata.version(pkg_name)
|
||||
except (ImportError, metadata.PackageNotFoundError):
|
||||
pass
|
||||
|
||||
if _version is None:
|
||||
raise ValueError(f"Package not found from options {pkg_names}")
|
||||
msg = f"Package not found from options {pkg_names}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return _version
|
||||
|
||||
|
||||
def is_pre_release(v: str) -> bool:
|
||||
"""
|
||||
Returns a boolean indicating whether the version is a pre-release version,
|
||||
"""Returns a boolean indicating whether the version is a pre-release version,
|
||||
as per the definition of a pre-release segment from PEP 440.
|
||||
"""
|
||||
return any(label in v for label in ["a", "b", "rc"])
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@ import os.path
|
|||
|
||||
|
||||
def get_required_env_var(var: str) -> str:
|
||||
"""
|
||||
Get the value of the specified environment variable.
|
||||
"""Get the value of the specified environment variable.
|
||||
|
||||
Args:
|
||||
var (str): The environment variable to get.
|
||||
|
|
@ -18,7 +17,8 @@ def get_required_env_var(var: str) -> str:
|
|||
"""
|
||||
value = os.getenv(var)
|
||||
if not value:
|
||||
raise ValueError(f"Environment variable {var} is not set")
|
||||
msg = f"Environment variable {var} is not set"
|
||||
raise ValueError(msg)
|
||||
return value
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,13 +16,6 @@ from base.langflow.components.inputs.ChatInput import ChatInput
|
|||
from dotenv import load_dotenv
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from loguru import logger
|
||||
from pytest import LogCaptureFixture
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
from sqlmodel.pool import StaticPool
|
||||
from tests.api_keys import get_openai_api_key
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
|
|
@ -34,6 +27,13 @@ from langflow.services.database.models.user.model import User, UserCreate, UserR
|
|||
from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from loguru import logger
|
||||
from pytest import LogCaptureFixture
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from tests.api_keys import get_openai_api_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
|
@ -114,7 +114,7 @@ def caplog(caplog: LogCaptureFixture):
|
|||
logger.remove(handler_id)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
async def async_client() -> AsyncGenerator:
|
||||
from langflow.main import create_app
|
||||
|
||||
|
|
@ -188,8 +188,7 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
|
|||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
|
||||
"""Get a graph from a json file."""
|
||||
if _type == "basic":
|
||||
path = pytest.BASIC_EXAMPLE_PATH
|
||||
elif _type == "complex":
|
||||
|
|
@ -197,7 +196,7 @@ def get_graph(_type="basic"):
|
|||
elif _type == "openapi":
|
||||
path = pytest.OPENAPI_EXAMPLE_PATH
|
||||
|
||||
with path.open() as f:
|
||||
with path.open(encoding="utf-8") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
|
|
@ -209,13 +208,13 @@ def get_graph(_type="basic"):
|
|||
|
||||
@pytest.fixture
|
||||
def basic_graph_data():
|
||||
with pytest.BASIC_EXAMPLE_PATH.open() as f:
|
||||
with pytest.BASIC_EXAMPLE_PATH.open(encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_graph():
|
||||
yield get_graph()
|
||||
return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -230,47 +229,47 @@ def openapi_graph():
|
|||
|
||||
@pytest.fixture
|
||||
def json_flow():
|
||||
return pytest.BASIC_EXAMPLE_PATH.read_text()
|
||||
return pytest.BASIC_EXAMPLE_PATH.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def grouped_chat_json_flow():
|
||||
return pytest.GROUPED_CHAT_EXAMPLE_PATH.read_text()
|
||||
return pytest.GROUPED_CHAT_EXAMPLE_PATH.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def one_grouped_chat_json_flow():
|
||||
return pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH.read_text()
|
||||
return pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store_grouped_json_flow():
|
||||
return pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH.read_text()
|
||||
return pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_flow_with_prompt_and_history():
|
||||
return pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY.read_text()
|
||||
return pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_simple_api_test():
|
||||
return pytest.SIMPLE_API_TEST.read_text()
|
||||
return pytest.SIMPLE_API_TEST.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_vector_store():
|
||||
return pytest.VECTOR_STORE_PATH.read_text()
|
||||
return pytest.VECTOR_STORE_PATH.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_webhook_test():
|
||||
return pytest.WEBHOOK_TEST.read_text()
|
||||
return pytest.WEBHOOK_TEST.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_memory_chatbot_no_llm():
|
||||
return pytest.MEMORY_CHATBOT_NO_LLM.read_text()
|
||||
return pytest.MEMORY_CHATBOT_NO_LLM.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
|
||||
|
|
@ -315,12 +314,12 @@ def session_getter_fixture(client):
|
|||
with Session(db_service.engine) as session:
|
||||
yield session
|
||||
|
||||
yield blank_session_getter
|
||||
return blank_session_getter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
yield CliRunner()
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -337,7 +336,7 @@ async def test_user(client):
|
|||
await client.delete(f"/api/v1/users/{user['id']}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture
|
||||
def active_user(client):
|
||||
db_manager = get_db_service()
|
||||
with db_manager.with_session() as session:
|
||||
|
|
@ -372,7 +371,7 @@ async def logged_in_headers(client, active_user):
|
|||
assert response.status_code == 200
|
||||
tokens = response.json()
|
||||
a_token = tokens["access_token"]
|
||||
yield {"Authorization": f"Bearer {a_token}"}
|
||||
return {"Authorization": f"Bearer {a_token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -395,12 +394,12 @@ def flow(client, json_flow: str, active_user):
|
|||
|
||||
@pytest.fixture
|
||||
def json_chat_input():
|
||||
return pytest.CHAT_INPUT.read_text()
|
||||
return pytest.CHAT_INPUT.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_two_outputs():
|
||||
return pytest.TWO_OUTPUTS.read_text()
|
||||
return pytest.TWO_OUTPUTS.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -528,7 +527,8 @@ def get_starter_project(active_user):
|
|||
.where(Flow.name == "Basic Prompting (Hello, World)")
|
||||
).first()
|
||||
if not flow:
|
||||
raise ValueError("No starter project found")
|
||||
msg = "No starter project found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# ensure openai api key is set
|
||||
get_openai_api_key()
|
||||
|
|
|
|||
|
|
@ -1,18 +1,14 @@
|
|||
import os
|
||||
|
||||
from astrapy.db import AstraDB
|
||||
import pytest
|
||||
|
||||
from astrapy.db import AstraDB
|
||||
from langchain_core.documents import Document
|
||||
from langflow.components.embeddings import OpenAIEmbeddingsComponent
|
||||
from langflow.components.vectorstores import AstraVectorStoreComponent
|
||||
from tests.api_keys import get_astradb_application_token, get_astradb_api_endpoint, get_openai_api_key
|
||||
from tests.integration.components.mock_components import TextToData
|
||||
from tests.integration.utils import ComponentInputHandle
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from langflow.schema.data import Data
|
||||
from tests.integration.utils import run_single_component
|
||||
from tests.api_keys import get_astradb_api_endpoint, get_astradb_application_token, get_openai_api_key
|
||||
from tests.integration.components.mock_components import TextToData
|
||||
from tests.integration.utils import ComponentInputHandle, run_single_component
|
||||
|
||||
BASIC_COLLECTION = "test_basic"
|
||||
SEARCH_COLLECTION = "test_search"
|
||||
|
|
@ -30,7 +26,7 @@ ALL_COLLECTIONS = [
|
|||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def astradb_client(request):
|
||||
client = AstraDB(api_endpoint=get_astradb_api_endpoint(), token=get_astradb_application_token())
|
||||
yield client
|
||||
|
|
@ -139,7 +135,7 @@ def test_astra_vectorize():
|
|||
|
||||
@pytest.mark.api_key_required
|
||||
def test_astra_vectorize_with_provider_api_key():
|
||||
"""tests vectorize using an openai api key"""
|
||||
"""Tests vectorize using an openai api key."""
|
||||
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
|
||||
|
||||
application_token = get_astradb_application_token()
|
||||
|
|
@ -196,7 +192,7 @@ def test_astra_vectorize_with_provider_api_key():
|
|||
|
||||
@pytest.mark.api_key_required
|
||||
def test_astra_vectorize_passes_authentication():
|
||||
"""tests vectorize using the authentication parameter"""
|
||||
"""Tests vectorize using the authentication parameter."""
|
||||
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
|
||||
|
||||
store = None
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import pytest
|
||||
|
||||
from langflow.components.helpers.ParseJSONData import ParseJSONDataComponent
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.schema import Data
|
||||
from tests.integration.components.mock_components import TextToData
|
||||
from tests.integration.utils import run_single_component, ComponentInputHandle
|
||||
from tests.integration.utils import ComponentInputHandle, run_single_component
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import pytest
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.memory import get_messages
|
||||
from langflow.schema.message import Message
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
from langflow.components.inputs import ChatInput
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default():
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
import pytest
|
||||
from langflow.components.inputs import TextInputComponent
|
||||
from langflow.schema.message import Message
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
from langflow.components.inputs import TextInputComponent
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_input():
|
||||
outputs = await run_single_component(TextInputComponent, run_input="sample text", input_type="text")
|
||||
print(outputs)
|
||||
assert isinstance(outputs["text"], Message)
|
||||
assert outputs["text"].text == "sample text"
|
||||
assert outputs["text"].sender is None
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
import json
|
||||
from typing import List
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.inputs import StrInput, BoolInput
|
||||
from langflow.inputs import BoolInput, StrInput
|
||||
from langflow.schema import Data
|
||||
from langflow.template import Output
|
||||
|
||||
|
|
@ -21,5 +20,5 @@ class TextToData(Component):
|
|||
return Data(data=json.loads(text))
|
||||
return Data(text=text)
|
||||
|
||||
def create_data(self) -> List[Data]:
|
||||
def create_data(self) -> list[Data]:
|
||||
return [self._to_data(t) for t in self.text_data]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import pytest
|
||||
|
||||
import pytest
|
||||
from langflow.components.models.OpenAIModel import OpenAIModelComponent
|
||||
from langflow.components.output_parsers.OutputParser import OutputParserComponent
|
||||
from langflow.components.prompts.Prompt import PromptComponent
|
||||
|
|
@ -23,7 +23,7 @@ async def test_csv_output_parser_openai():
|
|||
prompt_handler = ComponentInputHandle(
|
||||
clazz=PromptComponent,
|
||||
inputs={
|
||||
"template": "List the first five positive integers.\n\n{format_instructions}",
|
||||
"template": "List the first five positive integers.\n\n{format_instructions}", # noqa: RUF027
|
||||
"format_instructions": format_instructions,
|
||||
},
|
||||
output_name="prompt",
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import pytest
|
||||
from langflow.components.outputs import ChatOutput
|
||||
from langflow.memory import get_messages
|
||||
from langflow.schema.message import Message
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string():
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
import pytest
|
||||
from langflow.components.outputs import TextOutputComponent
|
||||
from langflow.schema.message import Message
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test():
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
import pytest
|
||||
from langflow.components.prompts import PromptComponent
|
||||
from langflow.schema.message import Message
|
||||
from tests.integration.utils import run_single_component
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test():
|
||||
outputs = await run_single_component(PromptComponent, inputs={"template": "test {var1}", "var1": "from the var"})
|
||||
print(outputs)
|
||||
assert isinstance(outputs["prompt"], Message)
|
||||
assert outputs["prompt"].text == "test from the var"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.components.outputs import ChatOutput
|
||||
from langflow.components.prompts import PromptComponent
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from uuid import uuid4
|
|||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from langflow.graph.schema import RunOutputs
|
||||
from langflow.initial_setup.setup import load_starter_projects
|
||||
from langflow.load import run_flow_from_json
|
||||
|
|
@ -80,9 +79,8 @@ async def test_run_with_inputs_and_outputs(client, starter_project, created_api_
|
|||
@pytest.mark.noclient
|
||||
@pytest.mark.api_key_required
|
||||
def test_run_flow_from_json_object():
|
||||
"""Test loading a flow from a json file and applying tweaks"""
|
||||
_, projects = zip(*load_starter_projects())
|
||||
project = [project for project in projects if "Basic Prompting" in project["name"]][0]
|
||||
"""Test loading a flow from a json file and applying tweaks."""
|
||||
project = next(project for _, project in load_starter_projects() if "Basic Prompting" in project["name"])
|
||||
results = run_flow_from_json(project, input_value="test", fallback_to_env_vars=True)
|
||||
assert results is not None
|
||||
assert all(isinstance(result, RunOutputs) for result in results)
|
||||
|
|
|
|||
|
|
@ -1,21 +1,19 @@
|
|||
import dataclasses
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional, Any
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from astrapy.admin import parse_api_endpoint
|
||||
|
||||
from langflow.api.v1.schemas import InputValueRequest
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import Embeddings
|
||||
from langflow.graph import Graph
|
||||
from langflow.processing.process import run_graph_internal
|
||||
import requests
|
||||
|
||||
|
||||
def check_env_vars(*vars):
|
||||
"""
|
||||
Check if all specified environment variables are set.
|
||||
"""Check if all specified environment variables are set.
|
||||
|
||||
Args:
|
||||
*vars (str): The environment variables to check.
|
||||
|
|
@ -27,8 +25,7 @@ def check_env_vars(*vars):
|
|||
|
||||
|
||||
def valid_nvidia_vectorize_region(api_endpoint: str) -> bool:
|
||||
"""
|
||||
Check if the specified region is valid.
|
||||
"""Check if the specified region is valid.
|
||||
|
||||
Args:
|
||||
region (str): The region to check.
|
||||
|
|
@ -38,8 +35,9 @@ def valid_nvidia_vectorize_region(api_endpoint: str) -> bool:
|
|||
"""
|
||||
parsed_endpoint = parse_api_endpoint(api_endpoint)
|
||||
if not parsed_endpoint:
|
||||
raise ValueError("Invalid ASTRA_DB_API_ENDPOINT")
|
||||
return parsed_endpoint.region in ["us-east-2"]
|
||||
msg = "Invalid ASTRA_DB_API_ENDPOINT"
|
||||
raise ValueError(msg)
|
||||
return parsed_endpoint.region == "us-east-2"
|
||||
|
||||
|
||||
class MockEmbeddings(Embeddings):
|
||||
|
|
@ -70,15 +68,15 @@ class JSONFlow:
|
|||
if node["data"]["type"] == component_type:
|
||||
result.append(node["id"])
|
||||
if not result:
|
||||
raise ValueError(
|
||||
f"Component of type {component_type} not found, available types: {', '.join(set(node['data']['type'] for node in self.json['data']['nodes']))}"
|
||||
)
|
||||
msg = f"Component of type {component_type} not found, available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}"
|
||||
raise ValueError(msg)
|
||||
return result
|
||||
|
||||
def get_component_by_type(self, component_type):
|
||||
components = self.get_components_by_type(component_type)
|
||||
if len(components) > 1:
|
||||
raise ValueError(f"Multiple components of type {component_type} found")
|
||||
msg = f"Multiple components of type {component_type} found"
|
||||
raise ValueError(msg)
|
||||
return components[0]
|
||||
|
||||
def set_value(self, component_id, key, value):
|
||||
|
|
@ -86,13 +84,15 @@ class JSONFlow:
|
|||
for node in self.json["data"]["nodes"]:
|
||||
if node["id"] == component_id:
|
||||
if key not in node["data"]["node"]["template"]:
|
||||
raise ValueError(f"Component {component_id} does not have input {key}")
|
||||
msg = f"Component {component_id} does not have input {key}"
|
||||
raise ValueError(msg)
|
||||
node["data"]["node"]["template"][key]["value"] = value
|
||||
node["data"]["node"]["template"][key]["load_from_db"] = False
|
||||
done = True
|
||||
break
|
||||
if not done:
|
||||
raise ValueError(f"Component {component_id} not found")
|
||||
msg = f"Component {component_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def download_flow_from_github(name: str, version: str) -> JSONFlow:
|
||||
|
|
@ -105,18 +105,15 @@ def download_flow_from_github(name: str, version: str) -> JSONFlow:
|
|||
|
||||
|
||||
async def run_json_flow(
|
||||
json_flow: JSONFlow, run_input: Optional[Any] = None, session_id: Optional[str] = None
|
||||
json_flow: JSONFlow, run_input: Any | None = None, session_id: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
graph = Graph.from_payload(json_flow.json)
|
||||
return await run_flow(graph, run_input, session_id)
|
||||
|
||||
|
||||
async def run_flow(graph: Graph, run_input: Optional[Any] = None, session_id: Optional[str] = None) -> dict[str, Any]:
|
||||
async def run_flow(graph: Graph, run_input: Any | None = None, session_id: str | None = None) -> dict[str, Any]:
|
||||
graph.prepare()
|
||||
if run_input:
|
||||
graph_run_inputs = [InputValueRequest(input_value=run_input, type="chat")]
|
||||
else:
|
||||
graph_run_inputs = []
|
||||
graph_run_inputs = [InputValueRequest(input_value=run_input, type="chat")] if run_input else []
|
||||
|
||||
flow_id = str(uuid.uuid4())
|
||||
|
||||
|
|
@ -137,23 +134,24 @@ class ComponentInputHandle:
|
|||
|
||||
async def run_single_component(
|
||||
clazz: type,
|
||||
inputs: dict = None,
|
||||
run_input: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
input_type: Optional[str] = "chat",
|
||||
inputs: dict | None = None,
|
||||
run_input: Any | None = None,
|
||||
session_id: str | None = None,
|
||||
input_type: str | None = "chat",
|
||||
) -> dict[str, Any]:
|
||||
user_id = str(uuid.uuid4())
|
||||
flow_id = str(uuid.uuid4())
|
||||
graph = Graph(user_id=user_id, flow_id=flow_id)
|
||||
|
||||
def _add_component(clazz: type, inputs: Optional[dict] = None) -> str:
|
||||
def _add_component(clazz: type, inputs: dict | None = None) -> str:
|
||||
raw_inputs = {}
|
||||
if inputs:
|
||||
for key, value in inputs.items():
|
||||
if not isinstance(value, ComponentInputHandle):
|
||||
raw_inputs[key] = value
|
||||
if isinstance(value, Component):
|
||||
raise ValueError("Component inputs must be wrapped in ComponentInputHandle")
|
||||
msg = "Component inputs must be wrapped in ComponentInputHandle"
|
||||
raise ValueError(msg)
|
||||
component = clazz(**raw_inputs, _user_id=user_id)
|
||||
component_id = graph.add_component(component)
|
||||
if inputs:
|
||||
|
|
@ -165,10 +163,7 @@ async def run_single_component(
|
|||
|
||||
component_id = _add_component(clazz, inputs)
|
||||
graph.prepare()
|
||||
if run_input:
|
||||
graph_run_inputs = [InputValueRequest(input_value=run_input, type=input_type)]
|
||||
else:
|
||||
graph_run_inputs = []
|
||||
graph_run_inputs = [InputValueRequest(input_value=run_input, type=input_type)] if run_input else []
|
||||
|
||||
_, _ = await run_graph_internal(
|
||||
graph, flow_id, session_id=session_id, inputs=graph_run_inputs, outputs=[component_id]
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ from rich import print
|
|||
class NameTest(FastHttpUser):
|
||||
wait_time = between(1, 5)
|
||||
|
||||
with Path("names.txt").open() as file:
|
||||
names = [line.strip() for line in file.readlines()]
|
||||
with Path("names.txt").open(encoding="utf-8") as file:
|
||||
names = [line.strip() for line in file]
|
||||
|
||||
headers: dict = {}
|
||||
|
||||
|
|
@ -28,8 +28,9 @@ class NameTest(FastHttpUser):
|
|||
print(f"Poll Response: {response.js}")
|
||||
if status == "SUCCESS":
|
||||
return response.js.get("result")
|
||||
elif status in ["FAILURE", "REVOKED"]:
|
||||
raise ValueError(f"Task failed with status: {status}")
|
||||
if status in {"FAILURE", "REVOKED"}:
|
||||
msg = f"Task failed with status: {status}"
|
||||
raise ValueError(msg)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def process(self, name, flow_id, payload):
|
||||
|
|
@ -45,7 +46,8 @@ class NameTest(FastHttpUser):
|
|||
print(response.js)
|
||||
if response.status_code != 200:
|
||||
response.failure("Process call failed")
|
||||
raise ValueError("Process call failed")
|
||||
msg = "Process call failed"
|
||||
raise ValueError(msg)
|
||||
task_id = response.js.get("id")
|
||||
session_id = response.js.get("session_id")
|
||||
assert task_id, "Inner Task ID not found"
|
||||
|
|
@ -86,7 +88,9 @@ class NameTest(FastHttpUser):
|
|||
a_token = tokens["access_token"]
|
||||
logged_in_headers = {"Authorization": f"Bearer {a_token}"}
|
||||
print("Logged in")
|
||||
json_flow = (Path(__file__).parent.parent / "data" / "BasicChatwithPromptandHistory.json").read_text()
|
||||
json_flow = (Path(__file__).parent.parent / "data" / "BasicChatwithPromptandHistory.json").read_text(
|
||||
encoding="utf-8"
|
||||
)
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from langflow.api.utils import get_suggestion_message
|
||||
from unittest.mock import patch
|
||||
|
||||
from langflow.api.utils import get_suggestion_message
|
||||
from langflow.services.database.models.flow.utils import get_outdated_components
|
||||
from langflow.utils.version import get_version_info
|
||||
|
||||
|
|
|
|||
|
|
@ -21,11 +21,11 @@ async def test_create_variable(client: AsyncClient, body, logged_in_headers):
|
|||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_201_CREATED == response.status_code
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert body["name"] == result["name"]
|
||||
assert body["type"] == result["type"]
|
||||
assert body["default_fields"] == result["default_fields"]
|
||||
assert "id" in result.keys()
|
||||
assert "id" in result
|
||||
assert body["value"] != result["value"]
|
||||
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ async def test_create_variable__variable_name_already_exists(client: AsyncClient
|
|||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_400_BAD_REQUEST == response.status_code
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "Variable name already exists" in result["detail"]
|
||||
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ async def test_create_variable__variable_name_and_value_cannot_be_empty(client:
|
|||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_400_BAD_REQUEST == response.status_code
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "Variable name and value cannot be empty" in result["detail"]
|
||||
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClien
|
|||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_400_BAD_REQUEST == response.status_code
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "Variable name cannot be empty" in result["detail"]
|
||||
|
||||
|
||||
|
|
@ -70,7 +70,7 @@ async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClie
|
|||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_400_BAD_REQUEST == response.status_code
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "Variable value cannot be empty" in result["detail"]
|
||||
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ async def test_create_variable__HTTPException(client: AsyncClient, body, logged_
|
|||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_418_IM_A_TEAPOT == response.status_code
|
||||
assert response.status_code == status.HTTP_418_IM_A_TEAPOT
|
||||
assert generic_message in result["detail"]
|
||||
|
||||
|
||||
|
|
@ -97,7 +97,7 @@ async def test_create_variable__Exception(client: AsyncClient, body, logged_in_h
|
|||
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert generic_message in result["detail"]
|
||||
|
||||
|
||||
|
|
@ -111,7 +111,7 @@ async def test_read_variables(client: AsyncClient, body, logged_in_headers):
|
|||
response = await client.get("api/v1/variables/", headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_200_OK == response.status_code
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert all(name in [r["name"] for r in result] for name in names)
|
||||
|
||||
|
||||
|
|
@ -125,23 +125,22 @@ async def test_read_variables__empty(client: AsyncClient, logged_in_headers):
|
|||
response = await client.get("api/v1/variables/", headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_200_OK == response.status_code
|
||||
assert [] == result
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
async def test_read_variables__(client: AsyncClient, logged_in_headers):
|
||||
generic_message = "Generic error message"
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
with mock.patch("sqlmodel.Session.exec") as m:
|
||||
m.side_effect = Exception(generic_message)
|
||||
with pytest.raises(Exception) as exc, mock.patch("sqlmodel.Session.exec") as m:
|
||||
m.side_effect = Exception(generic_message)
|
||||
|
||||
response = await client.get("api/v1/variables/", headers=logged_in_headers)
|
||||
result = response.json()
|
||||
response = await client.get("api/v1/variables/", headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
|
||||
assert generic_message in result["detail"]
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert generic_message in result["detail"]
|
||||
|
||||
assert generic_message in str(exc.value)
|
||||
|
||||
|
|
@ -159,7 +158,7 @@ async def test_update_variable(client: AsyncClient, body, logged_in_headers):
|
|||
response = await client.patch(f"api/v1/variables/{saved.get('id')}", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_200_OK == response.status_code
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert saved["id"] == result["id"]
|
||||
assert saved["name"] != result["name"]
|
||||
assert saved["default_fields"] != result["default_fields"]
|
||||
|
|
@ -173,7 +172,7 @@ async def test_update_variable__Exception(client: AsyncClient, body, logged_in_h
|
|||
response = await client.patch(f"api/v1/variables/{wrong_id}", json=body, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert status.HTTP_404_NOT_FOUND == response.status_code
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Variable not found" in result["detail"]
|
||||
|
||||
|
||||
|
|
@ -183,7 +182,7 @@ async def test_delete_variable(client: AsyncClient, body, logged_in_headers):
|
|||
saved = response.json()
|
||||
response = await client.delete(f"api/v1/variables/{saved.get('id')}", headers=logged_in_headers)
|
||||
|
||||
assert status.HTTP_204_NO_CONTENT == response.status_code
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("active_user")
|
||||
|
|
@ -192,4 +191,4 @@ async def test_delete_variable__Exception(client: AsyncClient, logged_in_headers
|
|||
|
||||
response = await client.delete(f"api/v1/variables/{wrong_id}", headers=logged_in_headers)
|
||||
|
||||
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import inspect
|
||||
|
||||
from langflow.load import run_flow_from_json
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.base.tools.component_tool import ComponentToolkit
|
||||
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
|
|
@ -76,7 +75,7 @@ def test_component_tool():
|
|||
}
|
||||
assert component_toolkit.component == chat_input
|
||||
|
||||
result = component_tool.invoke(input=dict(input_value="test"))
|
||||
result = component_tool.invoke(input={"input_value": "test"})
|
||||
assert isinstance(result, Message)
|
||||
assert result.get_text() == "test"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.components.helpers.structured_output import StructuredOutputComponent
|
||||
from langflow.schema.data import Data
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from urllib.parse import urljoin
|
|||
|
||||
import pytest
|
||||
from langchain_community.chat_models.ollama import ChatOllama
|
||||
|
||||
from langflow.components.models.OllamaModel import ChatOllamaComponent
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from collections.abc import Callable
|
||||
|
||||
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
|
||||
from langflow.components.agents.CrewAIAgent import CrewAIAgentComponent
|
||||
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
|
||||
from langflow.components.helpers.SequentialTask import SequentialTaskComponent
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import time
|
|||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.events.event_manager import EventManager
|
||||
from langflow.schema.log import LoggableType
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from unittest.mock import patch, Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
from langflow.components.models.OpenAIModel import OpenAIModelComponent
|
||||
from langflow.components.outputs.ChatOutput import ChatOutput
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from langflow.components.inputs import ChatInput
|
||||
from langflow.components.outputs.ChatOutput import ChatOutput
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.graph.constants import Finish
|
||||
from langflow.graph.state.model import create_state_model
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import logging
|
|||
from collections import deque
|
||||
|
||||
import pytest
|
||||
from pytest import LogCaptureFixture
|
||||
|
||||
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
from langflow.components.outputs.ChatOutput import ChatOutput
|
||||
|
|
@ -11,6 +9,7 @@ from langflow.components.outputs.TextOutput import TextOutputComponent
|
|||
from langflow.components.tools.YfinanceTool import YfinanceToolComponent
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.graph.constants import Finish
|
||||
from pytest import LogCaptureFixture
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
|
||||
|
||||
from langflow.components.outputs.ChatOutput import ChatOutput
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.events.event_manager import EventManager
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
from langflow.components.models.OpenAIModel import OpenAIModelComponent
|
||||
from langflow.components.outputs.ChatOutput import ChatOutput
|
||||
|
|
@ -203,5 +202,3 @@ def test_updated_graph_with_prompts():
|
|||
# Extract the vertex IDs for analysis
|
||||
results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
|
||||
assert "chat_output_1" in results_ids, f"Expected outputs not in results: {results_ids}"
|
||||
|
||||
print(f"Execution completed with results: {results_ids}")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from langflow.components.helpers.Memory import MemoryComponent
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
from langflow.components.models.OpenAIModel import OpenAIModelComponent
|
||||
|
|
@ -10,6 +10,9 @@ from langflow.graph import Graph
|
|||
from langflow.graph.graph.constants import Finish
|
||||
from langflow.graph.graph.state_model import create_state_model_from_graph
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def test_graph_state_model():
|
||||
session_id = "test_session_id"
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import pickle
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data():
|
||||
|
|
@ -23,7 +25,7 @@ def data():
|
|||
def test_to_dict(data):
|
||||
result = RunnableVerticesManager.from_dict(data).to_dict()
|
||||
|
||||
assert all(key in result.keys() for key in data.keys())
|
||||
assert all(key in result for key in data)
|
||||
|
||||
|
||||
def test_from_dict(data):
|
||||
|
|
@ -158,8 +160,8 @@ def test_build_run_map(data):
|
|||
|
||||
manager.build_run_map(predecessor_map, vertices_to_run)
|
||||
|
||||
assert all(v in manager.run_map.keys() for v in ["Z", "X", "Y"])
|
||||
assert "W" not in manager.run_map.keys()
|
||||
assert all(v in manager.run_map for v in ["Z", "X", "Y"])
|
||||
assert "W" not in manager.run_map
|
||||
|
||||
|
||||
def test_update_vertex_run_state(data):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import copy
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.graph.graph import utils
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import copy
|
|||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.graph import Graph
|
||||
from langflow.graph.graph.utils import (
|
||||
find_last_node,
|
||||
|
|
@ -59,7 +58,7 @@ def sample_nodes():
|
|||
|
||||
|
||||
def get_node_by_type(graph, node_type: type[Vertex]) -> Vertex | None:
|
||||
"""Get a node by type"""
|
||||
"""Get a node by type."""
|
||||
return next((node for node in graph.vertices if isinstance(node, node_type)), None)
|
||||
|
||||
|
||||
|
|
@ -131,10 +130,10 @@ def test_process_flow_one_group(one_grouped_chat_json_flow):
|
|||
node_data = group_node["data"]["node"]
|
||||
assert node_data.get("flow") is not None
|
||||
template_data = node_data["template"]
|
||||
assert any("openai_api_key" in key for key in template_data.keys())
|
||||
assert any("openai_api_key" in key for key in template_data)
|
||||
# Get the openai_api_key dict
|
||||
openai_api_key = next(
|
||||
(template_data[key] for key in template_data.keys() if "openai_api_key" in key),
|
||||
(template_data[key] for key in template_data if "openai_api_key" in key),
|
||||
None,
|
||||
)
|
||||
assert openai_api_key is not None
|
||||
|
|
|
|||
|
|
@ -3,11 +3,10 @@
|
|||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
|
||||
|
||||
class TestBuildModelFromSchema:
|
||||
# Successfully creates a Pydantic model from a valid schema
|
||||
|
|
@ -115,9 +114,9 @@ class TestBuildModelFromSchema:
|
|||
{"name": "field3", "type": "list", "default": None, "description": "Field 3 description", "multiple": True},
|
||||
]
|
||||
model = build_model_from_schema(schema)
|
||||
assert model.model_fields["field1"].default == PydanticUndefined # noqa: E711
|
||||
assert model.model_fields["field2"].default == PydanticUndefined # noqa: E711
|
||||
assert model.model_fields["field3"].default == PydanticUndefined # noqa: E711
|
||||
assert model.model_fields["field1"].default == PydanticUndefined
|
||||
assert model.model_fields["field2"].default == PydanticUndefined
|
||||
assert model.model_fields["field3"].default == PydanticUndefined
|
||||
|
||||
# Checks for proper handling of nested list and dict types
|
||||
def test_nested_list_and_dict_types_handling(self):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import operator
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.components.helpers.Memory import MemoryComponent
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
from langflow.components.models.OpenAIModel import OpenAIModelComponent
|
||||
|
|
@ -9,7 +10,9 @@ from langflow.components.outputs.ChatOutput import ChatOutput
|
|||
from langflow.components.prompts.Prompt import PromptComponent
|
||||
from langflow.graph import Graph
|
||||
from langflow.graph.graph.constants import Finish
|
||||
from langflow.graph.graph.schema import GraphDump
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.graph.schema import GraphDump
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -100,7 +103,7 @@ def test_memory_chatbot_dump_components_and_edges(memory_chatbot_graph: Graph):
|
|||
edges = data_dict["edges"]
|
||||
|
||||
# sort the nodes by id
|
||||
nodes = sorted(nodes, key=lambda x: x["id"])
|
||||
nodes = sorted(nodes, key=operator.itemgetter("id"))
|
||||
|
||||
# Check each node
|
||||
assert nodes[0]["data"]["type"] == "ChatInput"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import copy
|
||||
import operator
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.components.data.File import FileComponent
|
||||
from langflow.components.embeddings.OpenAIEmbeddings import OpenAIEmbeddingsComponent
|
||||
from langflow.components.helpers.ParseData import ParseDataComponent
|
||||
|
|
@ -40,8 +40,7 @@ def ingestion_graph():
|
|||
vector_store.set_on_output(name="base_retriever", value="mock_retriever", cache=True)
|
||||
vector_store.set_on_output(name="search_results", value=[Data(text="This is a test file.")], cache=True)
|
||||
|
||||
ingestion_graph = Graph(file_component, vector_store)
|
||||
return ingestion_graph
|
||||
return Graph(file_component, vector_store)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -89,8 +88,7 @@ def rag_graph():
|
|||
chat_output = ChatOutput(_id="chatoutput-123")
|
||||
chat_output.set(input_value=openai_component.text_response)
|
||||
|
||||
graph = Graph(start=chat_input, end=chat_output)
|
||||
return graph
|
||||
return Graph(start=chat_input, end=chat_output)
|
||||
|
||||
|
||||
def test_vector_store_rag(ingestion_graph, rag_graph):
|
||||
|
|
@ -111,7 +109,7 @@ def test_vector_store_rag(ingestion_graph, rag_graph):
|
|||
"rag-vector-store-123",
|
||||
"openai-embeddings-124",
|
||||
]
|
||||
for ids, graph, len_results in zip([ingestion_ids, rag_ids], [ingestion_graph, rag_graph], [5, 8]):
|
||||
for ids, graph, len_results in [(ingestion_ids, ingestion_graph, 5), (rag_ids, rag_graph, 8)]:
|
||||
results = []
|
||||
for result in graph.start():
|
||||
results.append(result)
|
||||
|
|
@ -134,7 +132,7 @@ def test_vector_store_rag_dump_components_and_edges(ingestion_graph, rag_graph):
|
|||
ingestion_edges = ingestion_data["edges"]
|
||||
|
||||
# Sort nodes by id to check components
|
||||
ingestion_nodes = sorted(ingestion_nodes, key=lambda x: x["id"])
|
||||
ingestion_nodes = sorted(ingestion_nodes, key=operator.itemgetter("id"))
|
||||
|
||||
# Check components in the ingestion graph
|
||||
assert ingestion_nodes[0]["data"]["type"] == "File"
|
||||
|
|
@ -172,7 +170,7 @@ def test_vector_store_rag_dump_components_and_edges(ingestion_graph, rag_graph):
|
|||
rag_edges = rag_data["edges"]
|
||||
|
||||
# Sort nodes by id to check components
|
||||
rag_nodes = sorted(rag_nodes, key=lambda x: x["id"])
|
||||
rag_nodes = sorted(rag_nodes, key=operator.itemgetter("id"))
|
||||
|
||||
# Check components in the RAG graph
|
||||
assert rag_nodes[0]["data"]["type"] == "ChatInput"
|
||||
|
|
@ -235,7 +233,7 @@ def test_vector_store_rag_add(ingestion_graph: Graph, rag_graph: Graph):
|
|||
combined_edges = combined_data["edges"]
|
||||
|
||||
# Sort nodes by id to check components
|
||||
combined_nodes = sorted(combined_nodes, key=lambda x: x["id"])
|
||||
combined_nodes = sorted(combined_nodes, key=operator.itemgetter("id"))
|
||||
|
||||
# Expected components in the combined graph (both ingestion and RAG nodes)
|
||||
expected_nodes = sorted(
|
||||
|
|
@ -252,10 +250,10 @@ def test_vector_store_rag_add(ingestion_graph: Graph, rag_graph: Graph):
|
|||
{"id": "prompt-123", "type": "Prompt"},
|
||||
{"id": "rag-vector-store-123", "type": "AstraDB"},
|
||||
],
|
||||
key=lambda x: x["id"],
|
||||
key=operator.itemgetter("id"),
|
||||
)
|
||||
|
||||
for expected_node, combined_node in zip(expected_nodes, combined_nodes):
|
||||
for expected_node, combined_node in zip(expected_nodes, combined_nodes, strict=True):
|
||||
assert combined_node["data"]["type"] == expected_node["type"]
|
||||
assert combined_node["id"] == expected_node["id"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.inputs.inputs import (
|
||||
BoolInput,
|
||||
CodeInput,
|
||||
|
|
@ -24,6 +22,7 @@ from langflow.inputs.inputs import (
|
|||
)
|
||||
from langflow.inputs.utils import instantiate_input
|
||||
from langflow.schema.message import Message
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
def test_table_input_valid():
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
def test_create_input_schema():
|
||||
from langflow.io.schema import create_input_schema
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# Generated by qodo Gen
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.schema.table import Column, FormatterType
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import pytest
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,13 +3,12 @@ from unittest.mock import patch
|
|||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from langflow.services.database.models.variable.model import VariableUpdate
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
|
||||
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
|
||||
from langflow.services.variable.service import DatabaseVariableService
|
||||
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -44,8 +43,8 @@ def test_initialize_user_variables__create_and_update(service, session):
|
|||
value = service.get_variable(user_id, name, field, session=session)
|
||||
assert value == env_vars[name]
|
||||
|
||||
assert all([i in variables for i in good_vars.keys()])
|
||||
assert all([i not in variables for i in bad_vars.keys()])
|
||||
assert all(i in variables for i in good_vars)
|
||||
assert all(i not in variables for i in bad_vars)
|
||||
|
||||
|
||||
def test_initialize_user_variables__not_found_variable(service, session):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from langflow.services.database.models.api_key import ApiKeyCreate
|
||||
|
||||
|
||||
|
|
@ -30,7 +29,8 @@ async def test_create_api_key(client: AsyncClient, logged_in_headers):
|
|||
response = await client.post("api/v1/api_key/", json={"name": api_key_name}, headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "name" in data and data["name"] == api_key_name
|
||||
assert "name" in data
|
||||
assert data["name"] == api_key_name
|
||||
assert "api_key" in data
|
||||
assert "**" not in data["api_key"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import json
|
||||
from uuid import UUID
|
||||
|
||||
from orjson import orjson
|
||||
|
||||
from langflow.memory import get_messages
|
||||
from langflow.services.database.models.flow import FlowCreate, FlowUpdate
|
||||
from orjson import orjson
|
||||
|
||||
|
||||
async def test_build_flow(client, json_memory_chatbot_no_llm, logged_in_headers):
|
||||
|
|
@ -82,7 +81,8 @@ async def consume_and_assert_stream(r):
|
|||
elif count == 5:
|
||||
assert parsed["event"] == "end"
|
||||
else:
|
||||
raise ValueError(f"Unexpected line: {line}")
|
||||
msg = f"Unexpected line: {line}"
|
||||
raise ValueError(msg)
|
||||
count += 1
|
||||
|
||||
|
||||
|
|
@ -92,5 +92,4 @@ async def _create_flow(client, json_memory_chatbot_no_llm, logged_in_headers):
|
|||
vector_store = FlowCreate(name="Flow", description="description", data=data, endpoint_name="f")
|
||||
response = await client.post("api/v1/flows/", json=vector_store.model_dump(), headers=logged_in_headers)
|
||||
response.raise_for_status()
|
||||
flow_id = response.json()["id"]
|
||||
return flow_id
|
||||
return response.json()["id"]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
|
||||
from langflow.__main__ import app
|
||||
from langflow.services import deps
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.custom import Component, CustomComponent
|
||||
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
|
||||
from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError
|
||||
|
|
@ -14,7 +13,7 @@ from langflow.custom.utils import build_custom_component_template
|
|||
|
||||
@pytest.fixture
|
||||
def code_component_with_multiple_outputs():
|
||||
code = Path("src/backend/tests/data/component_multiple_outputs.py").read_text()
|
||||
code = Path("src/backend/tests/data/component_multiple_outputs.py").read_text(encoding="utf-8")
|
||||
return Component(_code=code)
|
||||
|
||||
|
||||
|
|
@ -39,25 +38,20 @@ class YourComponent(CustomComponent):
|
|||
|
||||
|
||||
def test_code_parser_init():
|
||||
"""
|
||||
Test the initialization of the CodeParser class.
|
||||
"""
|
||||
"""Test the initialization of the CodeParser class."""
|
||||
parser = CodeParser(code_default)
|
||||
assert parser.code == code_default
|
||||
|
||||
|
||||
def test_code_parser_get_tree():
|
||||
"""
|
||||
Test the __get_tree method of the CodeParser class.
|
||||
"""
|
||||
"""Test the __get_tree method of the CodeParser class."""
|
||||
parser = CodeParser(code_default)
|
||||
tree = parser.get_tree()
|
||||
assert isinstance(tree, ast.AST)
|
||||
|
||||
|
||||
def test_code_parser_syntax_error():
|
||||
"""
|
||||
Test the __get_tree method raises the
|
||||
"""Test the __get_tree method raises the
|
||||
CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
code_syntax_error = "zzz import os"
|
||||
|
|
@ -68,26 +62,21 @@ def test_code_parser_syntax_error():
|
|||
|
||||
|
||||
def test_component_init():
|
||||
"""
|
||||
Test the initialization of the Component class.
|
||||
"""
|
||||
"""Test the initialization of the Component class."""
|
||||
component = BaseComponent(_code=code_default, _function_entrypoint_name="build")
|
||||
assert component._code == code_default
|
||||
assert component._function_entrypoint_name == "build"
|
||||
|
||||
|
||||
def test_component_get_code_tree():
|
||||
"""
|
||||
Test the get_code_tree method of the Component class.
|
||||
"""
|
||||
"""Test the get_code_tree method of the Component class."""
|
||||
component = BaseComponent(_code=code_default, _function_entrypoint_name="build")
|
||||
tree = component.get_code_tree(component._code)
|
||||
assert "imports" in tree
|
||||
|
||||
|
||||
def test_component_code_null_error():
|
||||
"""
|
||||
Test the get_function method raises the
|
||||
"""Test the get_function method raises the
|
||||
ComponentCodeNullError when the code is empty.
|
||||
"""
|
||||
component = BaseComponent(_code="", _function_entrypoint_name="")
|
||||
|
|
@ -96,9 +85,7 @@ def test_component_code_null_error():
|
|||
|
||||
|
||||
def test_custom_component_init():
|
||||
"""
|
||||
Test the initialization of the CustomComponent class.
|
||||
"""
|
||||
"""Test the initialization of the CustomComponent class."""
|
||||
function_entrypoint_name = "build"
|
||||
|
||||
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name=function_entrypoint_name)
|
||||
|
|
@ -107,26 +94,21 @@ def test_custom_component_init():
|
|||
|
||||
|
||||
def test_custom_component_build_template_config():
|
||||
"""
|
||||
Test the build_template_config property of the CustomComponent class.
|
||||
"""
|
||||
"""Test the build_template_config property of the CustomComponent class."""
|
||||
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
|
||||
config = custom_component.build_template_config()
|
||||
assert isinstance(config, dict)
|
||||
|
||||
|
||||
def test_custom_component_get_function():
|
||||
"""
|
||||
Test the get_function property of the CustomComponent class.
|
||||
"""
|
||||
"""Test the get_function property of the CustomComponent class."""
|
||||
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
|
||||
my_function = custom_component.get_function()
|
||||
assert isinstance(my_function, types.FunctionType)
|
||||
|
||||
|
||||
def test_code_parser_parse_imports_import():
|
||||
"""
|
||||
Test the parse_imports method of the CodeParser
|
||||
"""Test the parse_imports method of the CodeParser
|
||||
class with an import statement.
|
||||
"""
|
||||
parser = CodeParser(code_default)
|
||||
|
|
@ -138,8 +120,7 @@ def test_code_parser_parse_imports_import():
|
|||
|
||||
|
||||
def test_code_parser_parse_imports_importfrom():
|
||||
"""
|
||||
Test the parse_imports method of the CodeParser
|
||||
"""Test the parse_imports method of the CodeParser
|
||||
class with an import from statement.
|
||||
"""
|
||||
parser = CodeParser("from os import path")
|
||||
|
|
@ -151,9 +132,7 @@ def test_code_parser_parse_imports_importfrom():
|
|||
|
||||
|
||||
def test_code_parser_parse_functions():
|
||||
"""
|
||||
Test the parse_functions method of the CodeParser class.
|
||||
"""
|
||||
"""Test the parse_functions method of the CodeParser class."""
|
||||
parser = CodeParser("def test(): pass")
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
|
|
@ -164,9 +143,7 @@ def test_code_parser_parse_functions():
|
|||
|
||||
|
||||
def test_code_parser_parse_classes():
|
||||
"""
|
||||
Test the parse_classes method of the CodeParser class.
|
||||
"""
|
||||
"""Test the parse_classes method of the CodeParser class."""
|
||||
parser = CodeParser("from langflow.custom import Component\n\nclass Test(Component): pass")
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
|
|
@ -177,9 +154,7 @@ def test_code_parser_parse_classes():
|
|||
|
||||
|
||||
def test_code_parser_parse_classes_raises():
|
||||
"""
|
||||
Test the parse_classes method of the CodeParser class.
|
||||
"""
|
||||
"""Test the parse_classes method of the CodeParser class."""
|
||||
parser = CodeParser("class Test: pass")
|
||||
tree = parser.get_tree()
|
||||
with pytest.raises(TypeError):
|
||||
|
|
@ -189,9 +164,7 @@ def test_code_parser_parse_classes_raises():
|
|||
|
||||
|
||||
def test_code_parser_parse_global_vars():
|
||||
"""
|
||||
Test the parse_global_vars method of the CodeParser class.
|
||||
"""
|
||||
"""Test the parse_global_vars method of the CodeParser class."""
|
||||
parser = CodeParser("x = 1")
|
||||
tree = parser.get_tree()
|
||||
for node in ast.walk(tree):
|
||||
|
|
@ -202,8 +175,7 @@ def test_code_parser_parse_global_vars():
|
|||
|
||||
|
||||
def test_component_get_function_valid():
|
||||
"""
|
||||
Test the get_function method of the Component
|
||||
"""Test the get_function method of the Component
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
component = BaseComponent(_code="def build(): pass", _function_entrypoint_name="build")
|
||||
|
|
@ -212,8 +184,7 @@ def test_component_get_function_valid():
|
|||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_args():
|
||||
"""
|
||||
Test the get_function_entrypoint_args
|
||||
"""Test the get_function_entrypoint_args
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
|
||||
|
|
@ -225,28 +196,23 @@ def test_custom_component_get_function_entrypoint_args():
|
|||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_return_type():
|
||||
"""
|
||||
Test the get_function_entrypoint_return_type
|
||||
"""Test the get_function_entrypoint_return_type
|
||||
property of the CustomComponent class.
|
||||
"""
|
||||
|
||||
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
|
||||
return_type = custom_component.get_function_entrypoint_return_type
|
||||
assert return_type == [Document]
|
||||
|
||||
|
||||
def test_custom_component_get_main_class_name():
|
||||
"""
|
||||
Test the get_main_class_name property of the CustomComponent class.
|
||||
"""
|
||||
"""Test the get_main_class_name property of the CustomComponent class."""
|
||||
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
|
||||
class_name = custom_component.get_main_class_name
|
||||
assert class_name == "YourComponent"
|
||||
|
||||
|
||||
def test_custom_component_get_function_valid():
|
||||
"""
|
||||
Test the get_function property of the CustomComponent
|
||||
"""Test the get_function property of the CustomComponent
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
|
||||
|
|
@ -255,9 +221,7 @@ def test_custom_component_get_function_valid():
|
|||
|
||||
|
||||
def test_code_parser_parse_arg_no_annotation():
|
||||
"""
|
||||
Test the parse_arg method of the CodeParser class without an annotation.
|
||||
"""
|
||||
"""Test the parse_arg method of the CodeParser class without an annotation."""
|
||||
parser = CodeParser("")
|
||||
arg = ast.arg(arg="x", annotation=None)
|
||||
result = parser.parse_arg(arg, None)
|
||||
|
|
@ -266,9 +230,7 @@ def test_code_parser_parse_arg_no_annotation():
|
|||
|
||||
|
||||
def test_code_parser_parse_arg_with_annotation():
|
||||
"""
|
||||
Test the parse_arg method of the CodeParser class with an annotation.
|
||||
"""
|
||||
"""Test the parse_arg method of the CodeParser class with an annotation."""
|
||||
parser = CodeParser("")
|
||||
arg = ast.arg(arg="x", annotation=ast.Name(id="int", ctx=ast.Load()))
|
||||
result = parser.parse_arg(arg, None)
|
||||
|
|
@ -277,8 +239,7 @@ def test_code_parser_parse_arg_with_annotation():
|
|||
|
||||
|
||||
def test_code_parser_parse_callable_details_no_args():
|
||||
"""
|
||||
Test the parse_callable_details method of the
|
||||
"""Test the parse_callable_details method of the
|
||||
CodeParser class with a function with no arguments.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
|
|
@ -295,9 +256,7 @@ def test_code_parser_parse_callable_details_no_args():
|
|||
|
||||
|
||||
def test_code_parser_parse_assign():
|
||||
"""
|
||||
Test the parse_assign method of the CodeParser class.
|
||||
"""
|
||||
"""Test the parse_assign method of the CodeParser class."""
|
||||
parser = CodeParser("")
|
||||
stmt = ast.Assign(targets=[ast.Name(id="x", ctx=ast.Store())], value=ast.Num(n=1))
|
||||
result = parser.parse_assign(stmt)
|
||||
|
|
@ -306,9 +265,7 @@ def test_code_parser_parse_assign():
|
|||
|
||||
|
||||
def test_code_parser_parse_ann_assign():
|
||||
"""
|
||||
Test the parse_ann_assign method of the CodeParser class.
|
||||
"""
|
||||
"""Test the parse_ann_assign method of the CodeParser class."""
|
||||
parser = CodeParser("")
|
||||
stmt = ast.AnnAssign(
|
||||
target=ast.Name(id="x", ctx=ast.Store()),
|
||||
|
|
@ -323,8 +280,7 @@ def test_code_parser_parse_ann_assign():
|
|||
|
||||
|
||||
def test_code_parser_parse_function_def_not_init():
|
||||
"""
|
||||
Test the parse_function_def method of the
|
||||
"""Test the parse_function_def method of the
|
||||
CodeParser class with a function that is not __init__.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
|
|
@ -341,8 +297,7 @@ def test_code_parser_parse_function_def_not_init():
|
|||
|
||||
|
||||
def test_code_parser_parse_function_def_init():
|
||||
"""
|
||||
Test the parse_function_def method of the
|
||||
"""Test the parse_function_def method of the
|
||||
CodeParser class with an __init__ function.
|
||||
"""
|
||||
parser = CodeParser("")
|
||||
|
|
@ -359,8 +314,7 @@ def test_code_parser_parse_function_def_init():
|
|||
|
||||
|
||||
def test_component_get_code_tree_syntax_error():
|
||||
"""
|
||||
Test the get_code_tree method of the Component class
|
||||
"""Test the get_code_tree method of the Component class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
component = BaseComponent(_code="import os as", _function_entrypoint_name="build")
|
||||
|
|
@ -369,8 +323,7 @@ def test_component_get_code_tree_syntax_error():
|
|||
|
||||
|
||||
def test_custom_component_class_template_validation_no_code():
|
||||
"""
|
||||
Test the _class_template_validation method of the CustomComponent class
|
||||
"""Test the _class_template_validation method of the CustomComponent class
|
||||
raises the HTTPException when the code is None.
|
||||
"""
|
||||
custom_component = CustomComponent(_code=None, _function_entrypoint_name="build")
|
||||
|
|
@ -379,8 +332,7 @@ def test_custom_component_class_template_validation_no_code():
|
|||
|
||||
|
||||
def test_custom_component_get_code_tree_syntax_error():
|
||||
"""
|
||||
Test the get_code_tree method of the CustomComponent class
|
||||
"""Test the get_code_tree method of the CustomComponent class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
custom_component = CustomComponent(_code="import os as", _function_entrypoint_name="build")
|
||||
|
|
@ -389,8 +341,7 @@ def test_custom_component_get_code_tree_syntax_error():
|
|||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_args_no_args():
|
||||
"""
|
||||
Test the get_function_entrypoint_args property of
|
||||
"""Test the get_function_entrypoint_args property of
|
||||
the CustomComponent class with a build method with no arguments.
|
||||
"""
|
||||
my_code = """
|
||||
|
|
@ -405,8 +356,7 @@ class MyMainClass(CustomComponent):
|
|||
|
||||
|
||||
def test_custom_component_get_function_entrypoint_return_type_no_return_type():
|
||||
"""
|
||||
Test the get_function_entrypoint_return_type property of the
|
||||
"""Test the get_function_entrypoint_return_type property of the
|
||||
CustomComponent class with a build method with no return type.
|
||||
"""
|
||||
my_code = """
|
||||
|
|
@ -421,8 +371,7 @@ class MyClass(CustomComponent):
|
|||
|
||||
|
||||
def test_custom_component_get_main_class_name_no_main_class():
|
||||
"""
|
||||
Test the get_main_class_name property of the
|
||||
"""Test the get_main_class_name property of the
|
||||
CustomComponent class when there is no main class.
|
||||
"""
|
||||
my_code = """
|
||||
|
|
@ -435,8 +384,7 @@ def build():
|
|||
|
||||
|
||||
def test_custom_component_build_not_implemented():
|
||||
"""
|
||||
Test the build method of the CustomComponent
|
||||
"""Test the build method of the CustomComponent
|
||||
class raises the NotImplementedError.
|
||||
"""
|
||||
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
|
||||
|
|
@ -453,7 +401,7 @@ def test_build_config_no_code():
|
|||
|
||||
@pytest.fixture
|
||||
def component():
|
||||
yield CustomComponent(
|
||||
return CustomComponent(
|
||||
field_config={
|
||||
"fields": {
|
||||
"llm": {"type": "str"},
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.custom.custom_component.custom_component import CustomComponent
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
|
|
@ -11,7 +10,7 @@ from langflow.services.settings.feature_flags import FEATURE_FLAGS
|
|||
|
||||
@pytest.fixture
|
||||
def code_component_with_multiple_outputs():
|
||||
code = Path("src/backend/tests/data/component_multiple_outputs.py").read_text()
|
||||
code = Path("src/backend/tests/data/component_multiple_outputs.py").read_text(encoding="utf-8")
|
||||
return Component(_code=code)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, ANY
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
from httpx import Response
|
||||
|
||||
from langflow.components import data
|
||||
|
||||
|
||||
|
|
@ -163,16 +162,16 @@ def test_directory_without_mocks():
|
|||
directory_component = data.DirectoryComponent()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
(Path(temp_dir) / "test.txt").write_text("test")
|
||||
(Path(temp_dir) / "test.txt").write_text("test", encoding="utf-8")
|
||||
# also add a json file
|
||||
(Path(temp_dir) / "test.json").write_text('{"test": "test"}')
|
||||
(Path(temp_dir) / "test.json").write_text('{"test": "test"}', encoding="utf-8")
|
||||
|
||||
directory_component.set_attributes({"path": str(temp_dir), "use_multithreading": False})
|
||||
results = directory_component.load_directory()
|
||||
assert len(results) == 2
|
||||
values = ["test", '{"test":"test"}']
|
||||
assert all(result.text in values for result in results), [
|
||||
(len(result.text), len(val)) for result, val in zip(results, values)
|
||||
(len(result.text), len(val)) for result, val in zip(results, values, strict=True)
|
||||
]
|
||||
|
||||
# in ../docs/docs/components there are many mdx files
|
||||
|
|
|
|||
|
|
@ -5,8 +5,6 @@ from uuid import UUID, uuid4
|
|||
import orjson
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.api.v1.schemas import FlowListCreate, ResultDataResponse
|
||||
from langflow.graph.utils import log_transaction, log_vertex_build
|
||||
from langflow.initial_setup.setup import load_flows_from_directory, load_starter_projects
|
||||
|
|
@ -15,6 +13,7 @@ from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
|
|||
from langflow.services.database.models.folder.model import FolderCreate
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
@ -520,5 +519,5 @@ def test_sqlite_pragmas():
|
|||
with db_service.with_session() as session:
|
||||
from sqlalchemy import text
|
||||
|
||||
assert "wal" == session.exec(text("PRAGMA journal_mode;")).scalar()
|
||||
assert 1 == session.exec(text("PRAGMA synchronous;")).scalar()
|
||||
assert session.exec(text("PRAGMA journal_mode;")).scalar() == "wal"
|
||||
assert session.exec(text("PRAGMA synchronous;")).scalar() == 1
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from uuid import UUID, uuid4
|
|||
import pytest
|
||||
from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
|
||||
from langflow.custom.directory_reader.directory_reader import DirectoryReader
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
|
|
@ -377,7 +376,7 @@ async def test_invalid_prompt(client: AsyncClient):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,expected_input_variables",
|
||||
("prompt", "expected_input_variables"),
|
||||
[
|
||||
("{color} is my favorite color.", ["color"]),
|
||||
("The weather is {weather} today.", ["weather"]),
|
||||
|
|
@ -442,13 +441,13 @@ async def test_successful_run_no_payload(client, simple_api_test, created_api_ke
|
|||
assert isinstance(outputs_dict.get("outputs"), list)
|
||||
assert len(outputs_dict.get("outputs")) == 1
|
||||
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
|
||||
assert all(["ChatOutput" in _id for _id in ids])
|
||||
assert all("ChatOutput" in _id for _id in ids)
|
||||
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
|
||||
assert all([name in display_names for name in ["Chat Output"]])
|
||||
assert all(name in display_names for name in ["Chat Output"])
|
||||
output_results_has_results = all("results" in output.get("results") for output in outputs_dict.get("outputs"))
|
||||
inner_results = [output.get("results") for output in outputs_dict.get("outputs")]
|
||||
|
||||
assert all([result is not None for result in inner_results]), (outputs_dict, output_results_has_results)
|
||||
assert all(result is not None for result in inner_results), (outputs_dict, output_results_has_results)
|
||||
|
||||
|
||||
async def test_successful_run_with_output_type_text(client, simple_api_test, created_api_key):
|
||||
|
|
@ -473,12 +472,12 @@ async def test_successful_run_with_output_type_text(client, simple_api_test, cre
|
|||
assert isinstance(outputs_dict.get("outputs"), list)
|
||||
assert len(outputs_dict.get("outputs")) == 1
|
||||
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
|
||||
assert all(["ChatOutput" in _id for _id in ids]), ids
|
||||
assert all("ChatOutput" in _id for _id in ids), ids
|
||||
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
|
||||
assert all([name in display_names for name in ["Chat Output"]]), display_names
|
||||
assert all(name in display_names for name in ["Chat Output"]), display_names
|
||||
inner_results = [output.get("results") for output in outputs_dict.get("outputs")]
|
||||
expected_keys = ["message"]
|
||||
assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict
|
||||
assert all(key in result for result in inner_results for key in expected_keys), outputs_dict
|
||||
|
||||
|
||||
async def test_successful_run_with_output_type_any(client, simple_api_test, created_api_key):
|
||||
|
|
@ -504,12 +503,12 @@ async def test_successful_run_with_output_type_any(client, simple_api_test, crea
|
|||
assert isinstance(outputs_dict.get("outputs"), list)
|
||||
assert len(outputs_dict.get("outputs")) == 1
|
||||
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
|
||||
assert all(["ChatOutput" in _id or "TextOutput" in _id for _id in ids]), ids
|
||||
assert all("ChatOutput" in _id or "TextOutput" in _id for _id in ids), ids
|
||||
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
|
||||
assert all([name in display_names for name in ["Chat Output"]]), display_names
|
||||
assert all(name in display_names for name in ["Chat Output"]), display_names
|
||||
inner_results = [output.get("results") for output in outputs_dict.get("outputs")]
|
||||
expected_keys = ["message"]
|
||||
assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict
|
||||
assert all(key in result for result in inner_results for key in expected_keys), outputs_dict
|
||||
|
||||
|
||||
async def test_successful_run_with_output_type_debug(client, simple_api_test, created_api_key):
|
||||
|
|
@ -566,7 +565,7 @@ async def test_successful_run_with_input_type_text(client, simple_api_test, crea
|
|||
# Now we check if the input_value is correct
|
||||
# We get text key twice because the output is now a Message
|
||||
assert all(
|
||||
[output.get("results").get("text").get("text") == "value1" for output in text_input_outputs]
|
||||
output.get("results").get("text").get("text") == "value1" for output in text_input_outputs
|
||||
), text_input_outputs
|
||||
|
||||
|
||||
|
|
@ -599,7 +598,7 @@ async def test_successful_run_with_input_type_chat(client: AsyncClient, simple_a
|
|||
assert len(chat_input_outputs) == 1
|
||||
# Now we check if the input_value is correct
|
||||
assert all(
|
||||
[output.get("results").get("message").get("text") == "value1" for output in chat_input_outputs]
|
||||
output.get("results").get("message").get("text") == "value1" for output in chat_input_outputs
|
||||
), chat_input_outputs
|
||||
|
||||
|
||||
|
|
@ -653,7 +652,7 @@ async def test_successful_run_with_input_type_any(client, simple_api_test, creat
|
|||
result_dict.get("message", result_dict.get("text")) for result_dict in all_result_dicts
|
||||
]
|
||||
assert all(
|
||||
[message_or_text_dict.get("text") == "value1" for message_or_text_dict in all_message_or_text_dicts]
|
||||
message_or_text_dict.get("text") == "value1" for message_or_text_dict in all_message_or_text_dicts
|
||||
), any_input_outputs
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,10 +8,9 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
from asgi_lifespan import LifespanManager
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.services.deps import get_storage_service
|
||||
from langflow.services.storage.service import StorageService
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -26,7 +25,7 @@ def mock_storage_service():
|
|||
return service
|
||||
|
||||
|
||||
@pytest.fixture(name="files_client", scope="function")
|
||||
@pytest.fixture(name="files_client")
|
||||
async def files_client_fixture(session: Session, monkeypatch, request, load_flows_dir, mock_storage_service):
|
||||
# Set the database url to a test database
|
||||
if "noclient" in request.keywords:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
|
||||
from langflow.template.field.base import Input
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.template.base import Template
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from langflow.custom.utils import build_custom_component_template
|
|||
from langflow.schema import Data
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
||||
# def test_update_data_component():
|
||||
# # Arrange
|
||||
# update_data_component = helpers.UpdateDataComponent()
|
||||
|
|
@ -36,7 +35,7 @@ from langflow.schema.message import Message
|
|||
def test_uuid_generator_component():
|
||||
# Arrange
|
||||
uuid_generator_component = helpers.IDGeneratorComponent()
|
||||
uuid_generator_component._code = Path(helpers.IDGenerator.__file__).read_text()
|
||||
uuid_generator_component._code = Path(helpers.IDGenerator.__file__).read_text(encoding="utf-8")
|
||||
|
||||
frontend_node, _ = build_custom_component_template(uuid_generator_component)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.custom.directory_reader.utils import build_custom_component_list_from_path
|
||||
from langflow.initial_setup.setup import (
|
||||
STARTER_FOLDER_NAME,
|
||||
|
|
@ -14,6 +12,7 @@ from langflow.initial_setup.setup import (
|
|||
from langflow.interface.types import aget_all_types_dict
|
||||
from langflow.services.database.models.folder.model import Folder
|
||||
from langflow.services.deps import session_scope
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
def test_load_starter_projects():
|
||||
|
|
@ -101,10 +100,11 @@ async def test_create_or_update_starter_projects():
|
|||
|
||||
|
||||
def find_componeny_by_name(components, name):
|
||||
for category, children in components.items():
|
||||
for children in components.values():
|
||||
if name in children:
|
||||
return children[name]
|
||||
raise ValueError(f"Component {name} not found in components")
|
||||
msg = f"Component {name} not found in components"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def set_value(component, input_name, value):
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from kubernetes.client import V1ObjectMeta, V1Secret
|
||||
from base64 import b64encode
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from kubernetes.client import V1ObjectMeta, V1Secret
|
||||
from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from langflow.graph import Graph
|
|||
from langflow.initial_setup.setup import load_starter_projects
|
||||
from langflow.load import load_flow_from_json
|
||||
|
||||
|
||||
# TODO: UPDATE BASIC EXAMPLE
|
||||
# def test_load_flow_from_json():
|
||||
# """Test loading a flow from a json file"""
|
||||
|
|
@ -20,9 +19,8 @@ from langflow.load import load_flow_from_json
|
|||
|
||||
|
||||
def test_load_flow_from_json_object():
|
||||
"""Test loading a flow from a json file and applying tweaks"""
|
||||
_, projects = zip(*load_starter_projects())
|
||||
project = projects[0]
|
||||
"""Test loading a flow from a json file and applying tweaks."""
|
||||
project = load_starter_projects()[0][1]
|
||||
loaded = load_flow_from_json(project)
|
||||
assert loaded is not None
|
||||
assert isinstance(loaded, Graph)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import pytest
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langflow.logging.logger import SizedLogBuffer
|
||||
|
||||
|
||||
|
|
@ -27,8 +28,8 @@ def test_write(sized_log_buffer):
|
|||
sized_log_buffer.max = 1 # Set max size to 1 for testing
|
||||
sized_log_buffer.write(message)
|
||||
assert len(sized_log_buffer.buffer) == 1
|
||||
assert 1625097600124 == sized_log_buffer.buffer[0][0]
|
||||
assert "Test log" == sized_log_buffer.buffer[0][1]
|
||||
assert sized_log_buffer.buffer[0][0] == 1625097600124
|
||||
assert sized_log_buffer.buffer[0][1] == "Test log"
|
||||
|
||||
|
||||
def test_write_overflow(sized_log_buffer):
|
||||
|
|
@ -38,8 +39,8 @@ def test_write_overflow(sized_log_buffer):
|
|||
sized_log_buffer.write(message)
|
||||
|
||||
assert len(sized_log_buffer.buffer) == 2
|
||||
assert 1625097601000 == sized_log_buffer.buffer[0][0]
|
||||
assert 1625097602000 == sized_log_buffer.buffer[1][0]
|
||||
assert sized_log_buffer.buffer[0][0] == 1625097601000
|
||||
assert sized_log_buffer.buffer[1][0] == 1625097602000
|
||||
|
||||
|
||||
def test_len(sized_log_buffer):
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.user import User
|
||||
from langflow.services.deps import session_scope
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
|
||||
from langflow.memory import add_messages, add_messagetables, delete_messages, get_messages, store_message
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
|
@ -10,17 +9,16 @@ from langflow.services.deps import session_scope
|
|||
from langflow.services.tracing.utils import convert_to_langchain_type
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def created_message():
|
||||
with session_scope() as session:
|
||||
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
|
||||
messagetable = MessageTable.model_validate(message, from_attributes=True)
|
||||
messagetables = add_messagetables([messagetable], session)
|
||||
message_read = MessageRead.model_validate(messagetables[0], from_attributes=True)
|
||||
return message_read
|
||||
return MessageRead.model_validate(messagetables[0], from_attributes=True)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def created_messages(session):
|
||||
with session_scope() as session:
|
||||
messages = [
|
||||
|
|
@ -30,10 +28,7 @@ def created_messages(session):
|
|||
]
|
||||
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
|
||||
messagetables = add_messagetables(messagetables, session)
|
||||
messages_read = [
|
||||
MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables
|
||||
]
|
||||
return messages_read
|
||||
return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("client")
|
||||
|
|
@ -87,10 +82,10 @@ def test_convert_to_langchain(method_name):
|
|||
def convert(value):
|
||||
if method_name == "message":
|
||||
return value.to_lc_message()
|
||||
elif method_name == "convert_to_langchain_type":
|
||||
if method_name == "convert_to_langchain_type":
|
||||
return convert_to_langchain_type(value)
|
||||
else:
|
||||
raise ValueError(f"Invalid method: {method_name}")
|
||||
msg = f"Invalid method: {method_name}"
|
||||
raise ValueError(msg)
|
||||
|
||||
lc_message = convert(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"))
|
||||
assert lc_message.content == "Test message 1"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from uuid import UUID
|
|||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from langflow.memory import add_messagetables
|
||||
|
||||
# Assuming you have these imports available
|
||||
|
|
@ -11,17 +10,16 @@ from langflow.services.database.models.message.model import MessageTable
|
|||
from langflow.services.deps import session_scope
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
async def created_message():
|
||||
with session_scope() as session:
|
||||
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
|
||||
messagetable = MessageTable.model_validate(message, from_attributes=True)
|
||||
messagetables = add_messagetables([messagetable], session)
|
||||
message_read = MessageRead.model_validate(messagetables[0], from_attributes=True)
|
||||
return message_read
|
||||
return MessageRead.model_validate(messagetables[0], from_attributes=True)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def created_messages(session):
|
||||
with session_scope() as session:
|
||||
messages = [
|
||||
|
|
@ -30,9 +28,7 @@ def created_messages(session):
|
|||
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
|
||||
]
|
||||
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
|
||||
message_list = add_messagetables(messagetables, session)
|
||||
|
||||
return message_list
|
||||
return add_messagetables(messagetables, session)
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
|
||||
from langflow.processing.process import process_tweaks
|
||||
from langflow.services.deps import get_session_service
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
from collections.abc import Sequence as SequenceABC
|
||||
from types import NoneType
|
||||
from typing import Union
|
||||
|
||||
from langflow.schema.data import Data
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.schema.data import Data
|
||||
from langflow.template import Input, Output
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
from collections.abc import Sequence as SequenceABC
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
class TestInput:
|
||||
|
|
@ -68,8 +67,8 @@ class TestInput:
|
|||
assert set(post_process_type(Union[None, list[None]])) == {None, NoneType}
|
||||
|
||||
# Handling complex nested structures
|
||||
assert set(post_process_type(Union[SequenceABC[Union[int, str]], list[float]])) == {int, str, float}
|
||||
assert set(post_process_type(Union[Union[Union[int, list[str]], list[float]], str])) == {int, str, float}
|
||||
assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float}
|
||||
assert set(post_process_type(Union[int | list[str] | list[float], str])) == {int, str, float}
|
||||
|
||||
# Non-generic types should return as is
|
||||
assert set(post_process_type(dict)) == {dict}
|
||||
|
|
@ -87,12 +86,12 @@ class TestInput:
|
|||
assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)}
|
||||
|
||||
# Multiple Data types combined
|
||||
assert set(post_process_type(Union[Data, Union[str, float]])) == {Data, str, float}
|
||||
assert set(post_process_type(Union[Data, str | float])) == {Data, str, float}
|
||||
assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str}
|
||||
|
||||
# Testing with nested unions and lists
|
||||
assert set(post_process_type(Union[list[Data], list[Union[int, str]]])) == {Data, int, str}
|
||||
assert set(post_process_type(Data | list[Union[float, str]])) == {Data, float, str}
|
||||
assert set(post_process_type(Union[list[Data], list[int | str]])) == {Data, int, str}
|
||||
assert set(post_process_type(Data | list[float | str])) == {Data, float, str}
|
||||
|
||||
def test_input_to_dict(self):
|
||||
input_obj = Input(field_type="str")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import pytest
|
||||
import threading
|
||||
from langflow.services.telemetry.opentelemetry import OpenTelemetry
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from langflow.services.telemetry.opentelemetry import OpenTelemetry
|
||||
|
||||
fixed_labels = {"flow_id": "this_flow_id", "service": "this", "user": "that"}
|
||||
|
||||
|
|
@ -72,9 +72,9 @@ def test_missing_labels(opentelemetry_instance):
|
|||
with pytest.raises(ValueError, match="Labels must be provided for the metric"):
|
||||
opentelemetry_instance.up_down_counter("num_files_uploaded", 1, None)
|
||||
with pytest.raises(ValueError, match="Labels must be provided for the metric"):
|
||||
opentelemetry_instance.update_gauge(metric_name="num_files_uploaded", value=1.0, labels=dict())
|
||||
opentelemetry_instance.update_gauge(metric_name="num_files_uploaded", value=1.0, labels={})
|
||||
with pytest.raises(ValueError, match="Labels must be provided for the metric"):
|
||||
opentelemetry_instance.observe_histogram("num_files_uploaded", 1, dict())
|
||||
opentelemetry_instance.observe_histogram("num_files_uploaded", 1, {})
|
||||
|
||||
|
||||
def test_multithreaded_singleton():
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from pydantic import BaseModel
|
|||
|
||||
# Dummy classes for testing purposes
|
||||
class Parent(BaseModel):
|
||||
"""Parent Class"""
|
||||
"""Parent Class."""
|
||||
|
||||
parent_field: str
|
||||
|
||||
|
||||
class Child(Parent):
|
||||
"""Child Class"""
|
||||
"""Child Class."""
|
||||
|
||||
child_field: int
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ def test_get_default_factory():
|
|||
return "default_value"
|
||||
|
||||
# Add dummy_function to your_module
|
||||
setattr(importlib.import_module(module_name), "dummy_function", dummy_function)
|
||||
importlib.import_module(module_name).dummy_function = dummy_function
|
||||
|
||||
default_value = get_default_factory(module_name, function_repr)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,13 +2,12 @@ from datetime import datetime
|
|||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -86,7 +85,7 @@ async def test_user_waiting_for_approval(client):
|
|||
with session_getter(get_db_service()) as session:
|
||||
existing_user = session.exec(select(User).where(User.username == username)).first()
|
||||
if existing_user:
|
||||
print(f"User {username} still exists after the test. This is expected.")
|
||||
pass
|
||||
else:
|
||||
pytest.fail(f"User {username} does not exist after the test. This is unexpected.")
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,8 @@ from pathlib import Path
|
|||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import MissingSchema
|
||||
|
||||
from langflow.utils.validate import create_function, execute_function, extract_function_name, validate_code
|
||||
from requests.exceptions import MissingSchema
|
||||
|
||||
|
||||
def test_create_function():
|
||||
|
|
@ -99,6 +98,5 @@ import requests
|
|||
def my_function(x):
|
||||
return requests.get(x).text
|
||||
"""
|
||||
with mock.patch("requests.get", side_effect=MissingSchema):
|
||||
with pytest.raises(MissingSchema):
|
||||
execute_function(code, "my_function", "invalid_url")
|
||||
with mock.patch("requests.get", side_effect=MissingSchema), pytest.raises(MissingSchema):
|
||||
execute_function(code, "my_function", "invalid_url")
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ def test_version():
|
|||
|
||||
|
||||
def test_compute_main():
|
||||
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.post0")
|
||||
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.a1")
|
||||
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.b112")
|
||||
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.rc0")
|
||||
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.dev9")
|
||||
assert "1.0.10" == _compute_non_prerelease_version("1.0.10")
|
||||
assert _compute_non_prerelease_version("1.0.10.post0") == "1.0.10"
|
||||
assert _compute_non_prerelease_version("1.0.10.a1") == "1.0.10"
|
||||
assert _compute_non_prerelease_version("1.0.10.b112") == "1.0.10"
|
||||
assert _compute_non_prerelease_version("1.0.10.rc0") == "1.0.10"
|
||||
assert _compute_non_prerelease_version("1.0.10.dev9") == "1.0.10"
|
||||
assert _compute_non_prerelease_version("1.0.10") == "1.0.10"
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from langflow.utils.connection_string_parser import transform_connection_string
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"connection_string, expected",
|
||||
("connection_string", "expected"),
|
||||
[
|
||||
("protocol:user:password@host", "protocol:user:password@host"),
|
||||
("protocol:user@host", "protocol:user@host"),
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from langflow.base.data.utils import format_directory_path
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_path, expected",
|
||||
("input_path", "expected"),
|
||||
[
|
||||
# Test case 1: Standard path with no newlines (no change expected)
|
||||
("/home/user/documents/file.txt", "/home/user/documents/file.txt"),
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from langflow.base.data.utils import format_directory_path
|
||||
import pytest
|
||||
from langflow.base.data.utils import format_directory_path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_path, expected",
|
||||
("input_path", "expected"),
|
||||
[
|
||||
# Test case 1: Standard path with no newlines
|
||||
("/home/user/documents/file.txt", "/home/user/documents/file.txt"),
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import math
|
||||
|
||||
import pytest
|
||||
from langflow.utils.util_strings import truncate_long_strings
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_data, max_length, expected",
|
||||
("input_data", "max_length", "expected"),
|
||||
[
|
||||
# Test case 1: String shorter than max_length
|
||||
("short string", 20, "short string"),
|
||||
|
|
@ -20,7 +22,7 @@ from langflow.utils.util_strings import truncate_long_strings
|
|||
# Test case 7: Integer input
|
||||
(12345, 3, 12345),
|
||||
# Test case 8: Float input
|
||||
(3.14159, 4, 3.14159),
|
||||
(math.pi, 4, math.pi),
|
||||
# Test case 9: Boolean input
|
||||
(True, 2, True),
|
||||
# Test case 10: None input
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from langflow.utils.util_strings import truncate_long_strings
|
||||
from langflow.utils.constants import MAX_TEXT_LENGTH
|
||||
import pytest
|
||||
from langflow.utils.constants import MAX_TEXT_LENGTH
|
||||
from langflow.utils.util_strings import truncate_long_strings
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_data, max_length, expected",
|
||||
("input_data", "max_length", "expected"),
|
||||
[
|
||||
# Test case 1: Simple string truncation
|
||||
({"key": "a" * 100}, 10, {"key": "a" * 10 + "..."}),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue