diff --git a/src/backend/langflow/api/callback.py b/src/backend/langflow/api/callback.py
index a6590673e..d63e107c4 100644
--- a/src/backend/langflow/api/callback.py
+++ b/src/backend/langflow/api/callback.py
@@ -7,7 +7,7 @@ from langflow.api.schemas import ChatResponse
 
 
 # https://github.com/hwchase17/chat-langchain/blob/master/callback.py
-class StreamingLLMCallbackHandler(AsyncCallbackHandler):
+class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
     """Callback handler for streaming LLM responses."""
 
     def __init__(self, websocket):
@@ -16,3 +16,17 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         resp = ChatResponse(message=token, type="stream", intermediate_steps="")
         await self.websocket.send_json(resp.dict())
+
+
+class StreamingLLMCallbackHandler(BaseCallbackHandler):
+    """Callback handler for streaming LLM responses."""
+
+    def __init__(self, websocket):
+        self.websocket = websocket
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        resp = ChatResponse(message=token, type="stream", intermediate_steps="")
+
+        loop = asyncio.get_event_loop()
+        coroutine = self.websocket.send_json(resp.dict())
+        asyncio.run_coroutine_threadsafe(coroutine, loop)
diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py
index e4a4a1b4e..73d9ad6da 100644
--- a/src/backend/langflow/api/chat_manager.py
+++ b/src/backend/langflow/api/chat_manager.py
@@ -117,7 +117,7 @@ class ChatManager:
 
         try:
             logger.debug("Generating result and thought")
-            _, intermediate_steps = await process_graph(
+            result, intermediate_steps = await process_graph(
                 graph_data=graph_data,
                 is_first_message=is_first_message,
                 chat_message=chat_message,
@@ -144,7 +144,7 @@ class ChatManager:
                     break
 
             response = ChatResponse(
-                message="",
+                message=result,
                 intermediate_steps=intermediate_steps.strip(),
                 type="end",
                 files=file_responses,
@@ -212,9 +212,8 @@ async def process_graph(
     # Generate result and thought
     try:
         logger.debug("Generating result and thought")
-        stream_handler = StreamingLLMCallbackHandler(websocket)
         result, intermediate_steps = await get_result_and_steps(
-            langchain_object, chat_message.message or "", callbacks=[stream_handler]
+            langchain_object, chat_message.message or "", websocket=websocket
         )
         logger.debug("Generated result and intermediate_steps")
         return result, intermediate_steps
diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py
index a8d55d3be..7bc1bbb0c 100644
--- a/src/backend/langflow/interface/run.py
+++ b/src/backend/langflow/interface/run.py
@@ -2,7 +2,8 @@ import contextlib
 import io
 from typing import Any, Dict
 
-from chromadb.errors import NotEnoughElementsException  # type: ignore
+from chromadb.errors import NotEnoughElementsException
+from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler  # type: ignore
 from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
 from langflow.graph.graph import Graph
@@ -185,11 +186,9 @@ def fix_memory_inputs(langchain_object):
         update_memory_keys(langchain_object, possible_new_mem_key)
 
 
-async def get_result_and_steps(langchain_object, message: str, callbacks=None):
+async def get_result_and_steps(langchain_object, message: str, **kwargs):
     """Get result and thought from extracted json"""
 
-    if callbacks is None:
-        callbacks = []
     try:
         if hasattr(langchain_object, "verbose"):
             langchain_object.verbose = True
@@ -215,11 +214,15 @@ async def get_result_and_steps(langchain_object, message: str, callbacks=None):
 
     with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
         try:
-            output = await langchain_object.acall(chat_input, callbacks=callbacks)
-        except ValueError as exc:
+            async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
+            output = await langchain_object.acall(
+                chat_input, callbacks=async_callbacks
+            )
+        except Exception as exc:
             # make the error message more informative
             logger.debug(f"Error: {str(exc)}")
-            output = langchain_object.run(chat_input, callbacks=callbacks)
+            sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
+            output = langchain_object(chat_input, callbacks=sync_callbacks)
 
     intermediate_steps = (
         output.get("intermediate_steps", []) if isinstance(output, dict) else []
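
Note on the synchronous handler added in callback.py: when the chain runs synchronously, on_llm_new_token fires outside the event loop, so the websocket coroutine cannot be awaited directly; asyncio.run_coroutine_threadsafe schedules it onto the loop from the calling thread instead. A minimal, self-contained sketch of that pattern, assuming the handler is invoked from a worker thread (FakeWebsocket and the token value are illustrative stand-ins, not part of this patch):

    import asyncio


    class FakeWebsocket:
        """Illustrative stand-in for the FastAPI websocket; only send_json matters here."""

        async def send_json(self, payload: dict) -> None:
            print("sent:", payload)


    def on_token(loop: asyncio.AbstractEventLoop, ws: FakeWebsocket, token: str) -> None:
        # Runs in a worker thread: schedule the coroutine on the main loop
        # rather than awaiting it (this thread has no running loop).
        future = asyncio.run_coroutine_threadsafe(ws.send_json({"message": token}), loop)
        future.result()  # block this thread until the send completes


    async def main() -> None:
        loop = asyncio.get_running_loop()
        ws = FakeWebsocket()
        # Simulate the callback being invoked from a worker thread.
        await asyncio.to_thread(on_token, loop, ws, "hello")


    asyncio.run(main())

Capturing the loop up front, as the sketch does, avoids relying on asyncio.get_event_loop() inside the worker thread, which only resolves when a loop has already been set for that thread.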
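Taken together, the run.py hunks make get_result_and_steps attempt native async execution first and retry synchronously only if acall raises, attaching the matching handler in each branch. Roughly, assuming a LangChain-style chain that accepts a callbacks list (run_chain is a hypothetical wrapper, not code from this patch):

    async def run_chain(langchain_object, chat_input, websocket):
        try:
            # Preferred path: the chain supports async execution.
            callbacks = [AsyncStreamingLLMCallbackHandler(websocket=websocket)]
            return await langchain_object.acall(chat_input, callbacks=callbacks)
        except Exception:
            # Fallback: synchronous call with the thread-safe handler, so
            # streamed tokens still reach the websocket.
            callbacks = [StreamingLLMCallbackHandler(websocket=websocket)]
            return langchain_object(chat_input, callbacks=callbacks)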