diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 1dc75b7bc..0e0f09501 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -7,7 +7,7 @@ from fastapi import ( WebSocketException, status, ) -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from langflow.chat.manager import ChatManager from langflow.graph.graph.base import Graph @@ -15,13 +15,18 @@ from langflow.utils.logger import logger router = APIRouter() chat_manager = ChatManager() +flow_data_store = {} @router.websocket("/chat/{client_id}") async def websocket_endpoint(client_id: str, websocket: WebSocket): """Websocket endpoint for chat.""" try: - await chat_manager.handle_websocket(client_id, websocket) + if client_id in chat_manager.in_memory_cache: + await chat_manager.handle_websocket(client_id, websocket) + else: + message = "Please, build the flow before sending messages" + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message) except WebSocketException as exc: logger.error(exc) await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc)) @@ -30,16 +35,73 @@ async def websocket_endpoint(client_id: str, websocket: WebSocket): await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc)) -@router.post("/build/{client_id}", response_class=StreamingResponse) -async def stream_build(client_id: str, graph_data: dict): - """Build langchain object from data_graph.""" +@router.post("/build/init") +async def init_build(graph_data: dict): + """Initialize the build by storing graph data and returning a unique session ID.""" + + flow_id = graph_data.get("id") + + flow_data_store[flow_id] = graph_data + + return JSONResponse(content={"flowId": flow_id}) + + +# @router.post("/build/{client_id}", response_class=StreamingResponse) +# async def stream_build(client_id: str, graph_data: dict): +# """Build langchain object from data_graph.""" + +# async def event_stream(graph_data): +# node_id = None +# try: +# graph_data = graph_data.get("data") +# if not graph_data: +# raise HTTPException(status_code=400, detail="No data provided") + +# logger.debug("Building langchain object") +# graph = Graph.from_payload(graph_data) +# for node_repr, node_id in graph.generator_build(): +# logger.debug( +# f"Building node {node_repr[:50]}{'...' if len(node_repr) > 50 else ''}" +# ) +# response = json.dumps( +# { +# "valid": True, +# "params": node_repr, +# "id": node_id, +# } +# ) +# yield f"data: {response}\n\n" # SSE format + +# chat_manager.set_cache(client_id, graph.build()) + +# except Exception as exc: +# logger.exception(exc) +# error_response = json.dumps( +# {"valid": False, "params": str(exc), "id": node_id} +# ) +# yield f"data: {error_response}\n\n" # SSE format + +# return StreamingResponse(event_stream(graph_data), media_type="text/event-stream") + + +@router.get("/build/stream/{flow_id}", response_class=StreamingResponse) +async def stream_build(flow_id: str): + """Stream the build process based on stored flow data.""" + + async def event_stream(flow_id): + if flow_id not in flow_data_store: + error_message = "Invalid session ID" + yield f"data: {json.dumps({'error': error_message})}\n\n" + return + + graph_data = flow_data_store[flow_id].get("data") + + if not graph_data: + error_message = "No data provided" + yield f"data: {json.dumps({'error': error_message})}\n\n" + return - async def event_stream(graph_data): try: - graph_data = graph_data.get("data") - if not graph_data: - raise HTTPException(status_code=400, detail="No data provided") - logger.debug("Building langchain object") graph = Graph.from_payload(graph_data) for node_repr, node_id in graph.generator_build(): @@ -55,13 +117,13 @@ async def stream_build(client_id: str, graph_data: dict): ) yield f"data: {response}\n\n" # SSE format - chat_manager.set_cache(client_id, graph.build()) + chat_manager.set_cache(flow_id, graph.build()) + final_response = json.dumps({"end_of_stream": True}) + yield f"data: {final_response}\n\n" # SSE format except Exception as exc: logger.exception(exc) - error_response = json.dumps( - {"valid": False, "params": str(exc), "id": node_id} - ) + error_response = json.dumps({"valid": False, "params": str(exc)}) yield f"data: {error_response}\n\n" # SSE format - return StreamingResponse(event_stream(graph_data), media_type="text/event-stream") + return StreamingResponse(event_stream(flow_id), media_type="text/event-stream") diff --git a/src/backend/langflow/chat/manager.py b/src/backend/langflow/chat/manager.py index 7352fafd4..7c3a08240 100644 --- a/src/backend/langflow/chat/manager.py +++ b/src/backend/langflow/chat/manager.py @@ -186,15 +186,8 @@ class ChatManager: continue with self.cache_manager.set_client_id(client_id): - if client_id not in self.in_memory_cache: - await self.close_connection( - client_id=client_id, - code=status.WS_1011_INTERNAL_ERROR, - reason="Please, build the flow before sending messages", - ) - else: - langchain_object = self.in_memory_cache.get(client_id) - await self.process_message(client_id, payload, langchain_object) + langchain_object = self.in_memory_cache.get(client_id) + await self.process_message(client_id, payload, langchain_object) except Exception as e: # Handle any exceptions that might occur