From b90ff00cf4a94e0dc7cfbd3b43075f1748b13c64 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 20 Feb 2024 13:16:43 -0300 Subject: [PATCH] Remove all websocket logic --- src/backend/langflow/services/chat/service.py | 237 +----------------- 1 file changed, 2 insertions(+), 235 deletions(-) diff --git a/src/backend/langflow/services/chat/service.py b/src/backend/langflow/services/chat/service.py index 48e67da0c..054609f97 100644 --- a/src/backend/langflow/services/chat/service.py +++ b/src/backend/langflow/services/chat/service.py @@ -1,180 +1,22 @@ -import asyncio -import uuid -from collections import defaultdict -from typing import Any, Dict, List +from typing import Any, Dict -import orjson -from fastapi import WebSocket, status -from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse -from langflow.interface.utils import pil_to_base64 +from fastapi import WebSocket from langflow.services.base import Service -from langflow.services.chat.cache import Subject -from langflow.services.chat.utils import process_graph from langflow.services.deps import get_cache_service -from loguru import logger -from starlette.websockets import WebSocketState from .cache import cache_service -class ChatHistory(Subject): - def __init__(self): - super().__init__() - self.history: Dict[str, List[ChatMessage]] = defaultdict(list) - - def add_message(self, client_id: str, message: ChatMessage): - """Add a message to the chat history.""" - - self.history[client_id].append(message) - - if not isinstance(message, FileResponse): - self.notify() - - def get_history(self, client_id: str, filter_messages=True) -> List[ChatMessage]: - """Get the chat history for a client.""" - if history := self.history.get(client_id, []): - if filter_messages: - return [msg for msg in history if msg.type not in ["start", "stream"]] - return history - else: - return [] - - def empty_history(self, client_id: str): - """Empty the chat history for a client.""" - self.history[client_id] = [] - - class ChatService(Service): name = "chat_service" def __init__(self): self.active_connections: Dict[str, WebSocket] = {} self.connection_ids: Dict[str, str] = {} - self.chat_history = ChatHistory() self.chat_cache = cache_service self.chat_cache.attach(self.update) self.cache_service = get_cache_service() - def on_chat_history_update(self): - """Send the last chat message to the client.""" - client_id = self.chat_cache.current_client_id - if client_id in self.active_connections: - chat_response = self.chat_history.get_history( - client_id, filter_messages=False - )[-1] - if chat_response.is_bot: - # Process FileResponse - if isinstance(chat_response, FileResponse): - # If data_type is pandas, convert to csv - if chat_response.data_type == "pandas": - chat_response.data = chat_response.data.to_csv() - elif chat_response.data_type == "image": - # Base64 encode the image - chat_response.data = pil_to_base64(chat_response.data) - # get event loop - loop = asyncio.get_event_loop() - - coroutine = self.send_json(client_id, chat_response) - asyncio.run_coroutine_threadsafe(coroutine, loop) - - def update(self): - if self.chat_cache.current_client_id in self.active_connections: - self.last_cached_object_dict = self.chat_cache.get_last() - # Add a new ChatResponse with the data - chat_response = FileResponse( - message=None, - type="file", - data=self.last_cached_object_dict["obj"], - data_type=self.last_cached_object_dict["type"], - ) - - self.chat_history.add_message( - self.chat_cache.current_client_id, chat_response - ) - - async def connect(self, client_id: str, websocket: WebSocket): - self.active_connections[client_id] = websocket - # This is to avoid having multiple clients with the same id - #! Temporary solution - self.connection_ids[client_id] = f"{client_id}-{uuid.uuid4()}" - - def disconnect(self, client_id: str): - self.active_connections.pop(client_id, None) - self.connection_ids.pop(client_id, None) - - async def send_message(self, client_id: str, message: str): - websocket = self.active_connections[client_id] - await websocket.send_text(message) - - async def send_json(self, client_id: str, message: ChatMessage): - websocket = self.active_connections[client_id] - await websocket.send_json(message.model_dump()) - - async def close_connection(self, client_id: str, code: int, reason: str): - if websocket := self.active_connections[client_id]: - try: - await websocket.close(code=code, reason=reason) - self.disconnect(client_id) - except RuntimeError as exc: - # This is to catch the following error: - # Unexpected ASGI message 'websocket.close', after sending 'websocket.close' - if "after sending" in str(exc): - logger.error(f"Error closing connection: {exc}") - - async def process_message(self, client_id: str, payload: Dict, build_result: Any): - # Process the graph data and chat message - chat_inputs = payload.pop("inputs", {}) - chatkey = payload.pop("chatKey", None) - chat_inputs = ChatMessage(message=chat_inputs, chatKey=chatkey) - self.chat_history.add_message(client_id, chat_inputs) - - # graph_data = payload - start_resp = ChatResponse(message=None, type="start", intermediate_steps="") - await self.send_json(client_id, start_resp) - - # is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1 - # Generate result and thought - try: - logger.debug("Generating result and thought") - - result, intermediate_steps, raw_output = await process_graph( - build_result=build_result, - chat_inputs=chat_inputs, - client_id=client_id, - session_id=self.connection_ids[client_id], - ) - self.set_cache(client_id, build_result) - except Exception as e: - # Log stack trace - logger.exception(e) - self.chat_history.empty_history(client_id) - raise e - # Send a response back to the frontend, if needed - intermediate_steps = intermediate_steps or "" - history = self.chat_history.get_history(client_id, filter_messages=False) - file_responses = [] - if history: - # Iterate backwards through the history - for msg in reversed(history): - if isinstance(msg, FileResponse): - if msg.data_type == "image": - # Base64 encode the image - if isinstance(msg.data, str): - continue - msg.data = pil_to_base64(msg.data) - file_responses.append(msg) - if msg.type == "start": - break - - response = ChatResponse( - message=result, - intermediate_steps=intermediate_steps.strip(), - type="end", - files=file_responses, - ) - await self.send_json(client_id, response) - self.chat_history.add_message(client_id, response) - def set_cache(self, client_id: str, data: Any) -> bool: """ Set the cache for a client. @@ -189,58 +31,6 @@ class ChatService(Service): self.cache_service.upsert(client_id, result_dict) return client_id in self.cache_service - async def handle_websocket(self, client_id: str, websocket: WebSocket): - await self.connect(client_id, websocket) - - try: - chat_history = self.chat_history.get_history(client_id) - # iterate and make BaseModel into dict - chat_history = [chat.model_dump() for chat in chat_history] - await websocket.send_json(chat_history) - - while True: - json_payload = await websocket.receive_json() - if isinstance(json_payload, str): - payload = orjson.loads(json_payload) - elif isinstance(json_payload, dict): - payload = json_payload - if "clear_history" in payload and payload["clear_history"]: - self.chat_history.history[client_id] = [] - continue - - with self.chat_cache.set_client_id(client_id): - if build_result := self.cache_service.get(client_id).get("result"): - await self.process_message(client_id, payload, build_result) - - else: - raise RuntimeError( - f"Could not find a build result for client_id {client_id}" - ) - except Exception as exc: - # Handle any exceptions that might occur - logger.exception(f"Error handling websocket: {exc}") - if websocket.client_state == WebSocketState.CONNECTED: - await self.close_connection( - client_id=client_id, - code=status.WS_1011_INTERNAL_ERROR, - reason=str(exc), - ) - elif websocket.client_state == WebSocketState.DISCONNECTED: - self.disconnect(client_id) - - finally: - try: - # first check if the connection is still open - if websocket.client_state == WebSocketState.CONNECTED: - await self.close_connection( - client_id=client_id, - code=status.WS_1000_NORMAL_CLOSURE, - reason="Client disconnected", - ) - except Exception as exc: - logger.error(f"Error closing connection: {exc}") - self.disconnect(client_id) - def get_cache(self, client_id: str) -> Any: """ Get the cache for a client. @@ -252,26 +42,3 @@ class ChatService(Service): Clear the cache for a client. """ self.cache_service.delete(client_id) - - -def dict_to_markdown_table(my_dict): - markdown_table = "| Key | Value |\n|---|---|\n" - for key, value in my_dict.items(): - markdown_table += f"| {key} | {value} |\n" - return markdown_table - - -def list_of_dicts_to_markdown_table(dict_list): - if not dict_list: - return "No data provided." - - # Extract headers from the keys of the first dictionary - headers = dict_list[0].keys() - markdown_table = "| " + " | ".join(headers) + " |\n" - markdown_table += "| " + " | ".join("---" for _ in headers) + " |\n" - - for row_dict in dict_list: - row = [str(row_dict.get(header, "")) for header in headers] - markdown_table += "| " + " | ".join(row) + " |\n" - - return markdown_table