diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index b71f652d1..ca31139ab 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -1,24 +1,17 @@ import asyncio -import base64 -from io import BytesIO from typing import Dict, List from collections import defaultdict from fastapi import WebSocket import json -from langchain.llms import OpenAI, AzureOpenAI -from langchain.chat_models import ChatOpenAI, AzureChatOpenAI from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse -from langflow.cache.manager import AsyncSubject, Subject -from langchain.callbacks.base import AsyncCallbackManager -from langflow.api.callback import StreamingLLMCallbackHandler +from langflow.cache.manager import Subject from langflow.interface.run import ( - async_get_result_and_steps, get_result_and_steps, load_or_build_langchain_object, ) +from langflow.interface.utils import pil_to_base64, try_setting_streaming_options from langflow.utils.logger import logger from langflow.cache import cache_manager -from PIL.Image import Image class ChatHistory(Subject): @@ -149,6 +142,10 @@ class ChatManager: payload = json.loads(json_payload) except TypeError: payload = json_payload + if "clear_history" in payload: + self.chat_history.history[client_id] = [] + continue + with self.cache_manager.set_client_id(client_id): await self.process_message(client_id, payload) except Exception as e: @@ -187,30 +184,3 @@ async def process_graph( # Log stack trace logger.exception(e) raise e - - -def try_setting_streaming_options(langchain_object, websocket): - # If the LLM type is OpenAI or ChatOpenAI, - # set streaming to True - # First we need to find the LLM - llm = None - if hasattr(langchain_object, "llm"): - llm = langchain_object.llm - elif hasattr(langchain_object, "llm_chain") and hasattr( - langchain_object.llm_chain, "llm" - ): - llm = langchain_object.llm_chain.llm - if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)): - llm.streaming = bool(hasattr(llm, "streaming")) - stream_handler = StreamingLLMCallbackHandler(websocket) - stream_manager = AsyncCallbackManager([stream_handler]) - llm.callback_manager = stream_manager - - return langchain_object - - -def pil_to_base64(image: Image) -> str: - buffered = BytesIO() - image.save(buffered, format="PNG") - img_str = base64.b64encode(buffered.getvalue()) - return img_str.decode("utf-8") diff --git a/src/backend/langflow/interface/utils.py b/src/backend/langflow/interface/utils.py index b3b154790..604df6f28 100644 --- a/src/backend/langflow/interface/utils.py +++ b/src/backend/langflow/interface/utils.py @@ -1,5 +1,12 @@ +import base64 +from io import BytesIO import json import os +from PIL.Image import Image +from langchain.callbacks.base import AsyncCallbackManager +from langchain.chat_models import AzureChatOpenAI, ChatOpenAI +from langchain.llms import AzureOpenAI, OpenAI +from langflow.api.callback import StreamingLLMCallbackHandler import yaml @@ -20,3 +27,30 @@ def load_file_into_dict(file_path: str) -> dict: raise ValueError("Unsupported file type. Please provide a JSON or YAML file.") return data + + +def pil_to_base64(image: Image) -> str: + buffered = BytesIO() + image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()) + return img_str.decode("utf-8") + + +def try_setting_streaming_options(langchain_object, websocket): + # If the LLM type is OpenAI or ChatOpenAI, + # set streaming to True + # First we need to find the LLM + llm = None + if hasattr(langchain_object, "llm"): + llm = langchain_object.llm + elif hasattr(langchain_object, "llm_chain") and hasattr( + langchain_object.llm_chain, "llm" + ): + llm = langchain_object.llm_chain.llm + if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)): + llm.streaming = bool(hasattr(llm, "streaming")) + stream_handler = StreamingLLMCallbackHandler(websocket) + stream_manager = AsyncCallbackManager([stream_handler]) + llm.callback_manager = stream_manager + + return langchain_object diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 74e147075..9ec128bcf 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -28,7 +28,7 @@ def test_chat_history(client: TestClient): # Receive the response from the server response = websocket.receive_json() assert json.loads(response) == { - "sender": "bot", + "is_bot": True, "message": None, "intermediate_steps": "", "type": "start", @@ -40,7 +40,7 @@ def test_chat_history(client: TestClient): # Receive the response from the server response = websocket.receive_json() assert json.loads(response) == { - "sender": "bot", + "is_bot": True, "message": "Hello, I'm a mock response!", "intermediate_steps": "", "type": "end",