Merge remote-tracking branch 'origin/chat_and_cache' into websocket

This commit is contained in:
anovazzi1 2023-04-25 17:34:21 -03:00
commit bc97420b5e
4 changed files with 20 additions and 28 deletions

View file

@ -12,7 +12,5 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
self.websocket = websocket
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(
sender="bot", message=token, type="stream", intermediate_steps=""
)
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
await self.websocket.send_json(resp.dict())

View file

@ -8,6 +8,5 @@ chat_manager = ChatManager()
@router.websocket("/chat/{client_id}")
async def websocket_endpoint(client_id: str, websocket: WebSocket):
"""Websocket endpoint for chat."""
await chat_manager.handle_websocket(client_id, websocket)

View file

@ -26,11 +26,17 @@ class ChatHistory(AsyncSubject):
self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
async def add_message(self, client_id: str, message: ChatMessage):
"""Add a message to the chat history."""
self.history[client_id].append(message)
await self.notify()
def get_history(self, client_id: str) -> List[ChatMessage]:
return self.history[client_id]
"""Get the chat history for a client."""
if history := self.history.get(client_id, []):
return [msg for msg in history if msg.type not in ["start", "stream"]]
else:
return []
class ChatManager:
@ -46,7 +52,7 @@ class ChatManager:
client_id = self.cache_manager.current_client_id
if client_id in self.active_connections:
chat_response = self.chat_history.get_history(client_id)[-1]
if chat_response.sender == "bot":
if chat_response.is_bot:
# Process FileResponse
if isinstance(chat_response, FileResponse):
# If data_type is pandas, convert to csv
@ -63,7 +69,6 @@ class ChatManager:
self.last_cached_object_dict = self.cache_manager.get_last()
# Add a new ChatResponse with the data
chat_response = FileResponse(
sender="bot",
message=None,
type="file",
data=self.last_cached_object_dict["obj"],
@ -94,13 +99,11 @@ class ChatManager:
async def process_message(self, client_id: str, payload: Dict):
# Process the graph data and chat message
chat_message = payload.pop("message", "")
chat_message = ChatMessage(sender="you", message=chat_message)
chat_message = ChatMessage(message=chat_message)
await self.chat_history.add_message(client_id, chat_message)
graph_data = payload
start_resp = ChatResponse(
sender="bot", message=None, type="start", intermediate_steps=""
)
start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
await self.chat_history.add_message(client_id, start_resp)
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
@ -117,14 +120,9 @@ class ChatManager:
except Exception as e:
# Log stack trace
logger.exception(e)
error_resp = ChatResponse(
sender="bot", message=str(e), type="error", intermediate_steps=""
)
await self.send_json(client_id, error_resp)
return
raise e
# Send a response back to the frontend, if needed
response = ChatResponse(
sender="bot",
message=result or "",
intermediate_steps=intermediate_steps or "",
type="end",
@ -151,6 +149,7 @@ class ChatManager:
except Exception as e:
# Handle any exceptions that might occur
print(f"Error: {e}")
raise e
finally:
self.disconnect(client_id)
@ -198,11 +197,10 @@ def try_setting_streaming_options(langchain_object, websocket):
llm = langchain_object.llm_chain.llm
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
llm.streaming = bool(hasattr(llm, "streaming"))
if hasattr(langchain_object, "callback_manager"):
stream_handler = StreamingLLMCallbackHandler(websocket)
stream_manager = AsyncCallbackManager([stream_handler])
langchain_object.callback_manager = stream_manager
llm.callback_manager = stream_manager
return langchain_object

View file

@ -5,14 +5,9 @@ from pydantic import BaseModel, validator
class ChatMessage(BaseModel):
"""Chat message schema."""
sender: str
is_bot: bool = False
message: Union[str, None] = None
@validator("sender")
def sender_must_be_bot_or_you(cls, v):
if v not in ["bot", "you"]:
raise ValueError("sender must be bot or you")
return v
type: str = "human"
class ChatResponse(ChatMessage):
@ -20,6 +15,7 @@ class ChatResponse(ChatMessage):
intermediate_steps: str
type: str
is_bot: bool = True
@validator("type")
def validate_message_type(cls, v):
@ -34,6 +30,7 @@ class FileResponse(ChatMessage):
data: Any
data_type: str
type: str = "file"
is_bot: bool = True
@validator("data_type")
def validate_data_type(cls, v):