From 35d81b7e344af34df0db76de21dfc2ffec71cb52 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 27 Sep 2024 15:18:24 -0300 Subject: [PATCH] feat: Add `components_only` parameter to filter flows by components in API endpoint (#3932) * Add fixture for creating flow component in tests * Add test for reading flows with components only in test_database.py * Add `components_only` parameter to filter flows by components in API endpoint --- src/backend/base/langflow/api/v1/flows.py | 19 ++++++++++++------- src/backend/tests/conftest.py | 14 ++++++++++++-- src/backend/tests/unit/test_database.py | 8 ++++++++ 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/backend/base/langflow/api/v1/flows.py b/src/backend/base/langflow/api/v1/flows.py index 90bb82d4e..ceffae481 100644 --- a/src/backend/base/langflow/api/v1/flows.py +++ b/src/backend/base/langflow/api/v1/flows.py @@ -126,6 +126,7 @@ def read_flows( session: Session = Depends(get_session), settings_service: "SettingsService" = Depends(get_settings_service), remove_example_flows: bool = False, + components_only: bool = False, ): """ Retrieve a list of flows. @@ -135,7 +136,7 @@ def read_flows( session (Session): The database session. settings_service (SettingsService): The settings service. remove_example_flows (bool, optional): Whether to remove example flows. Defaults to False. - + components_only (bool, optional): Whether to return only components. Defaults to False. Returns: List[Dict]: A list of flows in JSON format. @@ -144,18 +145,22 @@ def read_flows( try: auth_settings = settings_service.auth_settings if auth_settings.AUTO_LOGIN: - flows = session.exec( - select(Flow).where( - (Flow.user_id == None) | (Flow.user_id == current_user.id) # noqa - ) - ).all() + stmt = select(Flow).where( + (Flow.user_id == None) | (Flow.user_id == current_user.id) # noqa + ) + if components_only: + stmt = stmt.where(Flow.is_component == True) # noqa + flows = session.exec(stmt).all() + else: flows = current_user.flows flows = validate_is_component(flows) # type: ignore + if components_only: + flows = [flow for flow in flows if flow.is_component] flow_ids = [flow.id for flow in flows] # with the session get the flows that DO NOT have a user_id - if not remove_example_flows: + if not remove_example_flows and not components_only: try: folder = session.exec(select(Folder).where(Folder.name == STARTER_FOLDER_NAME)).first() diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index e2a020a92..b632d4935 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING import orjson import pytest +from base.langflow.components.inputs.ChatInput import ChatInput from dotenv import load_dotenv from fastapi.testclient import TestClient from httpx import AsyncClient @@ -420,6 +421,17 @@ def added_webhook_test(client, json_webhook_test, logged_in_headers): client.delete(f"api/v1/flows/{response.json()['id']}", headers=logged_in_headers) +@pytest.fixture +def flow_component(client: TestClient, logged_in_headers): + chat_input = ChatInput() + graph = Graph(start=chat_input, end=chat_input) + graph_dict = graph.dump(name="Chat Input Component") + flow = FlowCreate(**graph_dict) + response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) + assert response.status_code == 201 + return response.json() + + @pytest.fixture def created_api_key(active_user): hashed = get_password_hash("random_key") @@ -448,8 +460,6 @@ def get_simple_api_test(client, logged_in_headers, json_simple_api_test): flow = FlowCreate(name="Simple API Test", data=data, description="Simple API Test") response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) assert response.status_code == 201 - assert response.json()["name"] == flow.name - assert response.json()["data"] == flow.data return response.json() diff --git a/src/backend/tests/unit/test_database.py b/src/backend/tests/unit/test_database.py index 60b617545..c59762417 100644 --- a/src/backend/tests/unit/test_database.py +++ b/src/backend/tests/unit/test_database.py @@ -67,6 +67,14 @@ def test_read_flows(client: TestClient, json_flow: str, active_user, logged_in_h assert len(response.json()) > 0 +def test_read_flows_components_only(client: TestClient, flow_component: dict, logged_in_headers): + response = client.get("api/v1/flows/", headers=logged_in_headers, params={"components_only": True}) + assert response.status_code == 200 + names = [flow["name"] for flow in response.json()] + assert any("Chat Input Component" in name for name in names) + assert all(flow["is_component"] is True for flow in response.json()), [flow["name"] for flow in response.json()] + + def test_read_flow(client: TestClient, json_flow: str, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"]