From b042a8e8855c566a8a87a3a9bbfbbfbc7e854bb9 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 28 Sep 2023 09:02:07 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix(callback.py):=20remove=20unu?= =?UTF-8?q?sed=20imports=20and=20methods=20in=20AsyncStreamingLLMCallbackH?= =?UTF-8?q?andler=20class?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/callback.py | 41 +++++++------------------ 1 file changed, 11 insertions(+), 30 deletions(-) diff --git a/src/backend/langflow/api/v1/callback.py b/src/backend/langflow/api/v1/callback.py index 2a16a0bd2..3427cd090 100644 --- a/src/backend/langflow/api/v1/callback.py +++ b/src/backend/langflow/api/v1/callback.py @@ -1,15 +1,16 @@ import asyncio +from uuid import UUID from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langflow.api.v1.schemas import ChatResponse -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional from fastapi import WebSocket -from langchain.schema import AgentAction, LLMResult, AgentFinish +from langchain.schema import AgentAction, AgentFinish from loguru import logger @@ -24,32 +25,6 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler): resp = ChatResponse(message=token, type="stream", intermediate_steps="") await self.websocket.send_json(resp.dict()) - async def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> Any: - """Run when LLM starts running.""" - - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any: - """Run when LLM ends running.""" - - async def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: - """Run when LLM errors.""" - - async def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> Any: - """Run when chain starts running.""" - - async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: - """Run when chain ends running.""" - - async def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: - """Run when chain errors.""" - async def on_tool_start( self, serialized: Dict[str, Any], input_str: str, **kwargs: Any ) -> Any: @@ -95,8 +70,14 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler): logger.error(f"Error sending response: {exc}") async def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: """Run when tool errors.""" async def on_text(self, text: str, **kwargs: Any) -> Any: