diff --git a/src/backend/langflow/api/v1/callback.py b/src/backend/langflow/api/v1/callback.py
index a838a8750..3d6bb4061 100644
--- a/src/backend/langflow/api/v1/callback.py
+++ b/src/backend/langflow/api/v1/callback.py
@@ -1,28 +1,32 @@
 import asyncio
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 from uuid import UUID
 
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish
-from loguru import logger
-
 from langflow.api.v1.schemas import ChatResponse, PromptResponse
 from langflow.services.deps import get_chat_service
 from langflow.utils.util import remove_ansi_escape_codes
+from loguru import logger
+
+if TYPE_CHECKING:
+    from langflow.services.socket.service import SocketIOService
 
 
 # https://github.com/hwchase17/chat-langchain/blob/master/callback.py
 class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
     """Callback handler for streaming LLM responses."""
 
-    def __init__(self, client_id: str):
+    def __init__(self, session_id: str):
         self.chat_service = get_chat_service()
-        self.client_id = client_id
-        self.websocket = self.chat_service.active_connections[self.client_id]
+        self.client_id = session_id
+        self.socketio_service: "SocketIOService" = self.chat_service.socketio_service
+        self.sid = session_id
+        # self.socketio_service = self.chat_service.active_connections[self.client_id]
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         resp = ChatResponse(message=token, type="stream", intermediate_steps="")
-        await self.websocket.send_json(resp.model_dump())
+        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
 
     async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
         """Run when tool starts running."""
@@ -31,7 +35,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
             type="stream",
             intermediate_steps=f"Tool input: {input_str}",
         )
-        await self.websocket.send_json(resp.model_dump())
+        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
 
     async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         """Run when tool ends running."""
@@ -62,7 +66,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
         try:
             # This is to emulate the stream of tokens
             for resp in resps:
-                await self.websocket.send_json(resp.model_dump())
+                await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
         except Exception as exc:
             logger.error(f"Error sending response: {exc}")
 
@@ -88,7 +92,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
         resp = PromptResponse(
             prompt=text,
         )
-        await self.websocket.send_json(resp.model_dump())
+        await self.socketio_service.emit_message(to=self.sid, data=resp.model_dump())
         self.chat_service.chat_history.add_message(self.client_id, resp)
 
     async def on_agent_action(self, action: AgentAction, **kwargs: Any):
@@ -99,10 +103,10 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
             logs = log.split("\n")
             for log in logs:
                 resp = ChatResponse(message="", type="stream", intermediate_steps=log)
-                await self.websocket.send_json(resp.model_dump())
+                await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
         else:
             resp = ChatResponse(message="", type="stream", intermediate_steps=log)
-            await self.websocket.send_json(resp.model_dump())
+            await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
 
     async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
@@ -111,7 +115,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
             type="stream",
             intermediate_steps=finish.log,
         )
-        await self.websocket.send_json(resp.model_dump())
+        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
 
 
 class StreamingLLMCallbackHandler(BaseCallbackHandler):
@@ -120,11 +124,11 @@ class StreamingLLMCallbackHandler(BaseCallbackHandler):
     def __init__(self, client_id: str):
        self.chat_service = get_chat_service()
         self.client_id = client_id
-        self.websocket = self.chat_service.active_connections[self.client_id]
+        self.socketio_service = self.chat_service.active_connections[self.client_id]
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         resp = ChatResponse(message=token, type="stream", intermediate_steps="")
         loop = asyncio.get_event_loop()
-        coroutine = self.websocket.send_json(resp.model_dump())
+        coroutine = self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
         asyncio.run_coroutine_threadsafe(coroutine, loop)