diff --git a/src/backend/langflow/api/callback.py b/src/backend/langflow/api/callback.py
new file mode 100644
index 000000000..47a8d945c
--- /dev/null
+++ b/src/backend/langflow/api/callback.py
@@ -0,0 +1,18 @@
+from typing import Any
+from langchain.callbacks.base import AsyncCallbackHandler
+
+from langflow.api.schemas import ChatResponse
+
+
+# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
+class StreamingLLMCallbackHandler(AsyncCallbackHandler):
+    """Callback handler for streaming LLM responses."""
+
+    def __init__(self, websocket):
+        self.websocket = websocket
+
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        resp = ChatResponse(
+            sender="bot", message=token, type="stream", intermediate_steps=""
+        )
+        await self.websocket.send_json(resp.dict())
diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/chat.py
index 11b861c77..b2da73d52 100644
--- a/src/backend/langflow/api/chat.py
+++ b/src/backend/langflow/api/chat.py
@@ -1,5 +1,4 @@
 from fastapi import APIRouter, WebSocket
-from uuid import uuid4
 
 from langflow.api.chat_manager import ChatManager
 
@@ -7,6 +6,8 @@ router = APIRouter()
 chat_manager = ChatManager()
 
 
-@router.websocket("/ws/{client_id}")
+@router.websocket("/chat/{client_id}")
 async def websocket_endpoint(client_id: str, websocket: WebSocket):
     await chat_manager.handle_websocket(client_id, websocket)
+
+
diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py
index 8dcaf05ac..5ce7d2452 100644
--- a/src/backend/langflow/api/chat_manager.py
+++ b/src/backend/langflow/api/chat_manager.py
@@ -5,9 +5,12 @@ from typing import Dict, List
 from collections import defaultdict
 from fastapi import WebSocket
 import json
+from langchain.llms import OpenAI, AzureOpenAI
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
 
 from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
 from langflow.cache.manager import AsyncSubject
-
+from langchain.callbacks.base import AsyncCallbackManager
+from langflow.api.callback import StreamingLLMCallbackHandler
 from langflow.interface.run import (
     async_get_result_and_steps,
     load_or_build_langchain_object,
@@ -90,7 +93,6 @@ class ChatManager:
 
     async def process_message(self, client_id: str, payload: Dict):
         # Process the graph data and chat message
-
         chat_message = payload.pop("message", "")
         chat_message = ChatMessage(sender="you", message=chat_message)
         await self.chat_history.add_message(client_id, chat_message)
@@ -105,10 +107,12 @@ class ChatManager:
         # Generate result and thought
         try:
             logger.debug("Generating result and thought")
+
             result, intermediate_steps = await process_graph(
                 graph_data=graph_data,
                 is_first_message=is_first_message,
                 chat_message=chat_message,
+                websocket=self.active_connections[client_id],
             )
         except Exception as e:
             # Log stack trace
@@ -129,6 +133,7 @@ class ChatManager:
 
     async def handle_websocket(self, client_id: str, websocket: WebSocket):
         await self.connect(client_id, websocket)
+
         try:
             chat_history = self.chat_history.get_history(client_id)
             await websocket.send_json(json.dumps(chat_history))
@@ -146,9 +151,13 @@
 
 
 async def process_graph(
-    graph_data: Dict, is_first_message: bool, chat_message: ChatMessage
+    graph_data: Dict,
+    is_first_message: bool,
+    chat_message: ChatMessage,
+    websocket: WebSocket,
 ):
     langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
+    langchain_object = try_setting_streaming_options(langchain_object, websocket)
     logger.debug("Loaded langchain object")
 
     if langchain_object is None:
@@ -171,6 +180,27 @@ async def process_graph(
         raise e
 
 
+def try_setting_streaming_options(langchain_object, websocket):
+    # If the LLM type is OpenAI or ChatOpenAI,
+    # set streaming to True
+    # First we need to find the LLM
+    llm = None
+    if hasattr(langchain_object, "llm"):
+        llm = langchain_object.llm
+    elif hasattr(langchain_object, "llm_chain") and hasattr(
+        langchain_object.llm_chain, "llm"
+    ):
+        llm = langchain_object.llm_chain.llm
+    if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
+        llm.streaming = bool(hasattr(llm, "streaming"))
+
+    if hasattr(langchain_object, "callback_manager"):
+        stream_handler = StreamingLLMCallbackHandler(websocket)
+        stream_manager = AsyncCallbackManager([stream_handler])
+        langchain_object.callback_manager = stream_manager
+    return langchain_object
+
+
 def pil_to_base64(image: Image) -> str:
     buffered = BytesIO()
     image.save(buffered, format="PNG")