Merge remote-tracking branch 'origin/chat_and_cache' into websocket
This commit is contained in:
commit
6ae18ff14d
3 changed files with 42 additions and 38 deletions
|
|
@ -1,24 +1,17 @@
|
|||
import asyncio
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict
|
||||
from fastapi import WebSocket
|
||||
import json
|
||||
from langchain.llms import OpenAI, AzureOpenAI
|
||||
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
|
||||
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.cache.manager import AsyncSubject, Subject
|
||||
from langchain.callbacks.base import AsyncCallbackManager
|
||||
from langflow.api.callback import StreamingLLMCallbackHandler
|
||||
from langflow.cache.manager import Subject
|
||||
from langflow.interface.run import (
|
||||
async_get_result_and_steps,
|
||||
get_result_and_steps,
|
||||
load_or_build_langchain_object,
|
||||
)
|
||||
from langflow.interface.utils import pil_to_base64, try_setting_streaming_options
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.cache import cache_manager
|
||||
from PIL.Image import Image
|
||||
|
||||
|
||||
class ChatHistory(Subject):
|
||||
|
|
@ -149,6 +142,10 @@ class ChatManager:
|
|||
payload = json.loads(json_payload)
|
||||
except TypeError:
|
||||
payload = json_payload
|
||||
if "clear_history" in payload:
|
||||
self.chat_history.history[client_id] = []
|
||||
continue
|
||||
|
||||
with self.cache_manager.set_client_id(client_id):
|
||||
await self.process_message(client_id, payload)
|
||||
except Exception as e:
|
||||
|
|
@ -187,30 +184,3 @@ async def process_graph(
|
|||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
|
||||
def try_setting_streaming_options(langchain_object, websocket):
|
||||
# If the LLM type is OpenAI or ChatOpenAI,
|
||||
# set streaming to True
|
||||
# First we need to find the LLM
|
||||
llm = None
|
||||
if hasattr(langchain_object, "llm"):
|
||||
llm = langchain_object.llm
|
||||
elif hasattr(langchain_object, "llm_chain") and hasattr(
|
||||
langchain_object.llm_chain, "llm"
|
||||
):
|
||||
llm = langchain_object.llm_chain.llm
|
||||
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
|
||||
llm.streaming = bool(hasattr(llm, "streaming"))
|
||||
stream_handler = StreamingLLMCallbackHandler(websocket)
|
||||
stream_manager = AsyncCallbackManager([stream_handler])
|
||||
llm.callback_manager = stream_manager
|
||||
|
||||
return langchain_object
|
||||
|
||||
|
||||
def pil_to_base64(image: Image) -> str:
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue())
|
||||
return img_str.decode("utf-8")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,12 @@
|
|||
import base64
|
||||
from io import BytesIO
|
||||
import json
|
||||
import os
|
||||
from PIL.Image import Image
|
||||
from langchain.callbacks.base import AsyncCallbackManager
|
||||
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
||||
from langchain.llms import AzureOpenAI, OpenAI
|
||||
from langflow.api.callback import StreamingLLMCallbackHandler
|
||||
|
||||
import yaml
|
||||
|
||||
|
|
@ -20,3 +27,30 @@ def load_file_into_dict(file_path: str) -> dict:
|
|||
raise ValueError("Unsupported file type. Please provide a JSON or YAML file.")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def pil_to_base64(image: Image) -> str:
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue())
|
||||
return img_str.decode("utf-8")
|
||||
|
||||
|
||||
def try_setting_streaming_options(langchain_object, websocket):
|
||||
# If the LLM type is OpenAI or ChatOpenAI,
|
||||
# set streaming to True
|
||||
# First we need to find the LLM
|
||||
llm = None
|
||||
if hasattr(langchain_object, "llm"):
|
||||
llm = langchain_object.llm
|
||||
elif hasattr(langchain_object, "llm_chain") and hasattr(
|
||||
langchain_object.llm_chain, "llm"
|
||||
):
|
||||
llm = langchain_object.llm_chain.llm
|
||||
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
|
||||
llm.streaming = bool(hasattr(llm, "streaming"))
|
||||
stream_handler = StreamingLLMCallbackHandler(websocket)
|
||||
stream_manager = AsyncCallbackManager([stream_handler])
|
||||
llm.callback_manager = stream_manager
|
||||
|
||||
return langchain_object
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def test_chat_history(client: TestClient):
|
|||
# Receive the response from the server
|
||||
response = websocket.receive_json()
|
||||
assert json.loads(response) == {
|
||||
"sender": "bot",
|
||||
"is_bot": True,
|
||||
"message": None,
|
||||
"intermediate_steps": "",
|
||||
"type": "start",
|
||||
|
|
@ -40,7 +40,7 @@ def test_chat_history(client: TestClient):
|
|||
# Receive the response from the server
|
||||
response = websocket.receive_json()
|
||||
assert json.loads(response) == {
|
||||
"sender": "bot",
|
||||
"is_bot": True,
|
||||
"message": "Hello, I'm a mock response!",
|
||||
"intermediate_steps": "",
|
||||
"type": "end",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue