Merge remote-tracking branch 'origin/chat_and_cache' into websocket
This commit is contained in:
commit
6ba97f64c4
7 changed files with 70 additions and 33 deletions
|
|
@ -1 +1,4 @@
|
|||
from langflow.interface.loading import load_flow_from_json # noqa
|
||||
from langflow.interface.loading import load_flow_from_json
|
||||
from langflow.cache import cache_manager
|
||||
|
||||
__all__ = ["load_flow_from_json", "cache_manager"]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from fastapi import APIRouter, WebSocket
|
||||
|
||||
from langflow.api.chat_manager import ChatManager
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
router = APIRouter()
|
||||
chat_manager = ChatManager()
|
||||
|
|
@ -9,4 +10,9 @@ chat_manager = ChatManager()
|
|||
@router.websocket("/chat/{client_id}")
|
||||
async def websocket_endpoint(client_id: str, websocket: WebSocket):
|
||||
"""Websocket endpoint for chat."""
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
try:
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -8,11 +8,12 @@ import json
|
|||
from langchain.llms import OpenAI, AzureOpenAI
|
||||
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
|
||||
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.cache.manager import AsyncSubject
|
||||
from langflow.cache.manager import AsyncSubject, Subject
|
||||
from langchain.callbacks.base import AsyncCallbackManager
|
||||
from langflow.api.callback import StreamingLLMCallbackHandler
|
||||
from langflow.interface.run import (
|
||||
async_get_result_and_steps,
|
||||
get_result_and_steps,
|
||||
load_or_build_langchain_object,
|
||||
)
|
||||
from langflow.utils.logger import logger
|
||||
|
|
@ -20,21 +21,23 @@ from langflow.cache import cache_manager
|
|||
from PIL.Image import Image
|
||||
|
||||
|
||||
class ChatHistory(AsyncSubject):
|
||||
class ChatHistory(Subject):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
|
||||
|
||||
async def add_message(self, client_id: str, message: ChatMessage):
|
||||
def add_message(self, client_id: str, message: ChatMessage):
|
||||
"""Add a message to the chat history."""
|
||||
|
||||
self.history[client_id].append(message)
|
||||
await self.notify()
|
||||
self.notify()
|
||||
|
||||
def get_history(self, client_id: str) -> List[ChatMessage]:
|
||||
def get_history(self, client_id: str, filter=True) -> List[ChatMessage]:
|
||||
"""Get the chat history for a client."""
|
||||
if history := self.history.get(client_id, []):
|
||||
return [msg for msg in history if msg.type not in ["start", "stream"]]
|
||||
if filter:
|
||||
return [msg for msg in history if msg.type not in ["start", "stream"]]
|
||||
return history
|
||||
else:
|
||||
return []
|
||||
|
||||
|
|
@ -47,11 +50,11 @@ class ChatManager:
|
|||
self.cache_manager = cache_manager
|
||||
self.cache_manager.attach(self.update)
|
||||
|
||||
async def on_chat_history_update(self):
|
||||
def on_chat_history_update(self):
|
||||
"""Send the last chat message to the client."""
|
||||
client_id = self.cache_manager.current_client_id
|
||||
if client_id in self.active_connections:
|
||||
chat_response = self.chat_history.get_history(client_id)[-1]
|
||||
chat_response = self.chat_history.get_history(client_id, filter=False)[-1]
|
||||
if chat_response.is_bot:
|
||||
# Process FileResponse
|
||||
if isinstance(chat_response, FileResponse):
|
||||
|
|
@ -61,8 +64,11 @@ class ChatManager:
|
|||
elif chat_response.data_type == "image":
|
||||
# Base64 encode the image
|
||||
chat_response.data = pil_to_base64(chat_response.data)
|
||||
# get event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
await self.send_json(client_id, chat_response)
|
||||
coroutine = self.send_json(client_id, chat_response)
|
||||
asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
|
||||
def update(self):
|
||||
if self.cache_manager.current_client_id in self.active_connections:
|
||||
|
|
@ -75,10 +81,8 @@ class ChatManager:
|
|||
data_type=self.last_cached_object_dict["type"],
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.chat_history.add_message(
|
||||
self.cache_manager.current_client_id, chat_response
|
||||
)
|
||||
self.chat_history.add_message(
|
||||
self.cache_manager.current_client_id, chat_response
|
||||
)
|
||||
|
||||
async def connect(self, client_id: str, websocket: WebSocket):
|
||||
|
|
@ -100,11 +104,11 @@ class ChatManager:
|
|||
# Process the graph data and chat message
|
||||
chat_message = payload.pop("message", "")
|
||||
chat_message = ChatMessage(message=chat_message)
|
||||
await self.chat_history.add_message(client_id, chat_message)
|
||||
self.chat_history.add_message(client_id, chat_message)
|
||||
|
||||
graph_data = payload
|
||||
start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
|
||||
await self.chat_history.add_message(client_id, start_resp)
|
||||
self.chat_history.add_message(client_id, start_resp)
|
||||
|
||||
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
|
||||
# Generate result and thought
|
||||
|
|
@ -122,12 +126,13 @@ class ChatManager:
|
|||
logger.exception(e)
|
||||
raise e
|
||||
# Send a response back to the frontend, if needed
|
||||
intermediate_steps = intermediate_steps or ""
|
||||
response = ChatResponse(
|
||||
message=result or "",
|
||||
intermediate_steps=intermediate_steps or "",
|
||||
intermediate_steps=intermediate_steps.strip(),
|
||||
type="end",
|
||||
)
|
||||
await self.chat_history.add_message(client_id, response)
|
||||
self.chat_history.add_message(client_id, response)
|
||||
|
||||
async def handle_websocket(self, client_id: str, websocket: WebSocket):
|
||||
await self.connect(client_id, websocket)
|
||||
|
|
@ -173,7 +178,7 @@ async def process_graph(
|
|||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = await async_get_result_and_steps(
|
||||
result, intermediate_steps = get_result_and_steps(
|
||||
langchain_object, chat_message.message or ""
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# - Build each inner agent first, then build the outer agent
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import types
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
|
|
@ -14,6 +15,7 @@ from langflow.graph.constants import DIRECT_TYPES
|
|||
from langflow.interface import loading
|
||||
from langflow.interface.listing import ALL_TYPES_DICT
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import sync_to_async
|
||||
|
||||
|
||||
class Node:
|
||||
|
|
@ -158,13 +160,20 @@ class Node:
|
|||
continue
|
||||
result = value.build()
|
||||
# If the key is "func", then we need to use the run method
|
||||
if key == "func" and not isinstance(result, types.FunctionType):
|
||||
# func can be PythonFunction(code='\ndef upper_case(text: str) -> str:\n return text.upper()\n')
|
||||
# so we need to check if there is an attribute called run
|
||||
if hasattr(result, "run"):
|
||||
result = result.run # type: ignore
|
||||
elif hasattr(result, "get_function"):
|
||||
result = result.get_function() # type: ignore
|
||||
if key == "func":
|
||||
if not isinstance(result, types.FunctionType):
|
||||
# func can be PythonFunction(code='\ndef upper_case(text: str) -> str:\n return text.upper()\n')
|
||||
# so we need to check if there is an attribute called run
|
||||
if hasattr(result, "run"):
|
||||
result = result.run # type: ignore
|
||||
elif hasattr(result, "get_function"):
|
||||
result = result.get_function() # type: ignore
|
||||
elif inspect.iscoroutinefunction(result):
|
||||
self.params["coroutine"] = result
|
||||
else:
|
||||
# turn result which is a function into a coroutine
|
||||
# so that it can be awaited
|
||||
self.params["coroutine"] = sync_to_async(result)
|
||||
|
||||
self.params[key] = result
|
||||
elif isinstance(value, list) and all(
|
||||
|
|
|
|||
|
|
@ -240,7 +240,7 @@ def get_result_and_steps(langchain_object, message: str):
|
|||
return result, thought
|
||||
|
||||
|
||||
async def async_get_result_and_steps(langchain_object, message: str):
|
||||
def async_get_result_and_steps(langchain_object, message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
|
|
@ -267,10 +267,10 @@ async def async_get_result_and_steps(langchain_object, message: str):
|
|||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
try:
|
||||
if hasattr(langchain_object, "acall"):
|
||||
output = await langchain_object.acall(chat_input)
|
||||
else:
|
||||
output = langchain_object(chat_input)
|
||||
# if hasattr(langchain_object, "acall"):
|
||||
# output = await langchain_object.acall(chat_input)
|
||||
# else:
|
||||
output = langchain_object(chat_input)
|
||||
except ValueError as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import asyncio
|
||||
from functools import partial, wraps
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
|
|
@ -301,3 +303,15 @@ def update_verbose(d: dict, new_value: bool) -> dict:
|
|||
elif k == "verbose":
|
||||
d[k] = new_value
|
||||
return d
|
||||
|
||||
|
||||
def sync_to_async(func):
|
||||
"""
|
||||
Decorator to convert a sync function to an async function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper
|
||||
|
|
|
|||
|
|
@ -155,7 +155,7 @@ def create_function(code, function_name):
|
|||
exec_globals[function_name] = locals()[function_name]
|
||||
|
||||
# Return a function that imports necessary modules and calls the target function
|
||||
def wrapped_function(*args, **kwargs):
|
||||
async def wrapped_function(*args, **kwargs):
|
||||
for module_name, module in exec_globals.items():
|
||||
if isinstance(module, type(importlib)):
|
||||
globals()[module_name] = module
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue