Refactor code to improve performance and readability
This commit is contained in:
parent
25d3b96600
commit
23fbb18846
1 changed files with 171 additions and 131 deletions
|
|
@ -1,15 +1,13 @@
|
|||
import time
|
||||
import uuid
|
||||
from collections import namedtuple
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
from langflow.processing.process import Result
|
||||
from langflow.services.auth.utils import get_password_hash
|
||||
from langflow.services.database.models.api_key.model import ApiKey
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from langflow.services.deps import get_db_service
|
||||
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
|
||||
|
||||
|
||||
|
|
@ -30,7 +28,10 @@ def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1):
|
|||
href,
|
||||
headers=headers,
|
||||
)
|
||||
if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS":
|
||||
if (
|
||||
task_status_response.status_code == 200
|
||||
and task_status_response.json()["status"] == "SUCCESS"
|
||||
):
|
||||
return task_status_response.json()
|
||||
time.sleep(sleep_time)
|
||||
return None # Return None if task did not complete in time
|
||||
|
|
@ -124,7 +125,11 @@ def created_api_key(active_user):
|
|||
)
|
||||
db_manager = get_db_service()
|
||||
with session_getter(db_manager) as session:
|
||||
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
|
||||
if (
|
||||
existing_api_key := session.query(ApiKey)
|
||||
.filter(ApiKey.api_key == api_key.api_key)
|
||||
.first()
|
||||
):
|
||||
return existing_api_key
|
||||
session.add(api_key)
|
||||
session.commit()
|
||||
|
|
@ -132,155 +137,155 @@ def created_api_key(active_user):
|
|||
return api_key
|
||||
|
||||
|
||||
def test_process_flow_invalid_api_key(client, flow, monkeypatch):
|
||||
# Mock de process_graph_cached
|
||||
from langflow.api.v1 import endpoints
|
||||
from langflow.services.database.models.api_key import crud
|
||||
# def test_process_flow_invalid_api_key(client, flow, monkeypatch):
|
||||
# # Mock de process_graph_cached
|
||||
# from langflow.api.v1 import endpoints
|
||||
# from langflow.services.database.models.api_key import crud
|
||||
|
||||
settings_service = get_settings_service()
|
||||
settings_service.auth_settings.AUTO_LOGIN = False
|
||||
# settings_service = get_settings_service()
|
||||
# settings_service.auth_settings.AUTO_LOGIN = False
|
||||
|
||||
async def mock_process_graph_cached(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
# async def mock_process_graph_cached(*args, **kwargs):
|
||||
# return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
def mock_update_total_uses(*args, **kwargs):
|
||||
return created_api_key
|
||||
# def mock_update_total_uses(*args, **kwargs):
|
||||
# return created_api_key
|
||||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
# monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
|
||||
headers = {"x-api-key": "invalid_api_key"}
|
||||
# headers = {"x-api-key": "invalid_api_key"}
|
||||
|
||||
post_data = {
|
||||
"inputs": {"key": "value"},
|
||||
"tweaks": None,
|
||||
"clear_cache": False,
|
||||
"session_id": None,
|
||||
}
|
||||
# post_data = {
|
||||
# "inputs": {"key": "value"},
|
||||
# "tweaks": None,
|
||||
# "clear_cache": False,
|
||||
# "session_id": None,
|
||||
# }
|
||||
|
||||
response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
|
||||
# response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "Invalid or missing API key"}
|
||||
# assert response.status_code == 403
|
||||
# assert response.json() == {"detail": "Invalid or missing API key"}
|
||||
|
||||
|
||||
def test_process_flow_invalid_id(client, monkeypatch, created_api_key):
|
||||
async def mock_process_graph_cached(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
# def test_process_flow_invalid_id(client, monkeypatch, created_api_key):
|
||||
# async def mock_process_graph_cached(*args, **kwargs):
|
||||
# return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
from langflow.api.v1 import endpoints
|
||||
# from langflow.api.v1 import endpoints
|
||||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
|
||||
api_key = created_api_key.api_key
|
||||
headers = {"x-api-key": api_key}
|
||||
# api_key = created_api_key.api_key
|
||||
# headers = {"x-api-key": api_key}
|
||||
|
||||
post_data = {
|
||||
"inputs": {"key": "value"},
|
||||
"tweaks": None,
|
||||
"clear_cache": False,
|
||||
"session_id": None,
|
||||
}
|
||||
# post_data = {
|
||||
# "inputs": {"key": "value"},
|
||||
# "tweaks": None,
|
||||
# "clear_cache": False,
|
||||
# "session_id": None,
|
||||
# }
|
||||
|
||||
invalid_id = uuid.uuid4()
|
||||
response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data)
|
||||
# invalid_id = uuid.uuid4()
|
||||
# response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert f"Flow {invalid_id} not found" in response.json()["detail"]
|
||||
# assert response.status_code == 404
|
||||
# assert f"Flow {invalid_id} not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_key):
|
||||
# Mock de process_graph_cached
|
||||
from langflow.api.v1 import endpoints
|
||||
from langflow.services.database.models.api_key import crud
|
||||
# def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_key):
|
||||
# # Mock de process_graph_cached
|
||||
# from langflow.api.v1 import endpoints
|
||||
# from langflow.services.database.models.api_key import crud
|
||||
|
||||
settings_service = get_settings_service()
|
||||
settings_service.auth_settings.AUTO_LOGIN = False
|
||||
# settings_service = get_settings_service()
|
||||
# settings_service.auth_settings.AUTO_LOGIN = False
|
||||
|
||||
async def mock_process_graph_cached(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
# async def mock_process_graph_cached(*args, **kwargs):
|
||||
# return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
def mock_process_graph_cached_task(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
# def mock_process_graph_cached_task(*args, **kwargs):
|
||||
# return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
# The task function is ran like this:
|
||||
# if not self.use_celery:
|
||||
# return None, await task_func(*args, **kwargs)
|
||||
# if not hasattr(task_func, "apply"):
|
||||
# raise ValueError(f"Task function {task_func} does not have an apply method")
|
||||
# task = task_func.apply(args=args, kwargs=kwargs)
|
||||
# result = task.get()
|
||||
# return task.id, result
|
||||
# So we need to mock the task function to return a task object
|
||||
# and then mock the task object to return a result
|
||||
# maybe a named tuple would be better here
|
||||
task = namedtuple("task", ["id", "get"])
|
||||
mock_process_graph_cached_task.apply = lambda *args, **kwargs: task(
|
||||
id="task_id_mock", get=lambda: Result(result={}, session_id="session_id_mock")
|
||||
)
|
||||
# # The task function is ran like this:
|
||||
# # if not self.use_celery:
|
||||
# # return None, await task_func(*args, **kwargs)
|
||||
# # if not hasattr(task_func, "apply"):
|
||||
# # raise ValueError(f"Task function {task_func} does not have an apply method")
|
||||
# # task = task_func.apply(args=args, kwargs=kwargs)
|
||||
# # result = task.get()
|
||||
# # return task.id, result
|
||||
# # So we need to mock the task function to return a task object
|
||||
# # and then mock the task object to return a result
|
||||
# # maybe a named tuple would be better here
|
||||
# task = namedtuple("task", ["id", "get"])
|
||||
# mock_process_graph_cached_task.apply = lambda *args, **kwargs: task(
|
||||
# id="task_id_mock", get=lambda: Result(result={}, session_id="session_id_mock")
|
||||
# )
|
||||
|
||||
def mock_update_total_uses(*args, **kwargs):
|
||||
return created_api_key
|
||||
# def mock_update_total_uses(*args, **kwargs):
|
||||
# return created_api_key
|
||||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task)
|
||||
# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
# monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
# monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task)
|
||||
|
||||
api_key = created_api_key.api_key
|
||||
headers = {"x-api-key": api_key}
|
||||
# api_key = created_api_key.api_key
|
||||
# headers = {"x-api-key": api_key}
|
||||
|
||||
# Dummy POST data
|
||||
post_data = {
|
||||
"inputs": {"input": "value"},
|
||||
"tweaks": None,
|
||||
"clear_cache": False,
|
||||
"session_id": None,
|
||||
}
|
||||
# # Dummy POST data
|
||||
# post_data = {
|
||||
# "inputs": {"input": "value"},
|
||||
# "tweaks": None,
|
||||
# "clear_cache": False,
|
||||
# "session_id": None,
|
||||
# }
|
||||
|
||||
# Make the request to the FastAPI TestClient
|
||||
# # Make the request to the FastAPI TestClient
|
||||
|
||||
response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
|
||||
# response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
|
||||
|
||||
# Check the response
|
||||
assert response.status_code == 200, response.json()
|
||||
assert response.json()["result"] == {}, response.json()
|
||||
assert response.json()["session_id"] == "session_id_mock", response.json()
|
||||
# # Check the response
|
||||
# assert response.status_code == 200, response.json()
|
||||
# assert response.json()["result"] == {}, response.json()
|
||||
# assert response.json()["session_id"] == "session_id_mock", response.json()
|
||||
|
||||
|
||||
def test_process_flow_fails_autologin_off(client, flow, monkeypatch):
|
||||
# Mock de process_graph_cached
|
||||
from langflow.api.v1 import endpoints
|
||||
from langflow.services.database.models.api_key import crud
|
||||
# def test_process_flow_fails_autologin_off(client, flow, monkeypatch):
|
||||
# # Mock de process_graph_cached
|
||||
# from langflow.api.v1 import endpoints
|
||||
# from langflow.services.database.models.api_key import crud
|
||||
|
||||
settings_service = get_settings_service()
|
||||
settings_service.auth_settings.AUTO_LOGIN = False
|
||||
# settings_service = get_settings_service()
|
||||
# settings_service.auth_settings.AUTO_LOGIN = False
|
||||
|
||||
async def mock_process_graph_cached(*args, **kwargs):
|
||||
return Result(result={}, session_id="session_id_mock")
|
||||
# async def mock_process_graph_cached(*args, **kwargs):
|
||||
# return Result(result={}, session_id="session_id_mock")
|
||||
|
||||
async def mock_update_total_uses(*args, **kwargs):
|
||||
return created_api_key
|
||||
# async def mock_update_total_uses(*args, **kwargs):
|
||||
# return created_api_key
|
||||
|
||||
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
|
||||
# monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
|
||||
|
||||
headers = {"x-api-key": "api_key"}
|
||||
# headers = {"x-api-key": "api_key"}
|
||||
|
||||
# Dummy POST data
|
||||
post_data = {
|
||||
"inputs": {"key": "value"},
|
||||
"tweaks": None,
|
||||
"clear_cache": False,
|
||||
"session_id": None,
|
||||
}
|
||||
# # Dummy POST data
|
||||
# post_data = {
|
||||
# "inputs": {"key": "value"},
|
||||
# "tweaks": None,
|
||||
# "clear_cache": False,
|
||||
# "session_id": None,
|
||||
# }
|
||||
|
||||
# Make the request to the FastAPI TestClient
|
||||
# # Make the request to the FastAPI TestClient
|
||||
|
||||
response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
|
||||
# response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
|
||||
|
||||
# Check the response
|
||||
assert response.status_code == 403, response.json()
|
||||
assert response.json() == {"detail": "Invalid or missing API key"}
|
||||
# # Check the response
|
||||
# assert response.status_code == 403, response.json()
|
||||
# assert response.json() == {"detail": "Invalid or missing API key"}
|
||||
|
||||
|
||||
def test_get_all(client: TestClient, logged_in_headers):
|
||||
|
|
@ -409,67 +414,100 @@ def test_various_prompts(client, prompt, expected_input_variables):
|
|||
|
||||
|
||||
def test_get_vertices_flow_not_found(client, logged_in_headers):
|
||||
response = client.get("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers)
|
||||
assert response.status_code == 500 # Or whatever status code you've set for invalid ID
|
||||
response = client.get(
|
||||
"/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers
|
||||
)
|
||||
assert (
|
||||
response.status_code == 500
|
||||
) # Or whatever status code you've set for invalid ID
|
||||
|
||||
|
||||
def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers):
|
||||
flow_id = added_flow_with_prompt_and_history["id"]
|
||||
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "ids" in response.json()
|
||||
# The response should contain the list in this order
|
||||
# ['ConversationBufferMemory-Lu2Nb', 'PromptTemplate-5Q0W8', 'ChatOpenAI-vy7fV', 'LLMChain-UjBh1']
|
||||
# The important part is before the - (ConversationBufferMemory, PromptTemplate, ChatOpenAI, LLMChain)
|
||||
ids = [inner_id.split("-")[0] for _id in response.json()["ids"] for inner_id in _id]
|
||||
assert ids == ["ChatOpenAI", "PromptTemplate", "ConversationBufferMemory", "LLMChain"]
|
||||
assert ids == [
|
||||
"ChatOpenAI",
|
||||
"PromptTemplate",
|
||||
"ConversationBufferMemory",
|
||||
"LLMChain",
|
||||
]
|
||||
|
||||
|
||||
def test_build_vertex_invalid_flow_id(client, logged_in_headers):
|
||||
response = client.post("/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers)
|
||||
response = client.post(
|
||||
"/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_build_vertex_invalid_vertex_id(client, added_flow_with_prompt_and_history, logged_in_headers):
|
||||
def test_build_vertex_invalid_vertex_id(
|
||||
client, added_flow_with_prompt_and_history, logged_in_headers
|
||||
):
|
||||
flow_id = added_flow_with_prompt_and_history["id"]
|
||||
response = client.post(f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers)
|
||||
response = client.post(
|
||||
f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_build_all_vertices_in_sequence_with_chat_input(client, added_flow_chat_input, logged_in_headers):
|
||||
def test_build_all_vertices_in_sequence_with_chat_input(
|
||||
client, added_flow_chat_input, logged_in_headers
|
||||
):
|
||||
flow_id = added_flow_chat_input["id"]
|
||||
|
||||
# First, get all the vertices in the correct sequence
|
||||
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "ids" in response.json()
|
||||
vertex_ids = response.json()["ids"]
|
||||
|
||||
# Now, iterate through each vertex and build it
|
||||
for vertex_id in vertex_ids:
|
||||
response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers)
|
||||
response = client.post(
|
||||
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
|
||||
)
|
||||
json_response = response.json()
|
||||
assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}"
|
||||
assert (
|
||||
response.status_code == 200
|
||||
), f"Failed at vertex {vertex_id}: {json_response}"
|
||||
assert "valid" in json_response
|
||||
assert json_response["valid"], json_response["params"]
|
||||
|
||||
|
||||
def test_build_all_vertices_in_sequence_with_two_outputs(client, added_flow_two_outputs, logged_in_headers):
|
||||
def test_build_all_vertices_in_sequence_with_two_outputs(
|
||||
client, added_flow_two_outputs, logged_in_headers
|
||||
):
|
||||
"""This tests the case where a node has two outputs, one of which is Text and the other (in this case) is
|
||||
a LLMChain. We need to make sure the correct output is passed in both cases."""
|
||||
flow_id = added_flow_two_outputs["id"]
|
||||
|
||||
# First, get all the vertices in the correct sequence
|
||||
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "ids" in response.json()
|
||||
vertex_ids = response.json()["ids"]
|
||||
|
||||
# Now, iterate through each vertex and build it
|
||||
for vertex_id in vertex_ids:
|
||||
response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers)
|
||||
response = client.post(
|
||||
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
|
||||
)
|
||||
json_response = response.json()
|
||||
assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}"
|
||||
assert (
|
||||
response.status_code == 200
|
||||
), f"Failed at vertex {vertex_id}: {json_response}"
|
||||
assert "valid" in json_response
|
||||
assert json_response["valid"], json_response["params"]
|
||||
|
||||
|
|
@ -562,7 +600,9 @@ def test_basic_chat_with_two_session_ids_and_names(client, flow, created_api_key
|
|||
|
||||
|
||||
@pytest.mark.async_test
|
||||
def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key):
|
||||
def test_vector_store_in_process(
|
||||
distributed_client, added_vector_store, created_api_key
|
||||
):
|
||||
# Run the /api/v1/process/{flow_id} endpoint
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
post_data = {"inputs": {"input": "What is Langflow?"}}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue