🔀 refactor(callback.py): reorganize imports and add missing type hints

This commit reorganizes the imports into proper groups and adds missing type hints to improve code readability and maintainability.
Gabriel Luiz Freitas Almeida 2023-07-06 15:32:46 -03:00
commit 54b13a6bb6


@@ -1,22 +1,107 @@
import asyncio
from typing import Any, Dict, List, Union

from fastapi import WebSocket
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult

from langflow.api.v1.schemas import ChatResponse


# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses."""

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Forward each generated token to the client as a "stream" message.
        resp = ChatResponse(message=token, type="stream", intermediate_steps="")
        await self.websocket.send_json(resp.dict())

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Run when LLM starts running."""

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        """Run when LLM ends running."""

    async def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when LLM errors."""

    async def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        """Run when chain starts running."""

    async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Run when chain ends running."""

    async def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when chain errors."""

    async def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        """Run when tool starts running."""
        resp = ChatResponse(
            message="",
            type="stream",
            intermediate_steps=f"Tool input: {input_str}",
        )
        await self.websocket.send_json(resp.dict())

    async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        """Run when tool ends running."""
        resp = ChatResponse(
            message="",
            type="stream",
            intermediate_steps=f"Tool output: {output}",
        )
        await self.websocket.send_json(resp.dict())

    async def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when tool errors."""

    async def on_text(self, text: str, **kwargs: Any) -> Any:
        """Run on arbitrary text."""
        # This runs when first sending the prompt
        # to the LLM; adding it will send the final prompt
        # to the frontend

    async def on_agent_action(self, action: AgentAction, **kwargs: Any):
        log = f"Thought: {action.log}"
        # If there are line breaks, split them and send them
        # as separate messages
        if "\n" in log:
            logs = log.split("\n")
            for log in logs:
                resp = ChatResponse(message="", type="stream", intermediate_steps=log)
                await self.websocket.send_json(resp.dict())
        else:
            resp = ChatResponse(message="", type="stream", intermediate_steps=log)
            await self.websocket.send_json(resp.dict())

    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run on agent end."""
        resp = ChatResponse(
            message="",
            type="stream",
            intermediate_steps=finish.log,
        )
        await self.websocket.send_json(resp.dict())


class StreamingLLMCallbackHandler(BaseCallbackHandler):
    """Callback handler for streaming LLM responses."""