diff --git a/src/backend/langflow/api/v1/flows.py b/src/backend/langflow/api/v1/flows.py index d6c63a707..b215b9f95 100644 --- a/src/backend/langflow/api/v1/flows.py +++ b/src/backend/langflow/api/v1/flows.py @@ -4,16 +4,18 @@ from fastapi.encoders import jsonable_encoder from langflow.api.utils import remove_api_keys from langflow.api.v1.schemas import FlowListCreate, FlowListRead +from langflow.services.auth.utils import get_current_active_user from langflow.services.database.models.flow import ( Flow, FlowCreate, FlowRead, FlowUpdate, ) +from langflow.services.database.models.user.user import User from langflow.services.utils import get_session from langflow.services.utils import get_settings_manager import orjson -from sqlmodel import Session, select +from sqlmodel import Session from fastapi import APIRouter, Depends, HTTPException from fastapi import File, UploadFile @@ -23,9 +25,18 @@ router = APIRouter(prefix="/flows", tags=["Flows"]) @router.post("/", response_model=FlowRead, status_code=201) -def create_flow(*, session: Session = Depends(get_session), flow: FlowCreate): +def create_flow( + *, + session: Session = Depends(get_session), + flow: FlowCreate, + current_user: User = Depends(get_current_active_user), +): """Create a new flow.""" + if flow.user_id is None: + flow.user_id = current_user.id + db_flow = Flow.from_orm(flow) + session.add(db_flow) session.commit() session.refresh(db_flow) @@ -33,31 +44,49 @@ def create_flow(*, session: Session = Depends(get_session), flow: FlowCreate): @router.get("/", response_model=list[FlowRead], status_code=200) -def read_flows(*, session: Session = Depends(get_session)): +def read_flows( + *, + session: Session = Depends(get_session), + current_user: User = Depends(get_current_active_user), +): """Read all flows.""" try: - flows = session.exec(select(Flow)).all() + flows = current_user.flows except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e return [jsonable_encoder(flow) for flow in flows] @router.get("/{flow_id}", response_model=FlowRead, status_code=200) -def read_flow(*, session: Session = Depends(get_session), flow_id: UUID): +def read_flow( + *, + session: Session = Depends(get_session), + flow_id: UUID, + current_user: User = Depends(get_current_active_user), +): """Read a flow.""" - if flow := session.get(Flow, flow_id): - return flow + if user_flow := ( + session.query(Flow) + .filter(Flow.id == flow_id) + .filter(Flow.user_id == current_user.id) + .first() + ): + return user_flow else: raise HTTPException(status_code=404, detail="Flow not found") @router.patch("/{flow_id}", response_model=FlowRead, status_code=200) def update_flow( - *, session: Session = Depends(get_session), flow_id: UUID, flow: FlowUpdate + *, + session: Session = Depends(get_session), + flow_id: UUID, + flow: FlowUpdate, + current_user: User = Depends(get_current_active_user), ): """Update a flow.""" - db_flow = session.get(Flow, flow_id) + db_flow = read_flow(session=session, flow_id=flow_id, current_user=current_user) if not db_flow: raise HTTPException(status_code=404, detail="Flow not found") flow_data = flow.dict(exclude_unset=True) @@ -65,7 +94,8 @@ def update_flow( if settings_manager.settings.REMOVE_API_KEYS: flow_data = remove_api_keys(flow_data) for key, value in flow_data.items(): - setattr(db_flow, key, value) + if value is not None: + setattr(db_flow, key, value) session.add(db_flow) session.commit() session.refresh(db_flow) @@ -73,9 +103,14 @@ def update_flow( @router.delete("/{flow_id}", status_code=200) -def delete_flow(*, session: Session = Depends(get_session), flow_id: UUID): +def delete_flow( + *, + session: Session = Depends(get_session), + flow_id: UUID, + current_user: User = Depends(get_current_active_user), +): """Delete a flow.""" - flow = session.get(Flow, flow_id) + flow = read_flow(session=session, flow_id=flow_id, current_user=current_user) if not flow: raise HTTPException(status_code=404, detail="Flow not found") session.delete(flow) @@ -87,10 +122,16 @@ def delete_flow(*, session: Session = Depends(get_session), flow_id: UUID): @router.post("/batch/", response_model=List[FlowRead], status_code=201) -def create_flows(*, session: Session = Depends(get_session), flow_list: FlowListCreate): +def create_flows( + *, + session: Session = Depends(get_session), + flow_list: FlowListCreate, + current_user: User = Depends(get_current_active_user), +): """Create multiple new flows.""" db_flows = [] for flow in flow_list.flows: + flow.user_id = current_user.id db_flow = Flow.from_orm(flow) session.add(db_flow) db_flows.append(db_flow) @@ -102,7 +143,10 @@ def create_flows(*, session: Session = Depends(get_session), flow_list: FlowList @router.post("/upload/", response_model=List[FlowRead], status_code=201) async def upload_file( - *, session: Session = Depends(get_session), file: UploadFile = File(...) + *, + session: Session = Depends(get_session), + file: UploadFile = File(...), + current_user: User = Depends(get_current_active_user), ): """Upload flows from a file.""" contents = await file.read() @@ -111,11 +155,19 @@ async def upload_file( flow_list = FlowListCreate(**data) else: flow_list = FlowListCreate(flows=[FlowCreate(**flow) for flow in data]) - return create_flows(session=session, flow_list=flow_list) + # Now we set the user_id for all flows + for flow in flow_list.flows: + flow.user_id = current_user.id + + return create_flows(session=session, flow_list=flow_list, current_user=current_user) @router.get("/download/", response_model=FlowListRead, status_code=200) -async def download_file(*, session: Session = Depends(get_session)): +async def download_file( + *, + session: Session = Depends(get_session), + current_user: User = Depends(get_current_active_user), +): """Download all flows as a file.""" - flows = read_flows(session=session) + flows = read_flows(session=session, current_user=current_user) return FlowListRead(flows=flows) diff --git a/src/backend/langflow/services/database/models/flow/flow.py b/src/backend/langflow/services/database/models/flow/flow.py index a05de5791..e6ad4af4a 100644 --- a/src/backend/langflow/services/database/models/flow/flow.py +++ b/src/backend/langflow/services/database/models/flow/flow.py @@ -2,9 +2,12 @@ from langflow.services.database.models.base import SQLModelSerializable from pydantic import validator -from sqlmodel import Field, JSON, Column +from sqlmodel import Field, JSON, Column, Relationship from uuid import UUID, uuid4 -from typing import Dict, Optional +from typing import Dict, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from langflow.services.database.models.user import User class FlowBase(SQLModelSerializable): @@ -31,14 +34,17 @@ class FlowBase(SQLModelSerializable): class Flow(FlowBase, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True) data: Optional[Dict] = Field(default=None, sa_column=Column(JSON)) + user_id: UUID = Field(index=True, foreign_key="user.id") + user: "User" = Relationship(back_populates="flows") class FlowCreate(FlowBase): - pass + user_id: Optional[UUID] = None class FlowRead(FlowBase): id: UUID + user_id: UUID = Field() class FlowUpdate(SQLModelSerializable):