Remove all websocket logic
This commit is contained in:
parent
5cef5b868a
commit
b90ff00cf4
1 changed file with 2 additions and 235 deletions
|
|
@ -1,180 +1,22 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict
|
||||
|
||||
import orjson
|
||||
from fastapi import WebSocket, status
|
||||
from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.interface.utils import pil_to_base64
|
||||
from fastapi import WebSocket
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.chat.cache import Subject
|
||||
from langflow.services.chat.utils import process_graph
|
||||
from langflow.services.deps import get_cache_service
|
||||
from loguru import logger
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from .cache import cache_service
|
||||
|
||||
|
||||
class ChatHistory(Subject):
    """Observable per-client store of chat messages."""

    def __init__(self):
        super().__init__()
        # client_id -> ordered list of that client's messages
        self.history: Dict[str, List[ChatMessage]] = defaultdict(list)

    def add_message(self, client_id: str, message: ChatMessage):
        """Append *message* to the client's history and notify observers.

        File payloads are stored silently — observers are only notified
        for non-file messages.
        """
        self.history[client_id].append(message)
        if not isinstance(message, FileResponse):
            self.notify()

    def get_history(self, client_id: str, filter_messages=True) -> List[ChatMessage]:
        """Return the client's messages, optionally dropping transient ones.

        With ``filter_messages`` enabled, "start" and "stream" messages
        are excluded. Unknown clients yield an empty list.
        """
        messages = self.history.get(client_id) or []
        if messages and filter_messages:
            messages = [m for m in messages if m.type not in ("start", "stream")]
        return messages

    def empty_history(self, client_id: str):
        """Drop every stored message for *client_id*."""
        self.history[client_id] = []
|
||||
|
||||
|
||||
class ChatService(Service):
|
||||
name = "chat_service"
|
||||
|
||||
def __init__(self):
    """Wire up connection registries, chat history, and cache observers."""
    # Open websocket connections keyed by client id.
    self.active_connections: Dict[str, WebSocket] = {}
    # Per-connection unique ids (client id + uuid) used as session ids.
    self.connection_ids: Dict[str, str] = {}
    self.chat_history = ChatHistory()
    # Module-level chat cache singleton; observe it so update() runs
    # whenever a new object is cached for the current client.
    self.chat_cache = cache_service
    self.chat_cache.attach(self.update)
    self.cache_service = get_cache_service()
|
||||
|
||||
def on_chat_history_update(self):
    """Send the last chat message to the client.

    Converts file payloads to a wire-friendly form (CSV for pandas,
    base64 for images) before sending. Only acts when the cache's
    current client still has an open connection.
    """
    client_id = self.chat_cache.current_client_id
    if client_id in self.active_connections:
        # Unfiltered history so "start"/"stream" messages are included.
        chat_response = self.chat_history.get_history(
            client_id, filter_messages=False
        )[-1]
        if chat_response.is_bot:
            # Process FileResponse
            if isinstance(chat_response, FileResponse):
                # If data_type is pandas, convert to csv
                if chat_response.data_type == "pandas":
                    chat_response.data = chat_response.data.to_csv()
                elif chat_response.data_type == "image":
                    # Base64 encode the image
                    chat_response.data = pil_to_base64(chat_response.data)
            # get event loop
            loop = asyncio.get_event_loop()

            # NOTE(review): scheduled threadsafe rather than awaited —
            # presumably this callback can fire off the loop thread; confirm.
            coroutine = self.send_json(client_id, chat_response)
            asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
|
||||
def update(self):
    """Observer callback: mirror the newest cached object into chat history.

    Only acts when the cache's current client still has an open
    connection; the object is recorded as a FileResponse message.
    """
    client_id = self.chat_cache.current_client_id
    if client_id not in self.active_connections:
        return

    self.last_cached_object_dict = self.chat_cache.get_last()
    # Wrap the cached object so the frontend can render it as a file.
    file_message = FileResponse(
        message=None,
        type="file",
        data=self.last_cached_object_dict["obj"],
        data_type=self.last_cached_object_dict["type"],
    )
    self.chat_history.add_message(client_id, file_message)
|
||||
|
||||
async def connect(self, client_id: str, websocket: WebSocket):
    """Register *websocket* for *client_id* and mint a unique session id."""
    self.active_connections[client_id] = websocket
    # This is to avoid having multiple clients with the same id
    #! Temporary solution
    self.connection_ids[client_id] = "-".join((client_id, str(uuid.uuid4())))
|
||||
|
||||
def disconnect(self, client_id: str):
    """Forget the client's websocket and session id (no-op if unknown)."""
    for registry in (self.active_connections, self.connection_ids):
        registry.pop(client_id, None)
|
||||
|
||||
async def send_message(self, client_id: str, message: str):
    """Push a raw text frame to the client's websocket.

    Raises KeyError if the client has no registered connection.
    """
    await self.active_connections[client_id].send_text(message)
|
||||
|
||||
async def send_json(self, client_id: str, message: ChatMessage):
    """Serialize *message* and push it as a JSON frame to the client.

    Raises KeyError if the client has no registered connection.
    """
    await self.active_connections[client_id].send_json(message.model_dump())
|
||||
|
||||
async def close_connection(self, client_id: str, code: int, reason: str):
    """Close the client's websocket and deregister it.

    Safe to call for an unknown or already-removed client (no-op).
    The benign "close after close" ASGI race is logged and swallowed;
    any other RuntimeError is re-raised instead of being silently lost.
    """
    # .get() avoids the KeyError that `self.active_connections[client_id]`
    # raised when the client had already been disconnected — the walrus
    # only guarded falsiness, not a missing key.
    websocket = self.active_connections.get(client_id)
    if websocket is None:
        return
    try:
        await websocket.close(code=code, reason=reason)
        self.disconnect(client_id)
    except RuntimeError as exc:
        # Starlette raises "Unexpected ASGI message 'websocket.close',
        # after sending 'websocket.close'" when both sides race to close.
        if "after sending" in str(exc):
            logger.error(f"Error closing connection: {exc}")
        else:
            # Previously swallowed silently; surface unexpected failures.
            raise
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict, build_result: Any):
    """Run one chat turn: record the input, execute the graph, reply.

    Sends a "start" frame, runs the built graph against the chat input,
    then sends an "end" frame carrying the result plus any file messages
    produced since the turn started. On failure the client's history is
    cleared and the exception re-raised to the websocket handler.
    """
    # Process the graph data and chat message
    chat_inputs = payload.pop("inputs", {})
    chatkey = payload.pop("chatKey", None)
    chat_inputs = ChatMessage(message=chat_inputs, chatKey=chatkey)
    self.chat_history.add_message(client_id, chat_inputs)

    # graph_data = payload
    start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
    await self.send_json(client_id, start_resp)

    # is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1
    # Generate result and thought
    try:
        logger.debug("Generating result and thought")

        result, intermediate_steps, raw_output = await process_graph(
            build_result=build_result,
            chat_inputs=chat_inputs,
            client_id=client_id,
            # Per-connection unique id doubles as the session id.
            session_id=self.connection_ids[client_id],
        )
        self.set_cache(client_id, build_result)
    except Exception as e:
        # Log stack trace
        logger.exception(e)
        # Drop the partial turn so the client doesn't replay a broken state.
        self.chat_history.empty_history(client_id)
        raise e
    # Send a response back to the frontend, if needed
    intermediate_steps = intermediate_steps or ""
    history = self.chat_history.get_history(client_id, filter_messages=False)
    file_responses = []
    if history:
        # Iterate backwards through the history, collecting files produced
        # during this turn — stop at the "start" marker of the turn.
        for msg in reversed(history):
            if isinstance(msg, FileResponse):
                if msg.data_type == "image":
                    # Base64 encode the image
                    # (a str payload is presumably already encoded — skip it)
                    if isinstance(msg.data, str):
                        continue
                    msg.data = pil_to_base64(msg.data)
                file_responses.append(msg)
            if msg.type == "start":
                break

    response = ChatResponse(
        message=result,
        intermediate_steps=intermediate_steps.strip(),
        type="end",
        files=file_responses,
    )
    await self.send_json(client_id, response)
    self.chat_history.add_message(client_id, response)
|
||||
|
||||
def set_cache(self, client_id: str, data: Any) -> bool:
|
||||
"""
|
||||
Set the cache for a client.
|
||||
|
|
@ -189,58 +31,6 @@ class ChatService(Service):
|
|||
self.cache_service.upsert(client_id, result_dict)
|
||||
return client_id in self.cache_service
|
||||
|
||||
async def handle_websocket(self, client_id: str, websocket: WebSocket):
    """Serve a client's chat websocket for its whole lifetime.

    Registers the connection, replays the filtered chat history, then
    loops on incoming JSON payloads, dispatching each to
    ``process_message`` against the client's cached build result. The
    connection is always closed and deregistered on the way out.
    """
    await self.connect(client_id, websocket)

    try:
        chat_history = self.chat_history.get_history(client_id)
        # Replay history as plain dicts so it is JSON-serializable.
        chat_history = [chat.model_dump() for chat in chat_history]
        await websocket.send_json(chat_history)

        while True:
            json_payload = await websocket.receive_json()
            if isinstance(json_payload, str):
                # Some clients send the JSON object doubly encoded.
                payload = orjson.loads(json_payload)
            elif isinstance(json_payload, dict):
                payload = json_payload
            else:
                # BUGFIX: previously fell through with `payload` unbound
                # (NameError on the first message) or silently stale from
                # the prior iteration; fail loudly instead. The outer
                # except still handles it like any other error.
                raise ValueError(
                    f"Unexpected websocket payload type: {type(json_payload)}"
                )
            if "clear_history" in payload and payload["clear_history"]:
                self.chat_history.history[client_id] = []
                continue

            # Scope the chat cache to this client while processing.
            with self.chat_cache.set_client_id(client_id):
                if build_result := self.cache_service.get(client_id).get("result"):
                    await self.process_message(client_id, payload, build_result)

                else:
                    raise RuntimeError(
                        f"Could not find a build result for client_id {client_id}"
                    )
    except Exception as exc:
        # Handle any exceptions that might occur
        logger.exception(f"Error handling websocket: {exc}")
        if websocket.client_state == WebSocketState.CONNECTED:
            await self.close_connection(
                client_id=client_id,
                code=status.WS_1011_INTERNAL_ERROR,
                reason=str(exc),
            )
        elif websocket.client_state == WebSocketState.DISCONNECTED:
            self.disconnect(client_id)

    finally:
        try:
            # first check if the connection is still open
            if websocket.client_state == WebSocketState.CONNECTED:
                await self.close_connection(
                    client_id=client_id,
                    code=status.WS_1000_NORMAL_CLOSURE,
                    reason="Client disconnected",
                )
        except Exception as exc:
            logger.error(f"Error closing connection: {exc}")
        self.disconnect(client_id)
|
||||
|
||||
def get_cache(self, client_id: str) -> Any:
|
||||
"""
|
||||
Get the cache for a client.
|
||||
|
|
@ -252,26 +42,3 @@ class ChatService(Service):
|
|||
Clear the cache for a client.
|
||||
"""
|
||||
self.cache_service.delete(client_id)
|
||||
|
||||
|
||||
def dict_to_markdown_table(my_dict):
    """Render *my_dict* as a two-column (Key | Value) markdown table."""
    rows = [f"| {key} | {value} |" for key, value in my_dict.items()]
    return "\n".join(["| Key | Value |", "|---|---|", *rows]) + "\n"
|
||||
|
||||
|
||||
def list_of_dicts_to_markdown_table(dict_list):
    """Render a list of dicts as a markdown table.

    Column headers come from the first dict; keys missing from later
    rows render as empty cells. Empty/None input yields a placeholder
    string instead of a table.
    """
    if not dict_list:
        return "No data provided."

    headers = list(dict_list[0].keys())

    def fmt(cells):
        # One markdown table row, newline-terminated.
        return "| " + " | ".join(cells) + " |\n"

    lines = [fmt(headers), fmt(["---"] * len(headers))]
    for record in dict_list:
        lines.append(fmt([str(record.get(h, "")) for h in headers]))
    return "".join(lines)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue