feat: Add possibility to save flow to FS (#6841)
Add the ability to save flows to the filesystem
This commit is contained in:
parent
c720933001
commit
581e8cb643
5 changed files with 153 additions and 57 deletions
|
|
@ -0,0 +1,43 @@
|
|||
"""Add column fs_path to Flow
|
||||
|
||||
Revision ID: 93e2705fa8d6
|
||||
Revises: dd9e0804ebd1
|
||||
Create Date: 2025-02-25 13:08:11.263504
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from langflow.utils import migration
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '93e2705fa8d6'
|
||||
down_revision: Union[str, None] = 'dd9e0804ebd1'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable ``fs_path`` column to the ``flow`` table.

    The column check makes the migration idempotent: re-running it after a
    partial upgrade does not fail with a duplicate-column error.
    """
    connection = op.get_bind()
    inspector = sa.inspect(connection)  # type: ignore
    # ### commands auto generated by Alembic - please adjust! ###
    existing_columns = {column["name"] for column in inspector.get_columns("flow")}
    with op.batch_alter_table("flow", schema=None) as batch_op:
        # Skip the ADD COLUMN when a previous run already created it.
        if "fs_path" not in existing_columns:
            batch_op.add_column(sa.Column("fs_path", sqlmodel.sql.sqltypes.AutoString(), nullable=True))

    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``fs_path`` column from the ``flow`` table.

    Mirrors :func:`upgrade`: the existence check keeps the downgrade
    idempotent when the column was already removed.
    """
    connection = op.get_bind()
    inspector = sa.inspect(connection)  # type: ignore
    # ### commands auto generated by Alembic - please adjust! ###
    existing_columns = {column["name"] for column in inspector.get_columns("flow")}
    with op.batch_alter_table("flow", schema=None) as batch_op:
        # Only drop the column when it is actually present.
        if "fs_path" in existing_columns:
            batch_op.drop_column("fs_path")
|
||||
|
|
@ -40,13 +40,12 @@ def has_api_terms(word: str):
|
|||
|
||||
def remove_api_keys(flow: dict):
    """Null out secret values in a flow's node templates.

    Walks every node template field and, for dict-valued fields whose name
    looks like an API key (per ``has_api_terms``) and that are marked as
    password inputs, resets the stored ``value`` to ``None``.

    Returns the same (mutated) flow dict for call-chaining.
    """
    for node in flow.get("data", {}).get("nodes", []):
        # Malformed nodes may lack "data", "node" or "template"; skip them
        # gracefully instead of raising AttributeError on a None lookup.
        node_data = (node.get("data") or {}).get("node") or {}
        template = node_data.get("template") or {}
        for value in template.values():
            if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"):
                value["value"] = None

    return flow
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from typing import Annotated
|
|||
from uuid import UUID
|
||||
|
||||
import orjson
|
||||
from aiofile import async_open
|
||||
from anyio import Path
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
|
@ -20,6 +22,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||
from langflow.api.utils import CurrentActiveUser, DbSession, cascade_delete_flow, remove_api_keys, validate_is_component
|
||||
from langflow.api.v1.schemas import FlowListCreate
|
||||
from langflow.initial_setup.constants import STARTER_FOLDER_NAME
|
||||
from langflow.logging import logger
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate, FlowRead, FlowUpdate
|
||||
from langflow.services.database.models.flow.model import FlowHeader
|
||||
from langflow.services.database.models.flow.utils import get_webhook_component_in_flow
|
||||
|
|
@ -32,6 +35,22 @@ from langflow.services.settings.service import SettingsService
|
|||
router = APIRouter(prefix="/flows", tags=["Flows"])
|
||||
|
||||
|
||||
async def _verify_fs_path(path: str | None) -> None:
    """Ensure the file at *path* exists, creating it when needed.

    A falsy *path* (``None`` or ``""``) is a no-op. Missing parent
    directories are created first so ``touch`` cannot fail with
    ``FileNotFoundError`` on a not-yet-existing folder.
    """
    if path:
        path_ = Path(path)
        if not await path_.exists():
            await path_.parent.mkdir(parents=True, exist_ok=True)
            await path_.touch()
|
||||
|
||||
|
||||
async def _save_flow_to_fs(flow: Flow) -> None:
    """Best-effort persistence of *flow* as JSON to its ``fs_path``.

    Filesystem errors are logged, never raised, so a disk problem cannot
    fail the API request that triggered the save. Flows without an
    ``fs_path`` are skipped.
    """
    if not flow.fs_path:
        return
    try:
        # The open itself can raise OSError (bad path, permissions), so it
        # must live inside the try block, not just the write.
        async with async_open(flow.fs_path, "w") as f:
            await f.write(flow.model_dump_json())
    except OSError:
        # f-string instead of %-style args: the project logger (loguru-style)
        # does not perform printf interpolation, which would log literal "%s".
        logger.exception(f"Failed to write flow {flow.name} to path {flow.fs_path}")
|
||||
|
||||
|
||||
async def _new_flow(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
|
|
@ -39,6 +58,8 @@ async def _new_flow(
|
|||
user_id: UUID,
|
||||
):
|
||||
try:
|
||||
await _verify_fs_path(flow.fs_path)
|
||||
|
||||
"""Create a new flow."""
|
||||
if flow.user_id is None:
|
||||
flow.user_id = user_id
|
||||
|
|
@ -124,6 +145,9 @@ async def create_flow(
|
|||
db_flow = await _new_flow(session=session, flow=flow, user_id=current_user.id)
|
||||
await session.commit()
|
||||
await session.refresh(db_flow)
|
||||
|
||||
await _save_flow_to_fs(db_flow)
|
||||
|
||||
except Exception as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
# Get the name of the column that failed
|
||||
|
|
@ -283,6 +307,8 @@ async def update_flow(
|
|||
for key, value in update_data.items():
|
||||
setattr(db_flow, key, value)
|
||||
|
||||
await _verify_fs_path(db_flow.fs_path)
|
||||
|
||||
webhook_component = get_webhook_component_in_flow(db_flow.data)
|
||||
db_flow.webhook = webhook_component is not None
|
||||
db_flow.updated_at = datetime.now(timezone.utc)
|
||||
|
|
@ -296,6 +322,8 @@ async def update_flow(
|
|||
await session.commit()
|
||||
await session.refresh(db_flow)
|
||||
|
||||
await _save_flow_to_fs(db_flow)
|
||||
|
||||
except Exception as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
# Get the name of the column that failed
|
||||
|
|
@ -381,6 +409,7 @@ async def upload_file(
|
|||
await session.commit()
|
||||
for db_flow in response_list:
|
||||
await session.refresh(db_flow)
|
||||
await _save_flow_to_fs(db_flow)
|
||||
except Exception as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
# Get the name of the column that failed
|
||||
|
|
|
|||
|
|
@ -169,6 +169,7 @@ class Flow(FlowBase, table=True): # type: ignore[call-arg]
|
|||
tags: list[str] | None = Field(sa_column=Column(JSON), default=[])
|
||||
locked: bool | None = Field(default=False, nullable=True)
|
||||
folder_id: UUID | None = Field(default=None, foreign_key="folder.id", nullable=True, index=True)
|
||||
fs_path: str | None = Field(default=None, nullable=True)
|
||||
folder: Optional["Folder"] = Relationship(back_populates="flows")
|
||||
messages: list["MessageTable"] = Relationship(back_populates="flow")
|
||||
transactions: list["TransactionTable"] = Relationship(back_populates="flow")
|
||||
|
|
@ -194,6 +195,7 @@ class Flow(FlowBase, table=True): # type: ignore[call-arg]
|
|||
class FlowCreate(FlowBase):
    """Payload model for creating a flow via the API."""

    # Owner of the flow; when omitted, filled in from the authenticated user.
    user_id: UUID | None = None
    # Destination folder id; presumably falls back to a default folder when
    # omitted — confirm against the flow-creation endpoint.
    folder_id: UUID | None = None
    # Optional filesystem path the flow JSON is mirrored to on save.
    fs_path: str | None = None
|
||||
|
||||
|
||||
class FlowRead(FlowBase):
|
||||
|
|
@ -233,6 +235,7 @@ class FlowUpdate(SQLModel):
|
|||
folder_id: UUID | None = None
|
||||
endpoint_name: str | None = None
|
||||
locked: bool | None = None
|
||||
fs_path: str | None = None
|
||||
|
||||
@field_validator("endpoint_name")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1,41 +1,54 @@
|
|||
import tempfile
|
||||
import uuid
|
||||
|
||||
from anyio import Path
|
||||
from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
from langflow.services.database.models import Flow
|
||||
|
||||
|
||||
async def test_create_flow(client: AsyncClient, logged_in_headers):
    """Creating a flow with an ``fs_path`` returns 201 and mirrors the flow JSON to disk."""
    # tempfile.tempdir is None until gettempdir() is first called, so
    # Path(tempfile.tempdir) can be Path(None); call gettempdir() directly.
    flow_file = Path(tempfile.gettempdir()) / f"{uuid.uuid4()!s}.json"
    try:
        basic_case = {
            "name": "string",
            "description": "string",
            "icon": "string",
            "icon_bg_color": "#ff00ff",
            "gradient": "string",
            "data": {},
            "is_component": False,
            "webhook": False,
            "endpoint_name": "string",
            "tags": ["string"],
            "user_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
            "folder_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
            "fs_path": str(flow_file),
        }
        response = await client.post("api/v1/flows/", json=basic_case, headers=logged_in_headers)
        result = response.json()

        assert response.status_code == status.HTTP_201_CREATED
        assert isinstance(result, dict), "The result must be a dictionary"
        # Every serialized field of the flow must be present in the response.
        expected_keys = (
            "data",
            "description",
            "endpoint_name",
            "folder_id",
            "gradient",
            "icon",
            "icon_bg_color",
            "id",
            "is_component",
            "name",
            "tags",
            "updated_at",
            "user_id",
            "webhook",
        )
        for key in expected_keys:
            assert key in result, f"The result must have a '{key}' key"

        # The flow must also have been written to fs_path as valid flow JSON.
        content = await flow_file.read_text()
        Flow.model_validate_json(content)
    finally:
        # Always clean up the temp file, even when an assertion fails.
        await flow_file.unlink(missing_ok=True)
|
||||
|
||||
|
||||
async def test_read_flows(client: AsyncClient, logged_in_headers):
|
||||
|
|
@ -112,26 +125,35 @@ async def test_update_flow(client: AsyncClient, logged_in_headers):
|
|||
response_ = await client.post("api/v1/flows/", json=basic_case, headers=logged_in_headers)
|
||||
id_ = response_.json()["id"]
|
||||
|
||||
flow_file = Path(tempfile.tempdir) / f"{uuid.uuid4()!s}.json"
|
||||
basic_case["name"] = updated_name
|
||||
response = await client.patch(f"api/v1/flows/{id_}", json=basic_case, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
basic_case["fs_path"] = str(flow_file)
|
||||
|
||||
assert isinstance(result, dict), "The result must be a dictionary"
|
||||
assert "data" in result, "The result must have a 'data' key"
|
||||
assert "description" in result, "The result must have a 'description' key"
|
||||
assert "endpoint_name" in result, "The result must have a 'endpoint_name' key"
|
||||
assert "folder_id" in result, "The result must have a 'folder_id' key"
|
||||
assert "gradient" in result, "The result must have a 'gradient' key"
|
||||
assert "icon" in result, "The result must have a 'icon' key"
|
||||
assert "icon_bg_color" in result, "The result must have a 'icon_bg_color' key"
|
||||
assert "id" in result, "The result must have a 'id' key"
|
||||
assert "is_component" in result, "The result must have a 'is_component' key"
|
||||
assert "name" in result, "The result must have a 'name' key"
|
||||
assert "tags" in result, "The result must have a 'tags' key"
|
||||
assert "updated_at" in result, "The result must have a 'updated_at' key"
|
||||
assert "user_id" in result, "The result must have a 'user_id' key"
|
||||
assert "webhook" in result, "The result must have a 'webhook' key"
|
||||
assert result["name"] == updated_name, "The name must be updated"
|
||||
try:
|
||||
response = await client.patch(f"api/v1/flows/{id_}", json=basic_case, headers=logged_in_headers)
|
||||
result = response.json()
|
||||
|
||||
assert isinstance(result, dict), "The result must be a dictionary"
|
||||
assert "data" in result, "The result must have a 'data' key"
|
||||
assert "description" in result, "The result must have a 'description' key"
|
||||
assert "endpoint_name" in result, "The result must have a 'endpoint_name' key"
|
||||
assert "folder_id" in result, "The result must have a 'folder_id' key"
|
||||
assert "gradient" in result, "The result must have a 'gradient' key"
|
||||
assert "icon" in result, "The result must have a 'icon' key"
|
||||
assert "icon_bg_color" in result, "The result must have a 'icon_bg_color' key"
|
||||
assert "id" in result, "The result must have a 'id' key"
|
||||
assert "is_component" in result, "The result must have a 'is_component' key"
|
||||
assert "name" in result, "The result must have a 'name' key"
|
||||
assert "tags" in result, "The result must have a 'tags' key"
|
||||
assert "updated_at" in result, "The result must have a 'updated_at' key"
|
||||
assert "user_id" in result, "The result must have a 'user_id' key"
|
||||
assert "webhook" in result, "The result must have a 'webhook' key"
|
||||
assert result["name"] == updated_name, "The name must be updated"
|
||||
|
||||
content = await flow_file.read_text()
|
||||
Flow.model_validate_json(content)
|
||||
finally:
|
||||
await flow_file.unlink(missing_ok=True)
|
||||
|
||||
|
||||
async def test_create_flows(client: AsyncClient, logged_in_headers):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue