From 2d6854165024a6bdffa98e05e78b792a6a88db17 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 25 Apr 2023 17:30:53 -0300 Subject: [PATCH] refactor(api): remove sender field from ChatMessage and ChatResponse schemas fix(api): fix ChatManager.get_history method to exclude start and stream messages feat(api): add is_bot field to ChatMessage, ChatResponse, and FileResponse schemas --- src/backend/langflow/api/callback.py | 4 +--- src/backend/langflow/api/chat.py | 3 +-- src/backend/langflow/api/chat_manager.py | 30 +++++++++++------------- src/backend/langflow/api/schemas.py | 11 ++++----- 4 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/backend/langflow/api/callback.py b/src/backend/langflow/api/callback.py index 47a8d945c..cad4b1416 100644 --- a/src/backend/langflow/api/callback.py +++ b/src/backend/langflow/api/callback.py @@ -12,7 +12,5 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler): self.websocket = websocket async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - resp = ChatResponse( - sender="bot", message=token, type="stream", intermediate_steps="" - ) + resp = ChatResponse(message=token, type="stream", intermediate_steps="") await self.websocket.send_json(resp.dict()) diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/chat.py index b2da73d52..d5c2dc879 100644 --- a/src/backend/langflow/api/chat.py +++ b/src/backend/langflow/api/chat.py @@ -8,6 +8,5 @@ chat_manager = ChatManager() @router.websocket("/chat/{client_id}") async def websocket_endpoint(client_id: str, websocket: WebSocket): + """Websocket endpoint for chat.""" await chat_manager.handle_websocket(client_id, websocket) - - diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index bc9b2dc2d..8f407a791 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -26,11 +26,17 @@ class ChatHistory(AsyncSubject): self.history: Dict[str, List[ChatMessage]] = defaultdict(list) async def add_message(self, client_id: str, message: ChatMessage): + """Add a message to the chat history.""" + self.history[client_id].append(message) await self.notify() def get_history(self, client_id: str) -> List[ChatMessage]: - return self.history[client_id] + """Get the chat history for a client.""" + if history := self.history.get(client_id, []): + return [msg for msg in history if msg.type not in ["start", "stream"]] + else: + return [] class ChatManager: @@ -46,7 +52,7 @@ class ChatManager: client_id = self.cache_manager.current_client_id if client_id in self.active_connections: chat_response = self.chat_history.get_history(client_id)[-1] - if chat_response.sender == "bot": + if chat_response.is_bot: # Process FileResponse if isinstance(chat_response, FileResponse): # If data_type is pandas, convert to csv @@ -63,7 +69,6 @@ class ChatManager: self.last_cached_object_dict = self.cache_manager.get_last() # Add a new ChatResponse with the data chat_response = FileResponse( - sender="bot", message=None, type="file", data=self.last_cached_object_dict["obj"], @@ -94,13 +99,11 @@ class ChatManager: async def process_message(self, client_id: str, payload: Dict): # Process the graph data and chat message chat_message = payload.pop("message", "") - chat_message = ChatMessage(sender="you", message=chat_message) + chat_message = ChatMessage(message=chat_message) await self.chat_history.add_message(client_id, chat_message) graph_data = payload - start_resp = ChatResponse( - sender="bot", message=None, type="start", intermediate_steps="" - ) + start_resp = ChatResponse(message=None, type="start", intermediate_steps="") await self.chat_history.add_message(client_id, start_resp) is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0 @@ -117,14 +120,9 @@ class ChatManager: except Exception as e: # Log stack trace logger.exception(e) - error_resp = ChatResponse( - sender="bot", message=str(e), type="error", intermediate_steps="" - ) - await self.send_json(client_id, error_resp) - return + raise e # Send a response back to the frontend, if needed response = ChatResponse( - sender="bot", message=result or "", intermediate_steps=intermediate_steps or "", type="end", @@ -151,6 +149,7 @@ class ChatManager: except Exception as e: # Handle any exceptions that might occur print(f"Error: {e}") + raise e finally: self.disconnect(client_id) @@ -198,11 +197,10 @@ def try_setting_streaming_options(langchain_object, websocket): llm = langchain_object.llm_chain.llm if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)): llm.streaming = bool(hasattr(llm, "streaming")) - - if hasattr(langchain_object, "callback_manager"): stream_handler = StreamingLLMCallbackHandler(websocket) stream_manager = AsyncCallbackManager([stream_handler]) - langchain_object.callback_manager = stream_manager + llm.callback_manager = stream_manager + return langchain_object diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py index 1aefe5c8e..c9b210210 100644 --- a/src/backend/langflow/api/schemas.py +++ b/src/backend/langflow/api/schemas.py @@ -5,14 +5,9 @@ from pydantic import BaseModel, validator class ChatMessage(BaseModel): """Chat message schema.""" - sender: str + is_bot: bool = False message: Union[str, None] = None - - @validator("sender") - def sender_must_be_bot_or_you(cls, v): - if v not in ["bot", "you"]: - raise ValueError("sender must be bot or you") - return v + type: str = "human" class ChatResponse(ChatMessage): @@ -20,6 +15,7 @@ class ChatResponse(ChatMessage): intermediate_steps: str type: str + is_bot: bool = True @validator("type") def validate_message_type(cls, v): @@ -34,6 +30,7 @@ class FileResponse(ChatMessage): data: Any data_type: str type: str = "file" + is_bot: bool = True @validator("data_type") def validate_data_type(cls, v):