refactor(chat_manager.py): move process_graph function outside of ChatManager class

test(websocket.py): add tests for websocket connection, chat history, and sending messages
This commit is contained in:
Gabriel Almeida 2023-04-19 22:23:31 -03:00
commit 0a630cd70d
2 changed files with 74 additions and 35 deletions

View file

@ -41,13 +41,13 @@ class ChatManager:
async def send_json(self, client_id: str, message: ChatMessage):
websocket = self.active_connections[client_id]
self.chat_history.add_message(client_id, message)
await websocket.send_json(message.dict())
await websocket.send_json(json.dumps(message.dict()))
async def process_message(self, client_id: str, payload: Dict):
# Process the graph data and chat message
chat_message = payload.pop("message", "")
chat_message = ChatMessage(sender="user", message=chat_message)
chat_message = ChatMessage(sender="you", message=chat_message)
self.chat_history.add_message(client_id, chat_message)
graph_data = payload
@ -57,22 +57,14 @@ class ChatManager:
await self.send_json(client_id, start_resp)
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await async_get_result_and_steps(
langchain_object, chat_message.message or ""
result, intermediate_steps = await process_graph(
graph_data=graph_data,
is_first_message=is_first_message,
chat_message=chat_message,
)
logger.debug("Generated result and intermediate_steps")
except Exception as e:
# Log stack trace
logger.exception(e)
@ -105,3 +97,29 @@ class ChatManager:
print(f"Error: {e}")
finally:
self.disconnect(client_id)
async def process_graph(
graph_data: Dict, is_first_message: bool, chat_message: ChatMessage
):
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await async_get_result_and_steps(
langchain_object, chat_message.message or ""
)
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps
except Exception as e:
# Log stack trace
logger.exception(e)
raise e