Merge remote-tracking branch 'origin/dev' into v2
This commit is contained in:
commit
1111bfa45d
334 changed files with 17982 additions and 6406 deletions
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
|
|
@ -1,10 +1,16 @@
|
|||
from contextlib import contextmanager
|
||||
import json
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, TYPE_CHECKING
|
||||
from langflow.api.v1.flows import get_session
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.flow.flow import Flow, FlowCreate
|
||||
from langflow.services.database.models.user.user import User, UserCreate
|
||||
import orjson
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
|
|
@ -12,8 +18,11 @@ from sqlmodel import SQLModel, Session, create_engine
|
|||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
# we need to import tmpdir
|
||||
import tempfile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.manager import DatabaseManager
|
||||
from langflow.services.database.manager import DatabaseService
|
||||
|
||||
|
||||
def pytest_configure():
|
||||
|
|
@ -26,7 +35,12 @@ def pytest_configure():
|
|||
pytest.OPENAPI_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "Openapi.json"
|
||||
)
|
||||
|
||||
pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = (
|
||||
Path(__file__).parent.absolute() / "data" / "BasicChatwithPromptandHistory.json"
|
||||
)
|
||||
pytest.VECTOR_STORE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "Vector_store.json"
|
||||
)
|
||||
pytest.CODE_WITH_SYNTAX_ERROR = """
|
||||
def get_text():
|
||||
retun "Hello World"
|
||||
|
|
@ -42,15 +56,62 @@ async def async_client() -> AsyncGenerator:
|
|||
yield client
|
||||
|
||||
|
||||
# Create client fixture for FastAPI
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
@pytest.fixture(name="session")
|
||||
def session_fixture():
|
||||
engine = create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
class Config:
|
||||
broker_url = "redis://localhost:6379/0"
|
||||
result_backend = "redis://localhost:6379/0"
|
||||
|
||||
|
||||
@pytest.fixture(name="distributed_env")
|
||||
def setup_env(monkeypatch):
|
||||
monkeypatch.setenv("LANGFLOW_CACHE_TYPE", "redis")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_HOST", "result_backend")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_PORT", "6379")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_DB", "0")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_EXPIRE", "3600")
|
||||
monkeypatch.setenv("LANGFLOW_REDIS_PASSWORD", "")
|
||||
monkeypatch.setenv("FLOWER_UNAUTHENTICATED_API", "True")
|
||||
monkeypatch.setenv("BROKER_URL", "redis://result_backend:6379/0")
|
||||
monkeypatch.setenv("RESULT_BACKEND", "redis://result_backend:6379/0")
|
||||
monkeypatch.setenv("C_FORCE_ROOT", "true")
|
||||
|
||||
|
||||
@pytest.fixture(name="distributed_client")
|
||||
def distributed_client_fixture(session: Session, monkeypatch, distributed_env):
|
||||
# Here we load the .env from ../deploy/.env
|
||||
from langflow.core import celery_app
|
||||
|
||||
db_dir = tempfile.mkdtemp()
|
||||
db_path = Path(db_dir) / "test.db"
|
||||
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
|
||||
# monkeypatch langflow.services.task.manager.USE_CELERY to True
|
||||
# monkeypatch.setattr(manager, "USE_CELERY", True)
|
||||
monkeypatch.setattr(
|
||||
celery_app, "celery_app", celery_app.make_celery("langflow", Config)
|
||||
)
|
||||
|
||||
# def get_session_override():
|
||||
# return session
|
||||
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
|
||||
# app.dependency_overrides[get_session] = get_session_override
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
app.dependency_overrides.clear()
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
|
|
@ -98,55 +159,53 @@ def json_flow():
|
|||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture(name="session")
|
||||
def session_fixture():
|
||||
engine = create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
@pytest.fixture
|
||||
def json_flow_with_prompt_and_history():
|
||||
with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
@pytest.fixture(name="client", scope="function", autouse=True)
|
||||
def client_fixture(session: Session):
|
||||
def get_session_override():
|
||||
return session
|
||||
=======
|
||||
@pytest.fixture
|
||||
def json_vector_store():
|
||||
with open(pytest.VECTOR_STORE_PATH, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture(name="client", autouse=True)
|
||||
def client_fixture(session: Session, monkeypatch):
|
||||
# Set the database url to a test database
|
||||
db_dir = tempfile.mkdtemp()
|
||||
db_path = Path(db_dir) / "test.db"
|
||||
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
|
||||
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
|
||||
>>>>>>> origin/dev
|
||||
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
|
||||
app.dependency_overrides[get_session] = get_session_override
|
||||
# app.dependency_overrides[get_session] = get_session_override
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# @contextmanager
|
||||
# def session_getter():
|
||||
# try:
|
||||
# session = Session(engine)
|
||||
# yield session
|
||||
# except Exception as e:
|
||||
# print("Session rollback because of exception:", e)
|
||||
# session.rollback()
|
||||
# raise
|
||||
# finally:
|
||||
# session.close()
|
||||
# app.dependency_overrides.clear()
|
||||
monkeypatch.undo()
|
||||
# clear the temp db
|
||||
with suppress(FileNotFoundError):
|
||||
db_path.unlink()
|
||||
|
||||
|
||||
# create a fixture for session_getter above
|
||||
@pytest.fixture(name="session_getter")
|
||||
def session_getter_fixture(client):
|
||||
engine = create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
@contextmanager
|
||||
def blank_session_getter(db_manager: "DatabaseManager"):
|
||||
with Session(db_manager.engine) as session:
|
||||
def blank_session_getter(db_service: "DatabaseService"):
|
||||
with Session(db_service.engine) as session:
|
||||
yield session
|
||||
|
||||
yield blank_session_getter
|
||||
|
|
@ -155,3 +214,90 @@ def session_getter_fixture(client):
|
|||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(client):
|
||||
user_data = UserCreate(
|
||||
username="testuser",
|
||||
password="testpassword",
|
||||
)
|
||||
response = client.post("/api/v1/users", json=user_data.dict())
|
||||
assert response.status_code == 201
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def active_user(client):
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
user = User(
|
||||
username="activeuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
# check if user exists
|
||||
if (
|
||||
active_user := session.query(User)
|
||||
.filter(User.username == user.username)
|
||||
.first()
|
||||
):
|
||||
return active_user
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logged_in_headers(client, active_user):
|
||||
login_data = {"username": active_user.username, "password": "testpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 200
|
||||
tokens = response.json()
|
||||
a_token = tokens["access_token"]
|
||||
return {"Authorization": f"Bearer {a_token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flow(client, json_flow: str, active_user):
|
||||
from langflow.services.database.models.flow.flow import FlowCreate
|
||||
|
||||
loaded_json = json.loads(json_flow)
|
||||
flow_data = FlowCreate(
|
||||
name="test_flow", data=loaded_json.get("data"), user_id=active_user.id
|
||||
)
|
||||
flow = Flow(**flow_data.dict())
|
||||
with session_getter(get_db_service()) as session:
|
||||
session.add(flow)
|
||||
session.commit()
|
||||
session.refresh(flow)
|
||||
|
||||
return flow
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def added_flow(client, json_flow_with_prompt_and_history, logged_in_headers):
|
||||
flow = orjson.loads(json_flow_with_prompt_and_history)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Basic Chat", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def added_vector_store(client, json_vector_store, logged_in_headers):
|
||||
vector_store = orjson.loads(json_vector_store)
|
||||
data = vector_store["data"]
|
||||
vector_store = FlowCreate(name="Vector Store", description="description", data=data)
|
||||
response = client.post(
|
||||
"api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == vector_store.name
|
||||
assert response.json()["data"] == vector_store.data
|
||||
return response.json()
|
||||
|
|
|
|||
1
tests/data/BasicChatwithPromptandHistory.json
Normal file
1
tests/data/BasicChatwithPromptandHistory.json
Normal file
File diff suppressed because one or more lines are too long
1283
tests/data/Vector_store.json
Normal file
1283
tests/data/Vector_store.json
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -236,7 +236,7 @@
|
|||
"placeholder": "",
|
||||
"show": true,
|
||||
"multiline": false,
|
||||
"value": "abc",
|
||||
"value": null,
|
||||
"password": true,
|
||||
"name": "openai_api_key",
|
||||
"display_name": "OpenAI API Key",
|
||||
|
|
|
|||
132
tests/locust/locustfile.py
Normal file
132
tests/locust/locustfile.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
from locust import FastHttpUser, task, between
|
||||
import random
|
||||
import time
|
||||
import orjson
|
||||
from rich import print
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class NameTest(FastHttpUser):
|
||||
wait_time = between(1, 5)
|
||||
|
||||
with open("names.txt", "r") as file:
|
||||
names = [line.strip() for line in file.readlines()]
|
||||
|
||||
headers = {}
|
||||
|
||||
def poll_task(self, task_id, sleep_time=1):
|
||||
while True:
|
||||
with self.rest(
|
||||
"GET",
|
||||
f"/task/{task_id}",
|
||||
name="task_status",
|
||||
headers=self.headers,
|
||||
) as response:
|
||||
status = response.js.get("status")
|
||||
print(f"Poll Response: {response.js}")
|
||||
if status == "SUCCESS":
|
||||
return response.js.get("result")
|
||||
elif status in ["FAILURE", "REVOKED"]:
|
||||
raise ValueError(f"Task failed with status: {status}")
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def process(self, name, flow_id, payload):
|
||||
task_id = None
|
||||
print(f"Processing {payload}")
|
||||
with self.rest(
|
||||
"POST",
|
||||
f"/process/{flow_id}",
|
||||
json=payload,
|
||||
name="process",
|
||||
headers=self.headers,
|
||||
) as response:
|
||||
print(response.js)
|
||||
if response.status_code != 200:
|
||||
response.failure("Process call failed")
|
||||
raise ValueError("Process call failed")
|
||||
task_id = response.js.get("id")
|
||||
session_id = response.js.get("session_id")
|
||||
assert task_id, "Inner Task ID not found"
|
||||
|
||||
assert task_id, "Task ID not found"
|
||||
result = self.poll_task(task_id)
|
||||
print(f"Result for {name}: {result}")
|
||||
|
||||
return result, session_id
|
||||
|
||||
@task
|
||||
def send_name_and_check(self):
|
||||
name = random.choice(self.names)
|
||||
|
||||
payload1 = {
|
||||
"inputs": {"text": f"Hello, My name is {name}"},
|
||||
"sync": False,
|
||||
}
|
||||
result1, session_id = self.process(name, self.flow_id, payload1)
|
||||
|
||||
payload2 = {
|
||||
"inputs": {
|
||||
"text": "What is my name? Please, answer like this: Your name is <name>"
|
||||
},
|
||||
"session_id": session_id,
|
||||
"sync": False,
|
||||
}
|
||||
result2, session_id = self.process(name, self.flow_id, payload2)
|
||||
|
||||
assert f"Your name is {name}" in str(result2), "Name not found in response"
|
||||
|
||||
def on_start(self):
|
||||
print("Starting")
|
||||
login_data = {"username": "superuser", "password": "superuser"}
|
||||
response = httpx.post(f"{self.host}/login", data=login_data)
|
||||
print(response.json())
|
||||
|
||||
tokens = response.json()
|
||||
print(tokens)
|
||||
a_token = tokens["access_token"]
|
||||
logged_in_headers = {"Authorization": f"Bearer {a_token}"}
|
||||
print("Logged in")
|
||||
with open(
|
||||
Path(__file__).parent.parent
|
||||
/ "data"
|
||||
/ "BasicChatwithPromptandHistory.json",
|
||||
"r",
|
||||
) as f:
|
||||
json_flow = f.read()
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
flow = {"name": "Flow 1", "description": "description", "data": data}
|
||||
print("Creating flow")
|
||||
# Make request to endpoint
|
||||
response = httpx.post(
|
||||
f"{self.host}/flows/",
|
||||
json=flow,
|
||||
headers=logged_in_headers,
|
||||
)
|
||||
self.flow_id = response.json()["id"]
|
||||
print(f"Flow ID: {self.flow_id}")
|
||||
|
||||
# read all users
|
||||
response = httpx.get(
|
||||
f"{self.host}/users/",
|
||||
headers=logged_in_headers,
|
||||
)
|
||||
print(response.json())
|
||||
user_id = next(
|
||||
(
|
||||
user["id"]
|
||||
for user in response.json()["users"]
|
||||
if user["username"] == "superuser"
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Create api key
|
||||
response = httpx.post(
|
||||
f"{self.host}/api_key/",
|
||||
json={"user_id": user_id},
|
||||
headers=logged_in_headers,
|
||||
)
|
||||
print(response.json())
|
||||
self.headers["x-api-key"] = response.json()["api_key"]
|
||||
5
tests/locust/names.txt
Normal file
5
tests/locust/names.txt
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
Bob
|
||||
Alice
|
||||
John
|
||||
Gabriel
|
||||
Lily
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_zero_shot_agent(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
agents = json_response["agents"]
|
||||
|
|
@ -113,8 +113,8 @@ def test_zero_shot_agent(client: TestClient):
|
|||
}
|
||||
|
||||
|
||||
def test_json_agent(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_json_agent(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
agents = json_response["agents"]
|
||||
|
|
@ -152,8 +152,8 @@ def test_json_agent(client: TestClient):
|
|||
}
|
||||
|
||||
|
||||
def test_csv_agent(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_csv_agent(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
agents = json_response["agents"]
|
||||
|
|
@ -195,8 +195,8 @@ def test_csv_agent(client: TestClient):
|
|||
}
|
||||
|
||||
|
||||
def test_initialize_agent(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_initialize_agent(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
agents = json_response["agents"]
|
||||
|
|
|
|||
50
tests/test_api_key.py
Normal file
50
tests/test_api_key.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
import pytest
|
||||
from langflow.services.database.models.api_key import ApiKeyCreate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key(client, logged_in_headers, active_user):
|
||||
api_key = ApiKeyCreate(name="test-api-key")
|
||||
|
||||
response = client.post(
|
||||
"api/v1/api_key", data=api_key.json(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_get_api_keys(client, logged_in_headers, api_key):
|
||||
response = client.get("api/v1/api_key", headers=logged_in_headers)
|
||||
assert response.status_code == 200, response.text
|
||||
data = response.json()
|
||||
assert "total_count" in data
|
||||
assert "user_id" in data
|
||||
assert "api_keys" in data
|
||||
assert any("test-api-key" in api_key["name"] for api_key in data["api_keys"])
|
||||
# assert all api keys in data["api_keys"] are masked
|
||||
assert all("**" in api_key["api_key"] for api_key in data["api_keys"])
|
||||
# Add more assertions as needed based on the expected data structure and content
|
||||
|
||||
|
||||
def test_create_api_key(client, logged_in_headers):
|
||||
api_key_name = "test-api-key"
|
||||
response = client.post(
|
||||
"api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "name" in data and data["name"] == api_key_name
|
||||
assert "api_key" in data
|
||||
# When creating the API key is returned which is
|
||||
# the only time the API key is unmasked
|
||||
assert "**" not in data["api_key"]
|
||||
|
||||
|
||||
def test_delete_api_key(client, logged_in_headers, active_user, api_key):
|
||||
# Assuming a function to create a test API key, returning the key ID
|
||||
api_key_id = api_key["id"]
|
||||
response = client.delete(f"api/v1/api_key/{api_key_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["detail"] == "API Key deleted"
|
||||
# Optionally, add a follow-up check to ensure that the key is actually removed from the database
|
||||
|
|
@ -2,9 +2,6 @@ import json
|
|||
from langflow.graph import Graph
|
||||
|
||||
import pytest
|
||||
from langflow.interface.run import (
|
||||
build_langchain_object_with_caching,
|
||||
)
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
|
|
@ -40,32 +37,9 @@ def langchain_objects_are_equal(obj1, obj2):
|
|||
return str(obj1) == str(obj2)
|
||||
|
||||
|
||||
# Test build_langchain_object_with_caching
|
||||
def test_build_langchain_object_with_caching(basic_data_graph):
|
||||
build_langchain_object_with_caching.clear_cache()
|
||||
graph = build_langchain_object_with_caching(basic_data_graph)
|
||||
assert graph is not None
|
||||
|
||||
|
||||
# Test build_graph
|
||||
def test_build_graph(basic_data_graph):
|
||||
def test_build_graph(client, basic_data_graph):
|
||||
graph = Graph.from_payload(basic_data_graph)
|
||||
assert graph is not None
|
||||
assert len(graph.nodes) == len(basic_data_graph["nodes"])
|
||||
assert len(graph.edges) == len(basic_data_graph["edges"])
|
||||
|
||||
|
||||
# Test cache size limit
|
||||
def test_cache_size_limit(basic_data_graph):
|
||||
build_langchain_object_with_caching.clear_cache()
|
||||
for i in range(11):
|
||||
modified_data_graph = basic_data_graph.copy()
|
||||
nodes = modified_data_graph["nodes"]
|
||||
node_id = nodes[0]["id"]
|
||||
# Now we replace all instances ode node_id with a new id in the json
|
||||
json_string = json.dumps(modified_data_graph)
|
||||
modified_json_string = json_string.replace(node_id, f"{node_id}_{i}")
|
||||
modified_data_graph_new_id = json.loads(modified_json_string)
|
||||
build_langchain_object_with_caching(modified_data_graph_new_id)
|
||||
|
||||
assert len(build_langchain_object_with_caching.cache) == 10
|
||||
|
|
|
|||
|
|
@ -2,81 +2,81 @@ from io import StringIO
|
|||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from langflow.services.cache.manager import CacheManager
|
||||
from langflow.services.chat.cache import CacheService
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_manager():
|
||||
return CacheManager()
|
||||
def cache_service():
|
||||
return CacheService()
|
||||
|
||||
|
||||
def test_cache_manager_attach_detach_notify(cache_manager):
|
||||
def test_cache_service_attach_detach_notify(cache_service):
|
||||
observer_called = False
|
||||
|
||||
def observer():
|
||||
nonlocal observer_called
|
||||
observer_called = True
|
||||
|
||||
cache_manager.attach(observer)
|
||||
cache_manager.notify()
|
||||
cache_service.attach(observer)
|
||||
cache_service.notify()
|
||||
|
||||
assert observer_called
|
||||
|
||||
observer_called = False
|
||||
cache_manager.detach(observer)
|
||||
cache_manager.notify()
|
||||
cache_service.detach(observer)
|
||||
cache_service.notify()
|
||||
|
||||
assert not observer_called
|
||||
|
||||
|
||||
def test_cache_manager_client_context(cache_manager):
|
||||
with cache_manager.set_client_id("client1"):
|
||||
cache_manager.add("foo", "bar", "string")
|
||||
assert cache_manager.get("foo") == {
|
||||
def test_cache_service_client_context(cache_service):
|
||||
with cache_service.set_client_id("client1"):
|
||||
cache_service.add("foo", "bar", "string")
|
||||
assert cache_service.get("foo") == {
|
||||
"obj": "bar",
|
||||
"type": "string",
|
||||
"extension": "str",
|
||||
}
|
||||
|
||||
with cache_manager.set_client_id("client2"):
|
||||
cache_manager.add("baz", "qux", "string")
|
||||
assert cache_manager.get("baz") == {
|
||||
with cache_service.set_client_id("client2"):
|
||||
cache_service.add("baz", "qux", "string")
|
||||
assert cache_service.get("baz") == {
|
||||
"obj": "qux",
|
||||
"type": "string",
|
||||
"extension": "str",
|
||||
}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
cache_manager.get("foo")
|
||||
cache_service.get("foo")
|
||||
|
||||
|
||||
def test_cache_manager_add_pandas(cache_manager):
|
||||
def test_cache_service_add_pandas(cache_service):
|
||||
df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
|
||||
|
||||
with cache_manager.set_client_id("client1"):
|
||||
cache_manager.add_pandas("test_df", df)
|
||||
cached_df = cache_manager.get("test_df")
|
||||
with cache_service.set_client_id("client1"):
|
||||
cache_service.add_pandas("test_df", df)
|
||||
cached_df = cache_service.get("test_df")
|
||||
assert cached_df["type"] == "pandas"
|
||||
assert cached_df["extension"] == "csv"
|
||||
read_df = pd.read_csv(StringIO(cached_df["obj"]), index_col=0)
|
||||
pd.testing.assert_frame_equal(df, read_df)
|
||||
|
||||
|
||||
def test_cache_manager_add_image(cache_manager):
|
||||
def test_cache_service_add_image(cache_service):
|
||||
img = Image.new("RGB", (50, 50), color="red")
|
||||
|
||||
with cache_manager.set_client_id("client1"):
|
||||
cache_manager.add_image("test_image", img)
|
||||
cached_img = cache_manager.get("test_image")
|
||||
with cache_service.set_client_id("client1"):
|
||||
cache_service.add_image("test_image", img)
|
||||
cached_img = cache_service.get("test_image")
|
||||
assert cached_img["type"] == "image"
|
||||
assert cached_img["extension"] == "png"
|
||||
assert isinstance(cached_img["obj"], Image.Image)
|
||||
|
||||
|
||||
def test_cache_manager_get_last(cache_manager):
|
||||
with cache_manager.set_client_id("client1"):
|
||||
cache_manager.add("foo", "bar", "string")
|
||||
cache_manager.add("baz", "qux", "string")
|
||||
last_item = cache_manager.get_last()
|
||||
def test_cache_service_get_last(cache_service):
|
||||
with cache_service.set_client_id("client1"):
|
||||
cache_service.add("foo", "bar", "string")
|
||||
cache_service.add("baz", "qux", "string")
|
||||
last_item = cache_service.get_last()
|
||||
assert last_item == {"obj": "qux", "type": "string", "extension": "str"}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# def test_chains_settings(client: TestClient):
|
||||
# response = client.get("api/v1/all")
|
||||
# def test_chains_settings(client: TestClient, logged_in_headers):
|
||||
# response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
# assert response.status_code == 200
|
||||
# json_response = response.json()
|
||||
# chains = json_response["chains"]
|
||||
|
|
@ -10,8 +10,8 @@ from fastapi.testclient import TestClient
|
|||
|
||||
|
||||
# Test the ConversationChain object
|
||||
def test_conversation_chain(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_conversation_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
|
@ -102,8 +102,8 @@ def test_conversation_chain(client: TestClient):
|
|||
)
|
||||
|
||||
|
||||
def test_llm_chain(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_llm_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
|
@ -173,8 +173,8 @@ def test_llm_chain(client: TestClient):
|
|||
}
|
||||
|
||||
|
||||
def test_llm_checker_chain(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_llm_checker_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
|
@ -207,8 +207,8 @@ def test_llm_checker_chain(client: TestClient):
|
|||
assert chain["description"] == ""
|
||||
|
||||
|
||||
def test_llm_math_chain(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_llm_math_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
|
@ -299,8 +299,8 @@ def test_llm_math_chain(client: TestClient):
|
|||
)
|
||||
|
||||
|
||||
def test_series_character_chain(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_series_character_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
|
@ -367,8 +367,8 @@ def test_series_character_chain(client: TestClient):
|
|||
)
|
||||
|
||||
|
||||
def test_mid_journey_prompt_chain(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_mid_journey_prompt_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
|
@ -408,8 +408,8 @@ def test_mid_journey_prompt_chain(client: TestClient):
|
|||
)
|
||||
|
||||
|
||||
def test_time_travel_guide_chain(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_time_travel_guide_chain(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
chains = json_response["chains"]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from tempfile import tempdir
|
|||
from langflow.__main__ import app
|
||||
import pytest
|
||||
|
||||
from langflow.services import utils
|
||||
from langflow.services import getters
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
@ -23,8 +23,19 @@ def test_components_path(runner, client, default_settings):
|
|||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["--components-path", str(temp_dir), *default_settings],
|
||||
["run", "--components-path", str(temp_dir), *default_settings],
|
||||
)
|
||||
assert result.exit_code == 0, result.stdout
|
||||
<<<<<<< HEAD
|
||||
settings_manager = utils.get_settings_manager()
|
||||
assert str(temp_dir) in settings_manager.settings.COMPONENTS_PATH
|
||||
=======
|
||||
settings_service = getters.get_settings_service()
|
||||
assert str(temp_dir) in settings_service.settings.COMPONENTS_PATH
|
||||
|
||||
|
||||
def test_superuser(runner, client, session):
|
||||
result = runner.invoke(app, ["superuser"], input="admin\nadmin\n")
|
||||
assert result.exit_code == 0, result.stdout
|
||||
assert "Superuser created successfully." in result.stdout
|
||||
>>>>>>> origin/dev
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ def sample_agent_creator() -> AgentCreator:
|
|||
|
||||
|
||||
def test_lang_chain_type_creator_to_dict(
|
||||
client,
|
||||
sample_lang_chain_type_creator: LangChainTypeCreator,
|
||||
):
|
||||
type_dict = sample_lang_chain_type_creator.to_dict()
|
||||
|
|
|
|||
|
|
@ -473,15 +473,16 @@ def test_build_config_no_code():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def component():
|
||||
def component(client, active_user):
|
||||
return CustomComponent(
|
||||
user_id=active_user.id,
|
||||
field_config={
|
||||
"fields": {
|
||||
"llm": {"type": "str"},
|
||||
"url": {"type": "str"},
|
||||
"year": {"type": "int"},
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -517,13 +518,13 @@ def db(app):
|
|||
app.db.drop_all()
|
||||
|
||||
|
||||
def test_list_flows_return_type(component, session_getter):
|
||||
flows = component.list_flows(get_session=session_getter)
|
||||
def test_list_flows_return_type(component):
|
||||
flows = component.list_flows()
|
||||
assert isinstance(flows, list)
|
||||
|
||||
|
||||
def test_list_flows_flow_objects(component, session_getter):
|
||||
flows = component.list_flows(get_session=session_getter)
|
||||
def test_list_flows_flow_objects(component):
|
||||
flows = component.list_flows()
|
||||
assert all(isinstance(flow, Flow) for flow in flows)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import json
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlmodel import Session
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
|
@ -16,7 +19,7 @@ def json_style():
|
|||
# color: str = Field(index=True)
|
||||
# emoji: str = Field(index=False)
|
||||
# flow_id: UUID = Field(default=None, foreign_key="flow.id")
|
||||
return json.dumps(
|
||||
return orjson_dumps(
|
||||
{
|
||||
"color": "red",
|
||||
"emoji": "👍",
|
||||
|
|
@ -24,63 +27,69 @@ def json_style():
|
|||
)
|
||||
|
||||
|
||||
def test_create_flow(client: TestClient, json_flow: str):
|
||||
flow = json.loads(json_flow)
|
||||
def test_create_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict())
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
# flow is optional so we can create a flow without a flow
|
||||
flow = FlowCreate(name="Test Flow")
|
||||
response = client.post("api/v1/flows/", json=flow.dict(exclude_unset=True))
|
||||
response = client.post(
|
||||
"api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
|
||||
def test_read_flows(client: TestClient, json_flow: str):
|
||||
flow_data = json.loads(json_flow)
|
||||
def test_read_flows(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict())
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict())
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
response = client.get("api/v1/flows/")
|
||||
response = client.get("api/v1/flows/", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) > 0
|
||||
|
||||
|
||||
def test_read_flow(client: TestClient, json_flow: str):
|
||||
flow = json.loads(json_flow)
|
||||
def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict())
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"] # flow_id should be a UUID but is a string
|
||||
# turn it into a UUID
|
||||
flow_id = UUID(flow_id)
|
||||
|
||||
response = client.get(f"api/v1/flows/{flow_id}")
|
||||
response = client.get(f"api/v1/flows/{flow_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == flow.name
|
||||
assert response.json()["data"] == flow.data
|
||||
|
||||
|
||||
def test_update_flow(client: TestClient, json_flow: str):
|
||||
flow = json.loads(json_flow)
|
||||
def test_update_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict())
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
|
||||
flow_id = response.json()["id"]
|
||||
updated_flow = FlowUpdate(
|
||||
|
|
@ -88,7 +97,9 @@ def test_update_flow(client: TestClient, json_flow: str):
|
|||
description="updated description",
|
||||
data=data,
|
||||
)
|
||||
response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.dict())
|
||||
response = client.patch(
|
||||
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == updated_flow.name
|
||||
|
|
@ -96,19 +107,23 @@ def test_update_flow(client: TestClient, json_flow: str):
|
|||
# assert response.json()["data"] == updated_flow.data
|
||||
|
||||
|
||||
def test_delete_flow(client: TestClient, json_flow: str):
|
||||
flow = json.loads(json_flow)
|
||||
def test_delete_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
flow = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow.dict())
|
||||
response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers)
|
||||
flow_id = response.json()["id"]
|
||||
response = client.delete(f"api/v1/flows/{flow_id}")
|
||||
response = client.delete(f"api/v1/flows/{flow_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Flow deleted successfully"
|
||||
|
||||
|
||||
def test_create_flows(client: TestClient, session: Session, json_flow: str):
|
||||
flow = json.loads(json_flow)
|
||||
def test_create_flows(
|
||||
client: TestClient, session: Session, json_flow: str, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
flow_list = FlowListCreate(
|
||||
|
|
@ -118,7 +133,9 @@ def test_create_flows(client: TestClient, session: Session, json_flow: str):
|
|||
]
|
||||
)
|
||||
# Make request to endpoint
|
||||
response = client.post("api/v1/flows/batch/", json=flow_list.dict())
|
||||
response = client.post(
|
||||
"api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers
|
||||
)
|
||||
# Check response status code
|
||||
assert response.status_code == 201
|
||||
# Check response data
|
||||
|
|
@ -132,8 +149,10 @@ def test_create_flows(client: TestClient, session: Session, json_flow: str):
|
|||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_upload_file(client: TestClient, session: Session, json_flow: str):
|
||||
flow = json.loads(json_flow)
|
||||
def test_upload_file(
|
||||
client: TestClient, session: Session, json_flow: str, logged_in_headers
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
flow_list = FlowListCreate(
|
||||
|
|
@ -142,10 +161,11 @@ def test_upload_file(client: TestClient, session: Session, json_flow: str):
|
|||
FlowCreate(name="Flow 2", description="description", data=data),
|
||||
]
|
||||
)
|
||||
file_contents = json.dumps(flow_list.dict())
|
||||
file_contents = orjson_dumps(flow_list.dict())
|
||||
response = client.post(
|
||||
"api/v1/flows/upload/",
|
||||
files={"file": ("examples.json", file_contents, "application/json")},
|
||||
headers=logged_in_headers,
|
||||
)
|
||||
# Check response status code
|
||||
assert response.status_code == 201
|
||||
|
|
@ -160,8 +180,14 @@ def test_upload_file(client: TestClient, session: Session, json_flow: str):
|
|||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_download_file(client: TestClient, session: Session, json_flow):
|
||||
flow = json.loads(json_flow)
|
||||
def test_download_file(
|
||||
client: TestClient,
|
||||
session: Session,
|
||||
json_flow,
|
||||
active_user,
|
||||
logged_in_headers,
|
||||
):
|
||||
flow = orjson.loads(json_flow)
|
||||
data = flow["data"]
|
||||
# Create test data
|
||||
flow_list = FlowListCreate(
|
||||
|
|
@ -170,17 +196,20 @@ def test_download_file(client: TestClient, session: Session, json_flow):
|
|||
FlowCreate(name="Flow 2", description="description", data=data),
|
||||
]
|
||||
)
|
||||
for flow in flow_list.flows:
|
||||
db_flow = Flow.from_orm(flow)
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = active_user.id
|
||||
db_flow = Flow.from_orm(flow)
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
# Make request to endpoint
|
||||
response = client.get("api/v1/flows/download/")
|
||||
response = client.get("api/v1/flows/download/", headers=logged_in_headers)
|
||||
# Check response status code
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == 200, response.json()
|
||||
# Check response data
|
||||
response_data = response.json()["flows"]
|
||||
assert len(response_data) == 2
|
||||
assert len(response_data) == 2, response_data
|
||||
assert response_data[0]["name"] == "Flow 1"
|
||||
assert response_data[0]["description"] == "description"
|
||||
assert response_data[0]["data"] == data
|
||||
|
|
@ -189,32 +218,44 @@ def test_download_file(client: TestClient, session: Session, json_flow):
|
|||
assert response_data[1]["data"] == data
|
||||
|
||||
|
||||
def test_create_flow_with_invalid_data(client: TestClient):
|
||||
def test_create_flow_with_invalid_data(
|
||||
client: TestClient, active_user, logged_in_headers
|
||||
):
|
||||
flow = {"name": "a" * 256, "data": "Invalid flow data"}
|
||||
response = client.post("api/v1/flows/", json=flow)
|
||||
response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_get_nonexistent_flow(client: TestClient):
|
||||
def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers):
|
||||
uuid = uuid4()
|
||||
response = client.get(f"api/v1/flows/{uuid}")
|
||||
response = client.get(f"api/v1/flows/{uuid}", headers=logged_in_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_update_flow_idempotency(client: TestClient, json_flow: str):
|
||||
flow_data = json.loads(json_flow)
|
||||
def test_update_flow_idempotency(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
flow_data = FlowCreate(name="Test Flow", description="description", data=data)
|
||||
response = client.post("api/v1/flows/", json=flow_data.dict())
|
||||
response = client.post(
|
||||
"api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
flow_id = response.json()["id"]
|
||||
updated_flow = FlowCreate(name="Updated Flow", description="description", data=data)
|
||||
response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict())
|
||||
response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict())
|
||||
response1 = client.put(
|
||||
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
response2 = client.put(
|
||||
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response1.json() == response2.json()
|
||||
|
||||
|
||||
def test_update_nonexistent_flow(client: TestClient, json_flow: str):
|
||||
flow_data = json.loads(json_flow)
|
||||
def test_update_nonexistent_flow(
|
||||
client: TestClient, json_flow: str, active_user, logged_in_headers
|
||||
):
|
||||
flow_data = orjson.loads(json_flow)
|
||||
data = flow_data["data"]
|
||||
uuid = uuid4()
|
||||
updated_flow = FlowCreate(
|
||||
|
|
@ -222,17 +263,19 @@ def test_update_nonexistent_flow(client: TestClient, json_flow: str):
|
|||
description="description",
|
||||
data=data,
|
||||
)
|
||||
response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.dict())
|
||||
response = client.patch(
|
||||
f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_delete_nonexistent_flow(client: TestClient):
|
||||
def test_delete_nonexistent_flow(client: TestClient, active_user, logged_in_headers):
|
||||
uuid = uuid4()
|
||||
response = client.delete(f"api/v1/flows/{uuid}")
|
||||
response = client.delete(f"api/v1/flows/{uuid}", headers=logged_in_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_read_empty_flows(client: TestClient):
|
||||
response = client.get("api/v1/flows/")
|
||||
def test_read_empty_flows(client: TestClient, active_user, logged_in_headers):
|
||||
response = client.get("api/v1/flows/", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 0
|
||||
|
|
|
|||
|
|
@ -1,8 +1,44 @@
|
|||
from collections import namedtuple
|
||||
import uuid
|
||||
from langflow.processing.process import Result
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.api_key import ApiKey
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def run_post(client, flow_id, headers, post_data):
|
||||
response = client.post(
|
||||
f"api/v1/process/{flow_id}",
|
||||
headers=headers,
|
||||
json=post_data,
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
return response.json()
|
||||
|
||||
|
||||
# Helper function to poll task status
|
||||
def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1):
|
||||
for _ in range(max_attempts):
|
||||
task_status_response = client.get(
|
||||
href,
|
||||
headers=headers,
|
||||
)
|
||||
if (
|
||||
task_status_response.status_code == 200
|
||||
and task_status_response.json()["status"] == "SUCCESS"
|
||||
):
|
||||
return task_status_response.json()
|
||||
time.sleep(sleep_time)
|
||||
return None # Return None if task did not complete in time
|
||||
|
||||
|
||||
PROMPT_REQUEST = {
|
||||
"name": "string",
|
||||
|
|
@ -83,8 +119,186 @@ PROMPT_REQUEST = {
|
|||
}
|
||||
|
||||
|
||||
def test_get_all(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
@pytest.fixture
|
||||
def created_api_key(active_user):
|
||||
hashed = get_password_hash("random_key")
|
||||
api_key = ApiKey(
|
||||
name="test_api_key",
|
||||
user_id=active_user.id,
|
||||
api_key="random_key",
|
||||
hashed_api_key=hashed,
|
||||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
if (
|
||||
existing_api_key := session.query(ApiKey)
|
||||
.filter(ApiKey.api_key == api_key.api_key)
|
||||
.first()
|
||||
):
|
||||
return existing_api_key
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
session.refresh(api_key)
|
||||
return api_key
|
||||
|
||||
|
||||
def test_process_flow_invalid_api_key(client, flow, monkeypatch):
|
||||
# Mock de process_graph_cached
|
||||
from langflow.api.v1 import endpoints
|
||||
from langflow.services.database.models.api_key import crud
|
||||
|
||||
settings_service = get_settings_service()
|
||||
settings_service.auth_settings.AUTO_LOGIN = False
|
||||
|
||||
async def mock_process_graph_cached(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
def mock_update_total_uses(*args, **kwargs):
|
||||
return created_api_key
|
||||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
|
||||
headers = {"x-api-key": "invalid_api_key"}
|
||||
|
||||
post_data = {
|
||||
"inputs": {"key": "value"},
|
||||
"tweaks": None,
|
||||
"clear_cache": False,
|
||||
"session_id": None,
|
||||
}
|
||||
|
||||
response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "Invalid or missing API key"}
|
||||
|
||||
|
||||
def test_process_flow_invalid_id(client, monkeypatch, created_api_key):
|
||||
async def mock_process_graph_cached(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
from langflow.api.v1 import endpoints
|
||||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
|
||||
api_key = created_api_key.api_key
|
||||
headers = {"x-api-key": api_key}
|
||||
|
||||
post_data = {
|
||||
"inputs": {"key": "value"},
|
||||
"tweaks": None,
|
||||
"clear_cache": False,
|
||||
"session_id": None,
|
||||
}
|
||||
|
||||
invalid_id = uuid.uuid4()
|
||||
response = client.post(
|
||||
f"api/v1/process/{invalid_id}", headers=headers, json=post_data
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert f"Flow {invalid_id} not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_key):
|
||||
# Mock de process_graph_cached
|
||||
from langflow.api.v1 import endpoints
|
||||
from langflow.services.database.models.api_key import crud
|
||||
|
||||
settings_service = get_settings_service()
|
||||
settings_service.auth_settings.AUTO_LOGIN = False
|
||||
|
||||
async def mock_process_graph_cached(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
def mock_process_graph_cached_task(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
# The task function is ran like this:
|
||||
# if not self.use_celery:
|
||||
# return None, await task_func(*args, **kwargs)
|
||||
# if not hasattr(task_func, "apply"):
|
||||
# raise ValueError(f"Task function {task_func} does not have an apply method")
|
||||
# task = task_func.apply(args=args, kwargs=kwargs)
|
||||
# result = task.get()
|
||||
# return task.id, result
|
||||
# So we need to mock the task function to return a task object
|
||||
# and then mock the task object to return a result
|
||||
# maybe a named tuple would be better here
|
||||
task = namedtuple("task", ["id", "get"])
|
||||
mock_process_graph_cached_task.apply = lambda *args, **kwargs: task(
|
||||
id="task_id_mock", get=lambda: Result(result={}, session_id="session_id_mock")
|
||||
)
|
||||
|
||||
def mock_update_total_uses(*args, **kwargs):
|
||||
return created_api_key
|
||||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
monkeypatch.setattr(
|
||||
endpoints, "process_graph_cached_task", mock_process_graph_cached_task
|
||||
)
|
||||
|
||||
api_key = created_api_key.api_key
|
||||
headers = {"x-api-key": api_key}
|
||||
|
||||
# Dummy POST data
|
||||
post_data = {
|
||||
"inputs": {"input": "value"},
|
||||
"tweaks": None,
|
||||
"clear_cache": False,
|
||||
"session_id": None,
|
||||
}
|
||||
|
||||
# Make the request to the FastAPI TestClient
|
||||
|
||||
response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
|
||||
|
||||
# Check the response
|
||||
assert response.status_code == 200, response.json()
|
||||
assert response.json()["result"] == {}, response.json()
|
||||
assert response.json()["session_id"] == "session_id_mock", response.json()
|
||||
|
||||
|
||||
def test_process_flow_fails_autologin_off(client, flow, monkeypatch):
|
||||
# Mock de process_graph_cached
|
||||
from langflow.api.v1 import endpoints
|
||||
from langflow.services.database.models.api_key import crud
|
||||
|
||||
settings_service = get_settings_service()
|
||||
settings_service.auth_settings.AUTO_LOGIN = False
|
||||
|
||||
async def mock_process_graph_cached(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
async def mock_update_total_uses(*args, **kwargs):
|
||||
return created_api_key
|
||||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
|
||||
headers = {"x-api-key": "api_key"}
|
||||
|
||||
# Dummy POST data
|
||||
post_data = {
|
||||
"inputs": {"key": "value"},
|
||||
"tweaks": None,
|
||||
"clear_cache": False,
|
||||
"session_id": None,
|
||||
}
|
||||
|
||||
# Make the request to the FastAPI TestClient
|
||||
|
||||
response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
|
||||
|
||||
# Check the response
|
||||
assert response.status_code == 403, response.json()
|
||||
assert response.json() == {"detail": "Invalid or missing API key"}
|
||||
|
||||
|
||||
def test_get_all(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
# We need to test the custom nodes
|
||||
|
|
@ -206,3 +420,180 @@ def test_various_prompts(client, prompt, expected_input_variables):
|
|||
response = client.post("api/v1/validate/prompt", json=PROMPT_REQUEST)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["input_variables"] == expected_input_variables
|
||||
|
||||
|
||||
def test_basic_chat_in_process(client, added_flow, created_api_key):
|
||||
# Run the /api/v1/process/{flow_id} endpoint
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"text": "Hi, My name is Gabriel"}}
|
||||
response = client.post(
|
||||
f"api/v1/process/{added_flow.get('id')}",
|
||||
headers=headers,
|
||||
json=post_data,
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
# Check the response
|
||||
assert "Gabriel" in response.json()["result"]["text"]
|
||||
# session_id should be returned
|
||||
assert "session_id" in response.json()
|
||||
assert response.json()["session_id"] is not None
|
||||
# New request with the same session_id
|
||||
# asking "What is my name?" should return "Gabriel"
|
||||
post_data = {
|
||||
"inputs": {"text": "What is my name?"},
|
||||
"session_id": response.json()["session_id"],
|
||||
}
|
||||
response = client.post(
|
||||
f"api/v1/process/{added_flow.get('id')}",
|
||||
headers=headers,
|
||||
json=post_data,
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
assert "Gabriel" in response.json()["result"]["text"]
|
||||
|
||||
|
||||
def test_basic_chat_different_session_ids(client, added_flow, created_api_key):
|
||||
# Run the /api/v1/process/{flow_id} endpoint
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"text": "Hi, My name is Gabriel"}}
|
||||
response = client.post(
|
||||
f"api/v1/process/{added_flow.get('id')}",
|
||||
headers=headers,
|
||||
json=post_data,
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
# Check the response
|
||||
assert "Gabriel" in response.json()["result"]["text"]
|
||||
# session_id should be returned
|
||||
assert "session_id" in response.json()
|
||||
assert response.json()["session_id"] is not None
|
||||
session_id1 = response.json()["session_id"]
|
||||
# New request with a different session_id
|
||||
# asking "What is my name?" should return "Gabriel"
|
||||
post_data = {
|
||||
"inputs": {"text": "What is my name?"},
|
||||
}
|
||||
response = client.post(
|
||||
f"api/v1/process/{added_flow.get('id')}",
|
||||
headers=headers,
|
||||
json=post_data,
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
assert "Gabriel" not in response.json()["result"]["text"]
|
||||
assert session_id1 != response.json()["session_id"]
|
||||
|
||||
|
||||
def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_api_key):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
flow_id = added_flow.get("id")
|
||||
names = ["Gabriel", "John"]
|
||||
session_ids = []
|
||||
|
||||
for name in names:
|
||||
post_data = {"inputs": {"text": f"Hi, My name is {name}"}}
|
||||
response_json = run_post(client, flow_id, headers, post_data)
|
||||
|
||||
assert name in response_json["result"]["text"]
|
||||
assert "session_id" in response_json
|
||||
assert response_json["session_id"] is not None
|
||||
|
||||
session_ids.append(response_json["session_id"])
|
||||
|
||||
for i, name in enumerate(names):
|
||||
post_data = {
|
||||
"inputs": {"text": "What is my name?"},
|
||||
"session_id": session_ids[i],
|
||||
}
|
||||
response_json = run_post(client, flow_id, headers, post_data)
|
||||
|
||||
assert name in response_json["result"]["text"]
|
||||
|
||||
|
||||
@pytest.mark.async_test
|
||||
def test_vector_store_in_process(
|
||||
distributed_client, added_vector_store, created_api_key
|
||||
):
|
||||
# Run the /api/v1/process/{flow_id} endpoint
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"input": "What is Langflow?"}}
|
||||
response = distributed_client.post(
|
||||
f"api/v1/process/{added_vector_store.get('id')}",
|
||||
headers=headers,
|
||||
json=post_data,
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
# Check the response
|
||||
assert "Langflow" in response.json()["result"]["output"]
|
||||
# session_id should be returned
|
||||
assert "session_id" in response.json()
|
||||
assert response.json()["session_id"] is not None
|
||||
|
||||
|
||||
# Test function without loop
|
||||
@pytest.mark.async_test
|
||||
def test_async_task_processing(distributed_client, added_flow, created_api_key):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"text": "Hi, My name is Gabriel"}}
|
||||
|
||||
# Run the /api/v1/process/{flow_id} endpoint with sync=False
|
||||
response = distributed_client.post(
|
||||
f"api/v1/process/{added_flow.get('id')}",
|
||||
headers=headers,
|
||||
json={**post_data, "sync": False},
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
# Extract the task ID from the response
|
||||
task = response.json().get("task")
|
||||
task_id = task.get("id")
|
||||
task_href = task.get("href")
|
||||
assert task_id is not None
|
||||
assert task_href is not None
|
||||
assert task_href == f"api/v1/task/{task_id}"
|
||||
|
||||
# Polling the task status using the helper function
|
||||
task_status_json = poll_task_status(distributed_client, headers, task_href)
|
||||
assert task_status_json is not None, "Task did not complete in time"
|
||||
|
||||
# Validate that the task completed successfully and the result is as expected
|
||||
assert "result" in task_status_json, task_status_json
|
||||
assert "text" in task_status_json["result"], task_status_json["result"]
|
||||
assert "Gabriel" in task_status_json["result"]["text"], task_status_json["result"]
|
||||
|
||||
|
||||
# Test function without loop
|
||||
@pytest.mark.async_test
|
||||
def test_async_task_processing_vector_store(
|
||||
client, added_vector_store, created_api_key
|
||||
):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"input": "How do I upload examples?"}}
|
||||
|
||||
# Run the /api/v1/process/{flow_id} endpoint with sync=False
|
||||
response = client.post(
|
||||
f"api/v1/process/{added_vector_store.get('id')}",
|
||||
headers=headers,
|
||||
json={**post_data, "sync": False},
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
assert "result" in response.json()
|
||||
assert "FAILURE" not in response.json()["result"]
|
||||
|
||||
# Extract the task ID from the response
|
||||
task = response.json().get("task")
|
||||
task_id = task.get("id")
|
||||
task_href = task.get("href")
|
||||
assert task_id is not None
|
||||
assert task_href is not None
|
||||
assert task_href == f"api/v1/task/{task_id}"
|
||||
|
||||
# Polling the task status using the helper function
|
||||
task_status_json = poll_task_status(client, headers, task_href)
|
||||
assert task_status_json is not None, "Task did not complete in time"
|
||||
|
||||
# Validate that the task completed successfully and the result is as expected
|
||||
assert "result" in task_status_json, task_status_json
|
||||
assert "output" in task_status_json["result"], task_status_json["result"]
|
||||
assert "Langflow" in task_status_json["result"]["output"], task_status_json[
|
||||
"result"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from typing import Type, Union
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
import pytest
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
|
|
@ -185,7 +187,7 @@ def test_build_edges(basic_graph):
|
|||
assert isinstance(edge.target, Vertex)
|
||||
|
||||
|
||||
def test_get_root_node(basic_graph, complex_graph):
|
||||
def test_get_root_node(client, basic_graph, complex_graph):
|
||||
"""Test getting root node"""
|
||||
assert isinstance(basic_graph, Graph)
|
||||
root = get_root_node(basic_graph)
|
||||
|
|
@ -261,7 +263,7 @@ def test_llm_node_build(basic_graph):
|
|||
assert built_object is not None
|
||||
|
||||
|
||||
def test_toolkit_node_build(openapi_graph):
|
||||
def test_toolkit_node_build(client, openapi_graph):
|
||||
# Write a file to the disk
|
||||
file_path = "api-with-examples.yaml"
|
||||
with open(file_path, "w") as f:
|
||||
|
|
@ -276,7 +278,7 @@ def test_toolkit_node_build(openapi_graph):
|
|||
assert not Path(file_path).exists()
|
||||
|
||||
|
||||
def test_file_tool_node_build(openapi_graph):
|
||||
def test_file_tool_node_build(client, openapi_graph):
|
||||
file_path = "api-with-examples.yaml"
|
||||
with open(file_path, "w") as f:
|
||||
f.write("openapi: 3.0.0")
|
||||
|
|
@ -318,3 +320,29 @@ def test_get_result_and_thought(basic_graph):
|
|||
# Get the result and thought
|
||||
result = get_result_and_thought(langchain_object, message)
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_pickle_graph(json_vector_store):
|
||||
loaded_json = json.loads(json_vector_store)
|
||||
graph = Graph.from_payload(loaded_json)
|
||||
assert isinstance(graph, Graph)
|
||||
first_result = graph.build()
|
||||
assert isinstance(first_result, AgentExecutor)
|
||||
pickled = pickle.dumps(graph)
|
||||
assert pickled is not None
|
||||
unpickled = pickle.loads(pickled)
|
||||
assert unpickled is not None
|
||||
result = unpickled.build()
|
||||
assert isinstance(result, AgentExecutor)
|
||||
|
||||
|
||||
def test_pickle_each_vertex(json_vector_store):
|
||||
loaded_json = json.loads(json_vector_store)
|
||||
graph = Graph.from_payload(loaded_json)
|
||||
assert isinstance(graph, Graph)
|
||||
for vertex in graph.nodes:
|
||||
vertex.build()
|
||||
pickled = pickle.dumps(vertex)
|
||||
assert pickled is not None
|
||||
unpickled = pickle.loads(pickled)
|
||||
assert unpickled is not None
|
||||
|
|
|
|||
|
|
@ -1,110 +1,8 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from langflow.services.utils import get_settings_manager
|
||||
|
||||
|
||||
def test_llms_settings(client: TestClient):
|
||||
settings_manager = get_settings_manager()
|
||||
response = client.get("api/v1/all")
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
llms = json_response["llms"]
|
||||
assert set(llms.keys()) == set(settings_manager.settings.LLMS)
|
||||
|
||||
|
||||
# def test_hugging_face_hub(client: TestClient):
|
||||
# response = client.get("api/v1/all")
|
||||
# assert response.status_code == 200
|
||||
# json_response = response.json()
|
||||
# language_models = json_response["llms"]
|
||||
|
||||
# model = language_models["HuggingFaceHub"]
|
||||
# template = model["template"]
|
||||
|
||||
# assert template["cache"] == {
|
||||
# "required": False,
|
||||
# "placeholder": "",
|
||||
# "show": False,
|
||||
# "multiline": False,
|
||||
# "password": False,
|
||||
# "name": "cache",
|
||||
# "type": "bool",
|
||||
# "list": False,
|
||||
# "advanced": False,
|
||||
# }
|
||||
# assert template["verbose"] == {
|
||||
# "required": False,
|
||||
# "placeholder": "",
|
||||
# "show": False,
|
||||
# "multiline": False,
|
||||
# "value": False,
|
||||
# "password": False,
|
||||
# "name": "verbose",
|
||||
# "type": "bool",
|
||||
# "list": False,
|
||||
# "advanced": False,
|
||||
# }
|
||||
# assert template["client"] == {
|
||||
# "required": False,
|
||||
# "placeholder": "",
|
||||
# "show": False,
|
||||
# "multiline": False,
|
||||
# "password": False,
|
||||
# "name": "client",
|
||||
# "type": "Any",
|
||||
# "list": False,
|
||||
# "advanced": False,
|
||||
# }
|
||||
# assert template["repo_id"] == {
|
||||
# "required": False,
|
||||
# "placeholder": "",
|
||||
# "show": True,
|
||||
# "multiline": False,
|
||||
# "value": "gpt2",
|
||||
# "password": False,
|
||||
# "name": "repo_id",
|
||||
# "type": "str",
|
||||
# "list": False,
|
||||
# "advanced": False,
|
||||
# }
|
||||
# assert template["task"] == {
|
||||
# "required": True,
|
||||
# "placeholder": "",
|
||||
# "show": True,
|
||||
# "multiline": False,
|
||||
# "password": False,
|
||||
# "options": ["text-generation", "text2text-generation"],
|
||||
# "name": "task",
|
||||
# "type": "str",
|
||||
# "list": True,
|
||||
# "advanced": True,
|
||||
# }
|
||||
# assert template["model_kwargs"] == {
|
||||
# "required": False,
|
||||
# "placeholder": "",
|
||||
# "show": True,
|
||||
# "multiline": False,
|
||||
# "password": False,
|
||||
# "name": "model_kwargs",
|
||||
# "type": "code",
|
||||
# "list": False,
|
||||
# "advanced": True,
|
||||
# }
|
||||
# assert template["huggingfacehub_api_token"] == {
|
||||
# "required": False,
|
||||
# "placeholder": "",
|
||||
# "show": True,
|
||||
# "multiline": False,
|
||||
# "password": True,
|
||||
# "name": "huggingfacehub_api_token",
|
||||
# "display_name": "HuggingFace Hub API Token",
|
||||
# "type": "str",
|
||||
# "list": False,
|
||||
# "advanced": False,
|
||||
# }
|
||||
|
||||
|
||||
def test_openai(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_openai(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
language_models = json_response["llms"]
|
||||
|
|
@ -279,7 +177,7 @@ def test_openai(client: TestClient):
|
|||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "model_kwargs",
|
||||
"type": "code",
|
||||
"type": "dict",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
|
|
@ -334,7 +232,7 @@ def test_openai(client: TestClient):
|
|||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "logit_bias",
|
||||
"type": "code",
|
||||
"type": "dict",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
|
|
@ -369,8 +267,8 @@ def test_openai(client: TestClient):
|
|||
}
|
||||
|
||||
|
||||
def test_chat_open_ai(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_chat_open_ai(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
language_models = json_response["llms"]
|
||||
|
|
@ -451,7 +349,7 @@ def test_chat_open_ai(client: TestClient):
|
|||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "model_kwargs",
|
||||
"type": "code",
|
||||
"type": "dict",
|
||||
"list": False,
|
||||
"advanced": True,
|
||||
"info": "",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from langchain.chains.base import Chain
|
||||
from langflow.processing.process import load_flow_from_json
|
||||
|
|
|
|||
50
tests/test_login.py
Normal file
50
tests/test_login.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service
|
||||
import pytest
|
||||
from langflow.services.database.models.user import User
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user():
|
||||
return User(
|
||||
username="testuser",
|
||||
password=get_password_hash(
|
||||
"testpassword"
|
||||
), # Assuming password needs to be hashed
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
|
||||
def test_login_successful(client, test_user):
|
||||
# Adding the test user to the database
|
||||
with session_getter(get_db_service()) as session:
|
||||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "testuser", "password": "testpassword"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
|
||||
|
||||
def test_login_unsuccessful_wrong_username(client):
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "wrongusername", "password": "testpassword"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
|
||||
|
||||
def test_login_unsuccessful_wrong_password(client, test_user, session):
|
||||
# Adding the test user to the database
|
||||
session.add(test_user)
|
||||
session.commit()
|
||||
|
||||
response = client.post(
|
||||
"api/v1/login", data={"username": "testuser", "password": "wrongpassword"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Incorrect username or password"
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from langflow.interface.run import build_sorted_vertices_with_caching
|
||||
from langflow.processing.process import load_langchain_object, process_tweaks
|
||||
from langflow.processing.process import process_tweaks
|
||||
from langflow.services.getters import get_session_service
|
||||
|
||||
|
||||
def test_no_tweaks():
|
||||
|
|
@ -198,51 +198,38 @@ def test_tweak_not_in_template():
|
|||
|
||||
|
||||
def test_load_langchain_object_with_cached_session(client, basic_graph_data):
|
||||
# Build the langchain_object once and get the session_id
|
||||
langchain_object1, artifacts1, session_id1 = load_langchain_object(
|
||||
basic_graph_data, None
|
||||
)
|
||||
# Use the same session_id to get the langchain_object again
|
||||
langchain_object2, artifacts2, session_id2 = load_langchain_object(
|
||||
basic_graph_data, session_id1
|
||||
)
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = "non-existent-session-id"
|
||||
graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data)
|
||||
|
||||
assert session_id1 == session_id2
|
||||
assert id(langchain_object1) == id(langchain_object2)
|
||||
assert graph1 == graph2
|
||||
assert artifacts1 == artifacts2
|
||||
|
||||
|
||||
def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
langchain_object1, artifacts1, session_id1 = load_langchain_object(
|
||||
basic_graph_data, "non_existent_session"
|
||||
)
|
||||
session_service = get_session_service()
|
||||
session_id1 = "non-existent-session-id"
|
||||
session_id = session_service.build_key(session_id1, basic_graph_data)
|
||||
graph1, artifacts1 = session_service.load_session(session_id, basic_graph_data)
|
||||
# Clear the cache
|
||||
build_sorted_vertices_with_caching.clear_cache()
|
||||
session_service.clear_session(session_id)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
langchain_object2, artifacts2, session_id2 = load_langchain_object(
|
||||
basic_graph_data, session_id1
|
||||
)
|
||||
graph2, artifacts2 = session_service.load_session(session_id, basic_graph_data)
|
||||
|
||||
assert session_id1 == session_id2
|
||||
assert id(langchain_object1) != id(
|
||||
langchain_object2
|
||||
) # Since the cache was cleared, objects should be different
|
||||
assert id(graph1) != id(graph2)
|
||||
# Since the cache was cleared, objects should be different
|
||||
|
||||
|
||||
def test_load_langchain_object_without_session_id(client, basic_graph_data):
|
||||
# Build the langchain_object without providing a session_id
|
||||
langchain_object1, artifacts1, session_id1 = load_langchain_object(
|
||||
basic_graph_data, None
|
||||
)
|
||||
# Build the langchain_object again without providing a session_id
|
||||
langchain_object2, artifacts2, session_id2 = load_langchain_object(
|
||||
basic_graph_data, None
|
||||
)
|
||||
# Provide a non-existent session_id
|
||||
session_service = get_session_service()
|
||||
session_id1 = None
|
||||
graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data)
|
||||
# Use the new session_id to get the langchain_object again
|
||||
graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data)
|
||||
|
||||
assert session_id1 == session_id2
|
||||
|
||||
assert id(langchain_object1) == id(
|
||||
langchain_object2
|
||||
) # Since no session_id was provided, the hash will be based on the graph_data
|
||||
assert artifacts1 == artifacts2
|
||||
assert graph1 == graph2
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from langflow.services.utils import get_settings_manager
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
|
||||
def test_prompts_settings(client: TestClient):
|
||||
settings_manager = get_settings_manager()
|
||||
response = client.get("api/v1/all")
|
||||
def test_prompts_settings(client: TestClient, logged_in_headers):
|
||||
settings_service = get_settings_service()
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
prompts = json_response["prompts"]
|
||||
assert set(prompts.keys()) == set(settings_manager.settings.PROMPTS)
|
||||
assert set(prompts.keys()) == set(settings_service.settings.PROMPTS)
|
||||
|
||||
|
||||
def test_prompt_template(client: TestClient):
|
||||
response = client.get("api/v1/all")
|
||||
def test_prompt_template(client: TestClient, logged_in_headers):
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
prompts = json_response["prompts"]
|
||||
|
|
@ -55,7 +55,7 @@ def test_prompt_template(client: TestClient):
|
|||
"multiline": False,
|
||||
"password": False,
|
||||
"name": "partial_variables",
|
||||
"type": "code",
|
||||
"type": "dict",
|
||||
"list": False,
|
||||
"advanced": False,
|
||||
"info": "",
|
||||
|
|
|
|||
143
tests/test_setup_superuser.py
Normal file
143
tests/test_setup_superuser.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
from unittest.mock import patch, Mock, MagicMock, call
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.settings.constants import (
|
||||
DEFAULT_SUPERUSER,
|
||||
DEFAULT_SUPERUSER_PASSWORD,
|
||||
)
|
||||
from langflow.services.utils import (
|
||||
setup_superuser,
|
||||
teardown_superuser,
|
||||
)
|
||||
|
||||
|
||||
@patch("langflow.services.getters.get_settings_service")
|
||||
@patch("langflow.services.utils.create_super_user")
|
||||
@patch("langflow.services.getters.get_session")
|
||||
def test_setup_superuser(
|
||||
mock_get_session, mock_create_super_user, mock_get_settings_service
|
||||
):
|
||||
# Test when AUTO_LOGIN is True
|
||||
calls = []
|
||||
mock_settings_service = Mock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = True
|
||||
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD
|
||||
mock_get_settings_service.return_value = mock_settings_service
|
||||
mock_session = Mock()
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_session
|
||||
)
|
||||
# return value of get_session is a generator
|
||||
mock_get_session.return_value = iter([mock_session, mock_session, mock_session])
|
||||
setup_superuser(mock_settings_service, mock_session)
|
||||
mock_session.query.assert_called_once_with(User)
|
||||
actual_expr = mock_session.query.return_value.filter.call_args[0][0]
|
||||
expected_expr = User.username == DEFAULT_SUPERUSER
|
||||
|
||||
assert str(actual_expr) == str(expected_expr)
|
||||
create_call = call(
|
||||
db=mock_session, username=DEFAULT_SUPERUSER, password=DEFAULT_SUPERUSER_PASSWORD
|
||||
)
|
||||
calls.append(create_call)
|
||||
# mock_create_super_user.assert_has_calls(calls)
|
||||
assert 1 == mock_create_super_user.call_count
|
||||
|
||||
def reset_mock_credentials():
|
||||
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = (
|
||||
DEFAULT_SUPERUSER_PASSWORD
|
||||
)
|
||||
|
||||
ADMIN_USER_NAME = "admin_user"
|
||||
# Test when username and password are default
|
||||
mock_settings_service.auth_settings = Mock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
|
||||
mock_settings_service.auth_settings.reset_credentials = Mock(
|
||||
side_effect=reset_mock_credentials
|
||||
)
|
||||
|
||||
mock_get_settings_service.return_value = mock_settings_service
|
||||
|
||||
setup_superuser(mock_settings_service, mock_session)
|
||||
mock_session.query.assert_called_with(User)
|
||||
actual_expr = mock_session.query.return_value.filter.call_args[0][0]
|
||||
expected_expr = User.username == ADMIN_USER_NAME
|
||||
|
||||
assert str(actual_expr) == str(expected_expr)
|
||||
create_call = call(db=mock_session, username=ADMIN_USER_NAME, password="password")
|
||||
calls.append(create_call)
|
||||
# mock_create_super_user.assert_has_calls(calls)
|
||||
assert 2 == mock_create_super_user.call_count
|
||||
# Test that superuser credentials are reset
|
||||
mock_settings_service.auth_settings.reset_credentials.assert_called_once()
|
||||
assert mock_settings_service.auth_settings.SUPERUSER != ADMIN_USER_NAME
|
||||
assert mock_settings_service.auth_settings.SUPERUSER_PASSWORD != "password"
|
||||
|
||||
# Test when superuser already exists
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
|
||||
mock_user = Mock()
|
||||
mock_user.is_superuser = True
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
setup_superuser(mock_settings_service, mock_session)
|
||||
mock_session.query.assert_called_with(User)
|
||||
actual_expr = mock_session.query.return_value.filter.call_args[0][0]
|
||||
expected_expr = User.username == ADMIN_USER_NAME
|
||||
|
||||
assert str(actual_expr) == str(expected_expr)
|
||||
|
||||
|
||||
@patch("langflow.services.getters.get_settings_service")
|
||||
@patch("langflow.services.getters.get_session")
|
||||
def test_teardown_superuser_default_superuser(
|
||||
mock_get_session, mock_get_settings_service
|
||||
):
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = True
|
||||
mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD
|
||||
mock_get_settings_service.return_value = mock_settings_service
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_user.is_superuser = True
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = iter([mock_session])
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.query.assert_called_once_with(User)
|
||||
actual_expr = mock_session.query.return_value.filter.call_args[0][0]
|
||||
expected_expr = User.username == DEFAULT_SUPERUSER
|
||||
|
||||
assert str(actual_expr) == str(expected_expr)
|
||||
mock_session.delete.assert_called_once_with(mock_user)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@patch("langflow.services.getters.get_settings_service")
|
||||
@patch("langflow.services.getters.get_session")
|
||||
def test_teardown_superuser_no_default_superuser(
|
||||
mock_get_session, mock_get_settings_service
|
||||
):
|
||||
ADMIN_USER_NAME = "admin_user"
|
||||
mock_settings_service = MagicMock()
|
||||
mock_settings_service.auth_settings.AUTO_LOGIN = False
|
||||
mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME
|
||||
mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password"
|
||||
mock_get_settings_service.return_value = mock_settings_service
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_user.is_superuser = False
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_get_session.return_value = [mock_session]
|
||||
|
||||
teardown_superuser(mock_settings_service, mock_session)
|
||||
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
|
@ -135,7 +135,7 @@ def test_format_dict():
|
|||
}
|
||||
expected_output = {
|
||||
"field1": {
|
||||
"type": "code", # Mapping type is replaced with dict which is replaced with code
|
||||
"type": "dict[str, int]", # Mapping type is replaced with dict which is replaced with code
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": False,
|
||||
|
|
@ -249,7 +249,7 @@ def test_format_dict():
|
|||
}
|
||||
expected_output = {
|
||||
"field1": {
|
||||
"type": "code",
|
||||
"type": "Dict[str, int]",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": False,
|
||||
|
|
|
|||
249
tests/test_user.py
Normal file
249
tests/test_user.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
from datetime import datetime
|
||||
from langflow.services.auth.utils import create_super_user, get_password_hash
|
||||
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service, get_settings_service
|
||||
import pytest
|
||||
from langflow.services.database.models.user import UserUpdate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def super_user(client):
|
||||
settings_manager = get_settings_service()
|
||||
auth_settings = settings_manager.auth_settings
|
||||
with session_getter(get_db_service()) as session:
|
||||
return create_super_user(
|
||||
db=session,
|
||||
username=auth_settings.SUPERUSER,
|
||||
password=auth_settings.SUPERUSER_PASSWORD,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def super_user_headers(client, super_user):
|
||||
settings_service = get_settings_service()
|
||||
auth_settings = settings_service.auth_settings
|
||||
login_data = {
|
||||
"username": auth_settings.SUPERUSER,
|
||||
"password": auth_settings.SUPERUSER_PASSWORD,
|
||||
}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 200
|
||||
tokens = response.json()
|
||||
a_token = tokens["access_token"]
|
||||
return {"Authorization": f"Bearer {a_token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deactivated_user():
|
||||
with session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username="deactivateduser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
is_superuser=False,
|
||||
last_login_at=datetime.now(),
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def test_user_waiting_for_approval(
|
||||
client,
|
||||
):
|
||||
# Create a user that is not active and has never logged in
|
||||
with session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username="waitingforapproval",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
last_login_at=None,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
login_data = {"username": "waitingforapproval", "password": "testpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Waiting for approval"
|
||||
|
||||
|
||||
def test_deactivated_user_cannot_login(client, deactivated_user):
|
||||
login_data = {"username": deactivated_user.username, "password": "testpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 400, response.json()
|
||||
assert response.json()["detail"] == "Inactive user"
|
||||
|
||||
|
||||
def test_deactivated_user_cannot_access(client, deactivated_user, logged_in_headers):
|
||||
# Assuming the headers for deactivated_user
|
||||
response = client.get("/api/v1/users", headers=logged_in_headers)
|
||||
assert response.status_code == 400, response.json()
|
||||
assert response.json()["detail"] == "The user doesn't have enough privileges"
|
||||
|
||||
|
||||
def test_data_consistency_after_update(
|
||||
client, active_user, logged_in_headers, super_user_headers
|
||||
):
|
||||
user_id = active_user.id
|
||||
update_data = UserUpdate(is_active=False)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=super_user_headers
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
# Fetch the updated user from the database
|
||||
response = client.get("/api/v1/users/whoami", headers=logged_in_headers)
|
||||
assert response.status_code == 401, response.json()
|
||||
assert response.json()["detail"] == "Could not validate credentials"
|
||||
|
||||
|
||||
def test_data_consistency_after_delete(client, test_user, super_user_headers):
|
||||
user_id = test_user.get("id")
|
||||
response = client.delete(f"/api/v1/users/{user_id}", headers=super_user_headers)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
# Attempt to fetch the deleted user from the database
|
||||
response = client.get("/api/v1/users", headers=super_user_headers)
|
||||
assert response.status_code == 200
|
||||
assert all(user["id"] != user_id for user in response.json()["users"])
|
||||
|
||||
|
||||
def test_inactive_user(client):
|
||||
# Create a user that is not active and has a last_login_at value
|
||||
with session_getter(get_db_service()) as session:
|
||||
user = User(
|
||||
username="inactiveuser",
|
||||
password=get_password_hash("testpassword"),
|
||||
is_active=False,
|
||||
last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
login_data = {"username": "inactiveuser", "password": "testpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Inactive user"
|
||||
|
||||
|
||||
def test_add_user(client, test_user):
|
||||
assert test_user["username"] == "testuser"
|
||||
|
||||
|
||||
# This is not used in the Frontend at the moment
|
||||
# def test_read_current_user(client: TestClient, active_user):
|
||||
# # First we need to login to get the access token
|
||||
# login_data = {"username": "testuser", "password": "testpassword"}
|
||||
# response = client.post("/api/v1/login", data=login_data)
|
||||
# assert response.status_code == 200
|
||||
|
||||
# headers = {"Authorization": f"Bearer {response.json()['access_token']}"}
|
||||
|
||||
# response = client.get("/api/v1/user", headers=headers)
|
||||
# assert response.status_code == 200, response.json()
|
||||
# assert response.json()["username"] == "testuser"
|
||||
|
||||
|
||||
def test_read_all_users(client, super_user_headers):
|
||||
response = client.get("/api/v1/users", headers=super_user_headers)
|
||||
assert response.status_code == 200, response.json()
|
||||
assert isinstance(response.json()["users"], list)
|
||||
|
||||
|
||||
def test_normal_user_cant_read_all_users(client, logged_in_headers):
|
||||
response = client.get("/api/v1/users", headers=logged_in_headers)
|
||||
assert response.status_code == 400, response.json()
|
||||
assert response.json() == {"detail": "The user doesn't have enough privileges"}
|
||||
|
||||
|
||||
def test_patch_user(client, active_user, logged_in_headers):
|
||||
user_id = active_user.id
|
||||
update_data = UserUpdate(
|
||||
username="newname",
|
||||
)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
update_data = UserUpdate(
|
||||
profile_image="new_image",
|
||||
)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
|
||||
|
||||
def test_patch_reset_password(client, active_user, logged_in_headers):
|
||||
user_id = active_user.id
|
||||
update_data = UserUpdate(
|
||||
password="newpassword",
|
||||
)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}/reset-password",
|
||||
json=update_data.dict(),
|
||||
headers=logged_in_headers,
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
# Now we need to test if the new password works
|
||||
login_data = {"username": active_user.username, "password": "newpassword"}
|
||||
response = client.post("/api/v1/login", data=login_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_patch_user_wrong_id(client, active_user, logged_in_headers):
|
||||
user_id = "wrong_id"
|
||||
update_data = UserUpdate(
|
||||
username="newname",
|
||||
)
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 422, response.json()
|
||||
assert response.json() == {
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "user_id"],
|
||||
"msg": "value is not a valid uuid",
|
||||
"type": "type_error.uuid",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_delete_user(client, test_user, super_user_headers):
|
||||
user_id = test_user["id"]
|
||||
response = client.delete(f"/api/v1/users/{user_id}", headers=super_user_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"detail": "User deleted"}
|
||||
|
||||
|
||||
def test_delete_user_wrong_id(client, test_user, super_user_headers):
|
||||
user_id = "wrong_id"
|
||||
response = client.delete(f"/api/v1/users/{user_id}", headers=super_user_headers)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == {
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "user_id"],
|
||||
"msg": "value is not a valid uuid",
|
||||
"type": "type_error.uuid",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_normal_user_cant_delete_user(client, test_user, logged_in_headers):
|
||||
user_id = test_user["id"]
|
||||
response = client.delete(f"/api/v1/users/{user_id}", headers=logged_in_headers)
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "The user doesn't have enough privileges"}
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from langflow.services.utils import get_settings_manager
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
|
||||
# check that all agents are in settings.agents
|
||||
# are in json_response["agents"]
|
||||
def test_vectorstores_settings(client: TestClient):
|
||||
settings_manager = get_settings_manager()
|
||||
response = client.get("api/v1/all")
|
||||
def test_vectorstores_settings(client: TestClient, logged_in_headers):
|
||||
settings_service = get_settings_service()
|
||||
response = client.get("api/v1/all", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
vectorstores = json_response["vectorstores"]
|
||||
settings_vecs = set(settings_manager.settings.VECTORSTORES)
|
||||
settings_vecs = set(settings_service.settings.VECTORSTORES)
|
||||
assert all(vs in vectorstores for vs in settings_vecs)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
from fastapi import WebSocketDisconnect
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# from langflow.services.chat.manager import ChatManager
|
||||
# from langflow.services.chat.manager import ChatService
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_init_build(client):
|
||||
def test_init_build(client, active_user, logged_in_headers):
|
||||
response = client.post(
|
||||
"api/v1/build/init/test", json={"id": "test", "data": {"key": "value"}}
|
||||
"api/v1/build/init/test",
|
||||
json={"id": "test", "data": {"key": "value"}},
|
||||
headers=logged_in_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {"flowId": "test"}
|
||||
|
|
@ -24,10 +27,12 @@ def test_init_build(client):
|
|||
# assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
|
||||
def test_websocket_endpoint(client):
|
||||
def test_websocket_endpoint(client: TestClient, active_user, logged_in_headers):
|
||||
# Assuming your websocket_endpoint uses chat_service which caches data from stream_build
|
||||
access_token = logged_in_headers["Authorization"].split(" ")[1]
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect(
|
||||
"api/v1/chat/non_existing_client_id"
|
||||
f"api/v1/chat/non_existing_client_id?token={access_token}"
|
||||
) as websocket:
|
||||
websocket.send_json({"type": "test"})
|
||||
data = websocket.receive_json()
|
||||
|
|
@ -35,12 +40,12 @@ def test_websocket_endpoint(client):
|
|||
|
||||
|
||||
def test_websocket_endpoint_after_build(client, basic_graph_data):
|
||||
# Assuming your websocket_endpoint uses chat_manager which caches data from stream_build
|
||||
# Assuming your websocket_endpoint uses chat_service which caches data from stream_build
|
||||
client.post("api/v1/build/init", json=basic_graph_data)
|
||||
client.get("api/v1/build/stream/websocket_test")
|
||||
|
||||
# There should be more to test here, but it depends on the inner workings of your websocket handler
|
||||
# and how your chat_manager and other classes behave. The following is just an example structure.
|
||||
# and how your chat_service and other classes behave. The following is just an example structure.
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect("api/v1/chat/websocket_test") as websocket:
|
||||
websocket.send_json({"input": "test"})
|
||||
|
|
|
|||
0
tests/utils.py
Normal file
0
tests/utils.py
Normal file
Loading…
Add table
Add a link
Reference in a new issue