refactor(callback.py): rename StreamingLLMCallbackHandler to AsyncStreamingLLMCallbackHandler and add new StreamingLLMCallbackHandler

fix(chat_manager.py): assign result to result variable instead of empty string in ChatResponse
refactor(run.py): add AsyncStreamingLLMCallbackHandler and StreamingLLMCallbackHandler imports and use kwargs instead of callbacks in get_result_and_steps function
This commit is contained in:
Gabriel Almeida 2023-05-06 08:29:06 -03:00
commit 5cf531b520
3 changed files with 28 additions and 12 deletions

View file

@@ -7,7 +7,7 @@ from langflow.api.schemas import ChatResponse
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
"""Callback handler for streaming LLM responses."""
def __init__(self, websocket):
@@ -16,3 +16,17 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
await self.websocket.send_json(resp.dict())
class StreamingLLMCallbackHandler(BaseCallbackHandler):
    """Synchronous callback handler that streams LLM tokens over a websocket.

    Unlike the async variant, this handler is invoked from synchronous
    LangChain code, so it schedules the websocket send onto the running
    event loop instead of awaiting it directly.
    """

    def __init__(self, websocket):
        # Websocket connection the streamed tokens are pushed through.
        self.websocket = websocket

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward each newly generated token to the client as a ChatResponse."""
        response = ChatResponse(message=token, type="stream", intermediate_steps="")
        # send_json is a coroutine; hand it to the event loop thread-safely
        # since this callback may fire from a worker thread.
        send_coro = self.websocket.send_json(response.dict())
        event_loop = asyncio.get_event_loop()
        asyncio.run_coroutine_threadsafe(send_coro, event_loop)

View file

@@ -117,7 +117,7 @@ class ChatManager:
try:
logger.debug("Generating result and thought")
_, intermediate_steps = await process_graph(
result, intermediate_steps = await process_graph(
graph_data=graph_data,
is_first_message=is_first_message,
chat_message=chat_message,
@@ -144,7 +144,7 @@ class ChatManager:
break
response = ChatResponse(
message="",
message=result,
intermediate_steps=intermediate_steps.strip(),
type="end",
files=file_responses,
@@ -212,9 +212,8 @@ async def process_graph(
# Generate result and thought
try:
logger.debug("Generating result and thought")
stream_handler = StreamingLLMCallbackHandler(websocket)
result, intermediate_steps = await get_result_and_steps(
langchain_object, chat_message.message or "", callbacks=[stream_handler]
langchain_object, chat_message.message or "", websocket=websocket
)
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps

View file

@@ -2,7 +2,8 @@ import contextlib
import io
from typing import Any, Dict
from chromadb.errors import NotEnoughElementsException # type: ignore
from chromadb.errors import NotEnoughElementsException
from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler # type: ignore
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
from langflow.graph.graph import Graph
@@ -185,11 +186,9 @@ def fix_memory_inputs(langchain_object):
update_memory_keys(langchain_object, possible_new_mem_key)
async def get_result_and_steps(langchain_object, message: str, callbacks=None):
async def get_result_and_steps(langchain_object, message: str, **kwargs):
"""Get result and thought from extracted json"""
if callbacks is None:
callbacks = []
try:
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
@@ -215,11 +214,15 @@ async def get_result_and_steps(langchain_object, message: str, **kwargs):
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
try:
output = await langchain_object.acall(chat_input, callbacks=callbacks)
except ValueError as exc:
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
output = await langchain_object.acall(
chat_input, callbacks=async_callbacks
)
except Exception as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
output = langchain_object.run(chat_input, callbacks=callbacks)
sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
output = langchain_object(chat_input, callbacks=sync_callbacks)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []