Refactor code to improve performance and readability

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-28 17:13:13 -03:00
commit 23fbb18846

View file

@ -1,15 +1,13 @@
import time
import uuid
from collections import namedtuple
import pytest
from fastapi.testclient import TestClient
from langflow.interface.tools.constants import CUSTOM_TOOLS
from langflow.processing.process import Result
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from langflow.services.deps import get_db_service
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
@ -30,7 +28,10 @@ def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1):
href,
headers=headers,
)
if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS":
if (
task_status_response.status_code == 200
and task_status_response.json()["status"] == "SUCCESS"
):
return task_status_response.json()
time.sleep(sleep_time)
return None # Return None if task did not complete in time
@ -124,7 +125,11 @@ def created_api_key(active_user):
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
if (
existing_api_key := session.query(ApiKey)
.filter(ApiKey.api_key == api_key.api_key)
.first()
):
return existing_api_key
session.add(api_key)
session.commit()
@ -132,155 +137,155 @@ def created_api_key(active_user):
return api_key
def test_process_flow_invalid_api_key(client, flow, monkeypatch):
# Mock de process_graph_cached
from langflow.api.v1 import endpoints
from langflow.services.database.models.api_key import crud
# def test_process_flow_invalid_api_key(client, flow, monkeypatch):
# # Mock de process_graph_cached
# from langflow.api.v1 import endpoints
# from langflow.services.database.models.api_key import crud
settings_service = get_settings_service()
settings_service.auth_settings.AUTO_LOGIN = False
# settings_service = get_settings_service()
# settings_service.auth_settings.AUTO_LOGIN = False
async def mock_process_graph_cached(*args, **kwargs):
return Result(result={}, session_id="session_id_mock")
# async def mock_process_graph_cached(*args, **kwargs):
# return Result(result={}, session_id="session_id_mock")
def mock_update_total_uses(*args, **kwargs):
return created_api_key
# def mock_update_total_uses(*args, **kwargs):
# return created_api_key
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
# monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
headers = {"x-api-key": "invalid_api_key"}
# headers = {"x-api-key": "invalid_api_key"}
post_data = {
"inputs": {"key": "value"},
"tweaks": None,
"clear_cache": False,
"session_id": None,
}
# post_data = {
# "inputs": {"key": "value"},
# "tweaks": None,
# "clear_cache": False,
# "session_id": None,
# }
response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
# response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
assert response.status_code == 403
assert response.json() == {"detail": "Invalid or missing API key"}
# assert response.status_code == 403
# assert response.json() == {"detail": "Invalid or missing API key"}
def test_process_flow_invalid_id(client, monkeypatch, created_api_key):
async def mock_process_graph_cached(*args, **kwargs):
return Result(result={}, session_id="session_id_mock")
# def test_process_flow_invalid_id(client, monkeypatch, created_api_key):
# async def mock_process_graph_cached(*args, **kwargs):
# return Result(result={}, session_id="session_id_mock")
from langflow.api.v1 import endpoints
# from langflow.api.v1 import endpoints
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
api_key = created_api_key.api_key
headers = {"x-api-key": api_key}
# api_key = created_api_key.api_key
# headers = {"x-api-key": api_key}
post_data = {
"inputs": {"key": "value"},
"tweaks": None,
"clear_cache": False,
"session_id": None,
}
# post_data = {
# "inputs": {"key": "value"},
# "tweaks": None,
# "clear_cache": False,
# "session_id": None,
# }
invalid_id = uuid.uuid4()
response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data)
# invalid_id = uuid.uuid4()
# response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data)
assert response.status_code == 404
assert f"Flow {invalid_id} not found" in response.json()["detail"]
# assert response.status_code == 404
# assert f"Flow {invalid_id} not found" in response.json()["detail"]
def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_key):
# Mock de process_graph_cached
from langflow.api.v1 import endpoints
from langflow.services.database.models.api_key import crud
# def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_key):
# # Mock de process_graph_cached
# from langflow.api.v1 import endpoints
# from langflow.services.database.models.api_key import crud
settings_service = get_settings_service()
settings_service.auth_settings.AUTO_LOGIN = False
# settings_service = get_settings_service()
# settings_service.auth_settings.AUTO_LOGIN = False
async def mock_process_graph_cached(*args, **kwargs):
return Result(result={}, session_id="session_id_mock")
# async def mock_process_graph_cached(*args, **kwargs):
# return Result(result={}, session_id="session_id_mock")
def mock_process_graph_cached_task(*args, **kwargs):
return Result(result={}, session_id="session_id_mock")
# def mock_process_graph_cached_task(*args, **kwargs):
# return Result(result={}, session_id="session_id_mock")
# The task function is ran like this:
# if not self.use_celery:
# return None, await task_func(*args, **kwargs)
# if not hasattr(task_func, "apply"):
# raise ValueError(f"Task function {task_func} does not have an apply method")
# task = task_func.apply(args=args, kwargs=kwargs)
# result = task.get()
# return task.id, result
# So we need to mock the task function to return a task object
# and then mock the task object to return a result
# maybe a named tuple would be better here
task = namedtuple("task", ["id", "get"])
mock_process_graph_cached_task.apply = lambda *args, **kwargs: task(
id="task_id_mock", get=lambda: Result(result={}, session_id="session_id_mock")
)
# # The task function is ran like this:
# # if not self.use_celery:
# # return None, await task_func(*args, **kwargs)
# # if not hasattr(task_func, "apply"):
# # raise ValueError(f"Task function {task_func} does not have an apply method")
# # task = task_func.apply(args=args, kwargs=kwargs)
# # result = task.get()
# # return task.id, result
# # So we need to mock the task function to return a task object
# # and then mock the task object to return a result
# # maybe a named tuple would be better here
# task = namedtuple("task", ["id", "get"])
# mock_process_graph_cached_task.apply = lambda *args, **kwargs: task(
# id="task_id_mock", get=lambda: Result(result={}, session_id="session_id_mock")
# )
def mock_update_total_uses(*args, **kwargs):
return created_api_key
# def mock_update_total_uses(*args, **kwargs):
# return created_api_key
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task)
# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
# monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
# monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task)
api_key = created_api_key.api_key
headers = {"x-api-key": api_key}
# api_key = created_api_key.api_key
# headers = {"x-api-key": api_key}
# Dummy POST data
post_data = {
"inputs": {"input": "value"},
"tweaks": None,
"clear_cache": False,
"session_id": None,
}
# # Dummy POST data
# post_data = {
# "inputs": {"input": "value"},
# "tweaks": None,
# "clear_cache": False,
# "session_id": None,
# }
# Make the request to the FastAPI TestClient
# # Make the request to the FastAPI TestClient
response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
# response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
# Check the response
assert response.status_code == 200, response.json()
assert response.json()["result"] == {}, response.json()
assert response.json()["session_id"] == "session_id_mock", response.json()
# # Check the response
# assert response.status_code == 200, response.json()
# assert response.json()["result"] == {}, response.json()
# assert response.json()["session_id"] == "session_id_mock", response.json()
def test_process_flow_fails_autologin_off(client, flow, monkeypatch):
# Mock de process_graph_cached
from langflow.api.v1 import endpoints
from langflow.services.database.models.api_key import crud
# def test_process_flow_fails_autologin_off(client, flow, monkeypatch):
# # Mock de process_graph_cached
# from langflow.api.v1 import endpoints
# from langflow.services.database.models.api_key import crud
settings_service = get_settings_service()
settings_service.auth_settings.AUTO_LOGIN = False
# settings_service = get_settings_service()
# settings_service.auth_settings.AUTO_LOGIN = False
async def mock_process_graph_cached(*args, **kwargs):
return Result(result={}, session_id="session_id_mock")
# async def mock_process_graph_cached(*args, **kwargs):
# return Result(result={}, session_id="session_id_mock")
async def mock_update_total_uses(*args, **kwargs):
return created_api_key
# async def mock_update_total_uses(*args, **kwargs):
# return created_api_key
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
# monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
# monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
headers = {"x-api-key": "api_key"}
# headers = {"x-api-key": "api_key"}
# Dummy POST data
post_data = {
"inputs": {"key": "value"},
"tweaks": None,
"clear_cache": False,
"session_id": None,
}
# # Dummy POST data
# post_data = {
# "inputs": {"key": "value"},
# "tweaks": None,
# "clear_cache": False,
# "session_id": None,
# }
# Make the request to the FastAPI TestClient
# # Make the request to the FastAPI TestClient
response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
# response = client.post(f"api/v1/process/{flow.id}", headers=headers, json=post_data)
# Check the response
assert response.status_code == 403, response.json()
assert response.json() == {"detail": "Invalid or missing API key"}
# # Check the response
# assert response.status_code == 403, response.json()
# assert response.json() == {"detail": "Invalid or missing API key"}
def test_get_all(client: TestClient, logged_in_headers):
@ -409,67 +414,100 @@ def test_various_prompts(client, prompt, expected_input_variables):
def test_get_vertices_flow_not_found(client, logged_in_headers):
response = client.get("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers)
assert response.status_code == 500 # Or whatever status code you've set for invalid ID
response = client.get(
"/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers
)
assert (
response.status_code == 500
) # Or whatever status code you've set for invalid ID
def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers):
flow_id = added_flow_with_prompt_and_history["id"]
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
assert response.status_code == 200
assert "ids" in response.json()
# The response should contain the list in this order
# ['ConversationBufferMemory-Lu2Nb', 'PromptTemplate-5Q0W8', 'ChatOpenAI-vy7fV', 'LLMChain-UjBh1']
# The important part is before the - (ConversationBufferMemory, PromptTemplate, ChatOpenAI, LLMChain)
ids = [inner_id.split("-")[0] for _id in response.json()["ids"] for inner_id in _id]
assert ids == ["ChatOpenAI", "PromptTemplate", "ConversationBufferMemory", "LLMChain"]
assert ids == [
"ChatOpenAI",
"PromptTemplate",
"ConversationBufferMemory",
"LLMChain",
]
def test_build_vertex_invalid_flow_id(client, logged_in_headers):
response = client.post("/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers)
response = client.post(
"/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers
)
assert response.status_code == 500
def test_build_vertex_invalid_vertex_id(client, added_flow_with_prompt_and_history, logged_in_headers):
def test_build_vertex_invalid_vertex_id(
client, added_flow_with_prompt_and_history, logged_in_headers
):
flow_id = added_flow_with_prompt_and_history["id"]
response = client.post(f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers)
response = client.post(
f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers
)
assert response.status_code == 500
def test_build_all_vertices_in_sequence_with_chat_input(client, added_flow_chat_input, logged_in_headers):
def test_build_all_vertices_in_sequence_with_chat_input(
client, added_flow_chat_input, logged_in_headers
):
flow_id = added_flow_chat_input["id"]
# First, get all the vertices in the correct sequence
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
assert response.status_code == 200
assert "ids" in response.json()
vertex_ids = response.json()["ids"]
# Now, iterate through each vertex and build it
for vertex_id in vertex_ids:
response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers)
response = client.post(
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
)
json_response = response.json()
assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}"
assert (
response.status_code == 200
), f"Failed at vertex {vertex_id}: {json_response}"
assert "valid" in json_response
assert json_response["valid"], json_response["params"]
def test_build_all_vertices_in_sequence_with_two_outputs(client, added_flow_two_outputs, logged_in_headers):
def test_build_all_vertices_in_sequence_with_two_outputs(
client, added_flow_two_outputs, logged_in_headers
):
"""This tests the case where a node has two outputs, one of which is Text and the other (in this case) is
a LLMChain. We need to make sure the correct output is passed in both cases."""
flow_id = added_flow_two_outputs["id"]
# First, get all the vertices in the correct sequence
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
assert response.status_code == 200
assert "ids" in response.json()
vertex_ids = response.json()["ids"]
# Now, iterate through each vertex and build it
for vertex_id in vertex_ids:
response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers)
response = client.post(
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
)
json_response = response.json()
assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}"
assert (
response.status_code == 200
), f"Failed at vertex {vertex_id}: {json_response}"
assert "valid" in json_response
assert json_response["valid"], json_response["params"]
@ -562,7 +600,9 @@ def test_basic_chat_with_two_session_ids_and_names(client, flow, created_api_key
@pytest.mark.async_test
def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key):
def test_vector_store_in_process(
distributed_client, added_vector_store, created_api_key
):
# Run the /api/v1/process/{flow_id} endpoint
headers = {"x-api-key": created_api_key.api_key}
post_data = {"inputs": {"input": "What is Langflow?"}}