Merge remote-tracking branch 'origin/chat_and_cache' into websocket

commit 6ae18ff14d
Author: anovazzi1
Date:   2023-04-25 22:02:34 -03:00

3 changed files with 42 additions and 38 deletions

File 1 of 3

@@ -1,24 +1,17 @@
import asyncio
import base64
from io import BytesIO
from typing import Dict, List
from collections import defaultdict
from fastapi import WebSocket
import json
from langchain.llms import OpenAI, AzureOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.cache.manager import AsyncSubject, Subject
from langchain.callbacks.base import AsyncCallbackManager
from langflow.api.callback import StreamingLLMCallbackHandler
from langflow.cache.manager import Subject
from langflow.interface.run import (
    async_get_result_and_steps,
    get_result_and_steps,
    load_or_build_langchain_object,
)
from langflow.interface.utils import pil_to_base64, try_setting_streaming_options
from langflow.utils.logger import logger
from langflow.cache import cache_manager
from PIL.Image import Image


class ChatHistory(Subject):
@@ -149,6 +142,10 @@ class ChatManager:
                payload = json.loads(json_payload)
            except TypeError:
                payload = json_payload
            if "clear_history" in payload:
                self.chat_history.history[client_id] = []
                continue

            with self.cache_manager.set_client_id(client_id):
                await self.process_message(client_id, payload)
        except Exception as e:
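For context on the clear_history handling added in the hunk above: any incoming payload that carries that key wipes the stored history for the client and skips normal message processing. A minimal client-side sketch follows; the websockets library and the ws://localhost:7860/chat/{client_id} endpoint are assumptions for illustration, not something this diff shows.

import asyncio
import json

import websockets  # assumed third-party client library for this sketch


async def clear_history(client_id: str) -> None:
    # Assumed endpoint; the actual route is defined elsewhere in langflow.
    uri = f"ws://localhost:7860/chat/{client_id}"
    async with websockets.connect(uri) as ws:
        # Any payload containing the "clear_history" key hits the new branch.
        await ws.send(json.dumps({"clear_history": True}))


asyncio.run(clear_history("my-client-id"))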
@@ -187,30 +184,3 @@ async def process_graph(
        # Log stack trace
        logger.exception(e)
        raise e


def try_setting_streaming_options(langchain_object, websocket):
    # If the LLM type is OpenAI or ChatOpenAI,
    # set streaming to True
    # First we need to find the LLM
    llm = None
    if hasattr(langchain_object, "llm"):
        llm = langchain_object.llm
    elif hasattr(langchain_object, "llm_chain") and hasattr(
        langchain_object.llm_chain, "llm"
    ):
        llm = langchain_object.llm_chain.llm

    if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
        llm.streaming = bool(hasattr(llm, "streaming"))
        stream_handler = StreamingLLMCallbackHandler(websocket)
        stream_manager = AsyncCallbackManager([stream_handler])
        llm.callback_manager = stream_manager

    return langchain_object


def pil_to_base64(image: Image) -> str:
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode("utf-8")
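The two helpers deleted from this file are not dropped: they reappear verbatim in langflow.interface.utils in the next file, which is what the new import at the top of this file points at. A hypothetical call site for the relocated streaming helper, to show the intended flow (run_flow and its signature are invented for this sketch):

from fastapi import WebSocket

from langflow.interface.utils import try_setting_streaming_options


async def run_flow(langchain_object, websocket: WebSocket):
    # Turn on token streaming over the websocket when the LLM supports it;
    # the helper mutates the object in place and returns it.
    return try_setting_streaming_options(langchain_object, websocket)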

File 2 of 3

@@ -1,5 +1,12 @@
import base64
from io import BytesIO
import json
import os
from PIL.Image import Image
from langchain.callbacks.base import AsyncCallbackManager
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.llms import AzureOpenAI, OpenAI
from langflow.api.callback import StreamingLLMCallbackHandler
import yaml
@@ -20,3 +27,30 @@ def load_file_into_dict(file_path: str) -> dict:
        raise ValueError("Unsupported file type. Please provide a JSON or YAML file.")
    return data


def pil_to_base64(image: Image) -> str:
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode("utf-8")


def try_setting_streaming_options(langchain_object, websocket):
    # If the LLM type is OpenAI or ChatOpenAI,
    # set streaming to True
    # First we need to find the LLM
    llm = None
    if hasattr(langchain_object, "llm"):
        llm = langchain_object.llm
    elif hasattr(langchain_object, "llm_chain") and hasattr(
        langchain_object.llm_chain, "llm"
    ):
        llm = langchain_object.llm_chain.llm

    if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
        llm.streaming = bool(hasattr(llm, "streaming"))
        stream_handler = StreamingLLMCallbackHandler(websocket)
        stream_manager = AsyncCallbackManager([stream_handler])
        llm.callback_manager = stream_manager

    return langchain_object
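A quick sanity check for the relocated pil_to_base64; the solid-color test image below is invented for this sketch and is not part of the commit:

from PIL import Image

from langflow.interface.utils import pil_to_base64

img = Image.new("RGB", (4, 4), color="red")
encoded = pil_to_base64(img)
# Every PNG begins with the bytes \x89PNG, which base64-encode to "iVBOR".
assert encoded.startswith("iVBOR")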

File 3 of 3

@@ -28,7 +28,7 @@ def test_chat_history(client: TestClient):
        # Receive the response from the server
        response = websocket.receive_json()
        assert json.loads(response) == {
-           "sender": "bot",
+           "is_bot": True,
            "message": None,
            "intermediate_steps": "",
            "type": "start",
@@ -40,7 +40,7 @@ def test_chat_history(client: TestClient):
        # Receive the response from the server
        response = websocket.receive_json()
        assert json.loads(response) == {
-           "sender": "bot",
+           "is_bot": True,
            "message": "Hello, I'm a mock response!",
            "intermediate_steps": "",
            "type": "end",