diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index f2b494803..459abcf51 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -1,10 +1,18 @@ -from fastapi import APIRouter, HTTPException, WebSocket, WebSocketException, status +from fastapi import ( + APIRouter, + Depends, + HTTPException, + WebSocket, + WebSocketException, + status, +) from fastapi.responses import StreamingResponse from langflow.api.utils import build_input_keys_response from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData from langflow.services import service_manager, ServiceType from langflow.graph.graph.base import Graph +from langflow.services.auth.utils import get_current_active_user from langflow.utils.logger import logger from cachetools import LRUCache @@ -14,7 +22,9 @@ flow_data_store: LRUCache = LRUCache(maxsize=10) @router.websocket("/chat/{client_id}") -async def chat(client_id: str, websocket: WebSocket): +async def chat( + client_id: str, websocket: WebSocket, current_user=Depends(get_current_active_user) +): """Websocket endpoint for chat.""" try: chat_manager = service_manager.get(ServiceType.CHAT_MANAGER) @@ -32,7 +42,9 @@ async def chat(client_id: str, websocket: WebSocket): @router.post("/build/init/{flow_id}", response_model=InitResponse, status_code=201) -async def init_build(graph_data: dict, flow_id: str): +async def init_build( + graph_data: dict, flow_id: str, current_user=Depends(get_current_active_user) +): """Initialize the build by storing graph data and returning a unique session ID.""" try: @@ -54,6 +66,7 @@ async def init_build(graph_data: dict, flow_id: str): flow_data_store[flow_id] = { "graph_data": graph_data, "status": BuildStatus.STARTED, + "user_id": current_user.id, } return InitResponse(flowId=flow_id) @@ -99,6 +112,7 @@ async def stream_build(flow_id: str): return graph_data = flow_data_store[flow_id].get("graph_data") + user_id = flow_data_store[flow_id]["user_id"] if not graph_data: error_message = "No data provided" @@ -119,7 +133,7 @@ async def stream_build(flow_id: str): "log": f"Building node {vertex.vertex_type}", } yield str(StreamData(event="log", data=log_dict)) - vertex.build() + vertex.build(user_id) params = vertex._built_object_repr() valid = True logger.debug(f"Building node {str(vertex.vertex_type)}")