diff --git a/src/backend/langflow/database/models/flow.py b/src/backend/langflow/database/models/flow.py index df042eeed..5ad49c668 100644 --- a/src/backend/langflow/database/models/flow.py +++ b/src/backend/langflow/database/models/flow.py @@ -1,6 +1,7 @@ -from sqlmodel import Field, SQLModel, Relationship +from pydantic import validator +from sqlmodel import Field, SQLModel, Relationship, JSON, Column from uuid import UUID, uuid4 -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict if TYPE_CHECKING: from langflow.database.models.flowstyle import FlowStyle @@ -8,9 +9,36 @@ if TYPE_CHECKING: class FlowBase(SQLModel): name: str = Field(index=True) - flow: str = Field(index=False) + flow: Dict = Field(default_factory=dict, sa_column=Column(JSON)) style: "FlowStyle" = Relationship(back_populates="flow") + @validator("flow") + def validate_json(v): + # dict_keys(['description', 'name', 'id', 'data']) + if not isinstance(v, dict): + raise ValueError("Flow must be a valid JSON") + if "description" not in v.keys(): + raise ValueError("Flow must have a description") + if "data" not in v.keys(): + raise ValueError("Flow must have data") + + # data must contain nodes and edges + if "nodes" not in v["data"].keys(): + raise ValueError("Flow must have nodes") + if "edges" not in v["data"].keys(): + raise ValueError("Flow must have edges") + + return v + + # @validator("flow") + # def flow_must_be_json(cls, v): + # try: + # valid_json = json.loads(v) + + # except Exception as e: + # raise ValueError(f"Flow must be a valid JSON: {e}") from e + # return v + class Flow(FlowBase, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True) diff --git a/tests/conftest.py b/tests/conftest.py index 9d596cc39..61eab9fbb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -108,6 +108,5 @@ def client_fixture(session: Session): # app.dependency_overrides[get_session] = get_session_override # - client = TestClient(app) # - yield client # + yield TestClient(app) app.dependency_overrides.clear() # diff --git a/tests/test_database.py b/tests/test_database.py index 0478332a0..6da9083bc 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,25 +1,29 @@ +from uuid import uuid4 from langflow.api.schemas import FlowListCreate from langflow.database.models.flow import FlowCreate import json from sqlalchemy.orm import Session from langflow.database.models.flow import Flow +from fastapi.testclient import TestClient + +import threading -def test_create_flow(client, json_flow): - flow = FlowCreate(name="Test Flow", flow=json_flow) +def test_create_flow(client: TestClient, json_flow: str): + flow = FlowCreate(name="Test Flow", flow=json.loads(json_flow)) response = client.post("/flows/", json=flow.dict()) assert response.status_code == 200 assert response.json()["name"] == flow.name assert response.json()["flow"] == flow.flow -def test_read_flows(client, json_flow): - flow = FlowCreate(name="Test Flow", flow=json_flow) +def test_read_flows(client: TestClient, json_flow: str): + flow = FlowCreate(name="Test Flow", flow=json.loads(json_flow)) response = client.post("/flows/", json=flow.dict()) assert response.status_code == 200 assert response.json()["name"] == flow.name assert response.json()["flow"] == flow.flow - flow = FlowCreate(name="Test Flow", flow=json_flow) + flow = FlowCreate(name="Test Flow", flow=json.loads(json_flow)) response = client.post("/flows/", json=flow.dict()) assert response.status_code == 200 assert response.json()["name"] == flow.name @@ -30,8 +34,8 @@ def test_read_flows(client, json_flow): assert len(response.json()) > 0 -def test_read_flow(client, json_flow): - flow = FlowCreate(name="Test Flow", flow=json_flow) +def test_read_flow(client: TestClient, json_flow: str): + flow = FlowCreate(name="Test Flow", flow=json.loads(json_flow)) response = client.post("/flows/", json=flow.dict()) flow_id = response.json()["id"] response = client.get(f"/flows/{flow_id}") @@ -40,12 +44,13 @@ def test_read_flow(client, json_flow): assert response.json()["flow"] == flow.flow -def test_update_flow(client, json_flow): - flow = FlowCreate(name="Test Flow", flow=json_flow) +def test_update_flow(client: TestClient, json_flow: str): + flow = FlowCreate(name="Test Flow", flow=json.loads(json_flow)) response = client.post("/flows/", json=flow.dict()) flow_id = response.json()["id"] updated_flow = FlowCreate( - name="Updated Flow", flow=json_flow.replace("BasicExample", "Updated Flow") + name="Updated Flow", + flow=json.loads(json_flow.replace("BasicExample", "Updated Flow")), ) response = client.put(f"/flows/{flow_id}", json=updated_flow.dict()) assert response.status_code == 200 @@ -53,8 +58,8 @@ def test_update_flow(client, json_flow): assert response.json()["flow"] == updated_flow.flow -def test_delete_flow(client, json_flow): - flow = FlowCreate(name="Test Flow", flow=json_flow) +def test_delete_flow(client: TestClient, json_flow: str): + flow = FlowCreate(name="Test Flow", flow=json.loads(json_flow)) response = client.post("/flows/", json=flow.dict()) flow_id = response.json()["id"] response = client.delete(f"/flows/{flow_id}") @@ -62,12 +67,12 @@ def test_delete_flow(client, json_flow): assert response.json()["message"] == "Flow deleted successfully" -def test_create_flows(client, session: Session): +def test_create_flows(client: TestClient, session: Session, json_flow: str): # Create test data flow_list = FlowListCreate( flows=[ - FlowCreate(name="Flow 1", flow="Test flow 1"), - FlowCreate(name="Flow 2", flow="Test flow 2"), + FlowCreate(name="Flow 1", flow=json.loads(json_flow)), + FlowCreate(name="Flow 2", flow=json.loads(json_flow)), ] ) # Make request to endpoint @@ -78,26 +83,20 @@ def test_create_flows(client, session: Session): response_data = response.json() assert len(response_data) == 2 assert response_data[0]["name"] == "Flow 1" - assert response_data[0]["flow"] == "Test flow 1" + assert response_data[0]["flow"] == json.loads(json_flow) assert response_data[1]["name"] == "Flow 2" - assert response_data[1]["flow"] == "Test flow 2" + assert response_data[1]["flow"] == json.loads(json_flow) -def test_upload_file(client, session: Session): +def test_upload_file(client: TestClient, session: Session, json_flow: str): # Create test data flow_list = FlowListCreate( flows=[ - FlowCreate(name="Flow 1", flow="Test flow 1"), - FlowCreate(name="Flow 2", flow="Test flow 2"), + FlowCreate(name="Flow 1", flow=json.loads(json_flow)), + FlowCreate(name="Flow 2", flow=json.loads(json_flow)), ] ) file_contents = json.dumps(flow_list.dict()) - # Make request to endpoint - # curl -X 'POST' \ - # 'http://127.0.0.1:7860/flows/upload/' \ - # -H 'accept: application/json' \ - # -H 'Content-Type: multipart/form-data' \ - # -F 'file=@examples.json;type=application/json' response = client.post( "/flows/upload/", files={"file": ("examples.json", file_contents, "application/json")}, @@ -108,17 +107,17 @@ def test_upload_file(client, session: Session): response_data = response.json() assert len(response_data) == 2 assert response_data[0]["name"] == "Flow 1" - assert response_data[0]["flow"] == "Test flow 1" + assert response_data[0]["flow"] == json.loads(json_flow) assert response_data[1]["name"] == "Flow 2" - assert response_data[1]["flow"] == "Test flow 2" + assert response_data[1]["flow"] == json.loads(json_flow) -def test_download_file(client, session: Session): +def test_download_file(client: TestClient, session: Session, json_flow): # Create test data flow_list = FlowListCreate( flows=[ - FlowCreate(name="Flow 1", flow="Test flow 1"), - FlowCreate(name="Flow 2", flow="Test flow 2"), + FlowCreate(name="Flow 1", flow=json.loads(json_flow)), + FlowCreate(name="Flow 2", flow=json.loads(json_flow)), ] ) for flow in flow_list.flows: @@ -133,6 +132,68 @@ def test_download_file(client, session: Session): response_data = json.loads(response.json()["file"]) assert len(response_data) == 2 assert response_data[0]["name"] == "Flow 1" - assert response_data[0]["flow"] == "Test flow 1" + assert response_data[0]["flow"] == json.loads(json_flow) assert response_data[1]["name"] == "Flow 2" - assert response_data[1]["flow"] == "Test flow 2" + assert response_data[1]["flow"] == json.loads(json_flow) + + +def test_create_flow_with_invalid_data(client: TestClient): + flow = {"name": "a" * 256, "flow": "Invalid flow data"} + response = client.post("/flows/", json=flow) + assert response.status_code == 422 + + +def test_get_nonexistent_flow(client: TestClient): + # uuid4 generates a random UUID + uuid = uuid4() + response = client.get(f"/flows/{uuid}") + assert response.status_code == 404 + + +def test_update_flow_idempotency(client: TestClient, json_flow: str): + flow = FlowCreate(name="Test Flow", flow=json.loads(json_flow)) + response = client.post("/flows/", json=flow.dict()) + flow_id = response.json()["id"] + updated_flow = FlowCreate(name="Updated Flow", flow=json.loads(json_flow)) + response1 = client.put(f"/flows/{flow_id}", json=updated_flow.dict()) + response2 = client.put(f"/flows/{flow_id}", json=updated_flow.dict()) + assert response1.json() == response2.json() + + +def test_update_nonexistent_flow(client: TestClient, json_flow: str): + uuid = uuid4() + updated_flow = FlowCreate( + name="Updated Flow", + flow=json.loads(json_flow.replace("BasicExample", "Updated Flow")), + ) + response = client.put(f"/flows/{uuid}", json=updated_flow.dict()) + assert response.status_code == 404 + + +def test_delete_nonexistent_flow(client: TestClient): + uuid = uuid4() + response = client.delete(f"/flows/{uuid}") + assert response.status_code == 404 + + +def test_read_empty_flows(client: TestClient): + response = client.get("/flows/") + assert response.status_code == 200 + assert len(response.json()) == 0 + + +def test_stress_create_flow(client: TestClient, json_flow: str): + flow = FlowCreate(name="Test Flow", flow=json.loads(json_flow)) + + def create_flow(): + response = client.post("/flows/", json=flow.dict()) + assert response.status_code == 200 + + threads = [] + for i in range(100): + t = threading.Thread(target=create_flow) + threads.append(t) + t.start() + + for t in threads: + t.join()