From abe4925cc3c38ea2305b4ecadcbd0cd7387999d4 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 25 Jan 2024 16:17:40 -0300 Subject: [PATCH] Refactor imports and update function names --- tests/test_chains_template.py | 24 +---- tests/test_cli.py | 9 +- tests/test_custom_component.py | 81 ++++++----------- tests/test_database.py | 74 +++++----------- tests/test_endpoints.py | 135 +++++++++------------------- tests/test_frontend_nodes.py | 9 +- tests/test_graph.py | 136 +++++++++++------------------ tests/test_vectorstore_template.py | 2 +- 8 files changed, 156 insertions(+), 314 deletions(-) diff --git a/tests/test_chains_template.py b/tests/test_chains_template.py index 4a8038d14..970c0f315 100644 --- a/tests/test_chains_template.py +++ b/tests/test_chains_template.py @@ -1,6 +1,5 @@ from fastapi.testclient import TestClient - # def test_chains_settings(client: TestClient, logged_in_headers): # response = client.get("api/v1/all", headers=logged_in_headers) # assert response.status_code == 200 @@ -18,13 +17,7 @@ def test_conversation_chain(client: TestClient, logged_in_headers): chain = chains["ConversationChain"] # Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects - assert set(chain["base_classes"]) == { - "ConversationChain", - "LLMChain", - "Chain", - "function", - "Text", - } + assert set(chain["base_classes"]) == {"Callable", "Chain"} template = chain["template"] assert template["memory"] == { @@ -97,10 +90,7 @@ def test_conversation_chain(client: TestClient, logged_in_headers): assert template["_type"] == "ConversationChain" # Test the description object - assert ( - chain["description"] - == "Chain to have a conversation and load context from memory." - ) + assert chain["description"] == "Chain to have a conversation and load context from memory." def test_llm_chain(client: TestClient, logged_in_headers): @@ -287,10 +277,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers): assert template["_type"] == "LLMMathChain" # Test the description object - assert ( - chain["description"] - == "Chain that interprets a prompt and executes python code to do math." - ) + assert chain["description"] == "Chain that interprets a prompt and executes python code to do math." def test_series_character_chain(client: TestClient, logged_in_headers): @@ -396,10 +383,7 @@ def test_mid_journey_prompt_chain(client: TestClient, logged_in_headers): "info": "", } # Test the description object - assert ( - chain["description"] - == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts." - ) + assert chain["description"] == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts." def test_time_travel_guide_chain(client: TestClient, logged_in_headers): diff --git a/tests/test_cli.py b/tests/test_cli.py index ee938db12..efde059f6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,9 +1,9 @@ from pathlib import Path from tempfile import tempdir -from langflow.__main__ import app -import pytest -from langflow.services import getters +import pytest +from langflow.__main__ import app +from langflow.services import deps @pytest.fixture(scope="module") @@ -16,6 +16,7 @@ def default_settings(): def test_components_path(runner, client, default_settings): # Create a foldr in the tmp directory + temp_dir = Path(tempdir) # create a "components" folder temp_dir = temp_dir / "components" @@ -26,7 +27,7 @@ def test_components_path(runner, client, default_settings): ["run", "--components-path", str(temp_dir), *default_settings], ) assert result.exit_code == 0, result.stdout - settings_service = getters.get_settings_service() + settings_service = deps.get_settings_service() assert str(temp_dir) in settings_service.settings.COMPONENTS_PATH diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index 47c9cbfb2..a03be10bf 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -1,19 +1,18 @@ import ast -import pytest import types from uuid import uuid4 - +import pytest from fastapi import HTTPException -from langflow.services.database.models.flow import Flow, FlowCreate +from langchain_core.documents import Document from langflow.interface.custom.base import CustomComponent -from langflow.interface.custom.component import ( +from langflow.interface.custom.code_parser.code_parser import CodeParser, CodeSyntaxError +from langflow.interface.custom.custom_component.component import ( Component, ComponentCodeNullError, ComponentFunctionEntrypointNameNullError, ) -from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError - +from langflow.services.database.models.flow import Flow, FlowCreate code_default = """ from langflow import Prompt @@ -53,7 +52,7 @@ def test_code_parser_get_tree(): Test the __get_tree method of the CodeParser class. """ parser = CodeParser(code_default) - tree = parser._CodeParser__get_tree() + tree = parser.get_tree() assert isinstance(tree, ast.AST) @@ -66,7 +65,7 @@ def test_code_parser_syntax_error(): parser = CodeParser(code_syntax_error) with pytest.raises(CodeSyntaxError): - parser._CodeParser__get_tree() + parser.get_tree() def test_component_init(): @@ -113,9 +112,7 @@ def test_custom_component_init(): """ function_entrypoint_name = "build" - custom_component = CustomComponent( - code=code_default, function_entrypoint_name=function_entrypoint_name - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name) assert custom_component.code == code_default assert custom_component.function_entrypoint_name == function_entrypoint_name @@ -124,10 +121,8 @@ def test_custom_component_build_template_config(): """ Test the build_template_config property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) - config = custom_component.build_template_config + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") + config = custom_component.build_template_config() assert isinstance(config, dict) @@ -135,9 +130,7 @@ def test_custom_component_get_function(): """ Test the get_function property of the CustomComponent class. """ - custom_component = CustomComponent( - code="def build(): pass", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") my_function = custom_component.get_function assert isinstance(my_function, types.FunctionType) @@ -148,7 +141,7 @@ def test_code_parser_parse_imports_import(): class with an import statement. """ parser = CodeParser(code_default) - tree = parser._CodeParser__get_tree() + tree = parser.get_tree() for node in ast.walk(tree): if isinstance(node, ast.Import): parser.parse_imports(node) @@ -161,7 +154,7 @@ def test_code_parser_parse_imports_importfrom(): class with an import from statement. """ parser = CodeParser("from os import path") - tree = parser._CodeParser__get_tree() + tree = parser.get_tree() for node in ast.walk(tree): if isinstance(node, ast.ImportFrom): parser.parse_imports(node) @@ -173,7 +166,7 @@ def test_code_parser_parse_functions(): Test the parse_functions method of the CodeParser class. """ parser = CodeParser("def test(): pass") - tree = parser._CodeParser__get_tree() + tree = parser.get_tree() for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): parser.parse_functions(node) @@ -186,7 +179,7 @@ def test_code_parser_parse_classes(): Test the parse_classes method of the CodeParser class. """ parser = CodeParser("class Test: pass") - tree = parser._CodeParser__get_tree() + tree = parser.get_tree() for node in ast.walk(tree): if isinstance(node, ast.ClassDef): parser.parse_classes(node) @@ -199,7 +192,7 @@ def test_code_parser_parse_global_vars(): Test the parse_global_vars method of the CodeParser class. """ parser = CodeParser("x = 1") - tree = parser._CodeParser__get_tree() + tree = parser.get_tree() for node in ast.walk(tree): if isinstance(node, ast.Assign): parser.parse_global_vars(node) @@ -222,9 +215,7 @@ def test_custom_component_get_function_entrypoint_args(): Test the get_function_entrypoint_args property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") args = custom_component.get_function_entrypoint_args assert len(args) == 4 assert args[0]["name"] == "self" @@ -237,20 +228,16 @@ def test_custom_component_get_function_entrypoint_return_type(): Test the get_function_entrypoint_return_type property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") return_type = custom_component.get_function_entrypoint_return_type - assert return_type == ["Document"] + assert return_type == [Document] def test_custom_component_get_main_class_name(): """ Test the get_main_class_name property of the CustomComponent class. """ - custom_component = CustomComponent( - code=code_default, function_entrypoint_name="build" - ) + custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") class_name = custom_component.get_main_class_name assert class_name == "YourComponent" @@ -260,9 +247,7 @@ def test_custom_component_get_function_valid(): Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name. """ - custom_component = CustomComponent( - code="def build(): pass", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") my_function = custom_component.get_function assert callable(my_function) @@ -297,9 +282,7 @@ def test_code_parser_parse_callable_details_no_args(): parser = CodeParser("") node = ast.FunctionDef( name="test", - args=ast.arguments( - args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] - ), + args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[], decorator_list=[], returns=None, @@ -345,9 +328,7 @@ def test_code_parser_parse_function_def_not_init(): parser = CodeParser("") stmt = ast.FunctionDef( name="test", - args=ast.arguments( - args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] - ), + args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[], decorator_list=[], returns=None, @@ -365,9 +346,7 @@ def test_code_parser_parse_function_def_init(): parser = CodeParser("") stmt = ast.FunctionDef( name="__init__", - args=ast.arguments( - args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] - ), + args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[], decorator_list=[], returns=None, @@ -402,9 +381,7 @@ def test_custom_component_get_code_tree_syntax_error(): Test the get_code_tree method of the CustomComponent class raises the CodeSyntaxError when given incorrect syntax. """ - custom_component = CustomComponent( - code="import os as", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="import os as", function_entrypoint_name="build") with pytest.raises(CodeSyntaxError): custom_component.get_code_tree(custom_component.code) @@ -458,9 +435,7 @@ def test_custom_component_build_not_implemented(): Test the build method of the CustomComponent class raises the NotImplementedError. """ - custom_component = CustomComponent( - code="def build(): pass", function_entrypoint_name="build" - ) + custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") with pytest.raises(NotImplementedError): custom_component.build() @@ -494,9 +469,7 @@ def test_flow(db): } # Create flow - flow = FlowCreate( - id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data - ) + flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data) # Add to database db.add(flow) diff --git a/tests/test_database.py b/tests/test_database.py index 21f0cec17..a5ed533d4 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,16 +1,14 @@ -from langflow.services.database.models.base import orjson_dumps -from langflow.services.database.utils import session_getter -from langflow.services.getters import get_db_service +from uuid import UUID, uuid4 + import orjson import pytest - -from uuid import UUID, uuid4 -from sqlmodel import Session - from fastapi.testclient import TestClient - from langflow.api.v1.schemas import FlowListCreate +from langflow.services.database.models.base import orjson_dumps from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate +from langflow.services.database.utils import session_getter +from langflow.services.deps import get_db_service +from sqlmodel import Session @pytest.fixture(scope="module") @@ -27,9 +25,7 @@ def json_style(): ) -def test_create_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) @@ -39,9 +35,7 @@ def test_create_flow( assert response.json()["data"] == flow.data # flow is optional so we can create a flow without a flow flow = FlowCreate(name="Test Flow") - response = client.post( - "api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers - ) + response = client.post("api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data @@ -82,9 +76,7 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he assert response.json()["data"] == flow.data -def test_update_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] @@ -97,9 +89,7 @@ def test_update_flow( description="updated description", data=data, ) - response = client.patch( - f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers - ) + response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers) assert response.status_code == 200 assert response.json()["name"] == updated_flow.name @@ -107,9 +97,7 @@ def test_update_flow( # assert response.json()["data"] == updated_flow.data -def test_delete_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) @@ -120,9 +108,7 @@ def test_delete_flow( assert response.json()["message"] == "Flow deleted successfully" -def test_create_flows( - client: TestClient, session: Session, json_flow: str, logged_in_headers -): +def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -133,9 +119,7 @@ def test_create_flows( ] ) # Make request to endpoint - response = client.post( - "api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers - ) + response = client.post("api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers) # Check response status code assert response.status_code == 201 # Check response data @@ -149,9 +133,7 @@ def test_create_flows( assert response_data[1]["data"] == data -def test_upload_file( - client: TestClient, session: Session, json_flow: str, logged_in_headers -): +def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -218,9 +200,7 @@ def test_download_file( assert response_data[1]["data"] == data -def test_create_flow_with_invalid_data( - client: TestClient, active_user, logged_in_headers -): +def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers): flow = {"name": "a" * 256, "data": "Invalid flow data"} response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers) assert response.status_code == 422 @@ -232,29 +212,19 @@ def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers assert response.status_code == 404 -def test_update_flow_idempotency( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers): flow_data = orjson.loads(json_flow) data = flow_data["data"] flow_data = FlowCreate(name="Test Flow", description="description", data=data) - response = client.post( - "api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers - ) + response = client.post("api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers) flow_id = response.json()["id"] updated_flow = FlowCreate(name="Updated Flow", description="description", data=data) - response1 = client.put( - f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers - ) - response2 = client.put( - f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers - ) + response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers) + response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers) assert response1.json() == response2.json() -def test_update_nonexistent_flow( - client: TestClient, json_flow: str, active_user, logged_in_headers -): +def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): flow_data = orjson.loads(json_flow) data = flow_data["data"] uuid = uuid4() @@ -263,9 +233,7 @@ def test_update_nonexistent_flow( description="description", data=data, ) - response = client.patch( - f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers - ) + response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers) assert response.status_code == 404 diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index a7c721b09..647a36ece 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,18 +1,17 @@ -from collections import namedtuple +import time import uuid -from langflow.processing.process import Result -from langflow.services.auth.utils import get_password_hash -from langflow.services.database.models.api_key.api_key import ApiKey -from langflow.services.getters import get_settings_service -from langflow.services.database.utils import session_getter -from langflow.services.getters import get_db_service +from collections import namedtuple + import pytest from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS +from langflow.processing.process import Result +from langflow.services.auth.utils import get_password_hash +from langflow.services.database.models.api_key.model import ApiKey +from langflow.services.database.utils import session_getter +from langflow.services.deps import get_db_service, get_settings_service from langflow.template.frontend_node.chains import TimeTravelGuideChainNode -import time - def run_post(client, flow_id, headers, post_data): response = client.post( @@ -31,10 +30,7 @@ def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1): href, headers=headers, ) - if ( - task_status_response.status_code == 200 - and task_status_response.json()["status"] == "SUCCESS" - ): + if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS": return task_status_response.json() time.sleep(sleep_time) return None # Return None if task did not complete in time @@ -127,11 +123,7 @@ def created_api_key(active_user): ) db_manager = get_db_service() with session_getter(db_manager) as session: - if ( - existing_api_key := session.query(ApiKey) - .filter(ApiKey.api_key == api_key.api_key) - .first() - ): + if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first(): return existing_api_key session.add(api_key) session.commit() @@ -190,9 +182,7 @@ def test_process_flow_invalid_id(client, monkeypatch, created_api_key): } invalid_id = uuid.uuid4() - response = client.post( - f"api/v1/process/{invalid_id}", headers=headers, json=post_data - ) + response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data) assert response.status_code == 404 assert f"Flow {invalid_id} not found" in response.json()["detail"] @@ -233,9 +223,7 @@ def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_k monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses) - monkeypatch.setattr( - endpoints, "process_graph_cached_task", mock_process_graph_cached_task - ) + monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task) api_key = created_api_key.api_key headers = {"x-api-key": api_key} @@ -420,110 +408,77 @@ def test_various_prompts(client, prompt, expected_input_variables): def test_get_vertices_flow_not_found(client, logged_in_headers): - response = client.get( - "/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers - ) - assert ( - response.status_code == 500 - ) # Or whatever status code you've set for invalid ID + response = client.get("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers) + assert response.status_code == 500 # Or whatever status code you've set for invalid ID def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers): flow_id = added_flow_with_prompt_and_history["id"] - response = client.get( - f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers - ) + response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers) assert response.status_code == 200 assert "ids" in response.json() # The response should contain the list in this order # ['ConversationBufferMemory-Lu2Nb', 'PromptTemplate-5Q0W8', 'ChatOpenAI-vy7fV', 'LLMChain-UjBh1'] # The important part is before the - (ConversationBufferMemory, PromptTemplate, ChatOpenAI, LLMChain) - ids = [id.split("-")[0] for id in response.json()["ids"]] - assert ids == [ - "ConversationBufferMemory", - "PromptTemplate", - "ChatOpenAI", - "LLMChain", - ] + ids = [inner_id.split("-")[0] for _id in response.json()["ids"] for inner_id in _id] + assert ids == ["ChatOpenAI", "PromptTemplate", "ConversationBufferMemory", "LLMChain"] def test_build_vertex_invalid_flow_id(client, logged_in_headers): - response = client.post( - "/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers - ) + response = client.post("/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers) assert response.status_code == 500 -def test_build_vertex_invalid_vertex_id( - client, added_flow_with_prompt_and_history, logged_in_headers -): +def test_build_vertex_invalid_vertex_id(client, added_flow_with_prompt_and_history, logged_in_headers): flow_id = added_flow_with_prompt_and_history["id"] - response = client.post( - f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers - ) + response = client.post(f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers) assert response.status_code == 500 -def test_build_all_vertices_in_sequence_with_chat_input( - client, added_flow_chat_input, logged_in_headers -): +def test_build_all_vertices_in_sequence_with_chat_input(client, added_flow_chat_input, logged_in_headers): flow_id = added_flow_chat_input["id"] # First, get all the vertices in the correct sequence - response = client.get( - f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers - ) + response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers) assert response.status_code == 200 assert "ids" in response.json() vertex_ids = response.json()["ids"] # Now, iterate through each vertex and build it for vertex_id in vertex_ids: - response = client.post( - f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers - ) + response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers) json_response = response.json() - assert ( - response.status_code == 200 - ), f"Failed at vertex {vertex_id}: {json_response}" + assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}" assert "valid" in json_response assert json_response["valid"], json_response["params"] -def test_build_all_vertices_in_sequence_with_two_outputs( - client, added_flow_two_outputs, logged_in_headers -): +def test_build_all_vertices_in_sequence_with_two_outputs(client, added_flow_two_outputs, logged_in_headers): """This tests the case where a node has two outputs, one of which is Text and the other (in this case) is a LLMChain. We need to make sure the correct output is passed in both cases.""" flow_id = added_flow_two_outputs["id"] # First, get all the vertices in the correct sequence - response = client.get( - f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers - ) + response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers) assert response.status_code == 200 assert "ids" in response.json() vertex_ids = response.json()["ids"] # Now, iterate through each vertex and build it for vertex_id in vertex_ids: - response = client.post( - f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers - ) + response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers) json_response = response.json() - assert ( - response.status_code == 200 - ), f"Failed at vertex {vertex_id}: {json_response}" + assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}" assert "valid" in json_response assert json_response["valid"], json_response["params"] -def test_basic_chat_in_process(client, added_flow, created_api_key): +def test_basic_chat_in_process(client, flow, created_api_key): # Run the /api/v1/process/{flow_id} endpoint headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"text": "Hi, My name is Gabriel"}} response = client.post( - f"api/v1/process/{added_flow.get('id')}", + f"api/v1/process/{flow.get('id')}", headers=headers, json=post_data, ) @@ -540,7 +495,7 @@ def test_basic_chat_in_process(client, added_flow, created_api_key): "session_id": response.json()["session_id"], } response = client.post( - f"api/v1/process/{added_flow.get('id')}", + f"api/v1/process/{flow.get('id')}", headers=headers, json=post_data, ) @@ -548,12 +503,12 @@ def test_basic_chat_in_process(client, added_flow, created_api_key): assert "Gabriel" in response.json()["result"]["text"] -def test_basic_chat_different_session_ids(client, added_flow, created_api_key): +def test_basic_chat_different_session_ids(client, flow, created_api_key): # Run the /api/v1/process/{flow_id} endpoint headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"text": "Hi, My name is Gabriel"}} response = client.post( - f"api/v1/process/{added_flow.get('id')}", + f"api/v1/process/{flow.get('id')}", headers=headers, json=post_data, ) @@ -570,7 +525,7 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key): "inputs": {"text": "What is my name?"}, } response = client.post( - f"api/v1/process/{added_flow.get('id')}", + f"api/v1/process/{flow.get('id')}", headers=headers, json=post_data, ) @@ -579,9 +534,9 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key): assert session_id1 != response.json()["session_id"] -def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_api_key): +def test_basic_chat_with_two_session_ids_and_names(client, flow, created_api_key): headers = {"x-api-key": created_api_key.api_key} - flow_id = added_flow.get("id") + flow_id = flow.get("id") names = ["Gabriel", "John"] session_ids = [] @@ -606,9 +561,7 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a @pytest.mark.async_test -def test_vector_store_in_process( - distributed_client, added_vector_store, created_api_key -): +def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key): # Run the /api/v1/process/{flow_id} endpoint headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"input": "What is Langflow?"}} @@ -627,13 +580,13 @@ def test_vector_store_in_process( # Test function without loop @pytest.mark.async_test -def test_async_task_processing(distributed_client, added_flow, created_api_key): +def test_async_task_processing(distributed_client, flow, created_api_key): headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"text": "Hi, My name is Gabriel"}} - + flow = flow.model_dump() # Run the /api/v1/process/{flow_id} endpoint with sync=False response = distributed_client.post( - f"api/v1/process/{added_flow.get('id')}", + f"api/v1/process/{flow.get('id')}", headers=headers, json={**post_data, "sync": False}, ) @@ -659,9 +612,7 @@ def test_async_task_processing(distributed_client, added_flow, created_api_key): # Test function without loop @pytest.mark.async_test -def test_async_task_processing_vector_store( - client, added_vector_store, created_api_key -): +def test_async_task_processing_vector_store(client, added_vector_store, created_api_key): headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"input": "How do I upload examples?"}} @@ -690,6 +641,4 @@ def test_async_task_processing_vector_store( # Validate that the task completed successfully and the result is as expected assert "result" in task_status_json, task_status_json assert "output" in task_status_json["result"], task_status_json["result"] - assert "Langflow" in task_status_json["result"]["output"], task_status_json[ - "result" - ] + assert "Langflow" in task_status_json["result"]["output"], task_status_json["result"] diff --git a/tests/test_frontend_nodes.py b/tests/test_frontend_nodes.py index 00fe9fcb1..e92ad1fe4 100644 --- a/tests/test_frontend_nodes.py +++ b/tests/test_frontend_nodes.py @@ -31,17 +31,14 @@ def test_template_field_defaults(sample_template_field: TemplateField): assert sample_template_field.is_list is False assert sample_template_field.show is True assert sample_template_field.multiline is False - assert sample_template_field.value is None - assert sample_template_field.suffixes == [] + assert sample_template_field.value == "" assert sample_template_field.file_types == [] - assert sample_template_field.file_path is None + assert sample_template_field.file_path == "" assert sample_template_field.password is False assert sample_template_field.name == "test_field" -def test_template_to_dict( - sample_template: Template, sample_template_field: TemplateField -): +def test_template_to_dict(sample_template: Template, sample_template_field: TemplateField): template_dict = sample_template.to_dict() assert template_dict["_type"] == "test_template" assert len(template_dict) == 2 # _type and test_field diff --git a/tests/test_graph.py b/tests/test_graph.py index f32bc21d7..6cc19a101 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,32 +1,29 @@ import copy import json import os -from pathlib import Path import pickle +from pathlib import Path from typing import Type, Union -from langflow.graph.edge.base import Edge -from langflow.graph.vertex.base import Vertex -from langchain.agents import AgentExecutor + import pytest +from langchain.agents import AgentExecutor from langchain.chains.base import Chain from langchain.llms.fake import FakeListLLM from langflow.graph import Graph -from langflow.graph.vertex.types import ( - FileToolVertex, - LLMVertex, - ToolkitVertex, -) -from langflow.processing.process import get_result_and_thought -from langflow.utils.payload import get_root_node +from langflow.graph.edge.base import Edge from langflow.graph.graph.utils import ( find_last_node, + process_flow, set_new_target_handle, ungroup_node, - process_flow, update_source_handle, update_target_handle, update_template, ) +from langflow.graph.vertex.base import Vertex +from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex +from langflow.processing.process import get_result_and_thought +from langflow.utils.payload import get_root_vertex # Test cases for the graph module @@ -47,13 +44,7 @@ def sample_nodes(): return [ { "id": "node1", - "data": { - "node": { - "template": { - "some_field": {"show": True, "advanced": False, "name": "Name1"} - } - } - }, + "data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}}, }, { "id": "node2", @@ -71,11 +62,7 @@ def sample_nodes(): }, { "id": "node3", - "data": { - "node": { - "template": {"unrelated_field": {"show": True, "advanced": True}} - } - }, + "data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}}, }, ] @@ -93,8 +80,10 @@ def test_graph_structure(basic_graph): assert isinstance(node, Vertex) for edge in basic_graph.edges: assert isinstance(edge, Edge) - assert edge.source in basic_graph.vertices - assert edge.target in basic_graph.vertices + source_vertex = basic_graph.get_vertex(edge.source_id) + target_vertex = basic_graph.get_vertex(edge.target_id) + assert source_vertex in basic_graph.vertices + assert target_vertex in basic_graph.vertices def test_circular_dependencies(basic_graph): @@ -102,7 +91,7 @@ def test_circular_dependencies(basic_graph): def check_circular(node, visited): visited.add(node) - neighbors = basic_graph.get_nodes_with_target(node) + neighbors = basic_graph.get_vertices_with_target(node) for neighbor in neighbors: if neighbor in visited: return True @@ -135,13 +124,13 @@ def test_invalid_node_types(): Graph(graph_data["nodes"], graph_data["edges"]) -def test_get_nodes_with_target(basic_graph): +def test_get_vertices_with_target(basic_graph): """Test getting connected nodes""" assert isinstance(basic_graph, Graph) # Get root node - root = get_root_node(basic_graph) + root = get_root_vertex(basic_graph) assert root is not None - connected_nodes = basic_graph.get_nodes_with_target(root) + connected_nodes = basic_graph.get_vertices_with_target(root.id) assert connected_nodes is not None @@ -150,23 +139,17 @@ def test_get_node_neighbors_basic(basic_graph): assert isinstance(basic_graph, Graph) # Get root node - root = get_root_node(basic_graph) + root = get_root_vertex(basic_graph) assert root is not None - neighbors = basic_graph.get_node_neighbors(root) + neighbors = basic_graph.get_vertex_neighbors(root) assert neighbors is not None assert isinstance(neighbors, dict) # Root Node is an Agent, it requires an LLMChain and tools # We need to check if there is a Chain in the one of the neighbors' # data attribute in the type key - assert any( - "ConversationBufferMemory" in neighbor.data["type"] - for neighbor, val in neighbors.items() - if val - ) + assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val) - assert any( - "OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val - ) + assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val) def test_get_node(basic_graph): @@ -180,7 +163,7 @@ def test_get_node(basic_graph): def test_build_nodes(basic_graph): """Test building nodes""" - assert len(basic_graph.vertices) == len(basic_graph._nodes) + assert len(basic_graph.vertices) == len(basic_graph._vertices) for node in basic_graph.vertices: assert isinstance(node, Vertex) @@ -190,20 +173,20 @@ def test_build_edges(basic_graph): assert len(basic_graph.edges) == len(basic_graph._edges) for edge in basic_graph.edges: assert isinstance(edge, Edge) - assert isinstance(edge.source, Vertex) - assert isinstance(edge.target, Vertex) + assert isinstance(edge.source_id, str) + assert isinstance(edge.target_id, str) -def test_get_root_node(client, basic_graph, complex_graph): +def test_get_root_vertex(client, basic_graph, complex_graph): """Test getting root node""" assert isinstance(basic_graph, Graph) - root = get_root_node(basic_graph) + root = get_root_vertex(basic_graph) assert root is not None assert isinstance(root, Vertex) assert root.data["type"] == "TimeTravelGuideChain" # For complex example, the root node is a ZeroShotAgent too assert isinstance(complex_graph, Graph) - root = get_root_node(complex_graph) + root = get_root_vertex(complex_graph) assert root is not None assert isinstance(root, Vertex) assert root.data["type"] == "ZeroShotAgent" @@ -239,7 +222,7 @@ def test_build_params(basic_graph): # The matched_type attribute should be in the source_types attr assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges) # Get the root node - root = get_root_node(basic_graph) + root = get_root_vertex(basic_graph) # Root node is a TimeTravelGuideChain # which requires an llm and memory assert root is not None @@ -248,17 +231,18 @@ def test_build_params(basic_graph): assert "memory" in root.params -def test_build(basic_graph): +@pytest.mark.asyncio +async def test_build(basic_graph): """Test Node's build method""" - assert_agent_was_built(basic_graph) + await assert_agent_was_built(basic_graph) -def assert_agent_was_built(graph): +async def assert_agent_was_built(graph): """Assert that the agent was built""" assert isinstance(graph, Graph) # Now we test the build method # Build the Agent - result = graph.build() + result = await graph.build() # The agent should be a AgentExecutor assert isinstance(result, Chain) @@ -307,7 +291,8 @@ def test_file_tool_node_build(client, openapi_graph): # assert built_object is not None -def test_get_result_and_thought(basic_graph): +@pytest.mark.asyncio +async def test_get_result_and_thought(basic_graph): """Test the get_result_and_thought method""" responses = [ "Final Answer: I am a response", @@ -319,7 +304,7 @@ def test_get_result_and_thought(basic_graph): assert llm_node is not None llm_node._built_object = FakeListLLM(responses=responses) llm_node._built = True - langchain_object = basic_graph.build() + langchain_object = await basic_graph.build() # assert all nodes are built assert all(node._built for node in basic_graph.vertices) # now build again and check if FakeListLLM was used @@ -339,9 +324,7 @@ def test_find_last_node(grouped_chat_json_flow): def test_ungroup_node(grouped_chat_json_flow): grouped_chat_data = json.loads(grouped_chat_json_flow).get("data") - group_node = grouped_chat_data["nodes"][ - 2 - ] # Assuming the first node is a group node + group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node base_flow = copy.deepcopy(grouped_chat_data) ungroup_node(group_node["data"], base_flow) # after ungroup_node is called, the base_flow and grouped_chat_data should be different @@ -393,14 +376,9 @@ def test_process_flow_one_group(one_grouped_chat_json_flow): assert "edges" in processed_flow # Now get the node that has ChatOpenAI in its id - chat_openai_node = next( - (node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None - ) + chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None) assert chat_openai_node is not None - assert ( - chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] - == "test" - ) + assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test" def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow): @@ -449,17 +427,11 @@ def test_update_template(sample_template, sample_nodes): assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False - assert ( - node1_updated["data"]["node"]["template"]["some_field"]["display_name"] - == "Name1" - ) + assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1" assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True - assert ( - node2_updated["data"]["node"]["template"]["other_field"]["display_name"] - == "DisplayName2" - ) + assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2" # Ensure node3 remains unchanged assert node3_updated == sample_nodes[2] @@ -490,9 +462,7 @@ def test_set_new_target_handle(): "data": { "node": { "flow": True, - "template": { - "field_1": {"proxy": {"field": "new_field", "id": "new_id"}} - }, + "template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}}, } } } @@ -512,33 +482,33 @@ def test_update_source_handle(): "nodes": [{"id": "some_node"}, {"id": "last_node"}], "edges": [{"source": "some_node"}], } - updated_edge = update_source_handle( - new_edge, flow_data["nodes"], flow_data["edges"] - ) + updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"]) assert updated_edge["source"] == "last_node" assert updated_edge["data"]["sourceHandle"]["id"] == "last_node" -def test_pickle_graph(json_vector_store): +@pytest.mark.asyncio +async def test_pickle_graph(json_vector_store): loaded_json = json.loads(json_vector_store) graph = Graph.from_payload(loaded_json) assert isinstance(graph, Graph) - first_result = graph.build() + first_result = await graph.build() assert isinstance(first_result, AgentExecutor) pickled = pickle.dumps(graph) assert pickled is not None unpickled = pickle.loads(pickled) assert unpickled is not None - result = unpickled.build() + result = await unpickled.build() assert isinstance(result, AgentExecutor) -def test_pickle_each_vertex(json_vector_store): +@pytest.mark.asyncio +async def test_pickle_each_vertex(json_vector_store): loaded_json = json.loads(json_vector_store) graph = Graph.from_payload(loaded_json) assert isinstance(graph, Graph) - for vertex in graph.nodes: - vertex.build() + for vertex in graph.vertices: + await vertex.build() pickled = pickle.dumps(vertex) assert pickled is not None unpickled = pickle.loads(pickled) diff --git a/tests/test_vectorstore_template.py b/tests/test_vectorstore_template.py index 9dd131dbc..3b5c7ed42 100644 --- a/tests/test_vectorstore_template.py +++ b/tests/test_vectorstore_template.py @@ -1,5 +1,5 @@ from fastapi.testclient import TestClient -from langflow.services.getters import get_settings_service +from langflow.services.deps import get_settings_service # check that all agents are in settings.agents