Refactor imports and update function names

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-25 16:17:40 -03:00
commit abe4925cc3
8 changed files with 156 additions and 314 deletions

View file

@ -1,6 +1,5 @@
from fastapi.testclient import TestClient
# def test_chains_settings(client: TestClient, logged_in_headers):
# response = client.get("api/v1/all", headers=logged_in_headers)
# assert response.status_code == 200
@ -18,13 +17,7 @@ def test_conversation_chain(client: TestClient, logged_in_headers):
chain = chains["ConversationChain"]
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {
"ConversationChain",
"LLMChain",
"Chain",
"function",
"Text",
}
assert set(chain["base_classes"]) == {"Callable", "Chain"}
template = chain["template"]
assert template["memory"] == {
@ -97,10 +90,7 @@ def test_conversation_chain(client: TestClient, logged_in_headers):
assert template["_type"] == "ConversationChain"
# Test the description object
assert (
chain["description"]
== "Chain to have a conversation and load context from memory."
)
assert chain["description"] == "Chain to have a conversation and load context from memory."
def test_llm_chain(client: TestClient, logged_in_headers):
@ -287,10 +277,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
assert template["_type"] == "LLMMathChain"
# Test the description object
assert (
chain["description"]
== "Chain that interprets a prompt and executes python code to do math."
)
assert chain["description"] == "Chain that interprets a prompt and executes python code to do math."
def test_series_character_chain(client: TestClient, logged_in_headers):
@ -396,10 +383,7 @@ def test_mid_journey_prompt_chain(client: TestClient, logged_in_headers):
"info": "",
}
# Test the description object
assert (
chain["description"]
== "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
)
assert chain["description"] == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
def test_time_travel_guide_chain(client: TestClient, logged_in_headers):

View file

@ -1,9 +1,9 @@
from pathlib import Path
from tempfile import tempdir
from langflow.__main__ import app
import pytest
from langflow.services import getters
import pytest
from langflow.__main__ import app
from langflow.services import deps
@pytest.fixture(scope="module")
@ -16,6 +16,7 @@ def default_settings():
def test_components_path(runner, client, default_settings):
# Create a folder in the tmp directory
temp_dir = Path(tempdir)
# create a "components" folder
temp_dir = temp_dir / "components"
@ -26,7 +27,7 @@ def test_components_path(runner, client, default_settings):
["run", "--components-path", str(temp_dir), *default_settings],
)
assert result.exit_code == 0, result.stdout
settings_service = getters.get_settings_service()
settings_service = deps.get_settings_service()
assert str(temp_dir) in settings_service.settings.COMPONENTS_PATH

View file

@ -1,19 +1,18 @@
import ast
import pytest
import types
from uuid import uuid4
import pytest
from fastapi import HTTPException
from langflow.services.database.models.flow import Flow, FlowCreate
from langchain_core.documents import Document
from langflow.interface.custom.base import CustomComponent
from langflow.interface.custom.component import (
from langflow.interface.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
from langflow.interface.custom.custom_component.component import (
Component,
ComponentCodeNullError,
ComponentFunctionEntrypointNameNullError,
)
from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError
from langflow.services.database.models.flow import Flow, FlowCreate
code_default = """
from langflow import Prompt
@ -53,7 +52,7 @@ def test_code_parser_get_tree():
Test the __get_tree method of the CodeParser class.
"""
parser = CodeParser(code_default)
tree = parser._CodeParser__get_tree()
tree = parser.get_tree()
assert isinstance(tree, ast.AST)
@ -66,7 +65,7 @@ def test_code_parser_syntax_error():
parser = CodeParser(code_syntax_error)
with pytest.raises(CodeSyntaxError):
parser._CodeParser__get_tree()
parser.get_tree()
def test_component_init():
@ -113,9 +112,7 @@ def test_custom_component_init():
"""
function_entrypoint_name = "build"
custom_component = CustomComponent(
code=code_default, function_entrypoint_name=function_entrypoint_name
)
custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name)
assert custom_component.code == code_default
assert custom_component.function_entrypoint_name == function_entrypoint_name
@ -124,10 +121,8 @@ def test_custom_component_build_template_config():
"""
Test the build_template_config property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
config = custom_component.build_template_config
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
config = custom_component.build_template_config()
assert isinstance(config, dict)
@ -135,9 +130,7 @@ def test_custom_component_get_function():
"""
Test the get_function property of the CustomComponent class.
"""
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
)
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
my_function = custom_component.get_function
assert isinstance(my_function, types.FunctionType)
@ -148,7 +141,7 @@ def test_code_parser_parse_imports_import():
class with an import statement.
"""
parser = CodeParser(code_default)
tree = parser._CodeParser__get_tree()
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
parser.parse_imports(node)
@ -161,7 +154,7 @@ def test_code_parser_parse_imports_importfrom():
class with an import from statement.
"""
parser = CodeParser("from os import path")
tree = parser._CodeParser__get_tree()
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom):
parser.parse_imports(node)
@ -173,7 +166,7 @@ def test_code_parser_parse_functions():
Test the parse_functions method of the CodeParser class.
"""
parser = CodeParser("def test(): pass")
tree = parser._CodeParser__get_tree()
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
parser.parse_functions(node)
@ -186,7 +179,7 @@ def test_code_parser_parse_classes():
Test the parse_classes method of the CodeParser class.
"""
parser = CodeParser("class Test: pass")
tree = parser._CodeParser__get_tree()
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
parser.parse_classes(node)
@ -199,7 +192,7 @@ def test_code_parser_parse_global_vars():
Test the parse_global_vars method of the CodeParser class.
"""
parser = CodeParser("x = 1")
tree = parser._CodeParser__get_tree()
tree = parser.get_tree()
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
parser.parse_global_vars(node)
@ -222,9 +215,7 @@ def test_custom_component_get_function_entrypoint_args():
Test the get_function_entrypoint_args
property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
args = custom_component.get_function_entrypoint_args
assert len(args) == 4
assert args[0]["name"] == "self"
@ -237,20 +228,16 @@ def test_custom_component_get_function_entrypoint_return_type():
Test the get_function_entrypoint_return_type
property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
return_type = custom_component.get_function_entrypoint_return_type
assert return_type == ["Document"]
assert return_type == [Document]
def test_custom_component_get_main_class_name():
"""
Test the get_main_class_name property of the CustomComponent class.
"""
custom_component = CustomComponent(
code=code_default, function_entrypoint_name="build"
)
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
class_name = custom_component.get_main_class_name
assert class_name == "YourComponent"
@ -260,9 +247,7 @@ def test_custom_component_get_function_valid():
Test the get_function property of the CustomComponent
class with valid code and function_entrypoint_name.
"""
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
)
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
my_function = custom_component.get_function
assert callable(my_function)
@ -297,9 +282,7 @@ def test_code_parser_parse_callable_details_no_args():
parser = CodeParser("")
node = ast.FunctionDef(
name="test",
args=ast.arguments(
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
),
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[],
decorator_list=[],
returns=None,
@ -345,9 +328,7 @@ def test_code_parser_parse_function_def_not_init():
parser = CodeParser("")
stmt = ast.FunctionDef(
name="test",
args=ast.arguments(
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
),
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[],
decorator_list=[],
returns=None,
@ -365,9 +346,7 @@ def test_code_parser_parse_function_def_init():
parser = CodeParser("")
stmt = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
),
args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]),
body=[],
decorator_list=[],
returns=None,
@ -402,9 +381,7 @@ def test_custom_component_get_code_tree_syntax_error():
Test the get_code_tree method of the CustomComponent class
raises the CodeSyntaxError when given incorrect syntax.
"""
custom_component = CustomComponent(
code="import os as", function_entrypoint_name="build"
)
custom_component = CustomComponent(code="import os as", function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
custom_component.get_code_tree(custom_component.code)
@ -458,9 +435,7 @@ def test_custom_component_build_not_implemented():
Test the build method of the CustomComponent
class raises the NotImplementedError.
"""
custom_component = CustomComponent(
code="def build(): pass", function_entrypoint_name="build"
)
custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build")
with pytest.raises(NotImplementedError):
custom_component.build()
@ -494,9 +469,7 @@ def test_flow(db):
}
# Create flow
flow = FlowCreate(
id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data
)
flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data)
# Add to database
db.add(flow)

View file

@ -1,16 +1,14 @@
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.utils import session_getter
from langflow.services.getters import get_db_service
from uuid import UUID, uuid4
import orjson
import pytest
from uuid import UUID, uuid4
from sqlmodel import Session
from fastapi.testclient import TestClient
from langflow.api.v1.schemas import FlowListCreate
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from sqlmodel import Session
@pytest.fixture(scope="module")
@ -27,9 +25,7 @@ def json_style():
)
def test_create_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
@ -39,9 +35,7 @@ def test_create_flow(
assert response.json()["data"] == flow.data
# flow is optional so we can create a flow without a flow
flow = FlowCreate(name="Test Flow")
response = client.post(
"api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers
)
response = client.post("api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers)
assert response.status_code == 201
assert response.json()["name"] == flow.name
assert response.json()["data"] == flow.data
@ -82,9 +76,7 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he
assert response.json()["data"] == flow.data
def test_update_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
@ -97,9 +89,7 @@ def test_update_flow(
description="updated description",
data=data,
)
response = client.patch(
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
)
response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers)
assert response.status_code == 200
assert response.json()["name"] == updated_flow.name
@ -107,9 +97,7 @@ def test_update_flow(
# assert response.json()["data"] == updated_flow.data
def test_delete_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
flow = FlowCreate(name="Test Flow", description="description", data=data)
@ -120,9 +108,7 @@ def test_delete_flow(
assert response.json()["message"] == "Flow deleted successfully"
def test_create_flows(
client: TestClient, session: Session, json_flow: str, logged_in_headers
):
def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data
@ -133,9 +119,7 @@ def test_create_flows(
]
)
# Make request to endpoint
response = client.post(
"api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers
)
response = client.post("api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers)
# Check response status code
assert response.status_code == 201
# Check response data
@ -149,9 +133,7 @@ def test_create_flows(
assert response_data[1]["data"] == data
def test_upload_file(
client: TestClient, session: Session, json_flow: str, logged_in_headers
):
def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers):
flow = orjson.loads(json_flow)
data = flow["data"]
# Create test data
@ -218,9 +200,7 @@ def test_download_file(
assert response_data[1]["data"] == data
def test_create_flow_with_invalid_data(
client: TestClient, active_user, logged_in_headers
):
def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers):
flow = {"name": "a" * 256, "data": "Invalid flow data"}
response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers)
assert response.status_code == 422
@ -232,29 +212,19 @@ def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers
assert response.status_code == 404
def test_update_flow_idempotency(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
flow_data = FlowCreate(name="Test Flow", description="description", data=data)
response = client.post(
"api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers
)
response = client.post("api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers)
flow_id = response.json()["id"]
updated_flow = FlowCreate(name="Updated Flow", description="description", data=data)
response1 = client.put(
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
)
response2 = client.put(
f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers
)
response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers)
response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers)
assert response1.json() == response2.json()
def test_update_nonexistent_flow(
client: TestClient, json_flow: str, active_user, logged_in_headers
):
def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers):
flow_data = orjson.loads(json_flow)
data = flow_data["data"]
uuid = uuid4()
@ -263,9 +233,7 @@ def test_update_nonexistent_flow(
description="description",
data=data,
)
response = client.patch(
f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers
)
response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers)
assert response.status_code == 404

View file

@ -1,18 +1,17 @@
from collections import namedtuple
import time
import uuid
from langflow.processing.process import Result
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.api_key import ApiKey
from langflow.services.getters import get_settings_service
from langflow.services.database.utils import session_getter
from langflow.services.getters import get_db_service
from collections import namedtuple
import pytest
from fastapi.testclient import TestClient
from langflow.interface.tools.constants import CUSTOM_TOOLS
from langflow.processing.process import Result
from langflow.services.auth.utils import get_password_hash
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from langflow.template.frontend_node.chains import TimeTravelGuideChainNode
import time
def run_post(client, flow_id, headers, post_data):
response = client.post(
@ -31,10 +30,7 @@ def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1):
href,
headers=headers,
)
if (
task_status_response.status_code == 200
and task_status_response.json()["status"] == "SUCCESS"
):
if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS":
return task_status_response.json()
time.sleep(sleep_time)
return None # Return None if task did not complete in time
@ -127,11 +123,7 @@ def created_api_key(active_user):
)
db_manager = get_db_service()
with session_getter(db_manager) as session:
if (
existing_api_key := session.query(ApiKey)
.filter(ApiKey.api_key == api_key.api_key)
.first()
):
if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first():
return existing_api_key
session.add(api_key)
session.commit()
@ -190,9 +182,7 @@ def test_process_flow_invalid_id(client, monkeypatch, created_api_key):
}
invalid_id = uuid.uuid4()
response = client.post(
f"api/v1/process/{invalid_id}", headers=headers, json=post_data
)
response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data)
assert response.status_code == 404
assert f"Flow {invalid_id} not found" in response.json()["detail"]
@ -233,9 +223,7 @@ def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_k
monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached)
monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses)
monkeypatch.setattr(
endpoints, "process_graph_cached_task", mock_process_graph_cached_task
)
monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task)
api_key = created_api_key.api_key
headers = {"x-api-key": api_key}
@ -420,110 +408,77 @@ def test_various_prompts(client, prompt, expected_input_variables):
def test_get_vertices_flow_not_found(client, logged_in_headers):
response = client.get(
"/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers
)
assert (
response.status_code == 500
) # Or whatever status code you've set for invalid ID
response = client.get("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers)
assert response.status_code == 500 # Or whatever status code you've set for invalid ID
def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers):
flow_id = added_flow_with_prompt_and_history["id"]
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
assert response.status_code == 200
assert "ids" in response.json()
# The response should contain the list in this order
# ['ConversationBufferMemory-Lu2Nb', 'PromptTemplate-5Q0W8', 'ChatOpenAI-vy7fV', 'LLMChain-UjBh1']
# The important part is before the - (ConversationBufferMemory, PromptTemplate, ChatOpenAI, LLMChain)
ids = [id.split("-")[0] for id in response.json()["ids"]]
assert ids == [
"ConversationBufferMemory",
"PromptTemplate",
"ChatOpenAI",
"LLMChain",
]
ids = [inner_id.split("-")[0] for _id in response.json()["ids"] for inner_id in _id]
assert ids == ["ChatOpenAI", "PromptTemplate", "ConversationBufferMemory", "LLMChain"]
def test_build_vertex_invalid_flow_id(client, logged_in_headers):
response = client.post(
"/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers
)
response = client.post("/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers)
assert response.status_code == 500
def test_build_vertex_invalid_vertex_id(
client, added_flow_with_prompt_and_history, logged_in_headers
):
def test_build_vertex_invalid_vertex_id(client, added_flow_with_prompt_and_history, logged_in_headers):
flow_id = added_flow_with_prompt_and_history["id"]
response = client.post(
f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers
)
response = client.post(f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers)
assert response.status_code == 500
def test_build_all_vertices_in_sequence_with_chat_input(
client, added_flow_chat_input, logged_in_headers
):
def test_build_all_vertices_in_sequence_with_chat_input(client, added_flow_chat_input, logged_in_headers):
flow_id = added_flow_chat_input["id"]
# First, get all the vertices in the correct sequence
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
assert response.status_code == 200
assert "ids" in response.json()
vertex_ids = response.json()["ids"]
# Now, iterate through each vertex and build it
for vertex_id in vertex_ids:
response = client.post(
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
)
response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers)
json_response = response.json()
assert (
response.status_code == 200
), f"Failed at vertex {vertex_id}: {json_response}"
assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}"
assert "valid" in json_response
assert json_response["valid"], json_response["params"]
def test_build_all_vertices_in_sequence_with_two_outputs(
client, added_flow_two_outputs, logged_in_headers
):
def test_build_all_vertices_in_sequence_with_two_outputs(client, added_flow_two_outputs, logged_in_headers):
"""This tests the case where a node has two outputs, one of which is Text and the other (in this case) is
a LLMChain. We need to make sure the correct output is passed in both cases."""
flow_id = added_flow_two_outputs["id"]
# First, get all the vertices in the correct sequence
response = client.get(
f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers
)
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
assert response.status_code == 200
assert "ids" in response.json()
vertex_ids = response.json()["ids"]
# Now, iterate through each vertex and build it
for vertex_id in vertex_ids:
response = client.post(
f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers
)
response = client.post(f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers)
json_response = response.json()
assert (
response.status_code == 200
), f"Failed at vertex {vertex_id}: {json_response}"
assert response.status_code == 200, f"Failed at vertex {vertex_id}: {json_response}"
assert "valid" in json_response
assert json_response["valid"], json_response["params"]
def test_basic_chat_in_process(client, added_flow, created_api_key):
def test_basic_chat_in_process(client, flow, created_api_key):
# Run the /api/v1/process/{flow_id} endpoint
headers = {"x-api-key": created_api_key.api_key}
post_data = {"inputs": {"text": "Hi, My name is Gabriel"}}
response = client.post(
f"api/v1/process/{added_flow.get('id')}",
f"api/v1/process/{flow.get('id')}",
headers=headers,
json=post_data,
)
@ -540,7 +495,7 @@ def test_basic_chat_in_process(client, added_flow, created_api_key):
"session_id": response.json()["session_id"],
}
response = client.post(
f"api/v1/process/{added_flow.get('id')}",
f"api/v1/process/{flow.get('id')}",
headers=headers,
json=post_data,
)
@ -548,12 +503,12 @@ def test_basic_chat_in_process(client, added_flow, created_api_key):
assert "Gabriel" in response.json()["result"]["text"]
def test_basic_chat_different_session_ids(client, added_flow, created_api_key):
def test_basic_chat_different_session_ids(client, flow, created_api_key):
# Run the /api/v1/process/{flow_id} endpoint
headers = {"x-api-key": created_api_key.api_key}
post_data = {"inputs": {"text": "Hi, My name is Gabriel"}}
response = client.post(
f"api/v1/process/{added_flow.get('id')}",
f"api/v1/process/{flow.get('id')}",
headers=headers,
json=post_data,
)
@ -570,7 +525,7 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key):
"inputs": {"text": "What is my name?"},
}
response = client.post(
f"api/v1/process/{added_flow.get('id')}",
f"api/v1/process/{flow.get('id')}",
headers=headers,
json=post_data,
)
@ -579,9 +534,9 @@ def test_basic_chat_different_session_ids(client, added_flow, created_api_key):
assert session_id1 != response.json()["session_id"]
def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_api_key):
def test_basic_chat_with_two_session_ids_and_names(client, flow, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
flow_id = added_flow.get("id")
flow_id = flow.get("id")
names = ["Gabriel", "John"]
session_ids = []
@ -606,9 +561,7 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a
@pytest.mark.async_test
def test_vector_store_in_process(
distributed_client, added_vector_store, created_api_key
):
def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key):
# Run the /api/v1/process/{flow_id} endpoint
headers = {"x-api-key": created_api_key.api_key}
post_data = {"inputs": {"input": "What is Langflow?"}}
@ -627,13 +580,13 @@ def test_vector_store_in_process(
# Test function without loop
@pytest.mark.async_test
def test_async_task_processing(distributed_client, added_flow, created_api_key):
def test_async_task_processing(distributed_client, flow, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
post_data = {"inputs": {"text": "Hi, My name is Gabriel"}}
flow = flow.model_dump()
# Run the /api/v1/process/{flow_id} endpoint with sync=False
response = distributed_client.post(
f"api/v1/process/{added_flow.get('id')}",
f"api/v1/process/{flow.get('id')}",
headers=headers,
json={**post_data, "sync": False},
)
@ -659,9 +612,7 @@ def test_async_task_processing(distributed_client, added_flow, created_api_key):
# Test function without loop
@pytest.mark.async_test
def test_async_task_processing_vector_store(
client, added_vector_store, created_api_key
):
def test_async_task_processing_vector_store(client, added_vector_store, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
post_data = {"inputs": {"input": "How do I upload examples?"}}
@ -690,6 +641,4 @@ def test_async_task_processing_vector_store(
# Validate that the task completed successfully and the result is as expected
assert "result" in task_status_json, task_status_json
assert "output" in task_status_json["result"], task_status_json["result"]
assert "Langflow" in task_status_json["result"]["output"], task_status_json[
"result"
]
assert "Langflow" in task_status_json["result"]["output"], task_status_json["result"]

View file

@ -31,17 +31,14 @@ def test_template_field_defaults(sample_template_field: TemplateField):
assert sample_template_field.is_list is False
assert sample_template_field.show is True
assert sample_template_field.multiline is False
assert sample_template_field.value is None
assert sample_template_field.suffixes == []
assert sample_template_field.value == ""
assert sample_template_field.file_types == []
assert sample_template_field.file_path is None
assert sample_template_field.file_path == ""
assert sample_template_field.password is False
assert sample_template_field.name == "test_field"
def test_template_to_dict(
sample_template: Template, sample_template_field: TemplateField
):
def test_template_to_dict(sample_template: Template, sample_template_field: TemplateField):
template_dict = sample_template.to_dict()
assert template_dict["_type"] == "test_template"
assert len(template_dict) == 2 # _type and test_field

View file

@ -1,32 +1,29 @@
import copy
import json
import os
from pathlib import Path
import pickle
from pathlib import Path
from typing import Type, Union
from langflow.graph.edge.base import Edge
from langflow.graph.vertex.base import Vertex
from langchain.agents import AgentExecutor
import pytest
from langchain.agents import AgentExecutor
from langchain.chains.base import Chain
from langchain.llms.fake import FakeListLLM
from langflow.graph import Graph
from langflow.graph.vertex.types import (
FileToolVertex,
LLMVertex,
ToolkitVertex,
)
from langflow.processing.process import get_result_and_thought
from langflow.utils.payload import get_root_node
from langflow.graph.edge.base import Edge
from langflow.graph.graph.utils import (
find_last_node,
process_flow,
set_new_target_handle,
ungroup_node,
process_flow,
update_source_handle,
update_target_handle,
update_template,
)
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex
from langflow.processing.process import get_result_and_thought
from langflow.utils.payload import get_root_vertex
# Test cases for the graph module
@ -47,13 +44,7 @@ def sample_nodes():
return [
{
"id": "node1",
"data": {
"node": {
"template": {
"some_field": {"show": True, "advanced": False, "name": "Name1"}
}
}
},
"data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}},
},
{
"id": "node2",
@ -71,11 +62,7 @@ def sample_nodes():
},
{
"id": "node3",
"data": {
"node": {
"template": {"unrelated_field": {"show": True, "advanced": True}}
}
},
"data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}},
},
]
@ -93,8 +80,10 @@ def test_graph_structure(basic_graph):
assert isinstance(node, Vertex)
for edge in basic_graph.edges:
assert isinstance(edge, Edge)
assert edge.source in basic_graph.vertices
assert edge.target in basic_graph.vertices
source_vertex = basic_graph.get_vertex(edge.source_id)
target_vertex = basic_graph.get_vertex(edge.target_id)
assert source_vertex in basic_graph.vertices
assert target_vertex in basic_graph.vertices
def test_circular_dependencies(basic_graph):
@ -102,7 +91,7 @@ def test_circular_dependencies(basic_graph):
def check_circular(node, visited):
visited.add(node)
neighbors = basic_graph.get_nodes_with_target(node)
neighbors = basic_graph.get_vertices_with_target(node)
for neighbor in neighbors:
if neighbor in visited:
return True
@ -135,13 +124,13 @@ def test_invalid_node_types():
Graph(graph_data["nodes"], graph_data["edges"])
def test_get_nodes_with_target(basic_graph):
def test_get_vertices_with_target(basic_graph):
"""Test getting connected nodes"""
assert isinstance(basic_graph, Graph)
# Get root node
root = get_root_node(basic_graph)
root = get_root_vertex(basic_graph)
assert root is not None
connected_nodes = basic_graph.get_nodes_with_target(root)
connected_nodes = basic_graph.get_vertices_with_target(root.id)
assert connected_nodes is not None
@ -150,23 +139,17 @@ def test_get_node_neighbors_basic(basic_graph):
assert isinstance(basic_graph, Graph)
# Get root node
root = get_root_node(basic_graph)
root = get_root_vertex(basic_graph)
assert root is not None
neighbors = basic_graph.get_node_neighbors(root)
neighbors = basic_graph.get_vertex_neighbors(root)
assert neighbors is not None
assert isinstance(neighbors, dict)
# Root Node is an Agent, it requires an LLMChain and tools
# We need to check if there is a Chain in the one of the neighbors'
# data attribute in the type key
assert any(
"ConversationBufferMemory" in neighbor.data["type"]
for neighbor, val in neighbors.items()
if val
)
assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
assert any(
"OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val
)
assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val)
def test_get_node(basic_graph):
@ -180,7 +163,7 @@ def test_get_node(basic_graph):
def test_build_nodes(basic_graph):
"""Test building nodes"""
assert len(basic_graph.vertices) == len(basic_graph._nodes)
assert len(basic_graph.vertices) == len(basic_graph._vertices)
for node in basic_graph.vertices:
assert isinstance(node, Vertex)
@ -190,20 +173,20 @@ def test_build_edges(basic_graph):
assert len(basic_graph.edges) == len(basic_graph._edges)
for edge in basic_graph.edges:
assert isinstance(edge, Edge)
assert isinstance(edge.source, Vertex)
assert isinstance(edge.target, Vertex)
assert isinstance(edge.source_id, str)
assert isinstance(edge.target_id, str)
def test_get_root_node(client, basic_graph, complex_graph):
def test_get_root_vertex(client, basic_graph, complex_graph):
"""Test getting root node"""
assert isinstance(basic_graph, Graph)
root = get_root_node(basic_graph)
root = get_root_vertex(basic_graph)
assert root is not None
assert isinstance(root, Vertex)
assert root.data["type"] == "TimeTravelGuideChain"
# For complex example, the root node is a ZeroShotAgent too
assert isinstance(complex_graph, Graph)
root = get_root_node(complex_graph)
root = get_root_vertex(complex_graph)
assert root is not None
assert isinstance(root, Vertex)
assert root.data["type"] == "ZeroShotAgent"
@ -239,7 +222,7 @@ def test_build_params(basic_graph):
# The matched_type attribute should be in the source_types attr
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
# Get the root node
root = get_root_node(basic_graph)
root = get_root_vertex(basic_graph)
# Root node is a TimeTravelGuideChain
# which requires an llm and memory
assert root is not None
@ -248,17 +231,18 @@ def test_build_params(basic_graph):
assert "memory" in root.params
def test_build(basic_graph):
@pytest.mark.asyncio
async def test_build(basic_graph):
"""Test Node's build method"""
assert_agent_was_built(basic_graph)
await assert_agent_was_built(basic_graph)
def assert_agent_was_built(graph):
async def assert_agent_was_built(graph):
"""Assert that the agent was built"""
assert isinstance(graph, Graph)
# Now we test the build method
# Build the Agent
result = graph.build()
result = await graph.build()
# The agent should be a AgentExecutor
assert isinstance(result, Chain)
@ -307,7 +291,8 @@ def test_file_tool_node_build(client, openapi_graph):
# assert built_object is not None
def test_get_result_and_thought(basic_graph):
@pytest.mark.asyncio
async def test_get_result_and_thought(basic_graph):
"""Test the get_result_and_thought method"""
responses = [
"Final Answer: I am a response",
@ -319,7 +304,7 @@ def test_get_result_and_thought(basic_graph):
assert llm_node is not None
llm_node._built_object = FakeListLLM(responses=responses)
llm_node._built = True
langchain_object = basic_graph.build()
langchain_object = await basic_graph.build()
# assert all nodes are built
assert all(node._built for node in basic_graph.vertices)
# now build again and check if FakeListLLM was used
@ -339,9 +324,7 @@ def test_find_last_node(grouped_chat_json_flow):
def test_ungroup_node(grouped_chat_json_flow):
grouped_chat_data = json.loads(grouped_chat_json_flow).get("data")
group_node = grouped_chat_data["nodes"][
2
] # Assuming the first node is a group node
group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node
base_flow = copy.deepcopy(grouped_chat_data)
ungroup_node(group_node["data"], base_flow)
# after ungroup_node is called, the base_flow and grouped_chat_data should be different
@ -393,14 +376,9 @@ def test_process_flow_one_group(one_grouped_chat_json_flow):
assert "edges" in processed_flow
# Now get the node that has ChatOpenAI in its id
chat_openai_node = next(
(node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None
)
chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None)
assert chat_openai_node is not None
assert (
chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"]
== "test"
)
assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test"
def test_process_flow_vector_store_grouped(vector_store_grouped_json_flow):
@ -449,17 +427,11 @@ def test_update_template(sample_template, sample_nodes):
assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True
assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False
assert (
node1_updated["data"]["node"]["template"]["some_field"]["display_name"]
== "Name1"
)
assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1"
assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False
assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True
assert (
node2_updated["data"]["node"]["template"]["other_field"]["display_name"]
== "DisplayName2"
)
assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2"
# Ensure node3 remains unchanged
assert node3_updated == sample_nodes[2]
@ -490,9 +462,7 @@ def test_set_new_target_handle():
"data": {
"node": {
"flow": True,
"template": {
"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}
},
"template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}},
}
}
}
@ -512,33 +482,33 @@ def test_update_source_handle():
"nodes": [{"id": "some_node"}, {"id": "last_node"}],
"edges": [{"source": "some_node"}],
}
updated_edge = update_source_handle(
new_edge, flow_data["nodes"], flow_data["edges"]
)
updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"])
assert updated_edge["source"] == "last_node"
assert updated_edge["data"]["sourceHandle"]["id"] == "last_node"
def test_pickle_graph(json_vector_store):
@pytest.mark.asyncio
async def test_pickle_graph(json_vector_store):
loaded_json = json.loads(json_vector_store)
graph = Graph.from_payload(loaded_json)
assert isinstance(graph, Graph)
first_result = graph.build()
first_result = await graph.build()
assert isinstance(first_result, AgentExecutor)
pickled = pickle.dumps(graph)
assert pickled is not None
unpickled = pickle.loads(pickled)
assert unpickled is not None
result = unpickled.build()
result = await unpickled.build()
assert isinstance(result, AgentExecutor)
def test_pickle_each_vertex(json_vector_store):
@pytest.mark.asyncio
async def test_pickle_each_vertex(json_vector_store):
loaded_json = json.loads(json_vector_store)
graph = Graph.from_payload(loaded_json)
assert isinstance(graph, Graph)
for vertex in graph.nodes:
vertex.build()
for vertex in graph.vertices:
await vertex.build()
pickled = pickle.dumps(vertex)
assert pickled is not None
unpickled = pickle.loads(pickled)

View file

@ -1,5 +1,5 @@
from fastapi.testclient import TestClient
from langflow.services.getters import get_settings_service
from langflow.services.deps import get_settings_service
# check that all agents are in settings.agents