From 5a46c1e1a0eda13e52134657614a95085ead849e Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sun, 4 Jun 2023 22:58:43 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(database.py):=20handle=20cas?= =?UTF-8?q?e=20where=20data=20does=20not=20contain=20"flows"=20key=20?= =?UTF-8?q?=E2=9C=A8=20feat(database.py):=20add=20default=20argument=20to?= =?UTF-8?q?=20json.dumps=20to=20handle=20datetime=20objects=20=F0=9F=9A=A8?= =?UTF-8?q?=20test(database.py):=20add=20tests=20for=20batch=20flow=20crea?= =?UTF-8?q?tion,=20file=20upload,=20and=20file=20download=20The=20fix=20in?= =?UTF-8?q?=20database.py=20handles=20the=20case=20where=20the=20data=20di?= =?UTF-8?q?ctionary=20does=20not=20contain=20the=20"flows"=20key.=20This?= =?UTF-8?q?=20is=20important=20because=20the=20code=20assumes=20that=20the?= =?UTF-8?q?=20"flows"=20key=20is=20present=20and=20will=20raise=20an=20exc?= =?UTF-8?q?eption=20if=20it=20is=20not.=20The=20fix=20adds=20a=20check=20t?= =?UTF-8?q?o=20see=20if=20the=20"flows"=20key=20is=20present=20and=20if=20?= =?UTF-8?q?not,=20it=20creates=20a=20new=20FlowListCreate=20object=20with?= =?UTF-8?q?=20the=20data=20as=20a=20list=20of=20FlowCreate=20objects.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The feature in database.py adds a default argument to the json.dumps function to handle datetime objects. This is important because the default json encoder does not handle datetime objects and will raise an exception if it encounters one. The tests in test_database.py cover the batch creation of flows, uploading a file containing flows, and downloading a file containing flows. These tests ensure that the endpoints are working as expected and that the data is being handled correctly. --- src/backend/langflow/api/database.py | 7 ++- src/backend/langflow/api/endpoints.py | 2 +- tests/test_database.py | 80 +++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 3 deletions(-) diff --git a/src/backend/langflow/api/database.py b/src/backend/langflow/api/database.py index f6e5cae47..34e3a2bc8 100644 --- a/src/backend/langflow/api/database.py +++ b/src/backend/langflow/api/database.py @@ -92,7 +92,10 @@ async def upload_file( """Upload flows from a file.""" contents = await file.read() data = json.loads(contents) - flow_list = FlowListCreate(**data) + if "flows" in data: + flow_list = FlowListCreate(**data) + else: + flow_list = FlowListCreate(flows=[FlowCreate(**flow) for flow in data]) return create_flows(session=session, flow_list=flow_list) @@ -100,4 +103,4 @@ async def upload_file( async def download_file(*, session: Session = Depends(get_session)): """Download all flows as a file.""" flows = read_flows(session=session) - return {"file": json.dumps([flow.dict() for flow in flows])} + return {"file": json.dumps([flow.dict() for flow in flows], default=str)} diff --git a/src/backend/langflow/api/endpoints.py b/src/backend/langflow/api/endpoints.py index 610d86f18..a9a5aaf66 100644 --- a/src/backend/langflow/api/endpoints.py +++ b/src/backend/langflow/api/endpoints.py @@ -35,7 +35,7 @@ async def get_load( if flow_obj is None: raise ValueError(f"Flow {flow_id} not found") graph_data: GraphData = json.loads(flow_obj.flow) - data = graph_data.dict() + data = graph_data.get("data") response = process_graph_cached(data, predict_request.message) return PredictResponse( result=response.get("result", ""), diff --git a/tests/test_database.py b/tests/test_database.py index 247d55e12..0478332a0 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,4 +1,8 @@ +from langflow.api.schemas import FlowListCreate from langflow.database.models.flow import FlowCreate +import json +from sqlalchemy.orm import Session +from langflow.database.models.flow import Flow def test_create_flow(client, json_flow): @@ -56,3 +60,79 @@ def test_delete_flow(client, json_flow): response = client.delete(f"/flows/{flow_id}") assert response.status_code == 200 assert response.json()["message"] == "Flow deleted successfully" + + +def test_create_flows(client, session: Session): + # Create test data + flow_list = FlowListCreate( + flows=[ + FlowCreate(name="Flow 1", flow="Test flow 1"), + FlowCreate(name="Flow 2", flow="Test flow 2"), + ] + ) + # Make request to endpoint + response = client.post("/flows/batch/", json=flow_list.dict()) + # Check response status code + assert response.status_code == 200 + # Check response data + response_data = response.json() + assert len(response_data) == 2 + assert response_data[0]["name"] == "Flow 1" + assert response_data[0]["flow"] == "Test flow 1" + assert response_data[1]["name"] == "Flow 2" + assert response_data[1]["flow"] == "Test flow 2" + + +def test_upload_file(client, session: Session): + # Create test data + flow_list = FlowListCreate( + flows=[ + FlowCreate(name="Flow 1", flow="Test flow 1"), + FlowCreate(name="Flow 2", flow="Test flow 2"), + ] + ) + file_contents = json.dumps(flow_list.dict()) + # Make request to endpoint + # curl -X 'POST' \ + # 'http://127.0.0.1:7860/flows/upload/' \ + # -H 'accept: application/json' \ + # -H 'Content-Type: multipart/form-data' \ + # -F 'file=@examples.json;type=application/json' + response = client.post( + "/flows/upload/", + files={"file": ("examples.json", file_contents, "application/json")}, + ) + # Check response status code + assert response.status_code == 200 + # Check response data + response_data = response.json() + assert len(response_data) == 2 + assert response_data[0]["name"] == "Flow 1" + assert response_data[0]["flow"] == "Test flow 1" + assert response_data[1]["name"] == "Flow 2" + assert response_data[1]["flow"] == "Test flow 2" + + +def test_download_file(client, session: Session): + # Create test data + flow_list = FlowListCreate( + flows=[ + FlowCreate(name="Flow 1", flow="Test flow 1"), + FlowCreate(name="Flow 2", flow="Test flow 2"), + ] + ) + for flow in flow_list.flows: + db_flow = Flow.from_orm(flow) + session.add(db_flow) + session.commit() + # Make request to endpoint + response = client.get("/flows/download/") + # Check response status code + assert response.status_code == 200 + # Check response data + response_data = json.loads(response.json()["file"]) + assert len(response_data) == 2 + assert response_data[0]["name"] == "Flow 1" + assert response_data[0]["flow"] == "Test flow 1" + assert response_data[1]["name"] == "Flow 2" + assert response_data[1]["flow"] == "Test flow 2"