refactor(langflow): replace langchain_object.run with langchain_object.acall in get_result_and_steps function
feat(langflow): add support for streaming intermediate steps to the client via websockets
parent c630f293b0
commit 474e14efaf
5 changed files with 21 additions and 27 deletions
@@ -1,6 +1,7 @@
 import asyncio
 from typing import Any

-from langchain.callbacks.base import AsyncCallbackHandler
+from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler

 from langflow.api.schemas import ChatResponse
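The widened import suggests that langflow/api/callback.py now builds its handlers on both base classes. As a point of reference, a minimal async streaming handler over a websocket could look like the sketch below; the class body and the exact ChatResponse fields are assumptions for illustration, not code from this commit.

from fastapi import WebSocket
from langchain.callbacks.base import AsyncCallbackHandler

from langflow.api.schemas import ChatResponse


class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Forward each newly generated LLM token to the connected client."""

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        # The ChatResponse field names here are assumptions for illustration.
        resp = ChatResponse(message=token, type="stream", intermediate_steps="")
        await self.websocket.send_json(resp.dict())

FastAPI's WebSocket.send_json serializes each payload for the client, which is how individual tokens can reach the frontend as they are generated.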
@@ -10,9 +10,5 @@ chat_manager = ChatManager()
 @router.websocket("/chat/{client_id}")
 async def websocket_endpoint(client_id: str, websocket: WebSocket):
     """Websocket endpoint for chat."""
-    try:
-        await chat_manager.handle_websocket(client_id, websocket)
-    except Exception as e:
-        # Log stack trace
-        logger.exception(e)
-        raise e
+    await chat_manager.handle_websocket(client_id, websocket)
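With the try/except removed here, connection errors are handled inside ChatManager.handle_websocket, which (per a later hunk) closes the socket with the error text as the close reason. For reference, a client might exercise the endpoint like the hypothetical snippet below; the host, port, and message shape are assumptions about the deployment, and `websockets` is a third-party package.

import asyncio
import json

import websockets  # third-party client library (pip install websockets)


async def chat(client_id: str, payload: dict):
    # Host/port and the payload schema are assumptions; adjust to your setup.
    uri = f"ws://localhost:7860/chat/{client_id}"
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps(payload))
        async for raw in ws:  # print streamed responses until the server closes
            print(json.loads(raw))


asyncio.run(chat("demo-client", {"message": "Hello"}))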
@@ -5,6 +5,7 @@ from typing import Dict, List
 from fastapi import WebSocket

+from langflow.api.callback import StreamingLLMCallbackHandler
 from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
 from langflow.cache import cache_manager
 from langflow.cache.manager import Subject
@@ -175,12 +176,11 @@ class ChatManager:
             # Handle any exceptions that might occur
             logger.exception(e)
             # send a message to the client
             await self.send_message(client_id, str(e))
-            raise e
+            await self.active_connections[client_id].close(code=1000, reason=str(e))
         finally:
-            await self.active_connections[client_id].close(
-                code=1000, reason="Client disconnected"
-            )
+            # await self.active_connections[client_id].close(
+            #     code=1000, reason="Client disconnected"
+            # )
             self.disconnect(client_id)
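Close code 1000 is the WebSocket "normal closure" code, and the error text travels in the close frame's reason field, which the protocol caps at 123 bytes of UTF-8, so long error messages would be truncated or rejected. A client using the `websockets` package (v10+, where ConnectionClosed exposes the received close frame as .rcvd) could surface the reason like this sketch:

import websockets


async def read_until_closed(ws):
    try:
        while True:
            print(await ws.recv())
    except websockets.exceptions.ConnectionClosed as exc:
        # exc.rcvd is the Close frame received from the server (websockets >= 10).
        code = exc.rcvd.code if exc.rcvd else None
        reason = exc.rcvd.reason if exc.rcvd else ""
        print(f"closed: code={code}, reason={reason!r}")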
@@ -203,8 +203,9 @@ async def process_graph(
     # Generate result and thought
     try:
         logger.debug("Generating result and thought")
-        result, intermediate_steps = get_result_and_steps(
-            langchain_object, chat_message.message or ""
+        stream_handler = StreamingLLMCallbackHandler(websocket)
+        result, intermediate_steps = await get_result_and_steps(
+            langchain_object, chat_message.message or "", callbacks=[stream_handler]
         )
         logger.debug("Generated result and intermediate_steps")
         return result, intermediate_steps
@@ -185,8 +185,11 @@ def fix_memory_inputs(langchain_object):
     update_memory_keys(langchain_object, possible_new_mem_key)


-def get_result_and_steps(langchain_object, message: str):
+async def get_result_and_steps(langchain_object, message: str, callbacks=None):
     """Get result and thought from extracted json"""
+
+    if callbacks is None:
+        callbacks = []
     try:
         if hasattr(langchain_object, "verbose"):
             langchain_object.verbose = True
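The `callbacks=None` default followed by the in-body `callbacks = []` is the standard guard against Python's mutable-default pitfall: default values are evaluated once at definition time, so a bare `callbacks=[]` would be shared across every call. A quick illustration:

def broken(items=[]):      # one list, created once at definition time
    items.append(1)
    return items

broken()  # [1]
broken()  # [1, 1]  <- state leaks between calls

def safe(items=None):      # fresh list per call, as in the diff above
    items = [] if items is None else items
    items.append(1)
    return items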
@@ -206,17 +209,17 @@ def get_result_and_steps(langchain_object, message: str):
         # https://github.com/hwchase17/langchain/issues/2068
         # Deactivating until we have a frontend solution
         # to display intermediate steps
-        langchain_object.return_intermediate_steps = False
+        langchain_object.return_intermediate_steps = True

         fix_memory_inputs(langchain_object)

         with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
             try:
-                output = langchain_object(chat_input)
+                output = await langchain_object.acall(chat_input, callbacks=callbacks)
             except ValueError as exc:
                 # make the error message more informative
                 logger.debug(f"Error: {str(exc)}")
-                output = langchain_object.run(chat_input)
+                output = langchain_object.run(chat_input, callbacks=callbacks)

         intermediate_steps = (
             output.get("intermediate_steps", []) if isinstance(output, dict) else []
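Switching from calling the chain synchronously to `await langchain_object.acall(..., callbacks=...)` is what lets an async handler's on_llm_new_token fire per token instead of blocking until the full completion returns; the synchronous `run(...)` fallback keeps the request alive for chains that reject the async path, at the cost of token-level streaming. A minimal, self-contained illustration of the same pattern, assuming a langchain version of this era with per-call callbacks and an OPENAI_API_KEY in the environment:

import asyncio

from langchain.callbacks.base import AsyncCallbackHandler
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate


class PrintTokens(AsyncCallbackHandler):
    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)  # fires as each token streams in


async def main():
    llm = OpenAI(streaming=True)
    chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template("{question}"))
    output = await chain.acall({"question": "Why is the sky blue?"},
                               callbacks=[PrintTokens()])
    print("\n", output)


asyncio.run(main())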
@@ -4,13 +4,9 @@ import os
 from io import BytesIO

 import yaml
-from langchain.callbacks.manager import AsyncCallbackManager
-from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
-from langchain.llms import AzureOpenAI, OpenAI
+from langchain.base_language import BaseLanguageModel
 from PIL.Image import Image

-from langflow.api.callback import StreamingLLMCallbackHandler
-

 def load_file_into_dict(file_path: str) -> dict:
     if not os.path.exists(file_path):
@@ -48,10 +44,7 @@ def try_setting_streaming_options(langchain_object, websocket):
         langchain_object.llm_chain, "llm"
     ):
         llm = langchain_object.llm_chain.llm
-        if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
+        if isinstance(llm, BaseLanguageModel):
             llm.streaming = bool(hasattr(llm, "streaming"))
-            stream_handler = StreamingLLMCallbackHandler(websocket)
-            stream_manager = AsyncCallbackManager([stream_handler])
-            llm.callback_manager = stream_manager

     return langchain_object
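Two notes on this last hunk. Broadening the isinstance check to BaseLanguageModel lets any langchain LLM or chat model opt into streaming rather than a hard-coded OpenAI list, and the handler is now attached per call (see the process_graph hunk above) instead of being pinned to the model through callback_manager. Also, the untouched context line `llm.streaming = bool(hasattr(llm, "streaming"))` writes False onto models that lack the attribute; a more defensive variant, offered only as a sketch and not part of this commit, would be:

if isinstance(llm, BaseLanguageModel) and hasattr(llm, "streaming"):
    llm.streaming = True  # only touch models that already expose the flag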