🐛 fix(chat.py): fix stream_build function to return StreamData objects instead of strings

 feat(chat.py): add progress information to the stream_build function
The stream_build function now returns StreamData objects instead of strings, which improves the readability of the code. The function also now includes progress information in the response, which allows the client to track the progress of the build process.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-23 09:54:15 -03:00
commit 8ccdeb9dd5
2 changed files with 51 additions and 22 deletions

View file

@ -1,4 +1,3 @@
import json
from fastapi import (
APIRouter,
HTTPException,
@ -7,7 +6,7 @@ from fastapi import (
status,
)
from fastapi.responses import StreamingResponse
from langflow.api.v1.schemas import BuiltResponse, InitResponse
from langflow.api.v1.schemas import BuiltResponse, InitResponse, StreamData
from langflow.chat.manager import ChatManager
from langflow.graph.graph.base import Graph
@ -39,14 +38,18 @@ async def init_build(graph_data: dict):
try:
flow_id = graph_data.get("id")
if flow_id is None:
raise ValueError("No ID provided")
# Check if already building
if flow_id in flow_data_store and flow_data_store[flow_id].get("building"):
return InitResponse(flowId=flow_id)
# Delete from cache if already exists
if flow_id in chat_manager.in_memory_cache:
with chat_manager.in_memory_cache._lock:
chat_manager.in_memory_cache.delete(flow_id)
logger.debug(f"Deleted flow {flow_id} from cache")
if flow_id is None:
raise ValueError("No ID provided")
flow_data_store[flow_id] = graph_data
flow_data_store[flow_id] = {"graph_data": graph_data, "building": False}
return InitResponse(flowId=flow_id)
except Exception as exc:
@ -76,26 +79,44 @@ async def stream_build(flow_id: str):
"""Stream the build process based on stored flow data."""
async def event_stream(flow_id):
final_response = json.dumps({"end_of_stream": True})
final_response = {"end_of_stream": True}
try:
if flow_id not in flow_data_store:
error_message = "Invalid session ID"
yield f"data: {json.dumps({'error': error_message})}\n\n"
yield str(StreamData(event="error", data={"error": error_message}))
return
if flow_data_store[flow_id].get("building"):
error_message = "Already building"
yield str(StreamData(event="error", data={"error": error_message}))
return
graph_data = flow_data_store[flow_id].get("data")
if not graph_data:
error_message = "No data provided"
yield f"data: {json.dumps({'error': error_message})}\n\n"
yield str(StreamData(event="error", data={"error": error_message}))
return
logger.debug("Building langchain object")
graph = Graph.from_payload(graph_data)
for node in graph.generator_build():
try:
# Some error could happen when building the graph
graph = Graph.from_payload(graph_data)
except Exception as exc:
logger.exception(exc)
error_message = str(exc)
yield str(StreamData(event="error", data={"error": error_message}))
return
number_of_nodes = len(graph.nodes)
for i, vertex in enumerate(graph.generator_build(), 1):
try:
node.build()
params = node._built_object_repr()
log_dict = {
"log": f"Building node {vertex.vertex_type}",
}
yield str(StreamData(event="log", data=log_dict))
vertex.build()
params = vertex._built_object_repr()
valid = True
logger.debug(
f"Building node {params[:50]}{'...' if len(params) > 50 else ''}"
@ -104,21 +125,21 @@ async def stream_build(flow_id: str):
params = str(exc)
valid = False
response = json.dumps(
{
"valid": valid,
"params": params,
"id": node.id,
}
)
yield f"data: {response}\n\n"
response = {
"valid": valid,
"params": params,
"id": vertex.id,
"progress": round(i / number_of_nodes, 2),
}
yield str(StreamData(event="message", data=response))
chat_manager.set_cache(flow_id, graph.build())
except Exception as exc:
logger.error("Error while building the flow: %s", exc)
yield f"error: {json.dumps({'error': str(exc)})}\n\n"
yield str(StreamData(event="error", data={"error": str(exc)}))
finally:
yield f"data: {final_response}\n\n"
yield str(StreamData(event="message", data=final_response))
try:
return StreamingResponse(event_stream(flow_id), media_type="text/event-stream")

View file

@ -101,3 +101,11 @@ class InitResponse(BaseModel):
class BuiltResponse(BaseModel):
built: bool
class StreamData(BaseModel):
event: str
data: dict
def __str__(self) -> str:
return f"event: {self.event}\ndata: {json.dumps(self.data)}\n\n"