refactor(chat_manager.py): move process_graph function outside of ChatManager class
test(websocket.py): add tests for websocket connection, chat history, and sending messages
This commit is contained in:
parent
7d183ff57e
commit
0a630cd70d
2 changed files with 74 additions and 35 deletions
|
|
@ -41,13 +41,13 @@ class ChatManager:
|
|||
async def send_json(self, client_id: str, message: ChatMessage):
|
||||
websocket = self.active_connections[client_id]
|
||||
self.chat_history.add_message(client_id, message)
|
||||
await websocket.send_json(message.dict())
|
||||
await websocket.send_json(json.dumps(message.dict()))
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict):
|
||||
# Process the graph data and chat message
|
||||
|
||||
chat_message = payload.pop("message", "")
|
||||
chat_message = ChatMessage(sender="user", message=chat_message)
|
||||
chat_message = ChatMessage(sender="you", message=chat_message)
|
||||
self.chat_history.add_message(client_id, chat_message)
|
||||
|
||||
graph_data = payload
|
||||
|
|
@ -57,22 +57,14 @@ class ChatManager:
|
|||
await self.send_json(client_id, start_resp)
|
||||
|
||||
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
|
||||
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = await async_get_result_and_steps(
|
||||
langchain_object, chat_message.message or ""
|
||||
result, intermediate_steps = await process_graph(
|
||||
graph_data=graph_data,
|
||||
is_first_message=is_first_message,
|
||||
chat_message=chat_message,
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
@ -105,3 +97,29 @@ class ChatManager:
|
|||
print(f"Error: {e}")
|
||||
finally:
|
||||
self.disconnect(client_id)
|
||||
|
||||
|
||||
async def process_graph(
|
||||
graph_data: Dict, is_first_message: bool, chat_message: ChatMessage
|
||||
):
|
||||
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = await async_get_result_and_steps(
|
||||
langchain_object, chat_message.message or ""
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
return result, intermediate_steps
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue