refactor(chat_manager.py, utils.py): remove unused imports and functions, move pil_to_base64 and try_setting_streaming_options to utils module
This commit is contained in:
parent
9138b1c55f
commit
ea210af19b
2 changed files with 39 additions and 36 deletions
|
|
@ -1,24 +1,17 @@
|
|||
import asyncio
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict
|
||||
from fastapi import WebSocket
|
||||
import json
|
||||
from langchain.llms import OpenAI, AzureOpenAI
|
||||
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
|
||||
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.cache.manager import AsyncSubject, Subject
|
||||
from langchain.callbacks.base import AsyncCallbackManager
|
||||
from langflow.api.callback import StreamingLLMCallbackHandler
|
||||
from langflow.cache.manager import Subject
|
||||
from langflow.interface.run import (
|
||||
async_get_result_and_steps,
|
||||
get_result_and_steps,
|
||||
load_or_build_langchain_object,
|
||||
)
|
||||
from langflow.interface.utils import pil_to_base64, try_setting_streaming_options
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.cache import cache_manager
|
||||
from PIL.Image import Image
|
||||
|
||||
|
||||
class ChatHistory(Subject):
|
||||
|
|
@ -149,6 +142,9 @@ class ChatManager:
|
|||
payload = json.loads(json_payload)
|
||||
except TypeError:
|
||||
payload = json_payload
|
||||
if "clear_history" in payload:
|
||||
self.chat_history.history[client_id] = []
|
||||
|
||||
with self.cache_manager.set_client_id(client_id):
|
||||
await self.process_message(client_id, payload)
|
||||
except Exception as e:
|
||||
|
|
@ -187,30 +183,3 @@ async def process_graph(
|
|||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
|
||||
def try_setting_streaming_options(langchain_object, websocket):
|
||||
# If the LLM type is OpenAI or ChatOpenAI,
|
||||
# set streaming to True
|
||||
# First we need to find the LLM
|
||||
llm = None
|
||||
if hasattr(langchain_object, "llm"):
|
||||
llm = langchain_object.llm
|
||||
elif hasattr(langchain_object, "llm_chain") and hasattr(
|
||||
langchain_object.llm_chain, "llm"
|
||||
):
|
||||
llm = langchain_object.llm_chain.llm
|
||||
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
|
||||
llm.streaming = bool(hasattr(llm, "streaming"))
|
||||
stream_handler = StreamingLLMCallbackHandler(websocket)
|
||||
stream_manager = AsyncCallbackManager([stream_handler])
|
||||
llm.callback_manager = stream_manager
|
||||
|
||||
return langchain_object
|
||||
|
||||
|
||||
def pil_to_base64(image: Image) -> str:
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue())
|
||||
return img_str.decode("utf-8")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,12 @@
|
|||
import base64
|
||||
from io import BytesIO
|
||||
import json
|
||||
import os
|
||||
from PIL.Image import Image
|
||||
from langchain.callbacks.base import AsyncCallbackManager
|
||||
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
||||
from langchain.llms import AzureOpenAI, OpenAI
|
||||
from langflow.api.callback import StreamingLLMCallbackHandler
|
||||
|
||||
import yaml
|
||||
|
||||
|
|
@ -20,3 +27,30 @@ def load_file_into_dict(file_path: str) -> dict:
|
|||
raise ValueError("Unsupported file type. Please provide a JSON or YAML file.")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def pil_to_base64(image: Image) -> str:
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue())
|
||||
return img_str.decode("utf-8")
|
||||
|
||||
|
||||
def try_setting_streaming_options(langchain_object, websocket):
|
||||
# If the LLM type is OpenAI or ChatOpenAI,
|
||||
# set streaming to True
|
||||
# First we need to find the LLM
|
||||
llm = None
|
||||
if hasattr(langchain_object, "llm"):
|
||||
llm = langchain_object.llm
|
||||
elif hasattr(langchain_object, "llm_chain") and hasattr(
|
||||
langchain_object.llm_chain, "llm"
|
||||
):
|
||||
llm = langchain_object.llm_chain.llm
|
||||
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
|
||||
llm.streaming = bool(hasattr(llm, "streaming"))
|
||||
stream_handler = StreamingLLMCallbackHandler(websocket)
|
||||
stream_manager = AsyncCallbackManager([stream_handler])
|
||||
llm.callback_manager = stream_manager
|
||||
|
||||
return langchain_object
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue