diff --git a/src/backend/langflow/api/v1/callback.py b/src/backend/langflow/api/v1/callback.py
index c907bea7b..90799499f 100644
--- a/src/backend/langflow/api/v1/callback.py
+++ b/src/backend/langflow/api/v1/callback.py
@@ -8,6 +8,7 @@ from langflow.api.v1.schemas import ChatResponse
 from typing import Any, Dict, List, Optional
 
 from fastapi import WebSocket
+from langflow.services.getters import get_chat_service
 from langflow.utils.util import remove_ansi_escape_codes
 
 
@@ -19,8 +20,10 @@ from loguru import logger
 class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
     """Callback handler for streaming LLM responses."""
 
-    def __init__(self, websocket: WebSocket):
-        self.websocket = websocket
+    def __init__(self, client_id: str):
+        self.chat_service = get_chat_service()
+        self.client_id = client_id
+        self.websocket = self.chat_service.active_connections[self.client_id]
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         resp = ChatResponse(message=token, type="stream", intermediate_steps="")
@@ -96,6 +99,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
             prompt=text,
         )
         await self.websocket.send_json(resp.dict())
+        self.chat_service.chat_history.add_message(self.client_id, resp)
 
     async def on_agent_action(self, action: AgentAction, **kwargs: Any):
         log = f"Thought: {action.log}"
diff --git a/src/backend/langflow/interface/utils.py b/src/backend/langflow/interface/utils.py
index 7b78e75c5..b28a660bf 100644
--- a/src/backend/langflow/interface/utils.py
+++ b/src/backend/langflow/interface/utils.py
@@ -36,7 +36,7 @@ def pil_to_base64(image: Image) -> str:
     return img_str.decode("utf-8")
 
 
-def try_setting_streaming_options(langchain_object, websocket):
+def try_setting_streaming_options(langchain_object):
     # If the LLM type is OpenAI or ChatOpenAI,
     # set streaming to True
     # First we need to find the LLM
diff --git a/src/backend/langflow/services/chat/manager.py b/src/backend/langflow/services/chat/manager.py
index ee4e6a9fc..9d50f3026 100644
--- a/src/backend/langflow/services/chat/manager.py
+++ b/src/backend/langflow/services/chat/manager.py
@@ -142,7 +142,7 @@ class ChatService(Service):
         result, intermediate_steps = await process_graph(
             langchain_object=langchain_object,
             chat_inputs=chat_inputs,
-            websocket=self.active_connections[client_id],
+            client_id=client_id,
             session_id=self.connection_ids[client_id],
         )
         self.set_cache(client_id, langchain_object)
diff --git a/src/backend/langflow/services/chat/utils.py b/src/backend/langflow/services/chat/utils.py
index a332e381e..2e5bd7b58 100644
--- a/src/backend/langflow/services/chat/utils.py
+++ b/src/backend/langflow/services/chat/utils.py
@@ -8,10 +8,10 @@ from loguru import logger
 async def process_graph(
     langchain_object,
     chat_inputs: ChatMessage,
-    websocket: WebSocket,
+    client_id: str,
     session_id: str,
 ):
-    langchain_object = try_setting_streaming_options(langchain_object, websocket)
+    langchain_object = try_setting_streaming_options(langchain_object)
     logger.debug("Loaded langchain object")
 
     if langchain_object is None:
@@ -30,7 +30,7 @@ async def process_graph(
     result, intermediate_steps = await get_result_and_steps(
         langchain_object,
         chat_inputs.message,
-        websocket=websocket,
+        client_id=client_id,
         session_id=session_id,
     )
     logger.debug("Generated result and intermediate_steps")
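
For context, here is a minimal sketch of the pattern this diff moves to: instead of threading a `WebSocket` reference through `process_graph` and into the callback handler, the handler resolves the live connection from a shared chat service by `client_id`, and the same service records messages in chat history. The `ChatService`, `ChatHistory`, and `StreamingHandler` classes below are simplified stand-ins for illustration, not langflow's actual implementations.

```python
from typing import Any, Dict, List


class ChatHistory:
    """Per-client message log, keyed by client_id."""

    def __init__(self) -> None:
        self._messages: Dict[str, List[Any]] = {}

    def add_message(self, client_id: str, message: Any) -> None:
        self._messages.setdefault(client_id, []).append(message)


class ChatService:
    """Holds active websocket connections and the chat history."""

    def __init__(self) -> None:
        # client_id -> websocket connection
        self.active_connections: Dict[str, Any] = {}
        self.chat_history = ChatHistory()


_chat_service = ChatService()


def get_chat_service() -> ChatService:
    # Stand-in for langflow.services.getters.get_chat_service
    return _chat_service


class StreamingHandler:
    def __init__(self, client_id: str) -> None:
        # Same construction order as the diff: resolve the shared service
        # first, then look up the live websocket for this client.
        self.chat_service = get_chat_service()
        self.client_id = client_id
        self.websocket = self.chat_service.active_connections[self.client_id]


if __name__ == "__main__":
    service = get_chat_service()
    service.active_connections["client-1"] = object()  # stand-in for a WebSocket
    handler = StreamingHandler(client_id="client-1")
    service.chat_history.add_message("client-1", {"message": "hi", "type": "stream"})
    print(handler.websocket is service.active_connections["client-1"])  # True
```

One consequence of this lookup-by-id design: constructing a handler for a client with no registered connection raises `KeyError`, so the connection presumably has to be registered with the chat service before the handler is built. In exchange, callers such as `process_graph` and `try_setting_streaming_options` only need a plain string, which is what the rest of this diff strips the `websocket` parameter out of.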