From ba2736f0858b246799246aee1a688461ab79da4a Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 29 Sep 2023 11:28:36 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(callback.py):=20replace=20Ch?= =?UTF-8?q?atResponse=20with=20PromptResponse=20in=20AsyncStreamingLLMCall?= =?UTF-8?q?backHandler=20to=20correctly=20handle=20prompt=20after=20format?= =?UTF-8?q?ting=20=F0=9F=94=80=20chore(schemas.py):=20add=20PromptResponse?= =?UTF-8?q?=20schema=20to=20handle=20prompt=20responses=20in=20addition=20?= =?UTF-8?q?to=20ChatResponse=20schema?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/callback.py | 7 ++----- src/backend/langflow/api/v1/schemas.py | 10 +++++++++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/backend/langflow/api/v1/callback.py b/src/backend/langflow/api/v1/callback.py index 90799499f..bda05865b 100644 --- a/src/backend/langflow/api/v1/callback.py +++ b/src/backend/langflow/api/v1/callback.py @@ -3,7 +3,7 @@ from uuid import UUID from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler -from langflow.api.v1.schemas import ChatResponse +from langflow.api.v1.schemas import ChatResponse, PromptResponse from typing import Any, Dict, List, Optional @@ -92,10 +92,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler): if "Prompt after formatting" in text: text = text.replace("Prompt after formatting:\n", "") text = remove_ansi_escape_codes(text) - resp = ChatResponse( - message="", - type="stream", - intermediate_steps="", + resp = PromptResponse( prompt=text, ) await self.websocket.send_json(resp.dict()) diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index 9a4df7c4f..37e7d712d 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -86,7 +86,7 @@ class ChatResponse(ChatMessage): """Chat response schema.""" intermediate_steps: str - prompt: Optional[str] = "" + type: str is_bot: bool = True files: list = [] @@ -98,6 +98,14 @@ class ChatResponse(ChatMessage): return v +class PromptResponse(ChatMessage): + """Prompt response schema.""" + + prompt: str + type: str = "prompt" + is_bot: bool = True + + class FileResponse(ChatMessage): """File response schema."""