🐛 fix(chat.py): add missing import statement for get_current_active_user function

 feat(chat.py): add current_user dependency to chat and init_build endpoints to enforce authentication and authorization
🔧 refactor(chat.py): pass user_id to vertex.build() method in stream_build endpoint to associate the build with the current user
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-25 15:24:45 -03:00
commit 853ce351c9

View file

@ -1,10 +1,18 @@
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketException, status
from fastapi import (
APIRouter,
Depends,
HTTPException,
WebSocket,
WebSocketException,
status,
)
from fastapi.responses import StreamingResponse
from langflow.api.utils import build_input_keys_response
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData
from langflow.services import service_manager, ServiceType
from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import get_current_active_user
from langflow.utils.logger import logger
from cachetools import LRUCache
@ -14,7 +22,9 @@ flow_data_store: LRUCache = LRUCache(maxsize=10)
@router.websocket("/chat/{client_id}")
async def chat(client_id: str, websocket: WebSocket):
async def chat(
client_id: str, websocket: WebSocket, current_user=Depends(get_current_active_user)
):
"""Websocket endpoint for chat."""
try:
chat_manager = service_manager.get(ServiceType.CHAT_MANAGER)
@ -32,7 +42,9 @@ async def chat(client_id: str, websocket: WebSocket):
@router.post("/build/init/{flow_id}", response_model=InitResponse, status_code=201)
async def init_build(graph_data: dict, flow_id: str):
async def init_build(
graph_data: dict, flow_id: str, current_user=Depends(get_current_active_user)
):
"""Initialize the build by storing graph data and returning a unique session ID."""
try:
@ -54,6 +66,7 @@ async def init_build(graph_data: dict, flow_id: str):
flow_data_store[flow_id] = {
"graph_data": graph_data,
"status": BuildStatus.STARTED,
"user_id": current_user.id,
}
return InitResponse(flowId=flow_id)
@ -99,6 +112,7 @@ async def stream_build(flow_id: str):
return
graph_data = flow_data_store[flow_id].get("graph_data")
user_id = flow_data_store[flow_id]["user_id"]
if not graph_data:
error_message = "No data provided"
@ -119,7 +133,7 @@ async def stream_build(flow_id: str):
"log": f"Building node {vertex.vertex_type}",
}
yield str(StreamData(event="log", data=log_dict))
vertex.build()
vertex.build(user_id)
params = vertex._built_object_repr()
valid = True
logger.debug(f"Building node {str(vertex.vertex_type)}")