🐛 fix(callback.py): initialize the callback handler with a client_id and retrieve the websocket from the chat service
🐛 fix(callback.py): add the response message to the chat history in the callback handler
🐛 fix(utils.py): remove the unused websocket parameter from try_setting_streaming_options
🐛 fix(manager.py): pass client_id instead of the websocket when calling process_graph
🐛 fix(utils.py): rename process_graph's websocket parameter to client_id and pass it on to get_result_and_steps
parent 746f2b6799
commit 3bf055a990
4 changed files with 11 additions and 7 deletions
callback.py

@@ -8,6 +8,7 @@ from langflow.api.v1.schemas import ChatResponse
 from typing import Any, Dict, List, Optional
 
 from fastapi import WebSocket
+from langflow.services.getters import get_chat_service
 
 from langflow.utils.util import remove_ansi_escape_codes
 from loguru import logger

@@ -19,8 +20,10 @@ from loguru import logger
 class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
     """Callback handler for streaming LLM responses."""
 
-    def __init__(self, websocket: WebSocket):
-        self.websocket = websocket
+    def __init__(self, client_id: str):
+        self.chat_service = get_chat_service()
+        self.client_id = client_id
+        self.websocket = self.chat_service.active_connections[self.client_id]
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         resp = ChatResponse(message=token, type="stream", intermediate_steps="")

@@ -96,6 +99,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
             prompt=text,
         )
         await self.websocket.send_json(resp.dict())
+        self.chat_service.chat_history.add_message(self.client_id, resp)
 
     async def on_agent_action(self, action: AgentAction, **kwargs: Any):
         log = f"Thought: {action.log}"
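The key change: the handler no longer receives a websocket from its caller. It is constructed from a client_id alone and resolves the live connection through the shared chat service, which also gives it a handle for appending responses to that client's chat history. A minimal, self-contained sketch of the pattern; the registry below is a hypothetical stand-in, not langflow's real ChatService:

from typing import Any, Dict


class FakeWebSocket:
    """Illustrative stand-in for fastapi.WebSocket."""

    async def send_json(self, payload: Dict[str, Any]) -> None:
        print("sent:", payload)


class FakeChatService:
    """Hypothetical registry with only the attribute the diff relies on."""

    def __init__(self) -> None:
        # client_id -> live connection, registered when a client connects
        self.active_connections: Dict[str, FakeWebSocket] = {}


_service = FakeChatService()


def get_chat_service() -> FakeChatService:
    return _service


class StreamingHandler:
    """Mirrors the new AsyncStreamingLLMCallbackHandler construction contract."""

    def __init__(self, client_id: str) -> None:
        self.chat_service = get_chat_service()
        self.client_id = client_id
        # Raises KeyError if the client never registered a connection.
        self.websocket = self.chat_service.active_connections[self.client_id]

One consequence worth noting: construction now fails fast with a KeyError for an unknown client_id, whereas the old handler would accept any websocket object handed to it.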
utils.py

@@ -36,7 +36,7 @@ def pil_to_base64(image: Image) -> str:
     return img_str.decode("utf-8")
 
 
-def try_setting_streaming_options(langchain_object, websocket):
+def try_setting_streaming_options(langchain_object):
     # If the LLM type is OpenAI or ChatOpenAI,
     # set streaming to True
     # First we need to find the LLM
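The hunk only touches the signature; the leading comments describe the body, which is unchanged. For orientation, a plausible sketch of what a function like this does — locate the LLM hanging off the langchain object and, if it is an OpenAI or ChatOpenAI instance, switch it to streaming — written as an assumption rather than langflow's exact implementation:

def try_setting_streaming_options(langchain_object):
    # If the LLM type is OpenAI or ChatOpenAI, set streaming to True.
    # First we need to find the LLM; chains and agents commonly expose
    # it as an `llm` attribute (an assumption for this sketch).
    llm = getattr(langchain_object, "llm", None)
    if llm is not None and type(llm).__name__ in ("OpenAI", "ChatOpenAI"):
        if hasattr(llm, "streaming"):
            llm.streaming = True  # emit tokens as they arrive
    return langchain_object

Dropping the websocket parameter makes sense here: whether streaming is enabled depends only on the LLM type, not on where the tokens are sent.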
manager.py

@@ -142,7 +142,7 @@ class ChatService(Service):
             result, intermediate_steps = await process_graph(
                 langchain_object=langchain_object,
                 chat_inputs=chat_inputs,
-                websocket=self.active_connections[client_id],
+                client_id=client_id,
                 session_id=self.connection_ids[client_id],
             )
             self.set_cache(client_id, langchain_object)
utils.py

@@ -8,10 +8,10 @@ from loguru import logger
 async def process_graph(
     langchain_object,
     chat_inputs: ChatMessage,
-    websocket: WebSocket,
+    client_id: str,
     session_id: str,
 ):
-    langchain_object = try_setting_streaming_options(langchain_object, websocket)
+    langchain_object = try_setting_streaming_options(langchain_object)
     logger.debug("Loaded langchain object")
 
     if langchain_object is None:

@@ -30,7 +30,7 @@ async def process_graph(
     result, intermediate_steps = await get_result_and_steps(
         langchain_object,
         chat_inputs.message,
-        websocket=websocket,
+        client_id=client_id,
         session_id=session_id,
     )
     logger.debug("Generated result and intermediate_steps")
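Taken together, the four files now thread a single identifier through the stack: only client_id crosses function boundaries, and the websocket is looked up at the last possible moment. A comment sketch of the resulting call chain (names from the diffs above; get_result_and_steps itself is not shown in this commit, so its place in the chain is inferred from the commit message):

# ChatService (manager.py)
#   └─ process_graph(..., client_id=client_id)                 # utils.py
#        └─ get_result_and_steps(..., client_id=client_id)
#             └─ AsyncStreamingLLMCallbackHandler(client_id)   # callback.py
#                  └─ get_chat_service().active_connections[client_id]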