Merge remote-tracking branch 'origin/chat_and_cache' into websocket
commit bc97420b5e

4 changed files with 20 additions and 28 deletions
@@ -12,7 +12,5 @@ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
         self.websocket = websocket

     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        resp = ChatResponse(
-            sender="bot", message=token, type="stream", intermediate_steps=""
-        )
+        resp = ChatResponse(message=token, type="stream", intermediate_steps="")
         await self.websocket.send_json(resp.dict())

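For orientation (an illustrative sketch, not part of the commit): the pattern this hunk leaves in place is "build a sender-less ChatResponse per token, push it over the websocket". The ChatResponse stub and EchoWebSocket below are stand-ins for the project's real classes, and pydantic v1's .dict() is assumed.

import asyncio
from typing import Union

from pydantic import BaseModel


class ChatResponse(BaseModel):
    """Illustrative stub; the real model lives in the project's schema module."""

    is_bot: bool = True
    message: Union[str, None] = None
    type: str = "stream"
    intermediate_steps: str = ""


class EchoWebSocket:
    """Stand-in for a FastAPI WebSocket that just prints the payload."""

    async def send_json(self, data: dict) -> None:
        print(data)


async def on_llm_new_token(websocket: EchoWebSocket, token: str) -> None:
    # Mirrors the post-merge handler: no sender kwarg, one response per token.
    resp = ChatResponse(message=token, type="stream", intermediate_steps="")
    await websocket.send_json(resp.dict())


asyncio.run(on_llm_new_token(EchoWebSocket(), "Hello"))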
@@ -8,6 +8,5 @@ chat_manager = ChatManager()

 @router.websocket("/chat/{client_id}")
 async def websocket_endpoint(client_id: str, websocket: WebSocket):
     """Websocket endpoint for chat."""
     await chat_manager.handle_websocket(client_id, websocket)
-

@@ -26,11 +26,17 @@ class ChatHistory(AsyncSubject):
         self.history: Dict[str, List[ChatMessage]] = defaultdict(list)

     async def add_message(self, client_id: str, message: ChatMessage):
+        """Add a message to the chat history."""
+
         self.history[client_id].append(message)
         await self.notify()

     def get_history(self, client_id: str) -> List[ChatMessage]:
-        return self.history[client_id]
+        """Get the chat history for a client."""
+        if history := self.history.get(client_id, []):
+            return [msg for msg in history if msg.type not in ["start", "stream"]]
+        else:
+            return []


 class ChatManager:
@@ -46,7 +52,7 @@ class ChatManager:
         client_id = self.cache_manager.current_client_id
         if client_id in self.active_connections:
             chat_response = self.chat_history.get_history(client_id)[-1]
-            if chat_response.sender == "bot":
+            if chat_response.is_bot:
                 # Process FileResponse
                 if isinstance(chat_response, FileResponse):
                     # If data_type is pandas, convert to csv
@@ -63,7 +69,6 @@ class ChatManager:
                 self.last_cached_object_dict = self.cache_manager.get_last()
                 # Add a new ChatResponse with the data
                 chat_response = FileResponse(
-                    sender="bot",
                     message=None,
                     type="file",
                     data=self.last_cached_object_dict["obj"],
@@ -94,13 +99,11 @@ class ChatManager:
     async def process_message(self, client_id: str, payload: Dict):
         # Process the graph data and chat message
         chat_message = payload.pop("message", "")
-        chat_message = ChatMessage(sender="you", message=chat_message)
+        chat_message = ChatMessage(message=chat_message)
         await self.chat_history.add_message(client_id, chat_message)

         graph_data = payload
-        start_resp = ChatResponse(
-            sender="bot", message=None, type="start", intermediate_steps=""
-        )
+        start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
         await self.chat_history.add_message(client_id, start_resp)

         is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
@@ -117,14 +120,9 @@ class ChatManager:
         except Exception as e:
             # Log stack trace
             logger.exception(e)
-            error_resp = ChatResponse(
-                sender="bot", message=str(e), type="error", intermediate_steps=""
-            )
-            await self.send_json(client_id, error_resp)
-            return
+            raise e
         # Send a response back to the frontend, if needed
         response = ChatResponse(
-            sender="bot",
             message=result or "",
             intermediate_steps=intermediate_steps or "",
             type="end",
@@ -151,6 +149,7 @@ class ChatManager:
         except Exception as e:
             # Handle any exceptions that might occur
             print(f"Error: {e}")
+            raise e
         finally:
             self.disconnect(client_id)

@@ -198,11 +197,10 @@ def try_setting_streaming_options(langchain_object, websocket):
         llm = langchain_object.llm_chain.llm
         if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
             llm.streaming = bool(hasattr(llm, "streaming"))

     if hasattr(langchain_object, "callback_manager"):
         stream_handler = StreamingLLMCallbackHandler(websocket)
         stream_manager = AsyncCallbackManager([stream_handler])
         langchain_object.callback_manager = stream_manager
         llm.callback_manager = stream_manager
-
     return langchain_object

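For context (an illustrative sketch, not part of the commit): the new get_history drops the transient "start" and "stream" entries and returns [] for unknown clients via the walrus if. The Msg dataclass is a stand-in for ChatMessage.

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Msg:
    """Stand-in for ChatMessage: only the fields the filter touches."""

    message: str
    type: str = "human"


history: Dict[str, List[Msg]] = defaultdict(list)
history["client-1"] += [
    Msg("hi"),                    # kept: a normal human message
    Msg("", type="start"),        # dropped: transient start marker
    Msg("Hel", type="stream"),    # dropped: streamed token chunk
    Msg("Hello!", type="end"),    # kept: the final bot answer
]


def get_history(client_id: str) -> List[Msg]:
    # Same shape as the diff: empty or missing histories fall through to [].
    if msgs := history.get(client_id, []):
        return [msg for msg in msgs if msg.type not in ["start", "stream"]]
    else:
        return []


assert [m.type for m in get_history("client-1")] == ["human", "end"]
assert get_history("nobody") == []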
@@ -5,14 +5,9 @@ from pydantic import BaseModel, validator
 class ChatMessage(BaseModel):
     """Chat message schema."""

-    sender: str
+    is_bot: bool = False
     message: Union[str, None] = None
-
-    @validator("sender")
-    def sender_must_be_bot_or_you(cls, v):
-        if v not in ["bot", "you"]:
-            raise ValueError("sender must be bot or you")
-        return v
     type: str = "human"

+
 class ChatResponse(ChatMessage):
@@ -20,6 +15,7 @@ class ChatResponse(ChatMessage):

     intermediate_steps: str
     type: str
+    is_bot: bool = True

     @validator("type")
     def validate_message_type(cls, v):
@@ -34,6 +30,7 @@ class FileResponse(ChatMessage):
     data: Any
     data_type: str
     type: str = "file"
+    is_bot: bool = True

     @validator("data_type")
     def validate_data_type(cls, v):
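For reference (an illustrative sketch, not part of the commit): the net schema change replaces the sender field and its "bot"/"you" validator with an is_bot flag that subclasses override. The type validators of the real models are omitted here.

from typing import Any, Union

from pydantic import BaseModel


class ChatMessage(BaseModel):
    """Chat message schema."""

    is_bot: bool = False  # was: sender: str plus a "bot"/"you" validator
    message: Union[str, None] = None
    type: str = "human"


class ChatResponse(ChatMessage):
    intermediate_steps: str
    type: str
    is_bot: bool = True  # responses come from the bot by default


class FileResponse(ChatMessage):
    data: Any
    data_type: str
    type: str = "file"
    is_bot: bool = True


# Callers no longer pass sender="you"/"bot":
user_msg = ChatMessage(message="hello")
start = ChatResponse(message=None, type="start", intermediate_steps="")
assert not user_msg.is_bot and start.is_bot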