ref: Auto-fix ruff rules in tests (#4154)

This commit is contained in:
Christophe Bornet 2024-10-16 17:42:36 +02:00 committed by GitHub
commit 45c8f98692
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 359 additions and 456 deletions

View file

@ -1,7 +1,5 @@
#!/usr/bin/env python
"""
Idea from https://github.com/streamlit/streamlit/blob/4841cf91f1c820a392441092390c4c04907f9944/scripts/pypi_nightly_create_tag.py
"""
"""Idea from https://github.com/streamlit/streamlit/blob/4841cf91f1c820a392441092390c4c04907f9944/scripts/pypi_nightly_create_tag.py."""
import sys
@ -24,13 +22,15 @@ def get_latest_published_version(build_type: str, is_nightly: bool) -> Version:
elif build_type == "main":
url = PYPI_LANGFLOW_NIGHTLY_URL if is_nightly else PYPI_LANGFLOW_URL
else:
raise ValueError(f"Invalid build type: {build_type}")
msg = f"Invalid build type: {build_type}"
raise ValueError(msg)
res = requests.get(url)
try:
version_str = res.json()["info"]["version"]
except Exception as e:
raise RuntimeError("Got unexpected response from PyPI", e)
msg = "Got unexpected response from PyPI"
raise RuntimeError(msg, e)
return Version(version_str)
@ -75,7 +75,8 @@ def create_tag(build_type: str):
if __name__ == "__main__":
if len(sys.argv) != 2:
raise Exception("Specify base or main")
msg = "Specify base or main"
raise Exception(msg)
build_type = sys.argv[1]
tag = create_tag(build_type)

View file

@ -1,5 +1,5 @@
import sys
import re
import sys
from pathlib import Path
import packaging.version
@ -10,39 +10,38 @@ BASE_DIR = Path(__file__).parent.parent.parent
def update_base_dep(pyproject_path: str, new_version: str) -> None:
"""Update the langflow-base dependency in pyproject.toml."""
filepath = BASE_DIR / pyproject_path
content = filepath.read_text()
content = filepath.read_text(encoding="utf-8")
replacement = f'langflow-base-nightly = "{new_version}"'
# Updates the pattern for poetry
pattern = re.compile(r'langflow-base = \{ path = "\./src/backend/base", develop = true \}')
if not pattern.search(content):
raise Exception(f'langflow-base poetry dependency not found in "{filepath}"')
msg = f'langflow-base poetry dependency not found in "{filepath}"'
raise Exception(msg)
content = pattern.sub(replacement, content)
filepath.write_text(content)
filepath.write_text(content, encoding="utf-8")
def verify_pep440(version):
"""
Verify if version is PEP440 compliant.
"""Verify if version is PEP440 compliant.
https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191
"""
try:
return packaging.version.Version(version)
except packaging.version.InvalidVersion as e:
raise e
except packaging.version.InvalidVersion:
raise
def main() -> None:
if len(sys.argv) != 2:
raise Exception("New version not specified")
msg = "New version not specified"
raise Exception(msg)
base_version = sys.argv[1]
# Strip "v" prefix from version if present
if base_version.startswith("v"):
base_version = base_version[1:]
base_version = base_version.removeprefix("v")
verify_pep440(base_version)
update_base_dep("pyproject.toml", base_version)

View file

@ -1,5 +1,5 @@
import sys
import re
import sys
from pathlib import Path
BASE_DIR = Path(__file__).parent.parent.parent
@ -8,22 +8,23 @@ BASE_DIR = Path(__file__).parent.parent.parent
def update_pyproject_name(pyproject_path: str, new_project_name: str) -> None:
"""Update the project name in pyproject.toml."""
filepath = BASE_DIR / pyproject_path
content = filepath.read_text()
content = filepath.read_text(encoding="utf-8")
# Regex to match the version line under [tool.poetry]
pattern = re.compile(r'(?<=^name = ")[^"]+(?=")', re.MULTILINE)
if not pattern.search(content):
raise Exception(f'Project name not found in "{filepath}"')
msg = f'Project name not found in "{filepath}"'
raise Exception(msg)
content = pattern.sub(new_project_name, content)
filepath.write_text(content)
filepath.write_text(content, encoding="utf-8")
def update_uv_dep(pyproject_path: str, new_project_name: str) -> None:
"""Update the langflow-base dependency in pyproject.toml."""
filepath = BASE_DIR / pyproject_path
content = filepath.read_text()
content = filepath.read_text(encoding="utf-8")
if new_project_name == "langflow-nightly":
pattern = re.compile(r"langflow = \{ workspace = true \}")
@ -32,18 +33,21 @@ def update_uv_dep(pyproject_path: str, new_project_name: str) -> None:
pattern = re.compile(r"langflow-base = \{ workspace = true \}")
replacement = "langflow-base-nightly = { workspace = true }"
else:
raise ValueError(f"Invalid project name: {new_project_name}")
msg = f"Invalid project name: {new_project_name}"
raise ValueError(msg)
# Updates the dependency name for uv
if not pattern.search(content):
raise Exception(f"{replacement} uv dependency not found in {filepath}")
msg = f"{replacement} uv dependency not found in {filepath}"
raise Exception(msg)
content = pattern.sub(replacement, content)
filepath.write_text(content)
filepath.write_text(content, encoding="utf-8")
def main() -> None:
if len(sys.argv) != 3:
raise Exception("Must specify project name and build type, e.g. langflow-nightly base")
msg = "Must specify project name and build type, e.g. langflow-nightly base"
raise Exception(msg)
new_project_name = sys.argv[1]
build_type = sys.argv[2]
@ -54,7 +58,8 @@ def main() -> None:
update_pyproject_name("pyproject.toml", new_project_name)
update_uv_dep("pyproject.toml", new_project_name)
else:
raise ValueError(f"Invalid build type: {build_type}")
msg = f"Invalid build type: {build_type}"
raise ValueError(msg)
if __name__ == "__main__":

View file

@ -1,5 +1,5 @@
import sys
import re
import sys
from pathlib import Path
import packaging.version
@ -10,40 +10,39 @@ BASE_DIR = Path(__file__).parent.parent.parent
def update_pyproject_version(pyproject_path: str, new_version: str) -> None:
"""Update the version in pyproject.toml."""
filepath = BASE_DIR / pyproject_path
content = filepath.read_text()
content = filepath.read_text(encoding="utf-8")
# Regex to match the version line under [tool.poetry]
pattern = re.compile(r'(?<=^version = ")[^"]+(?=")', re.MULTILINE)
if not pattern.search(content):
raise Exception(f'Project version not found in "{filepath}"')
msg = f'Project version not found in "{filepath}"'
raise Exception(msg)
content = pattern.sub(new_version, content)
filepath.write_text(content)
filepath.write_text(content, encoding="utf-8")
def verify_pep440(version):
"""
Verify if version is PEP440 compliant.
"""Verify if version is PEP440 compliant.
https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191
"""
try:
return packaging.version.Version(version)
except packaging.version.InvalidVersion as e:
raise e
except packaging.version.InvalidVersion:
raise
def main() -> None:
if len(sys.argv) != 3:
raise Exception("New version not specified")
msg = "New version not specified"
raise Exception(msg)
new_version = sys.argv[1]
# Strip "v" prefix from version if present
if new_version.startswith("v"):
new_version = new_version[1:]
new_version = new_version.removeprefix("v")
build_type = sys.argv[2]
@ -54,7 +53,8 @@ def main() -> None:
elif build_type == "main":
update_pyproject_version("pyproject.toml", new_version)
else:
raise ValueError(f"Invalid build type: {build_type}")
msg = f"Invalid build type: {build_type}"
raise ValueError(msg)
if __name__ == "__main__":

View file

@ -1,5 +1,5 @@
import sys
import re
import sys
from pathlib import Path
BASE_DIR = Path(__file__).parent.parent.parent
@ -7,30 +7,31 @@ BASE_DIR = Path(__file__).parent.parent.parent
def update_uv_dep(base_version: str) -> None:
"""Update the langflow-base dependency in pyproject.toml."""
pyproject_path = BASE_DIR / "pyproject.toml"
# Read the pyproject.toml file content
content = pyproject_path.read_text()
content = pyproject_path.read_text(encoding="utf-8")
# For the main project, update the langflow-base dependency in the UV section
pattern = re.compile(r'(dependencies\s*=\s*\[\s*\n\s*)("langflow-base==[\d.]+")')
replacement = r'\1"langflow-base-nightly=={}"'.format(base_version)
replacement = rf'\1"langflow-base-nightly=={base_version}"'
# Check if the pattern is found
if not pattern.search(content):
raise Exception(f"{pattern} UV dependency not found in {pyproject_path}")
msg = f"{pattern} UV dependency not found in {pyproject_path}"
raise Exception(msg)
# Replace the matched pattern with the new one
content = pattern.sub(replacement, content)
# Write the updated content back to the file
pyproject_path.write_text(content)
pyproject_path.write_text(content, encoding="utf-8")
def main() -> None:
if len(sys.argv) != 2:
raise Exception("specify base version")
msg = "specify base version"
raise Exception(msg)
base_version = sys.argv[1]
base_version = base_version.lstrip("v")
update_uv_dep(base_version)

View file

@ -6,6 +6,7 @@
# ]
# ///
import argparse
import sys
from huggingface_hub import HfApi, list_models
from rich import print
@ -23,11 +24,11 @@ space = parsed_args.space
if not space:
print("Please provide a space to restart.")
exit()
sys.exit()
if not parsed_args.token:
print("Please provide an API token.")
exit()
sys.exit()
# Or configure a HfApi client
hf_api = HfApi(

View file

@ -1,6 +1,8 @@
import contextlib
def get_version() -> str:
"""
Retrieves the version of the package from a possible list of package names.
"""Retrieves the version of the package from a possible list of package names.
This accounts for after package names are updated for -nightly builds.
Returns:
@ -19,20 +21,18 @@ def get_version() -> str:
]
_version = None
for pkg_name in pkg_names:
try:
with contextlib.suppress(ImportError, metadata.PackageNotFoundError):
_version = metadata.version(pkg_name)
except (ImportError, metadata.PackageNotFoundError):
pass
if _version is None:
raise ValueError(f"Package not found from options {pkg_names}")
msg = f"Package not found from options {pkg_names}"
raise ValueError(msg)
return _version
def is_pre_release(v: str) -> bool:
"""
Returns a boolean indicating whether the version is a pre-release version,
"""Returns a boolean indicating whether the version is a pre-release version,
as per the definition of a pre-release segment from PEP 440.
"""
return any(label in v for label in ["a", "b", "rc"])

View file

@ -4,8 +4,7 @@ import os.path
def get_required_env_var(var: str) -> str:
"""
Get the value of the specified environment variable.
"""Get the value of the specified environment variable.
Args:
var (str): The environment variable to get.
@ -18,7 +17,8 @@ def get_required_env_var(var: str) -> str:
"""
value = os.getenv(var)
if not value:
raise ValueError(f"Environment variable {var} is not set")
msg = f"Environment variable {var} is not set"
raise ValueError(msg)
return value

View file

@ -16,13 +16,6 @@ from base.langflow.components.inputs.ChatInput import ChatInput
from dotenv import load_dotenv
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from loguru import logger
from pytest import LogCaptureFixture
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool
from tests.api_keys import get_openai_api_key
from typer.testing import CliRunner
from langflow.graph.graph.base import Graph
from langflow.initial_setup.setup import STARTER_FOLDER_NAME
from langflow.services.auth.utils import get_password_hash
@ -34,6 +27,13 @@ from langflow.services.database.models.user.model import User, UserCreate, UserR
from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from loguru import logger
from pytest import LogCaptureFixture
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from tests.api_keys import get_openai_api_key
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
@ -114,7 +114,7 @@ def caplog(caplog: LogCaptureFixture):
logger.remove(handler_id)
@pytest.fixture()
@pytest.fixture
async def async_client() -> AsyncGenerator:
from langflow.main import create_app
@ -188,8 +188,7 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
def get_graph(_type="basic"):
"""Get a graph from a json file"""
"""Get a graph from a json file."""
if _type == "basic":
path = pytest.BASIC_EXAMPLE_PATH
elif _type == "complex":
@ -197,7 +196,7 @@ def get_graph(_type="basic"):
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with path.open() as f:
with path.open(encoding="utf-8") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
@ -209,13 +208,13 @@ def get_graph(_type="basic"):
@pytest.fixture
def basic_graph_data():
with pytest.BASIC_EXAMPLE_PATH.open() as f:
with pytest.BASIC_EXAMPLE_PATH.open(encoding="utf-8") as f:
return json.load(f)
@pytest.fixture
def basic_graph():
yield get_graph()
return get_graph()
@pytest.fixture
@ -230,47 +229,47 @@ def openapi_graph():
@pytest.fixture
def json_flow():
return pytest.BASIC_EXAMPLE_PATH.read_text()
return pytest.BASIC_EXAMPLE_PATH.read_text(encoding="utf-8")
@pytest.fixture
def grouped_chat_json_flow():
return pytest.GROUPED_CHAT_EXAMPLE_PATH.read_text()
return pytest.GROUPED_CHAT_EXAMPLE_PATH.read_text(encoding="utf-8")
@pytest.fixture
def one_grouped_chat_json_flow():
return pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH.read_text()
return pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH.read_text(encoding="utf-8")
@pytest.fixture
def vector_store_grouped_json_flow():
return pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH.read_text()
return pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH.read_text(encoding="utf-8")
@pytest.fixture
def json_flow_with_prompt_and_history():
return pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY.read_text()
return pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY.read_text(encoding="utf-8")
@pytest.fixture
def json_simple_api_test():
return pytest.SIMPLE_API_TEST.read_text()
return pytest.SIMPLE_API_TEST.read_text(encoding="utf-8")
@pytest.fixture
def json_vector_store():
return pytest.VECTOR_STORE_PATH.read_text()
return pytest.VECTOR_STORE_PATH.read_text(encoding="utf-8")
@pytest.fixture
def json_webhook_test():
return pytest.WEBHOOK_TEST.read_text()
return pytest.WEBHOOK_TEST.read_text(encoding="utf-8")
@pytest.fixture
def json_memory_chatbot_no_llm():
return pytest.MEMORY_CHATBOT_NO_LLM.read_text()
return pytest.MEMORY_CHATBOT_NO_LLM.read_text(encoding="utf-8")
@pytest.fixture(name="client")
@ -315,12 +314,12 @@ def session_getter_fixture(client):
with Session(db_service.engine) as session:
yield session
yield blank_session_getter
return blank_session_getter
@pytest.fixture
def runner():
yield CliRunner()
return CliRunner()
@pytest.fixture
@ -337,7 +336,7 @@ async def test_user(client):
await client.delete(f"/api/v1/users/{user['id']}")
@pytest.fixture(scope="function")
@pytest.fixture
def active_user(client):
db_manager = get_db_service()
with db_manager.with_session() as session:
@ -372,7 +371,7 @@ async def logged_in_headers(client, active_user):
assert response.status_code == 200
tokens = response.json()
a_token = tokens["access_token"]
yield {"Authorization": f"Bearer {a_token}"}
return {"Authorization": f"Bearer {a_token}"}
@pytest.fixture
@ -395,12 +394,12 @@ def flow(client, json_flow: str, active_user):
@pytest.fixture
def json_chat_input():
return pytest.CHAT_INPUT.read_text()
return pytest.CHAT_INPUT.read_text(encoding="utf-8")
@pytest.fixture
def json_two_outputs():
return pytest.TWO_OUTPUTS.read_text()
return pytest.TWO_OUTPUTS.read_text(encoding="utf-8")
@pytest.fixture
@ -528,7 +527,8 @@ def get_starter_project(active_user):
.where(Flow.name == "Basic Prompting (Hello, World)")
).first()
if not flow:
raise ValueError("No starter project found")
msg = "No starter project found"
raise ValueError(msg)
# ensure openai api key is set
get_openai_api_key()

View file

@ -1,18 +1,14 @@
import os
from astrapy.db import AstraDB
import pytest
from astrapy.db import AstraDB
from langchain_core.documents import Document
from langflow.components.embeddings import OpenAIEmbeddingsComponent
from langflow.components.vectorstores import AstraVectorStoreComponent
from tests.api_keys import get_astradb_application_token, get_astradb_api_endpoint, get_openai_api_key
from tests.integration.components.mock_components import TextToData
from tests.integration.utils import ComponentInputHandle
from langchain_core.documents import Document
from langflow.schema.data import Data
from tests.integration.utils import run_single_component
from tests.api_keys import get_astradb_api_endpoint, get_astradb_application_token, get_openai_api_key
from tests.integration.components.mock_components import TextToData
from tests.integration.utils import ComponentInputHandle, run_single_component
BASIC_COLLECTION = "test_basic"
SEARCH_COLLECTION = "test_search"
@ -30,7 +26,7 @@ ALL_COLLECTIONS = [
]
@pytest.fixture()
@pytest.fixture
def astradb_client(request):
client = AstraDB(api_endpoint=get_astradb_api_endpoint(), token=get_astradb_application_token())
yield client
@ -139,7 +135,7 @@ def test_astra_vectorize():
@pytest.mark.api_key_required
def test_astra_vectorize_with_provider_api_key():
"""tests vectorize using an openai api key"""
"""Tests vectorize using an openai api key."""
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
application_token = get_astradb_application_token()
@ -196,7 +192,7 @@ def test_astra_vectorize_with_provider_api_key():
@pytest.mark.api_key_required
def test_astra_vectorize_passes_authentication():
"""tests vectorize using the authentication parameter"""
"""Tests vectorize using the authentication parameter."""
from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions
store = None

View file

@ -1,10 +1,9 @@
import pytest
from langflow.components.helpers.ParseJSONData import ParseJSONDataComponent
from langflow.components.inputs import ChatInput
from langflow.schema import Data
from tests.integration.components.mock_components import TextToData
from tests.integration.utils import run_single_component, ComponentInputHandle
from tests.integration.utils import ComponentInputHandle, run_single_component
@pytest.mark.asyncio

View file

@ -1,10 +1,9 @@
import pytest
from langflow.components.inputs import ChatInput
from langflow.memory import get_messages
from langflow.schema.message import Message
from tests.integration.utils import run_single_component
from langflow.components.inputs import ChatInput
import pytest
@pytest.mark.asyncio
async def test_default():

View file

@ -1,14 +1,12 @@
import pytest
from langflow.components.inputs import TextInputComponent
from langflow.schema.message import Message
from tests.integration.utils import run_single_component
from langflow.components.inputs import TextInputComponent
import pytest
@pytest.mark.asyncio
async def test_text_input():
outputs = await run_single_component(TextInputComponent, run_input="sample text", input_type="text")
print(outputs)
assert isinstance(outputs["text"], Message)
assert outputs["text"].text == "sample text"
assert outputs["text"].sender is None

View file

@ -1,8 +1,7 @@
import json
from typing import List
from langflow.custom import Component
from langflow.inputs import StrInput, BoolInput
from langflow.inputs import BoolInput, StrInput
from langflow.schema import Data
from langflow.template import Output
@ -21,5 +20,5 @@ class TextToData(Component):
return Data(data=json.loads(text))
return Data(text=text)
def create_data(self) -> List[Data]:
def create_data(self) -> list[Data]:
return [self._to_data(t) for t in self.text_data]

View file

@ -1,6 +1,6 @@
import os
import pytest
import pytest
from langflow.components.models.OpenAIModel import OpenAIModelComponent
from langflow.components.output_parsers.OutputParser import OutputParserComponent
from langflow.components.prompts.Prompt import PromptComponent
@ -23,7 +23,7 @@ async def test_csv_output_parser_openai():
prompt_handler = ComponentInputHandle(
clazz=PromptComponent,
inputs={
"template": "List the first five positive integers.\n\n{format_instructions}",
"template": "List the first five positive integers.\n\n{format_instructions}", # noqa: RUF027
"format_instructions": format_instructions,
},
output_name="prompt",

View file

@ -1,10 +1,9 @@
import pytest
from langflow.components.outputs import ChatOutput
from langflow.memory import get_messages
from langflow.schema.message import Message
from tests.integration.utils import run_single_component
import pytest
@pytest.mark.asyncio
async def test_string():

View file

@ -1,9 +1,8 @@
import pytest
from langflow.components.outputs import TextOutputComponent
from langflow.schema.message import Message
from tests.integration.utils import run_single_component
import pytest
@pytest.mark.asyncio
async def test():

View file

@ -1,13 +1,11 @@
import pytest
from langflow.components.prompts import PromptComponent
from langflow.schema.message import Message
from tests.integration.utils import run_single_component
import pytest
@pytest.mark.asyncio
async def test():
outputs = await run_single_component(PromptComponent, inputs={"template": "test {var1}", "var1": "from the var"})
print(outputs)
assert isinstance(outputs["prompt"], Message)
assert outputs["prompt"].text == "test from the var"

View file

@ -1,5 +1,4 @@
import pytest
from langflow.components.inputs import ChatInput
from langflow.components.outputs import ChatOutput
from langflow.components.prompts import PromptComponent

View file

@ -3,7 +3,6 @@ from uuid import uuid4
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from langflow.graph.schema import RunOutputs
from langflow.initial_setup.setup import load_starter_projects
from langflow.load import run_flow_from_json
@ -80,9 +79,8 @@ async def test_run_with_inputs_and_outputs(client, starter_project, created_api_
@pytest.mark.noclient
@pytest.mark.api_key_required
def test_run_flow_from_json_object():
"""Test loading a flow from a json file and applying tweaks"""
_, projects = zip(*load_starter_projects())
project = [project for project in projects if "Basic Prompting" in project["name"]][0]
"""Test loading a flow from a json file and applying tweaks."""
project = next(project for _, project in load_starter_projects() if "Basic Prompting" in project["name"])
results = run_flow_from_json(project, input_value="test", fallback_to_env_vars=True)
assert results is not None
assert all(isinstance(result, RunOutputs) for result in results)

View file

@ -1,21 +1,19 @@
import dataclasses
import os
import uuid
from typing import Optional, Any
from typing import Any
import requests
from astrapy.admin import parse_api_endpoint
from langflow.api.v1.schemas import InputValueRequest
from langflow.custom import Component
from langflow.field_typing import Embeddings
from langflow.graph import Graph
from langflow.processing.process import run_graph_internal
import requests
def check_env_vars(*vars):
"""
Check if all specified environment variables are set.
"""Check if all specified environment variables are set.
Args:
*vars (str): The environment variables to check.
@ -27,8 +25,7 @@ def check_env_vars(*vars):
def valid_nvidia_vectorize_region(api_endpoint: str) -> bool:
"""
Check if the specified region is valid.
"""Check if the specified region is valid.
Args:
region (str): The region to check.
@ -38,8 +35,9 @@ def valid_nvidia_vectorize_region(api_endpoint: str) -> bool:
"""
parsed_endpoint = parse_api_endpoint(api_endpoint)
if not parsed_endpoint:
raise ValueError("Invalid ASTRA_DB_API_ENDPOINT")
return parsed_endpoint.region in ["us-east-2"]
msg = "Invalid ASTRA_DB_API_ENDPOINT"
raise ValueError(msg)
return parsed_endpoint.region == "us-east-2"
class MockEmbeddings(Embeddings):
@ -70,15 +68,15 @@ class JSONFlow:
if node["data"]["type"] == component_type:
result.append(node["id"])
if not result:
raise ValueError(
f"Component of type {component_type} not found, available types: {', '.join(set(node['data']['type'] for node in self.json['data']['nodes']))}"
)
msg = f"Component of type {component_type} not found, available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}"
raise ValueError(msg)
return result
def get_component_by_type(self, component_type):
components = self.get_components_by_type(component_type)
if len(components) > 1:
raise ValueError(f"Multiple components of type {component_type} found")
msg = f"Multiple components of type {component_type} found"
raise ValueError(msg)
return components[0]
def set_value(self, component_id, key, value):
@ -86,13 +84,15 @@ class JSONFlow:
for node in self.json["data"]["nodes"]:
if node["id"] == component_id:
if key not in node["data"]["node"]["template"]:
raise ValueError(f"Component {component_id} does not have input {key}")
msg = f"Component {component_id} does not have input {key}"
raise ValueError(msg)
node["data"]["node"]["template"][key]["value"] = value
node["data"]["node"]["template"][key]["load_from_db"] = False
done = True
break
if not done:
raise ValueError(f"Component {component_id} not found")
msg = f"Component {component_id} not found"
raise ValueError(msg)
def download_flow_from_github(name: str, version: str) -> JSONFlow:
@ -105,18 +105,15 @@ def download_flow_from_github(name: str, version: str) -> JSONFlow:
async def run_json_flow(
json_flow: JSONFlow, run_input: Optional[Any] = None, session_id: Optional[str] = None
json_flow: JSONFlow, run_input: Any | None = None, session_id: str | None = None
) -> dict[str, Any]:
graph = Graph.from_payload(json_flow.json)
return await run_flow(graph, run_input, session_id)
async def run_flow(graph: Graph, run_input: Optional[Any] = None, session_id: Optional[str] = None) -> dict[str, Any]:
async def run_flow(graph: Graph, run_input: Any | None = None, session_id: str | None = None) -> dict[str, Any]:
graph.prepare()
if run_input:
graph_run_inputs = [InputValueRequest(input_value=run_input, type="chat")]
else:
graph_run_inputs = []
graph_run_inputs = [InputValueRequest(input_value=run_input, type="chat")] if run_input else []
flow_id = str(uuid.uuid4())
@ -137,23 +134,24 @@ class ComponentInputHandle:
async def run_single_component(
clazz: type,
inputs: dict = None,
run_input: Optional[Any] = None,
session_id: Optional[str] = None,
input_type: Optional[str] = "chat",
inputs: dict | None = None,
run_input: Any | None = None,
session_id: str | None = None,
input_type: str | None = "chat",
) -> dict[str, Any]:
user_id = str(uuid.uuid4())
flow_id = str(uuid.uuid4())
graph = Graph(user_id=user_id, flow_id=flow_id)
def _add_component(clazz: type, inputs: Optional[dict] = None) -> str:
def _add_component(clazz: type, inputs: dict | None = None) -> str:
raw_inputs = {}
if inputs:
for key, value in inputs.items():
if not isinstance(value, ComponentInputHandle):
raw_inputs[key] = value
if isinstance(value, Component):
raise ValueError("Component inputs must be wrapped in ComponentInputHandle")
msg = "Component inputs must be wrapped in ComponentInputHandle"
raise ValueError(msg)
component = clazz(**raw_inputs, _user_id=user_id)
component_id = graph.add_component(component)
if inputs:
@ -165,10 +163,7 @@ async def run_single_component(
component_id = _add_component(clazz, inputs)
graph.prepare()
if run_input:
graph_run_inputs = [InputValueRequest(input_value=run_input, type=input_type)]
else:
graph_run_inputs = []
graph_run_inputs = [InputValueRequest(input_value=run_input, type=input_type)] if run_input else []
_, _ = await run_graph_internal(
graph, flow_id, session_id=session_id, inputs=graph_run_inputs, outputs=[component_id]

View file

@ -11,8 +11,8 @@ from rich import print
class NameTest(FastHttpUser):
wait_time = between(1, 5)
with Path("names.txt").open() as file:
names = [line.strip() for line in file.readlines()]
with Path("names.txt").open(encoding="utf-8") as file:
names = [line.strip() for line in file]
headers: dict = {}
@ -28,8 +28,9 @@ class NameTest(FastHttpUser):
print(f"Poll Response: {response.js}")
if status == "SUCCESS":
return response.js.get("result")
elif status in ["FAILURE", "REVOKED"]:
raise ValueError(f"Task failed with status: {status}")
if status in {"FAILURE", "REVOKED"}:
msg = f"Task failed with status: {status}"
raise ValueError(msg)
time.sleep(sleep_time)
def process(self, name, flow_id, payload):
@ -45,7 +46,8 @@ class NameTest(FastHttpUser):
print(response.js)
if response.status_code != 200:
response.failure("Process call failed")
raise ValueError("Process call failed")
msg = "Process call failed"
raise ValueError(msg)
task_id = response.js.get("id")
session_id = response.js.get("session_id")
assert task_id, "Inner Task ID not found"
@ -86,7 +88,9 @@ class NameTest(FastHttpUser):
a_token = tokens["access_token"]
logged_in_headers = {"Authorization": f"Bearer {a_token}"}
print("Logged in")
json_flow = (Path(__file__).parent.parent / "data" / "BasicChatwithPromptandHistory.json").read_text()
json_flow = (Path(__file__).parent.parent / "data" / "BasicChatwithPromptandHistory.json").read_text(
encoding="utf-8"
)
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data

View file

@ -1,5 +1,6 @@
from langflow.api.utils import get_suggestion_message
from unittest.mock import patch
from langflow.api.utils import get_suggestion_message
from langflow.services.database.models.flow.utils import get_outdated_components
from langflow.utils.version import get_version_info

View file

@ -21,11 +21,11 @@ async def test_create_variable(client: AsyncClient, body, logged_in_headers):
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_201_CREATED == response.status_code
assert response.status_code == status.HTTP_201_CREATED
assert body["name"] == result["name"]
assert body["type"] == result["type"]
assert body["default_fields"] == result["default_fields"]
assert "id" in result.keys()
assert "id" in result
assert body["value"] != result["value"]
@ -36,7 +36,7 @@ async def test_create_variable__variable_name_already_exists(client: AsyncClient
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_400_BAD_REQUEST == response.status_code
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "Variable name already exists" in result["detail"]
@ -48,7 +48,7 @@ async def test_create_variable__variable_name_and_value_cannot_be_empty(client:
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_400_BAD_REQUEST == response.status_code
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "Variable name and value cannot be empty" in result["detail"]
@ -59,7 +59,7 @@ async def test_create_variable__variable_name_cannot_be_empty(client: AsyncClien
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_400_BAD_REQUEST == response.status_code
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "Variable name cannot be empty" in result["detail"]
@ -70,7 +70,7 @@ async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClie
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_400_BAD_REQUEST == response.status_code
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "Variable value cannot be empty" in result["detail"]
@ -84,7 +84,7 @@ async def test_create_variable__HTTPException(client: AsyncClient, body, logged_
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_418_IM_A_TEAPOT == response.status_code
assert response.status_code == status.HTTP_418_IM_A_TEAPOT
assert generic_message in result["detail"]
@ -97,7 +97,7 @@ async def test_create_variable__Exception(client: AsyncClient, body, logged_in_h
response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert generic_message in result["detail"]
@ -111,7 +111,7 @@ async def test_read_variables(client: AsyncClient, body, logged_in_headers):
response = await client.get("api/v1/variables/", headers=logged_in_headers)
result = response.json()
assert status.HTTP_200_OK == response.status_code
assert response.status_code == status.HTTP_200_OK
assert all(name in [r["name"] for r in result] for name in names)
@ -125,23 +125,22 @@ async def test_read_variables__empty(client: AsyncClient, logged_in_headers):
response = await client.get("api/v1/variables/", headers=logged_in_headers)
result = response.json()
assert status.HTTP_200_OK == response.status_code
assert [] == result
assert response.status_code == status.HTTP_200_OK
assert result == []
@pytest.mark.usefixtures("active_user")
async def test_read_variables__(client: AsyncClient, logged_in_headers):
generic_message = "Generic error message"
with pytest.raises(Exception) as exc:
with mock.patch("sqlmodel.Session.exec") as m:
m.side_effect = Exception(generic_message)
with pytest.raises(Exception) as exc, mock.patch("sqlmodel.Session.exec") as m:
m.side_effect = Exception(generic_message)
response = await client.get("api/v1/variables/", headers=logged_in_headers)
result = response.json()
response = await client.get("api/v1/variables/", headers=logged_in_headers)
result = response.json()
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
assert generic_message in result["detail"]
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert generic_message in result["detail"]
assert generic_message in str(exc.value)
@ -159,7 +158,7 @@ async def test_update_variable(client: AsyncClient, body, logged_in_headers):
response = await client.patch(f"api/v1/variables/{saved.get('id')}", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_200_OK == response.status_code
assert response.status_code == status.HTTP_200_OK
assert saved["id"] == result["id"]
assert saved["name"] != result["name"]
assert saved["default_fields"] != result["default_fields"]
@ -173,7 +172,7 @@ async def test_update_variable__Exception(client: AsyncClient, body, logged_in_h
response = await client.patch(f"api/v1/variables/{wrong_id}", json=body, headers=logged_in_headers)
result = response.json()
assert status.HTTP_404_NOT_FOUND == response.status_code
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "Variable not found" in result["detail"]
@ -183,7 +182,7 @@ async def test_delete_variable(client: AsyncClient, body, logged_in_headers):
saved = response.json()
response = await client.delete(f"api/v1/variables/{saved.get('id')}", headers=logged_in_headers)
assert status.HTTP_204_NO_CONTENT == response.status_code
assert response.status_code == status.HTTP_204_NO_CONTENT
@pytest.mark.usefixtures("active_user")
@ -192,4 +191,4 @@ async def test_delete_variable__Exception(client: AsyncClient, logged_in_headers
response = await client.delete(f"api/v1/variables/{wrong_id}", headers=logged_in_headers)
assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

View file

@ -1,4 +1,5 @@
import inspect
from langflow.load import run_flow_from_json

View file

@ -1,7 +1,6 @@
import os
import pytest
from langflow.base.tools.component_tool import ComponentToolkit
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
from langflow.components.inputs.ChatInput import ChatInput
@ -76,7 +75,7 @@ def test_component_tool():
}
assert component_toolkit.component == chat_input
result = component_tool.invoke(input=dict(input_value="test"))
result = component_tool.invoke(input={"input_value": "test"})
assert isinstance(result, Message)
assert result.get_text() == "test"

View file

@ -1,10 +1,9 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic import BaseModel
from langflow.components.helpers.structured_output import StructuredOutputComponent
from langflow.schema.data import Data
from pydantic import BaseModel
@pytest.fixture

View file

@ -3,7 +3,6 @@ from urllib.parse import urljoin
import pytest
from langchain_community.chat_models.ollama import ChatOllama
from langflow.components.models.OllamaModel import ChatOllamaComponent

View file

@ -1,6 +1,5 @@
from collections.abc import Callable
from langflow.components.inputs.ChatInput import ChatInput

View file

@ -1,5 +1,4 @@
import pytest
from langflow.components.agents.CrewAIAgent import CrewAIAgentComponent
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
from langflow.components.helpers.SequentialTask import SequentialTaskComponent

View file

@ -4,7 +4,6 @@ import time
import uuid
import pytest
from langflow.events.event_manager import EventManager
from langflow.schema.log import LoggableType

View file

@ -1,4 +1,5 @@
from unittest.mock import patch, Mock
from unittest.mock import Mock, patch
from langflow.services.database.models.flow.model import Flow

View file

@ -1,5 +1,4 @@
import pytest
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.models.OpenAIModel import OpenAIModelComponent
from langflow.components.outputs.ChatOutput import ChatOutput

View file

@ -1,12 +1,11 @@
import pytest
from pydantic import Field
from langflow.components.inputs import ChatInput
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.state.model import create_state_model
from langflow.template.field.base import UNDEFINED
from pydantic import Field
@pytest.fixture

View file

@ -2,8 +2,6 @@ import logging
from collections import deque
import pytest
from pytest import LogCaptureFixture
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.outputs.ChatOutput import ChatOutput
@ -11,6 +9,7 @@ from langflow.components.outputs.TextOutput import TextOutputComponent
from langflow.components.tools.YfinanceTool import YfinanceToolComponent
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish
from pytest import LogCaptureFixture
@pytest.mark.asyncio

View file

@ -1,6 +1,5 @@
import asyncio
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.custom.custom_component.component import Component
from langflow.events.event_manager import EventManager

View file

@ -1,7 +1,6 @@
import os
import pytest
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.models.OpenAIModel import OpenAIModelComponent
from langflow.components.outputs.ChatOutput import ChatOutput
@ -203,5 +202,3 @@ def test_updated_graph_with_prompts():
# Extract the vertex IDs for analysis
results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")]
assert "chat_output_1" in results_ids, f"Expected outputs not in results: {results_ids}"
print(f"Execution completed with results: {results_ids}")

View file

@ -1,6 +1,6 @@
import pytest
from pydantic import BaseModel
from typing import TYPE_CHECKING
import pytest
from langflow.components.helpers.Memory import MemoryComponent
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.models.OpenAIModel import OpenAIModelComponent
@ -10,6 +10,9 @@ from langflow.graph import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.graph.state_model import create_state_model_from_graph
if TYPE_CHECKING:
from pydantic import BaseModel
def test_graph_state_model():
session_id = "test_session_id"

View file

@ -1,10 +1,12 @@
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING
import pytest
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
if TYPE_CHECKING:
from collections import defaultdict
@pytest.fixture
def data():
@ -23,7 +25,7 @@ def data():
def test_to_dict(data):
result = RunnableVerticesManager.from_dict(data).to_dict()
assert all(key in result.keys() for key in data.keys())
assert all(key in result for key in data)
def test_from_dict(data):
@ -158,8 +160,8 @@ def test_build_run_map(data):
manager.build_run_map(predecessor_map, vertices_to_run)
assert all(v in manager.run_map.keys() for v in ["Z", "X", "Y"])
assert "W" not in manager.run_map.keys()
assert all(v in manager.run_map for v in ["Z", "X", "Y"])
assert "W" not in manager.run_map
def test_update_vertex_run_state(data):

View file

@ -1,7 +1,6 @@
import copy
import pytest
from langflow.graph.graph import utils

View file

@ -2,7 +2,6 @@ import copy
import json
import pytest
from langflow.graph import Graph
from langflow.graph.graph.utils import (
find_last_node,
@ -59,7 +58,7 @@ def sample_nodes():
def get_node_by_type(graph, node_type: type[Vertex]) -> Vertex | None:
"""Get a node by type"""
"""Get a node by type."""
return next((node for node in graph.vertices if isinstance(node, node_type)), None)
@ -131,10 +130,10 @@ def test_process_flow_one_group(one_grouped_chat_json_flow):
node_data = group_node["data"]["node"]
assert node_data.get("flow") is not None
template_data = node_data["template"]
assert any("openai_api_key" in key for key in template_data.keys())
assert any("openai_api_key" in key for key in template_data)
# Get the openai_api_key dict
openai_api_key = next(
(template_data[key] for key in template_data.keys() if "openai_api_key" in key),
(template_data[key] for key in template_data if "openai_api_key" in key),
None,
)
assert openai_api_key is not None

View file

@ -3,11 +3,10 @@
from typing import Any
import pytest
from langflow.helpers.base_model import build_model_from_schema
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from langflow.helpers.base_model import build_model_from_schema
class TestBuildModelFromSchema:
# Successfully creates a Pydantic model from a valid schema
@ -115,9 +114,9 @@ class TestBuildModelFromSchema:
{"name": "field3", "type": "list", "default": None, "description": "Field 3 description", "multiple": True},
]
model = build_model_from_schema(schema)
assert model.model_fields["field1"].default == PydanticUndefined # noqa: E711
assert model.model_fields["field2"].default == PydanticUndefined # noqa: E711
assert model.model_fields["field3"].default == PydanticUndefined # noqa: E711
assert model.model_fields["field1"].default == PydanticUndefined
assert model.model_fields["field2"].default == PydanticUndefined
assert model.model_fields["field3"].default == PydanticUndefined
# Checks for proper handling of nested list and dict types
def test_nested_list_and_dict_types_handling(self):

View file

@ -1,7 +1,8 @@
import operator
from collections import deque
from typing import TYPE_CHECKING
import pytest
from langflow.components.helpers.Memory import MemoryComponent
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.models.OpenAIModel import OpenAIModelComponent
@ -9,7 +10,9 @@ from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.prompts.Prompt import PromptComponent
from langflow.graph import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.graph.schema import GraphDump
if TYPE_CHECKING:
from langflow.graph.graph.schema import GraphDump
@pytest.fixture
@ -100,7 +103,7 @@ def test_memory_chatbot_dump_components_and_edges(memory_chatbot_graph: Graph):
edges = data_dict["edges"]
# sort the nodes by id
nodes = sorted(nodes, key=lambda x: x["id"])
nodes = sorted(nodes, key=operator.itemgetter("id"))
# Check each node
assert nodes[0]["data"]["type"] == "ChatInput"

View file

@ -1,8 +1,8 @@
import copy
import operator
from textwrap import dedent
import pytest
from langflow.components.data.File import FileComponent
from langflow.components.embeddings.OpenAIEmbeddings import OpenAIEmbeddingsComponent
from langflow.components.helpers.ParseData import ParseDataComponent
@ -40,8 +40,7 @@ def ingestion_graph():
vector_store.set_on_output(name="base_retriever", value="mock_retriever", cache=True)
vector_store.set_on_output(name="search_results", value=[Data(text="This is a test file.")], cache=True)
ingestion_graph = Graph(file_component, vector_store)
return ingestion_graph
return Graph(file_component, vector_store)
@pytest.fixture
@ -89,8 +88,7 @@ def rag_graph():
chat_output = ChatOutput(_id="chatoutput-123")
chat_output.set(input_value=openai_component.text_response)
graph = Graph(start=chat_input, end=chat_output)
return graph
return Graph(start=chat_input, end=chat_output)
def test_vector_store_rag(ingestion_graph, rag_graph):
@ -111,7 +109,7 @@ def test_vector_store_rag(ingestion_graph, rag_graph):
"rag-vector-store-123",
"openai-embeddings-124",
]
for ids, graph, len_results in zip([ingestion_ids, rag_ids], [ingestion_graph, rag_graph], [5, 8]):
for ids, graph, len_results in [(ingestion_ids, ingestion_graph, 5), (rag_ids, rag_graph, 8)]:
results = []
for result in graph.start():
results.append(result)
@ -134,7 +132,7 @@ def test_vector_store_rag_dump_components_and_edges(ingestion_graph, rag_graph):
ingestion_edges = ingestion_data["edges"]
# Sort nodes by id to check components
ingestion_nodes = sorted(ingestion_nodes, key=lambda x: x["id"])
ingestion_nodes = sorted(ingestion_nodes, key=operator.itemgetter("id"))
# Check components in the ingestion graph
assert ingestion_nodes[0]["data"]["type"] == "File"
@ -172,7 +170,7 @@ def test_vector_store_rag_dump_components_and_edges(ingestion_graph, rag_graph):
rag_edges = rag_data["edges"]
# Sort nodes by id to check components
rag_nodes = sorted(rag_nodes, key=lambda x: x["id"])
rag_nodes = sorted(rag_nodes, key=operator.itemgetter("id"))
# Check components in the RAG graph
assert rag_nodes[0]["data"]["type"] == "ChatInput"
@ -235,7 +233,7 @@ def test_vector_store_rag_add(ingestion_graph: Graph, rag_graph: Graph):
combined_edges = combined_data["edges"]
# Sort nodes by id to check components
combined_nodes = sorted(combined_nodes, key=lambda x: x["id"])
combined_nodes = sorted(combined_nodes, key=operator.itemgetter("id"))
# Expected components in the combined graph (both ingestion and RAG nodes)
expected_nodes = sorted(
@ -252,10 +250,10 @@ def test_vector_store_rag_add(ingestion_graph: Graph, rag_graph: Graph):
{"id": "prompt-123", "type": "Prompt"},
{"id": "rag-vector-store-123", "type": "AstraDB"},
],
key=lambda x: x["id"],
key=operator.itemgetter("id"),
)
for expected_node, combined_node in zip(expected_nodes, combined_nodes):
for expected_node, combined_node in zip(expected_nodes, combined_nodes, strict=True):
assert combined_node["data"]["type"] == expected_node["type"]
assert combined_node["id"] == expected_node["id"]

View file

@ -1,6 +1,4 @@
import pytest
from pydantic import ValidationError
from langflow.inputs.inputs import (
BoolInput,
CodeInput,
@ -24,6 +22,7 @@ from langflow.inputs.inputs import (
)
from langflow.inputs.utils import instantiate_input
from langflow.schema.message import Message
from pydantic import ValidationError
def test_table_input_valid():

View file

@ -1,10 +1,11 @@
from typing import Literal
from typing import TYPE_CHECKING, Literal
import pytest
from pydantic.fields import FieldInfo
from langflow.components.inputs.ChatInput import ChatInput
if TYPE_CHECKING:
from pydantic.fields import FieldInfo
def test_create_input_schema():
from langflow.io.schema import create_input_schema

View file

@ -1,7 +1,6 @@
# Generated by qodo Gen
import pytest
from langflow.schema.table import Column, FormatterType

View file

@ -1,6 +1,5 @@
import pytest
from langchain_core.prompts.chat import ChatPromptTemplate
from langflow.schema.message import Message

View file

@ -3,13 +3,12 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlmodel import Session, SQLModel, create_engine
from langflow.services.database.models.variable.model import VariableUpdate
from langflow.services.deps import get_settings_service
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE
from langflow.services.variable.service import DatabaseVariableService
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
from sqlmodel import Session, SQLModel, create_engine
@pytest.fixture
@ -44,8 +43,8 @@ def test_initialize_user_variables__create_and_update(service, session):
value = service.get_variable(user_id, name, field, session=session)
assert value == env_vars[name]
assert all([i in variables for i in good_vars.keys()])
assert all([i not in variables for i in bad_vars.keys()])
assert all(i in variables for i in good_vars)
assert all(i not in variables for i in bad_vars)
def test_initialize_user_variables__not_found_variable(service, session):

View file

@ -1,6 +1,5 @@
import pytest
from httpx import AsyncClient
from langflow.services.database.models.api_key import ApiKeyCreate
@ -30,7 +29,8 @@ async def test_create_api_key(client: AsyncClient, logged_in_headers):
response = await client.post("api/v1/api_key/", json={"name": api_key_name}, headers=logged_in_headers)
assert response.status_code == 200
data = response.json()
assert "name" in data and data["name"] == api_key_name
assert "name" in data
assert data["name"] == api_key_name
assert "api_key" in data
assert "**" not in data["api_key"]

View file

@ -1,10 +1,9 @@
import json
from uuid import UUID
from orjson import orjson
from langflow.memory import get_messages
from langflow.services.database.models.flow import FlowCreate, FlowUpdate
from orjson import orjson
async def test_build_flow(client, json_memory_chatbot_no_llm, logged_in_headers):
@ -82,7 +81,8 @@ async def consume_and_assert_stream(r):
elif count == 5:
assert parsed["event"] == "end"
else:
raise ValueError(f"Unexpected line: {line}")
msg = f"Unexpected line: {line}"
raise ValueError(msg)
count += 1
@ -92,5 +92,4 @@ async def _create_flow(client, json_memory_chatbot_no_llm, logged_in_headers):
vector_store = FlowCreate(name="Flow", description="description", data=data, endpoint_name="f")
response = await client.post("api/v1/flows/", json=vector_store.model_dump(), headers=logged_in_headers)
response.raise_for_status()
flow_id = response.json()["id"]
return flow_id
return response.json()["id"]

View file

@ -1,5 +1,4 @@
import pytest
from langflow.__main__ import app
from langflow.services import deps

View file

@ -5,7 +5,6 @@ from textwrap import dedent
import pytest
from langchain_core.documents import Document
from langflow.custom import Component, CustomComponent
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError
@ -14,7 +13,7 @@ from langflow.custom.utils import build_custom_component_template
@pytest.fixture
def code_component_with_multiple_outputs():
code = Path("src/backend/tests/data/component_multiple_outputs.py").read_text()
code = Path("src/backend/tests/data/component_multiple_outputs.py").read_text(encoding="utf-8")
return Component(_code=code)
@ -39,25 +38,20 @@ class YourComponent(CustomComponent):
def test_code_parser_init():
"""
Test the initialization of the CodeParser class.
"""
"""Test the initialization of the CodeParser class."""
parser = CodeParser(code_default)
assert parser.code == code_default
def test_code_parser_get_tree():
"""
Test the __get_tree method of the CodeParser class.
"""
"""Test the __get_tree method of the CodeParser class."""
parser = CodeParser(code_default)
tree = parser.get_tree()
assert isinstance(tree, ast.AST)
def test_code_parser_syntax_error():
"""
Test the __get_tree method raises the
"""Test the __get_tree method raises the
CodeSyntaxError when given incorrect syntax.
"""
code_syntax_error = "zzz import os"
@ -68,26 +62,21 @@ def test_code_parser_syntax_error():
def test_component_init():
"""
Test the initialization of the Component class.
"""
"""Test the initialization of the Component class."""
component = BaseComponent(_code=code_default, _function_entrypoint_name="build")
assert component._code == code_default
assert component._function_entrypoint_name == "build"
def test_component_get_code_tree():
"""
Test the get_code_tree method of the Component class.
"""
"""Test the get_code_tree method of the Component class."""
component = BaseComponent(_code=code_default, _function_entrypoint_name="build")
tree = component.get_code_tree(component._code)
assert "imports" in tree
def test_component_code_null_error():
"""
Test the get_function method raises the
"""Test the get_function method raises the
ComponentCodeNullError when the code is empty.
"""
component = BaseComponent(_code="", _function_entrypoint_name="")
@ -96,9 +85,7 @@ def test_component_code_null_error():
def test_custom_component_init():
"""
Test the initialization of the CustomComponent class.
"""
"""Test the initialization of the CustomComponent class."""
function_entrypoint_name = "build"
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name=function_entrypoint_name)
@ -107,26 +94,21 @@ def test_custom_component_init():
def test_custom_component_build_template_config():
"""
Test the build_template_config property of the CustomComponent class.
"""
"""Test the build_template_config property of the CustomComponent class."""
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
config = custom_component.build_template_config()
assert isinstance(config, dict)
def test_custom_component_get_function():
"""
Test the get_function property of the CustomComponent class.
"""
"""Test the get_function property of the CustomComponent class."""
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
my_function = custom_component.get_function()
assert isinstance(my_function, types.FunctionType)
def test_code_parser_parse_imports_import():
"""
Test the parse_imports method of the CodeParser
"""Test the parse_imports method of the CodeParser
class with an import statement.
"""
parser = CodeParser(code_default)
@ -138,8 +120,7 @@ def test_code_parser_parse_imports_import():
def test_code_parser_parse_imports_importfrom():
"""
Test the parse_imports method of the CodeParser
"""Test the parse_imports method of the CodeParser
class with an import from statement.
"""
parser = CodeParser("from os import path")
@ -151,9 +132,7 @@ def test_code_parser_parse_imports_importfrom():
def test_code_parser_parse_functions():
"""
Test the parse_functions method of the CodeParser class.
"""
"""Test the parse_functions method of the CodeParser class."""
parser = CodeParser("def test(): pass")
tree = parser.get_tree()
for node in ast.walk(tree):
@ -164,9 +143,7 @@ def test_code_parser_parse_functions():
def test_code_parser_parse_classes():
"""
Test the parse_classes method of the CodeParser class.
"""
"""Test the parse_classes method of the CodeParser class."""
parser = CodeParser("from langflow.custom import Component\n\nclass Test(Component): pass")
tree = parser.get_tree()
for node in ast.walk(tree):
@ -177,9 +154,7 @@ def test_code_parser_parse_classes():
def test_code_parser_parse_classes_raises():
"""
Test the parse_classes method of the CodeParser class.
"""
"""Test the parse_classes method of the CodeParser class."""
parser = CodeParser("class Test: pass")
tree = parser.get_tree()
with pytest.raises(TypeError):
@ -189,9 +164,7 @@ def test_code_parser_parse_classes_raises():
def test_code_parser_parse_global_vars():
"""
Test the parse_global_vars method of the CodeParser class.
"""
"""Test the parse_global_vars method of the CodeParser class."""
parser = CodeParser("x = 1")
tree = parser.get_tree()
for node in ast.walk(tree):
@ -202,8 +175,7 @@ def test_code_parser_parse_global_vars():
def test_component_get_function_valid():
"""
Test the get_function method of the Component
"""Test the get_function method of the Component
class with valid code and function_entrypoint_name.
"""
component = BaseComponent(_code="def build(): pass", _function_entrypoint_name="build")
@ -212,8 +184,7 @@ def test_component_get_function_valid():
def test_custom_component_get_function_entrypoint_args():
"""
Test the get_function_entrypoint_args
"""Test the get_function_entrypoint_args
property of the CustomComponent class.
"""
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
@ -225,28 +196,23 @@ def test_custom_component_get_function_entrypoint_args():
def test_custom_component_get_function_entrypoint_return_type():
"""
Test the get_function_entrypoint_return_type
"""Test the get_function_entrypoint_return_type
property of the CustomComponent class.
"""
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
return_type = custom_component.get_function_entrypoint_return_type
assert return_type == [Document]
def test_custom_component_get_main_class_name():
"""
Test the get_main_class_name property of the CustomComponent class.
"""
"""Test the get_main_class_name property of the CustomComponent class."""
custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build")
class_name = custom_component.get_main_class_name
assert class_name == "YourComponent"
def test_custom_component_get_function_valid():
"""
Test the get_function property of the CustomComponent
"""Test the get_function property of the CustomComponent
class with valid code and function_entrypoint_name.
"""
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
@ -255,9 +221,7 @@ def test_custom_component_get_function_valid():
def test_code_parser_parse_arg_no_annotation():
"""
Test the parse_arg method of the CodeParser class without an annotation.
"""
"""Test the parse_arg method of the CodeParser class without an annotation."""
parser = CodeParser("")
arg = ast.arg(arg="x", annotation=None)
result = parser.parse_arg(arg, None)
@ -266,9 +230,7 @@ def test_code_parser_parse_arg_no_annotation():
def test_code_parser_parse_arg_with_annotation():
"""
Test the parse_arg method of the CodeParser class with an annotation.
"""
"""Test the parse_arg method of the CodeParser class with an annotation."""
parser = CodeParser("")
arg = ast.arg(arg="x", annotation=ast.Name(id="int", ctx=ast.Load()))
result = parser.parse_arg(arg, None)
@ -277,8 +239,7 @@ def test_code_parser_parse_arg_with_annotation():
def test_code_parser_parse_callable_details_no_args():
"""
Test the parse_callable_details method of the
"""Test the parse_callable_details method of the
CodeParser class with a function with no arguments.
"""
parser = CodeParser("")
@ -295,9 +256,7 @@ def test_code_parser_parse_callable_details_no_args():
def test_code_parser_parse_assign():
"""
Test the parse_assign method of the CodeParser class.
"""
"""Test the parse_assign method of the CodeParser class."""
parser = CodeParser("")
stmt = ast.Assign(targets=[ast.Name(id="x", ctx=ast.Store())], value=ast.Num(n=1))
result = parser.parse_assign(stmt)
@ -306,9 +265,7 @@ def test_code_parser_parse_assign():
def test_code_parser_parse_ann_assign():
"""
Test the parse_ann_assign method of the CodeParser class.
"""
"""Test the parse_ann_assign method of the CodeParser class."""
parser = CodeParser("")
stmt = ast.AnnAssign(
target=ast.Name(id="x", ctx=ast.Store()),
@ -323,8 +280,7 @@ def test_code_parser_parse_ann_assign():
def test_code_parser_parse_function_def_not_init():
"""
Test the parse_function_def method of the
"""Test the parse_function_def method of the
CodeParser class with a function that is not __init__.
"""
parser = CodeParser("")
@ -341,8 +297,7 @@ def test_code_parser_parse_function_def_not_init():
def test_code_parser_parse_function_def_init():
"""
Test the parse_function_def method of the
"""Test the parse_function_def method of the
CodeParser class with an __init__ function.
"""
parser = CodeParser("")
@ -359,8 +314,7 @@ def test_code_parser_parse_function_def_init():
def test_component_get_code_tree_syntax_error():
"""
Test the get_code_tree method of the Component class
"""Test the get_code_tree method of the Component class
raises the CodeSyntaxError when given incorrect syntax.
"""
component = BaseComponent(_code="import os as", _function_entrypoint_name="build")
@ -369,8 +323,7 @@ def test_component_get_code_tree_syntax_error():
def test_custom_component_class_template_validation_no_code():
"""
Test the _class_template_validation method of the CustomComponent class
"""Test the _class_template_validation method of the CustomComponent class
raises the HTTPException when the code is None.
"""
custom_component = CustomComponent(_code=None, _function_entrypoint_name="build")
@ -379,8 +332,7 @@ def test_custom_component_class_template_validation_no_code():
def test_custom_component_get_code_tree_syntax_error():
"""
Test the get_code_tree method of the CustomComponent class
"""Test the get_code_tree method of the CustomComponent class
raises the CodeSyntaxError when given incorrect syntax.
"""
custom_component = CustomComponent(_code="import os as", _function_entrypoint_name="build")
@ -389,8 +341,7 @@ def test_custom_component_get_code_tree_syntax_error():
def test_custom_component_get_function_entrypoint_args_no_args():
"""
Test the get_function_entrypoint_args property of
"""Test the get_function_entrypoint_args property of
the CustomComponent class with a build method with no arguments.
"""
my_code = """
@ -405,8 +356,7 @@ class MyMainClass(CustomComponent):
def test_custom_component_get_function_entrypoint_return_type_no_return_type():
"""
Test the get_function_entrypoint_return_type property of the
"""Test the get_function_entrypoint_return_type property of the
CustomComponent class with a build method with no return type.
"""
my_code = """
@ -421,8 +371,7 @@ class MyClass(CustomComponent):
def test_custom_component_get_main_class_name_no_main_class():
"""
Test the get_main_class_name property of the
"""Test the get_main_class_name property of the
CustomComponent class when there is no main class.
"""
my_code = """
@ -435,8 +384,7 @@ def build():
def test_custom_component_build_not_implemented():
"""
Test the build method of the CustomComponent
"""Test the build method of the CustomComponent
class raises the NotImplementedError.
"""
custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build")
@ -453,7 +401,7 @@ def test_build_config_no_code():
@pytest.fixture
def component():
yield CustomComponent(
return CustomComponent(
field_config={
"fields": {
"llm": {"type": "str"},

View file

@ -1,7 +1,6 @@
from pathlib import Path
import pytest
from langflow.custom import Component
from langflow.custom.custom_component.custom_component import CustomComponent
from langflow.custom.utils import build_custom_component_template
@ -11,7 +10,7 @@ from langflow.services.settings.feature_flags import FEATURE_FLAGS
@pytest.fixture
def code_component_with_multiple_outputs():
code = Path("src/backend/tests/data/component_multiple_outputs.py").read_text()
code = Path("src/backend/tests/data/component_multiple_outputs.py").read_text(encoding="utf-8")
return Component(_code=code)

View file

@ -1,6 +1,5 @@
import pytest
from langchain_core.documents import Document
from langflow.schema import Data

View file

@ -1,12 +1,11 @@
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch, ANY
from unittest.mock import ANY, Mock, patch
import httpx
import pytest
import respx
from httpx import Response
from langflow.components import data
@ -163,16 +162,16 @@ def test_directory_without_mocks():
directory_component = data.DirectoryComponent()
with tempfile.TemporaryDirectory() as temp_dir:
(Path(temp_dir) / "test.txt").write_text("test")
(Path(temp_dir) / "test.txt").write_text("test", encoding="utf-8")
# also add a json file
(Path(temp_dir) / "test.json").write_text('{"test": "test"}')
(Path(temp_dir) / "test.json").write_text('{"test": "test"}', encoding="utf-8")
directory_component.set_attributes({"path": str(temp_dir), "use_multithreading": False})
results = directory_component.load_directory()
assert len(results) == 2
values = ["test", '{"test":"test"}']
assert all(result.text in values for result in results), [
(len(result.text), len(val)) for result, val in zip(results, values)
(len(result.text), len(val)) for result, val in zip(results, values, strict=True)
]
# in ../docs/docs/components there are many mdx files

View file

@ -5,8 +5,6 @@ from uuid import UUID, uuid4
import orjson
import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session
from langflow.api.v1.schemas import FlowListCreate, ResultDataResponse
from langflow.graph.utils import log_transaction, log_vertex_build
from langflow.initial_setup.setup import load_flows_from_directory, load_starter_projects
@ -15,6 +13,7 @@ from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
from langflow.services.database.models.folder.model import FolderCreate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from sqlmodel import Session
@pytest.fixture(scope="module")
@ -520,5 +519,5 @@ def test_sqlite_pragmas():
with db_service.with_session() as session:
from sqlalchemy import text
assert "wal" == session.exec(text("PRAGMA journal_mode;")).scalar()
assert 1 == session.exec(text("PRAGMA synchronous;")).scalar()
assert session.exec(text("PRAGMA journal_mode;")).scalar() == "wal"
assert session.exec(text("PRAGMA synchronous;")).scalar() == 1

View file

@ -4,7 +4,6 @@ from uuid import UUID, uuid4
import pytest
from fastapi import status
from httpx import AsyncClient
from langflow.custom.directory_reader.directory_reader import DirectoryReader
from langflow.services.deps import get_settings_service
@ -377,7 +376,7 @@ async def test_invalid_prompt(client: AsyncClient):
@pytest.mark.parametrize(
"prompt,expected_input_variables",
("prompt", "expected_input_variables"),
[
("{color} is my favorite color.", ["color"]),
("The weather is {weather} today.", ["weather"]),
@ -442,13 +441,13 @@ async def test_successful_run_no_payload(client, simple_api_test, created_api_ke
assert isinstance(outputs_dict.get("outputs"), list)
assert len(outputs_dict.get("outputs")) == 1
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
assert all(["ChatOutput" in _id for _id in ids])
assert all("ChatOutput" in _id for _id in ids)
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
assert all([name in display_names for name in ["Chat Output"]])
assert all(name in display_names for name in ["Chat Output"])
output_results_has_results = all("results" in output.get("results") for output in outputs_dict.get("outputs"))
inner_results = [output.get("results") for output in outputs_dict.get("outputs")]
assert all([result is not None for result in inner_results]), (outputs_dict, output_results_has_results)
assert all(result is not None for result in inner_results), (outputs_dict, output_results_has_results)
async def test_successful_run_with_output_type_text(client, simple_api_test, created_api_key):
@ -473,12 +472,12 @@ async def test_successful_run_with_output_type_text(client, simple_api_test, cre
assert isinstance(outputs_dict.get("outputs"), list)
assert len(outputs_dict.get("outputs")) == 1
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
assert all(["ChatOutput" in _id for _id in ids]), ids
assert all("ChatOutput" in _id for _id in ids), ids
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
assert all([name in display_names for name in ["Chat Output"]]), display_names
assert all(name in display_names for name in ["Chat Output"]), display_names
inner_results = [output.get("results") for output in outputs_dict.get("outputs")]
expected_keys = ["message"]
assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict
assert all(key in result for result in inner_results for key in expected_keys), outputs_dict
async def test_successful_run_with_output_type_any(client, simple_api_test, created_api_key):
@ -504,12 +503,12 @@ async def test_successful_run_with_output_type_any(client, simple_api_test, crea
assert isinstance(outputs_dict.get("outputs"), list)
assert len(outputs_dict.get("outputs")) == 1
ids = [output.get("component_id") for output in outputs_dict.get("outputs")]
assert all(["ChatOutput" in _id or "TextOutput" in _id for _id in ids]), ids
assert all("ChatOutput" in _id or "TextOutput" in _id for _id in ids), ids
display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")]
assert all([name in display_names for name in ["Chat Output"]]), display_names
assert all(name in display_names for name in ["Chat Output"]), display_names
inner_results = [output.get("results") for output in outputs_dict.get("outputs")]
expected_keys = ["message"]
assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict
assert all(key in result for result in inner_results for key in expected_keys), outputs_dict
async def test_successful_run_with_output_type_debug(client, simple_api_test, created_api_key):
@ -566,7 +565,7 @@ async def test_successful_run_with_input_type_text(client, simple_api_test, crea
# Now we check if the input_value is correct
# We get text key twice because the output is now a Message
assert all(
[output.get("results").get("text").get("text") == "value1" for output in text_input_outputs]
output.get("results").get("text").get("text") == "value1" for output in text_input_outputs
), text_input_outputs
@ -599,7 +598,7 @@ async def test_successful_run_with_input_type_chat(client: AsyncClient, simple_a
assert len(chat_input_outputs) == 1
# Now we check if the input_value is correct
assert all(
[output.get("results").get("message").get("text") == "value1" for output in chat_input_outputs]
output.get("results").get("message").get("text") == "value1" for output in chat_input_outputs
), chat_input_outputs
@ -653,7 +652,7 @@ async def test_successful_run_with_input_type_any(client, simple_api_test, creat
result_dict.get("message", result_dict.get("text")) for result_dict in all_result_dicts
]
assert all(
[message_or_text_dict.get("text") == "value1" for message_or_text_dict in all_message_or_text_dicts]
message_or_text_dict.get("text") == "value1" for message_or_text_dict in all_message_or_text_dicts
), any_input_outputs

View file

@ -8,10 +8,9 @@ from unittest.mock import MagicMock
import pytest
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient
from sqlmodel import Session
from langflow.services.deps import get_storage_service
from langflow.services.storage.service import StorageService
from sqlmodel import Session
@pytest.fixture
@ -26,7 +25,7 @@ def mock_storage_service():
return service
@pytest.fixture(name="files_client", scope="function")
@pytest.fixture(name="files_client")
async def files_client_fixture(session: Session, monkeypatch, request, load_flows_dir, mock_storage_service):
# Set the database url to a test database
if "noclient" in request.keywords:

View file

@ -1,5 +1,4 @@
import pytest
from langflow.template.field.base import Input
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.template.base import Template

View file

@ -5,7 +5,6 @@ from langflow.custom.utils import build_custom_component_template
from langflow.schema import Data
from langflow.schema.message import Message
# def test_update_data_component():
# # Arrange
# update_data_component = helpers.UpdateDataComponent()
@ -36,7 +35,7 @@ from langflow.schema.message import Message
def test_uuid_generator_component():
# Arrange
uuid_generator_component = helpers.IDGeneratorComponent()
uuid_generator_component._code = Path(helpers.IDGenerator.__file__).read_text()
uuid_generator_component._code = Path(helpers.IDGenerator.__file__).read_text(encoding="utf-8")
frontend_node, _ = build_custom_component_template(uuid_generator_component)

View file

@ -2,8 +2,6 @@ from datetime import datetime
from pathlib import Path
import pytest
from sqlmodel import select
from langflow.custom.directory_reader.utils import build_custom_component_list_from_path
from langflow.initial_setup.setup import (
STARTER_FOLDER_NAME,
@ -14,6 +12,7 @@ from langflow.initial_setup.setup import (
from langflow.interface.types import aget_all_types_dict
from langflow.services.database.models.folder.model import Folder
from langflow.services.deps import session_scope
from sqlmodel import select
def test_load_starter_projects():
@ -101,10 +100,11 @@ async def test_create_or_update_starter_projects():
def find_componeny_by_name(components, name):
for category, children in components.items():
for children in components.values():
if name in children:
return children[name]
raise ValueError(f"Component {name} not found in components")
msg = f"Component {name} not found in components"
raise ValueError(msg)
def set_value(component, input_name, value):

View file

@ -1,9 +1,9 @@
import pytest
from unittest.mock import MagicMock
from kubernetes.client import V1ObjectMeta, V1Secret
from base64 import b64encode
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from kubernetes.client import V1ObjectMeta, V1Secret
from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id

View file

@ -2,7 +2,6 @@ from langflow.graph import Graph
from langflow.initial_setup.setup import load_starter_projects
from langflow.load import load_flow_from_json
# TODO: UPDATE BASIC EXAMPLE
# def test_load_flow_from_json():
# """Test loading a flow from a json file"""
@ -20,9 +19,8 @@ from langflow.load import load_flow_from_json
def test_load_flow_from_json_object():
"""Test loading a flow from a json file and applying tweaks"""
_, projects = zip(*load_starter_projects())
project = projects[0]
"""Test loading a flow from a json file and applying tweaks."""
project = load_starter_projects()[0][1]
loaded = load_flow_from_json(project)
assert loaded is not None
assert isinstance(loaded, Graph)

View file

@ -1,7 +1,8 @@
import pytest
import os
import json
import os
from unittest.mock import patch
import pytest
from langflow.logging.logger import SizedLogBuffer
@ -27,8 +28,8 @@ def test_write(sized_log_buffer):
sized_log_buffer.max = 1 # Set max size to 1 for testing
sized_log_buffer.write(message)
assert len(sized_log_buffer.buffer) == 1
assert 1625097600124 == sized_log_buffer.buffer[0][0]
assert "Test log" == sized_log_buffer.buffer[0][1]
assert sized_log_buffer.buffer[0][0] == 1625097600124
assert sized_log_buffer.buffer[0][1] == "Test log"
def test_write_overflow(sized_log_buffer):
@ -38,8 +39,8 @@ def test_write_overflow(sized_log_buffer):
sized_log_buffer.write(message)
assert len(sized_log_buffer.buffer) == 2
assert 1625097601000 == sized_log_buffer.buffer[0][0]
assert 1625097602000 == sized_log_buffer.buffer[1][0]
assert sized_log_buffer.buffer[0][0] == 1625097601000
assert sized_log_buffer.buffer[1][0] == 1625097602000
def test_len(sized_log_buffer):

View file

@ -1,9 +1,8 @@
import pytest
from sqlalchemy.exc import IntegrityError
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.user import User
from langflow.services.deps import session_scope
from sqlalchemy.exc import IntegrityError
@pytest.fixture

View file

@ -1,5 +1,4 @@
import pytest
from langflow.memory import add_messages, add_messagetables, delete_messages, get_messages, store_message
from langflow.schema.message import Message
@ -10,17 +9,16 @@ from langflow.services.deps import session_scope
from langflow.services.tracing.utils import convert_to_langchain_type
@pytest.fixture()
@pytest.fixture
def created_message():
with session_scope() as session:
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
messagetable = MessageTable.model_validate(message, from_attributes=True)
messagetables = add_messagetables([messagetable], session)
message_read = MessageRead.model_validate(messagetables[0], from_attributes=True)
return message_read
return MessageRead.model_validate(messagetables[0], from_attributes=True)
@pytest.fixture()
@pytest.fixture
def created_messages(session):
with session_scope() as session:
messages = [
@ -30,10 +28,7 @@ def created_messages(session):
]
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
messagetables = add_messagetables(messagetables, session)
messages_read = [
MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables
]
return messages_read
return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables]
@pytest.mark.usefixtures("client")
@ -87,10 +82,10 @@ def test_convert_to_langchain(method_name):
def convert(value):
if method_name == "message":
return value.to_lc_message()
elif method_name == "convert_to_langchain_type":
if method_name == "convert_to_langchain_type":
return convert_to_langchain_type(value)
else:
raise ValueError(f"Invalid method: {method_name}")
msg = f"Invalid method: {method_name}"
raise ValueError(msg)
lc_message = convert(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"))
assert lc_message.content == "Test message 1"

View file

@ -2,7 +2,6 @@ from uuid import UUID
import pytest
from httpx import AsyncClient
from langflow.memory import add_messagetables
# Assuming you have these imports available
@ -11,17 +10,16 @@ from langflow.services.database.models.message.model import MessageTable
from langflow.services.deps import session_scope
@pytest.fixture()
@pytest.fixture
async def created_message():
with session_scope() as session:
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
messagetable = MessageTable.model_validate(message, from_attributes=True)
messagetables = add_messagetables([messagetable], session)
message_read = MessageRead.model_validate(messagetables[0], from_attributes=True)
return message_read
return MessageRead.model_validate(messagetables[0], from_attributes=True)
@pytest.fixture()
@pytest.fixture
def created_messages(session):
with session_scope() as session:
messages = [
@ -30,9 +28,7 @@ def created_messages(session):
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"),
]
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages]
message_list = add_messagetables(messagetables, session)
return message_list
return add_messagetables(messagetables, session)
@pytest.mark.api_key_required

View file

@ -1,5 +1,4 @@
import pytest
from langflow.processing.process import process_tweaks
from langflow.services.deps import get_session_service

View file

@ -1,14 +1,13 @@
from collections.abc import Sequence as SequenceABC
from types import NoneType
from typing import Union
from langflow.schema.data import Data
import pytest
from pydantic import ValidationError
from langflow.schema.data import Data
from langflow.template import Input, Output
from langflow.template.field.base import UNDEFINED
from langflow.type_extraction.type_extraction import post_process_type
from collections.abc import Sequence as SequenceABC
from pydantic import ValidationError
class TestInput:
@ -68,8 +67,8 @@ class TestInput:
assert set(post_process_type(Union[None, list[None]])) == {None, NoneType}
# Handling complex nested structures
assert set(post_process_type(Union[SequenceABC[Union[int, str]], list[float]])) == {int, str, float}
assert set(post_process_type(Union[Union[Union[int, list[str]], list[float]], str])) == {int, str, float}
assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float}
assert set(post_process_type(Union[int | list[str] | list[float], str])) == {int, str, float}
# Non-generic types should return as is
assert set(post_process_type(dict)) == {dict}
@ -87,12 +86,12 @@ class TestInput:
assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)}
# Multiple Data types combined
assert set(post_process_type(Union[Data, Union[str, float]])) == {Data, str, float}
assert set(post_process_type(Union[Data, str | float])) == {Data, str, float}
assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str}
# Testing with nested unions and lists
assert set(post_process_type(Union[list[Data], list[Union[int, str]]])) == {Data, int, str}
assert set(post_process_type(Data | list[Union[float, str]])) == {Data, float, str}
assert set(post_process_type(Union[list[Data], list[int | str]])) == {Data, int, str}
assert set(post_process_type(Data | list[float | str])) == {Data, float, str}
def test_input_to_dict(self):
input_obj = Input(field_type="str")

View file

@ -1,8 +1,8 @@
import pytest
import threading
from langflow.services.telemetry.opentelemetry import OpenTelemetry
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from langflow.services.telemetry.opentelemetry import OpenTelemetry
fixed_labels = {"flow_id": "this_flow_id", "service": "this", "user": "that"}
@ -72,9 +72,9 @@ def test_missing_labels(opentelemetry_instance):
with pytest.raises(ValueError, match="Labels must be provided for the metric"):
opentelemetry_instance.up_down_counter("num_files_uploaded", 1, None)
with pytest.raises(ValueError, match="Labels must be provided for the metric"):
opentelemetry_instance.update_gauge(metric_name="num_files_uploaded", value=1.0, labels=dict())
opentelemetry_instance.update_gauge(metric_name="num_files_uploaded", value=1.0, labels={})
with pytest.raises(ValueError, match="Labels must be provided for the metric"):
opentelemetry_instance.observe_histogram("num_files_uploaded", 1, dict())
opentelemetry_instance.observe_histogram("num_files_uploaded", 1, {})
def test_multithreaded_singleton():

View file

@ -7,13 +7,13 @@ from pydantic import BaseModel
# Dummy classes for testing purposes
class Parent(BaseModel):
"""Parent Class"""
"""Parent Class."""
parent_field: str
class Child(Parent):
"""Child Class"""
"""Child Class."""
child_field: int
@ -85,7 +85,7 @@ def test_get_default_factory():
return "default_value"
# Add dummy_function to your_module
setattr(importlib.import_module(module_name), "dummy_function", dummy_function)
importlib.import_module(module_name).dummy_function = dummy_function
default_value = get_default_factory(module_name, function_repr)

View file

@ -2,13 +2,12 @@ from datetime import datetime
import pytest
from httpx import AsyncClient
from sqlmodel import select
from langflow.services.auth.utils import create_super_user, get_password_hash
from langflow.services.database.models.user import UserUpdate
from langflow.services.database.models.user.model import User
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from sqlmodel import select
@pytest.fixture
@ -86,7 +85,7 @@ async def test_user_waiting_for_approval(client):
with session_getter(get_db_service()) as session:
existing_user = session.exec(select(User).where(User.username == username)).first()
if existing_user:
print(f"User {username} still exists after the test. This is expected.")
pass
else:
pytest.fail(f"User {username} does not exist after the test. This is unexpected.")

View file

@ -2,9 +2,8 @@ from pathlib import Path
from unittest import mock
import pytest
from requests.exceptions import MissingSchema
from langflow.utils.validate import create_function, execute_function, extract_function_name, validate_code
from requests.exceptions import MissingSchema
def test_create_function():
@ -99,6 +98,5 @@ import requests
def my_function(x):
return requests.get(x).text
"""
with mock.patch("requests.get", side_effect=MissingSchema):
with pytest.raises(MissingSchema):
execute_function(code, "my_function", "invalid_url")
with mock.patch("requests.get", side_effect=MissingSchema), pytest.raises(MissingSchema):
execute_function(code, "my_function", "invalid_url")

View file

@ -9,9 +9,9 @@ def test_version():
def test_compute_main():
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.post0")
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.a1")
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.b112")
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.rc0")
assert "1.0.10" == _compute_non_prerelease_version("1.0.10.dev9")
assert "1.0.10" == _compute_non_prerelease_version("1.0.10")
assert _compute_non_prerelease_version("1.0.10.post0") == "1.0.10"
assert _compute_non_prerelease_version("1.0.10.a1") == "1.0.10"
assert _compute_non_prerelease_version("1.0.10.b112") == "1.0.10"
assert _compute_non_prerelease_version("1.0.10.rc0") == "1.0.10"
assert _compute_non_prerelease_version("1.0.10.dev9") == "1.0.10"
assert _compute_non_prerelease_version("1.0.10") == "1.0.10"

View file

@ -3,7 +3,7 @@ from langflow.utils.connection_string_parser import transform_connection_string
@pytest.mark.parametrize(
"connection_string, expected",
("connection_string", "expected"),
[
("protocol:user:password@host", "protocol:user:password@host"),
("protocol:user@host", "protocol:user@host"),

View file

@ -3,7 +3,7 @@ from langflow.base.data.utils import format_directory_path
@pytest.mark.parametrize(
"input_path, expected",
("input_path", "expected"),
[
# Test case 1: Standard path with no newlines (no change expected)
("/home/user/documents/file.txt", "/home/user/documents/file.txt"),

View file

@ -1,9 +1,9 @@
from langflow.base.data.utils import format_directory_path
import pytest
from langflow.base.data.utils import format_directory_path
@pytest.mark.parametrize(
"input_path, expected",
("input_path", "expected"),
[
# Test case 1: Standard path with no newlines
("/home/user/documents/file.txt", "/home/user/documents/file.txt"),

View file

@ -1,9 +1,11 @@
import math
import pytest
from langflow.utils.util_strings import truncate_long_strings
@pytest.mark.parametrize(
"input_data, max_length, expected",
("input_data", "max_length", "expected"),
[
# Test case 1: String shorter than max_length
("short string", 20, "short string"),
@ -20,7 +22,7 @@ from langflow.utils.util_strings import truncate_long_strings
# Test case 7: Integer input
(12345, 3, 12345),
# Test case 8: Float input
(3.14159, 4, 3.14159),
(math.pi, 4, math.pi),
# Test case 9: Boolean input
(True, 2, True),
# Test case 10: None input

View file

@ -1,10 +1,10 @@
from langflow.utils.util_strings import truncate_long_strings
from langflow.utils.constants import MAX_TEXT_LENGTH
import pytest
from langflow.utils.constants import MAX_TEXT_LENGTH
from langflow.utils.util_strings import truncate_long_strings
@pytest.mark.parametrize(
"input_data, max_length, expected",
("input_data", "max_length", "expected"),
[
# Test case 1: Simple string truncation
({"key": "a" * 100}, 10, {"key": "a" * 10 + "..."}),