🐛 fix(chat.py): add check for client_id in in_memory_cache before handling websocket

feat(chat.py): add support for storing graph data and returning a unique session ID for building the langchain object
feat(chat.py): add support for streaming the build process based on stored flow data
The fix adds a check for the client_id in the in_memory_cache before handling the websocket connection. This ensures that the flow has been built before any messages are sent.

The first feature adds support for storing graph data and returning a unique session ID that is later used to build the langchain object. This allows the user to build the flow first and then send messages.

The second feature adds support for streaming the build process based on the stored flow data. This allows the user to monitor the progress of the build as it happens.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-12 11:49:24 -03:00
commit 76f1aa8adf
2 changed files with 79 additions and 24 deletions

View file

@ -7,7 +7,7 @@ from fastapi import (
WebSocketException,
status,
)
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from langflow.chat.manager import ChatManager
from langflow.graph.graph.base import Graph
@ -15,13 +15,18 @@ from langflow.utils.logger import logger
router = APIRouter()
chat_manager = ChatManager()
flow_data_store = {}
@router.websocket("/chat/{client_id}")
async def websocket_endpoint(client_id: str, websocket: WebSocket):
"""Websocket endpoint for chat."""
try:
await chat_manager.handle_websocket(client_id, websocket)
if client_id in chat_manager.in_memory_cache:
await chat_manager.handle_websocket(client_id, websocket)
else:
message = "Please, build the flow before sending messages"
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
except WebSocketException as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
@ -30,16 +35,73 @@ async def websocket_endpoint(client_id: str, websocket: WebSocket):
await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc))
@router.post("/build/{client_id}", response_class=StreamingResponse)
async def stream_build(client_id: str, graph_data: dict):
"""Build langchain object from data_graph."""
@router.post("/build/init")
async def init_build(graph_data: dict):
"""Initialize the build by storing graph data and returning a unique session ID."""
flow_id = graph_data.get("id")
flow_data_store[flow_id] = graph_data
return JSONResponse(content={"flowId": flow_id})
# @router.post("/build/{client_id}", response_class=StreamingResponse)
# async def stream_build(client_id: str, graph_data: dict):
# """Build langchain object from data_graph."""
# async def event_stream(graph_data):
# node_id = None
# try:
# graph_data = graph_data.get("data")
# if not graph_data:
# raise HTTPException(status_code=400, detail="No data provided")
# logger.debug("Building langchain object")
# graph = Graph.from_payload(graph_data)
# for node_repr, node_id in graph.generator_build():
# logger.debug(
# f"Building node {node_repr[:50]}{'...' if len(node_repr) > 50 else ''}"
# )
# response = json.dumps(
# {
# "valid": True,
# "params": node_repr,
# "id": node_id,
# }
# )
# yield f"data: {response}\n\n" # SSE format
# chat_manager.set_cache(client_id, graph.build())
# except Exception as exc:
# logger.exception(exc)
# error_response = json.dumps(
# {"valid": False, "params": str(exc), "id": node_id}
# )
# yield f"data: {error_response}\n\n" # SSE format
# return StreamingResponse(event_stream(graph_data), media_type="text/event-stream")
@router.get("/build/stream/{flow_id}", response_class=StreamingResponse)
async def stream_build(flow_id: str):
"""Stream the build process based on stored flow data."""
async def event_stream(flow_id):
if flow_id not in flow_data_store:
error_message = "Invalid session ID"
yield f"data: {json.dumps({'error': error_message})}\n\n"
return
graph_data = flow_data_store[flow_id].get("data")
if not graph_data:
error_message = "No data provided"
yield f"data: {json.dumps({'error': error_message})}\n\n"
return
async def event_stream(graph_data):
try:
graph_data = graph_data.get("data")
if not graph_data:
raise HTTPException(status_code=400, detail="No data provided")
logger.debug("Building langchain object")
graph = Graph.from_payload(graph_data)
for node_repr, node_id in graph.generator_build():
@ -55,13 +117,13 @@ async def stream_build(client_id: str, graph_data: dict):
)
yield f"data: {response}\n\n" # SSE format
chat_manager.set_cache(client_id, graph.build())
chat_manager.set_cache(flow_id, graph.build())
final_response = json.dumps({"end_of_stream": True})
yield f"data: {final_response}\n\n" # SSE format
except Exception as exc:
logger.exception(exc)
error_response = json.dumps(
{"valid": False, "params": str(exc), "id": node_id}
)
error_response = json.dumps({"valid": False, "params": str(exc)})
yield f"data: {error_response}\n\n" # SSE format
return StreamingResponse(event_stream(graph_data), media_type="text/event-stream")
return StreamingResponse(event_stream(flow_id), media_type="text/event-stream")

View file

@ -186,15 +186,8 @@ class ChatManager:
continue
with self.cache_manager.set_client_id(client_id):
if client_id not in self.in_memory_cache:
await self.close_connection(
client_id=client_id,
code=status.WS_1011_INTERNAL_ERROR,
reason="Please, build the flow before sending messages",
)
else:
langchain_object = self.in_memory_cache.get(client_id)
await self.process_message(client_id, payload, langchain_object)
langchain_object = self.in_memory_cache.get(client_id)
await self.process_message(client_id, payload, langchain_object)
except Exception as e:
# Handle any exceptions that might occur