From 76f1aa8adff7f3b1fd67de6bbb531f54951312f5 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 12 Jun 2023 11:49:24 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(chat.py):=20add=20check=20fo?= =?UTF-8?q?r=20client=5Fid=20in=20in=5Fmemory=5Fcache=20before=20handling?= =?UTF-8?q?=20websocket=20=E2=9C=A8=20feat(chat.py):=20add=20support=20for?= =?UTF-8?q?=20storing=20graph=20data=20and=20returning=20a=20unique=20sess?= =?UTF-8?q?ion=20ID=20for=20building=20langchain=20object=20=E2=9C=A8=20fe?= =?UTF-8?q?at(chat.py):=20add=20support=20for=20streaming=20the=20build=20?= =?UTF-8?q?process=20based=20on=20stored=20flow=20data=20The=20fix=20adds?= =?UTF-8?q?=20a=20check=20for=20the=20client=5Fid=20in=20the=20in=5Fmemory?= =?UTF-8?q?=5Fcache=20before=20handling=20the=20websocket.=20This=20ensure?= =?UTF-8?q?s=20that=20the=20flow=20has=20been=20built=20before=20sending?= =?UTF-8?q?=20messages.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The first feature adds support for storing graph data and returning a unique session ID for building the langchain object. This allows the user to build the flow and then send messages. The second feature adds support for streaming the build process based on stored flow data. This allows the user to see the progress of the build process. --- src/backend/langflow/api/v1/chat.py | 92 +++++++++++++++++++++++----- src/backend/langflow/chat/manager.py | 11 +--- 2 files changed, 79 insertions(+), 24 deletions(-) diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 1dc75b7bc..0e0f09501 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -7,7 +7,7 @@ from fastapi import ( WebSocketException, status, ) -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from langflow.chat.manager import ChatManager from langflow.graph.graph.base import Graph @@ -15,13 +15,18 @@ from langflow.utils.logger import logger router = APIRouter() chat_manager = ChatManager() +flow_data_store = {} @router.websocket("/chat/{client_id}") async def websocket_endpoint(client_id: str, websocket: WebSocket): """Websocket endpoint for chat.""" try: - await chat_manager.handle_websocket(client_id, websocket) + if client_id in chat_manager.in_memory_cache: + await chat_manager.handle_websocket(client_id, websocket) + else: + message = "Please, build the flow before sending messages" + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message) except WebSocketException as exc: logger.error(exc) await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc)) @@ -30,16 +35,73 @@ async def websocket_endpoint(client_id: str, websocket: WebSocket): await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc)) -@router.post("/build/{client_id}", response_class=StreamingResponse) -async def stream_build(client_id: str, graph_data: dict): - """Build langchain object from data_graph.""" +@router.post("/build/init") +async def init_build(graph_data: dict): + """Initialize the build by storing graph data and returning a unique session ID.""" + + flow_id = graph_data.get("id") + + flow_data_store[flow_id] = graph_data + + return JSONResponse(content={"flowId": flow_id}) + + +# @router.post("/build/{client_id}", response_class=StreamingResponse) +# async def stream_build(client_id: str, graph_data: dict): +# """Build langchain object from data_graph.""" + +# async def event_stream(graph_data): +# node_id = None +# try: +# graph_data = graph_data.get("data") +# if not graph_data: +# raise HTTPException(status_code=400, detail="No data provided") + +# logger.debug("Building langchain object") +# graph = Graph.from_payload(graph_data) +# for node_repr, node_id in graph.generator_build(): +# logger.debug( +# f"Building node {node_repr[:50]}{'...' if len(node_repr) > 50 else ''}" +# ) +# response = json.dumps( +# { +# "valid": True, +# "params": node_repr, +# "id": node_id, +# } +# ) +# yield f"data: {response}\n\n" # SSE format + +# chat_manager.set_cache(client_id, graph.build()) + +# except Exception as exc: +# logger.exception(exc) +# error_response = json.dumps( +# {"valid": False, "params": str(exc), "id": node_id} +# ) +# yield f"data: {error_response}\n\n" # SSE format + +# return StreamingResponse(event_stream(graph_data), media_type="text/event-stream") + + +@router.get("/build/stream/{flow_id}", response_class=StreamingResponse) +async def stream_build(flow_id: str): + """Stream the build process based on stored flow data.""" + + async def event_stream(flow_id): + if flow_id not in flow_data_store: + error_message = "Invalid session ID" + yield f"data: {json.dumps({'error': error_message})}\n\n" + return + + graph_data = flow_data_store[flow_id].get("data") + + if not graph_data: + error_message = "No data provided" + yield f"data: {json.dumps({'error': error_message})}\n\n" + return - async def event_stream(graph_data): try: - graph_data = graph_data.get("data") - if not graph_data: - raise HTTPException(status_code=400, detail="No data provided") - logger.debug("Building langchain object") graph = Graph.from_payload(graph_data) for node_repr, node_id in graph.generator_build(): @@ -55,13 +117,13 @@ async def stream_build(client_id: str, graph_data: dict): ) yield f"data: {response}\n\n" # SSE format - chat_manager.set_cache(client_id, graph.build()) + chat_manager.set_cache(flow_id, graph.build()) + final_response = json.dumps({"end_of_stream": True}) + yield f"data: {final_response}\n\n" # SSE format except Exception as exc: logger.exception(exc) - error_response = json.dumps( - {"valid": False, "params": str(exc), "id": node_id} - ) + error_response = json.dumps({"valid": False, "params": str(exc)}) yield f"data: {error_response}\n\n" # SSE format - return StreamingResponse(event_stream(graph_data), media_type="text/event-stream") + return StreamingResponse(event_stream(flow_id), media_type="text/event-stream") diff --git a/src/backend/langflow/chat/manager.py b/src/backend/langflow/chat/manager.py index 7352fafd4..7c3a08240 100644 --- a/src/backend/langflow/chat/manager.py +++ b/src/backend/langflow/chat/manager.py @@ -186,15 +186,8 @@ class ChatManager: continue with self.cache_manager.set_client_id(client_id): - if client_id not in self.in_memory_cache: - await self.close_connection( - client_id=client_id, - code=status.WS_1011_INTERNAL_ERROR, - reason="Please, build the flow before sending messages", - ) - else: - langchain_object = self.in_memory_cache.get(client_id) - await self.process_message(client_id, payload, langchain_object) + langchain_object = self.in_memory_cache.get(client_id) + await self.process_message(client_id, payload, langchain_object) except Exception as e: # Handle any exceptions that might occur