Refactor callback handler to use SocketIOService

Gabriel Luiz Freitas Almeida committed on 2024-01-31 23:51:04 -03:00
commit edb9ba5a80


@@ -1,28 +1,32 @@
 import asyncio
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 from uuid import UUID
 
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish
-from loguru import logger
 
 from langflow.api.v1.schemas import ChatResponse, PromptResponse
 from langflow.services.deps import get_chat_service
 from langflow.utils.util import remove_ansi_escape_codes
+from loguru import logger
 
+if TYPE_CHECKING:
+    from langflow.services.socket.service import SocketIOService
+
 # https://github.com/hwchase17/chat-langchain/blob/master/callback.py
 class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
     """Callback handler for streaming LLM responses."""
 
-    def __init__(self, client_id: str):
+    def __init__(self, session_id: str):
         self.chat_service = get_chat_service()
-        self.client_id = client_id
-        self.websocket = self.chat_service.active_connections[self.client_id]
+        self.client_id = session_id
+        self.socketio_service: "SocketIOService" = self.chat_service.socketio_service
+        self.sid = session_id
+        # self.socketio_service = self.chat_service.active_connections[self.client_id]
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         resp = ChatResponse(message=token, type="stream", intermediate_steps="")
-        await self.websocket.send_json(resp.model_dump())
+        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
 
     async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
         """Run when tool starts running."""
@@ -31,7 +35,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
             type="stream",
             intermediate_steps=f"Tool input: {input_str}",
         )
-        await self.websocket.send_json(resp.model_dump())
+        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
 
     async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         """Run when tool ends running."""
@@ -62,7 +66,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
         try:
             # This is to emulate the stream of tokens
             for resp in resps:
-                await self.websocket.send_json(resp.model_dump())
+                await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
         except Exception as exc:
             logger.error(f"Error sending response: {exc}")
@@ -88,7 +92,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
         resp = PromptResponse(
             prompt=text,
         )
-        await self.websocket.send_json(resp.model_dump())
+        await self.socketio_service.emit_message(to=self.sid, data=resp.model_dump())
         self.chat_service.chat_history.add_message(self.client_id, resp)
 
     async def on_agent_action(self, action: AgentAction, **kwargs: Any):
@@ -99,10 +103,10 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
             logs = log.split("\n")
             for log in logs:
                 resp = ChatResponse(message="", type="stream", intermediate_steps=log)
-                await self.websocket.send_json(resp.model_dump())
+                await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
         else:
             resp = ChatResponse(message="", type="stream", intermediate_steps=log)
-            await self.websocket.send_json(resp.model_dump())
+            await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
 
     async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
@@ -111,7 +115,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
             type="stream",
             intermediate_steps=finish.log,
         )
-        await self.websocket.send_json(resp.model_dump())
+        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
 
 
 class StreamingLLMCallbackHandler(BaseCallbackHandler):
@@ -120,11 +124,11 @@ class StreamingLLMCallbackHandler(BaseCallbackHandler):
     def __init__(self, client_id: str):
         self.chat_service = get_chat_service()
         self.client_id = client_id
-        self.websocket = self.chat_service.active_connections[self.client_id]
+        self.socketio_service = self.chat_service.active_connections[self.client_id]
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         resp = ChatResponse(message=token, type="stream", intermediate_steps="")
         loop = asyncio.get_event_loop()
-        coroutine = self.websocket.send_json(resp.model_dump())
+        coroutine = self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
         asyncio.run_coroutine_threadsafe(coroutine, loop)
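For context, the emit_token and emit_message calls above target the SocketIOService imported under TYPE_CHECKING; its implementation lives in langflow/services/socket/service.py and is not part of this diff. A minimal sketch of the surface the handler relies on, assuming a python-socketio AsyncServer and illustrative event names "token" and "message":

import socketio


class SocketIOService:
    """Sketch of the emit interface assumed by the callback handlers."""

    def __init__(self, sio: socketio.AsyncServer):
        self.sio = sio

    async def emit_token(self, to: str, data: dict) -> None:
        # Stream a single LLM token to one client, addressed by Socket.IO sid.
        await self.sio.emit("token", data=data, to=to)

    async def emit_message(self, to: str, data: dict) -> None:
        # Send a complete payload (e.g. a PromptResponse dump) to one client.
        await self.sio.emit("message", data=data, to=to)

Addressing each emit with to=self.sid moves connection bookkeeping into the service: the handler no longer looks up its own websocket via chat_service.active_connections[client_id] as it did before this change.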