feat(api): add callback handler for streaming LLM responses
Add a new file `callback.py` containing a `StreamingLLMCallbackHandler` class that inherits from `AsyncCallbackHandler` and handles streaming LLM responses. Its constructor takes a `websocket` parameter and stores it as an instance variable, and its `on_llm_new_token` method takes a `token` parameter and sends a `ChatResponse` object over that websocket.

Update `chat_manager.py` to import the new `StreamingLLMCallbackHandler` class and add a `try_setting_streaming_options` function that takes `langchain_object` and `websocket` parameters. This function checks whether the `llm` attribute of the `langchain_object` (found directly or through its `llm_chain`) is an instance of `OpenAI`, `ChatOpenAI`, `AzureOpenAI`, or `AzureChatOpenAI`. If it is, it sets the LLM's `streaming` attribute to `True` and attaches an `AsyncCallbackManager` wrapping a `StreamingLLMCallbackHandler` as the object's `callback_manager`.
parent 5169c0bc27
commit ebc1f6a0df
3 changed files with 54 additions and 5 deletions
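On the wire, each generated token becomes one websocket frame. Per the `ChatResponse` construction in `callback.py` below, a frame serializes to a JSON object of this shape (the token value here is illustrative):

    {"sender": "bot", "message": "Hel", "type": "stream", "intermediate_steps": ""}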
src/backend/langflow/api/callback.py (new file, 18 additions)

@@ -0,0 +1,18 @@
+from typing import Any
+from langchain.callbacks.base import AsyncCallbackHandler
+
+from langflow.api.schemas import ChatResponse
+
+
+# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
+class StreamingLLMCallbackHandler(AsyncCallbackHandler):
+    """Callback handler for streaming LLM responses."""
+
+    def __init__(self, websocket):
+        self.websocket = websocket
+
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        resp = ChatResponse(
+            sender="bot", message=token, type="stream", intermediate_steps=""
+        )
+        await self.websocket.send_json(resp.dict())
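A quick way to exercise the handler in isolation is to substitute a stand-in for the FastAPI websocket. This is a test sketch under that assumption; `FakeWebSocket` is hypothetical and not part of the commit:

    import asyncio

    from langflow.api.callback import StreamingLLMCallbackHandler


    class FakeWebSocket:
        """Hypothetical stand-in for fastapi.WebSocket; records sent frames."""

        def __init__(self) -> None:
            self.frames = []

        async def send_json(self, data: dict) -> None:
            self.frames.append(data)


    async def main() -> None:
        ws = FakeWebSocket()
        handler = StreamingLLMCallbackHandler(ws)
        for token in ["Hel", "lo", "!"]:
            await handler.on_llm_new_token(token)
        # One frame per token, each tagged type "stream".
        assert [f["message"] for f in ws.frames] == ["Hel", "lo", "!"]


    asyncio.run(main())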
@@ -1,5 +1,4 @@
 from fastapi import APIRouter, WebSocket
-from uuid import uuid4
 
 from langflow.api.chat_manager import ChatManager
 

@@ -7,6 +6,8 @@ router = APIRouter()
 chat_manager = ChatManager()
 
 
-@router.websocket("/ws/{client_id}")
+@router.websocket("/chat/{client_id}")
 async def websocket_endpoint(client_id: str, websocket: WebSocket):
     await chat_manager.handle_websocket(client_id, websocket)
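Against the renamed route, a client now connects to `/chat/{client_id}` on the API router. A minimal client sketch, assuming the third-party `websockets` package and a local server on port 7860 (the port and the client id are assumptions):

    import asyncio

    import websockets  # assumption: third-party websocket client library


    async def main() -> None:
        uri = "ws://localhost:7860/chat/my-client-id"  # host/port assumed
        async with websockets.connect(uri) as ws:
            # handle_websocket sends the chat history first, then one frame
            # per streamed token while a message is being processed.
            async for frame in ws:
                print(frame)


    asyncio.run(main())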
src/backend/langflow/api/chat_manager.py

@@ -5,9 +5,12 @@ from typing import Dict, List
 from collections import defaultdict
 from fastapi import WebSocket
 import json
+from langchain.llms import OpenAI, AzureOpenAI
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
 from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
 from langflow.cache.manager import AsyncSubject
 
 from langchain.callbacks.base import AsyncCallbackManager
+from langflow.api.callback import StreamingLLMCallbackHandler
 from langflow.interface.run import (
     async_get_result_and_steps,
     load_or_build_langchain_object,
@@ -90,7 +93,6 @@ class ChatManager:
 
     async def process_message(self, client_id: str, payload: Dict):
         # Process the graph data and chat message
-
        chat_message = payload.pop("message", "")
        chat_message = ChatMessage(sender="you", message=chat_message)
        await self.chat_history.add_message(client_id, chat_message)
@@ -105,10 +107,12 @@ class ChatManager:
         # Generate result and thought
         try:
             logger.debug("Generating result and thought")
+
             result, intermediate_steps = await process_graph(
                 graph_data=graph_data,
                 is_first_message=is_first_message,
                 chat_message=chat_message,
+                websocket=self.active_connections[client_id],
             )
         except Exception as e:
             # Log stack trace
@@ -129,6 +133,7 @@ class ChatManager:
 
     async def handle_websocket(self, client_id: str, websocket: WebSocket):
         await self.connect(client_id, websocket)
+
         try:
             chat_history = self.chat_history.get_history(client_id)
             await websocket.send_json(json.dumps(chat_history))
@@ -146,9 +151,13 @@ class ChatManager:
 
 
 async def process_graph(
-    graph_data: Dict, is_first_message: bool, chat_message: ChatMessage
+    graph_data: Dict,
+    is_first_message: bool,
+    chat_message: ChatMessage,
+    websocket: WebSocket,
 ):
     langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
+    langchain_object = try_setting_streaming_options(langchain_object, websocket)
     logger.debug("Loaded langchain object")
 
     if langchain_object is None:
@@ -171,6 +180,27 @@ async def process_graph(
         raise e
 
 
+def try_setting_streaming_options(langchain_object, websocket):
+    # If the LLM type is OpenAI or ChatOpenAI,
+    # set streaming to True
+    # First we need to find the LLM
+    llm = None
+    if hasattr(langchain_object, "llm"):
+        llm = langchain_object.llm
+    elif hasattr(langchain_object, "llm_chain") and hasattr(
+        langchain_object.llm_chain, "llm"
+    ):
+        llm = langchain_object.llm_chain.llm
+    if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
+        llm.streaming = True
+
+    if hasattr(langchain_object, "callback_manager"):
+        stream_handler = StreamingLLMCallbackHandler(websocket)
+        stream_manager = AsyncCallbackManager([stream_handler])
+        langchain_object.callback_manager = stream_manager
+    return langchain_object
+
+
 def pil_to_base64(image: Image) -> str:
     buffered = BytesIO()
     image.save(buffered, format="PNG")
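To see `try_setting_streaming_options` at work outside the chat flow, the sketch below applies it to a plain `LLMChain`. It assumes the era-appropriate langchain API (a `callback_manager` field on chains and models, `LLMChain.llm`), a hypothetical `FakeWebSocket` like the one above, and an `OPENAI_API_KEY` in the environment for `OpenAI()`:

    from langchain.chains import LLMChain
    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate

    from langflow.api.chat_manager import try_setting_streaming_options


    class FakeWebSocket:
        async def send_json(self, data: dict) -> None:  # hypothetical stand-in
            print(data)


    chain = LLMChain(
        llm=OpenAI(),
        prompt=PromptTemplate(input_variables=["q"], template="{q}"),
    )
    chain = try_setting_streaming_options(chain, FakeWebSocket())
    # The helper found chain.llm, enabled llm.streaming, and attached an
    # AsyncCallbackManager wrapping a StreamingLLMCallbackHandler.
    assert chain.llm.streaming is True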