refactor(callback.py): rename StreamingLLMCallbackHandler to AsyncStreamingLLMCallbackHandler and add new StreamingLLMCallbackHandler
fix(chat_manager.py): assign the result to the result variable instead of an empty string in ChatResponse; refactor(run.py): add AsyncStreamingLLMCallbackHandler and StreamingLLMCallbackHandler imports and use kwargs instead of callbacks in get_result_and_steps function
This commit is contained in:
parent
234358bc6e
commit
5cf531b520
3 changed files with 28 additions and 12 deletions
|
|
@@ -7,7 +7,7 @@ from langflow.api.schemas import ChatResponse
|
|||
|
||||
|
||||
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
|
||||
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
"""Callback handler for streaming LLM responses."""
|
||||
|
||||
def __init__(self, websocket):
|
||||
|
|
@@ -16,3 +16,17 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
|
||||
await self.websocket.send_json(resp.dict())
|
||||
|
||||
|
||||
class StreamingLLMCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for streaming LLM responses."""
|
||||
|
||||
def __init__(self, websocket):
|
||||
self.websocket = websocket
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
coroutine = self.websocket.send_json(resp.dict())
|
||||
asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
|
|
|
|||
|
|
@@ -117,7 +117,7 @@ class ChatManager:
|
|||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
|
||||
_, intermediate_steps = await process_graph(
|
||||
result, intermediate_steps = await process_graph(
|
||||
graph_data=graph_data,
|
||||
is_first_message=is_first_message,
|
||||
chat_message=chat_message,
|
||||
|
|
@@ -144,7 +144,7 @@ class ChatManager:
|
|||
break
|
||||
|
||||
response = ChatResponse(
|
||||
message="",
|
||||
message=result,
|
||||
intermediate_steps=intermediate_steps.strip(),
|
||||
type="end",
|
||||
files=file_responses,
|
||||
|
|
@@ -212,9 +212,8 @@ async def process_graph(
|
|||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
stream_handler = StreamingLLMCallbackHandler(websocket)
|
||||
result, intermediate_steps = await get_result_and_steps(
|
||||
langchain_object, chat_message.message or "", callbacks=[stream_handler]
|
||||
langchain_object, chat_message.message or "", websocket=websocket
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
return result, intermediate_steps
|
||||
|
|
|
|||
|
|
@@ -2,7 +2,8 @@ import contextlib
|
|||
import io
|
||||
from typing import Any, Dict
|
||||
|
||||
from chromadb.errors import NotEnoughElementsException # type: ignore
|
||||
from chromadb.errors import NotEnoughElementsException
|
||||
from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler # type: ignore
|
||||
|
||||
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
|
||||
from langflow.graph.graph import Graph
|
||||
|
|
@@ -185,11 +186,9 @@ def fix_memory_inputs(langchain_object):
|
|||
update_memory_keys(langchain_object, possible_new_mem_key)
|
||||
|
||||
|
||||
async def get_result_and_steps(langchain_object, message: str, callbacks=None):
|
||||
async def get_result_and_steps(langchain_object, message: str, **kwargs):
|
||||
"""Get result and thought from extracted json"""
|
||||
|
||||
if callbacks is None:
|
||||
callbacks = []
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
|
|
@@ -215,11 +214,15 @@ async def get_result_and_steps(langchain_object, message: str, callbacks=None):
|
|||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
try:
|
||||
output = await langchain_object.acall(chat_input, callbacks=callbacks)
|
||||
except ValueError as exc:
|
||||
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
|
||||
output = await langchain_object.acall(
|
||||
chat_input, callbacks=async_callbacks
|
||||
)
|
||||
except Exception as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
output = langchain_object.run(chat_input, callbacks=callbacks)
|
||||
sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
|
||||
output = langchain_object(chat_input, callbacks=sync_callbacks)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue