🐛 fix(manager.py): add check for langchain object in process_message to avoid errors
✨ feat(manager.py): add build method to build langchain object and store it in an in-memory cache
The `process_message` method now checks if the langchain object has been built and stored in the in-memory cache before processing the message. If the object is not found, the connection is closed with an error message. A new `build` method has been added to build the langchain object and store it in an in-memory cache. This method is called before processing any messages.
This commit is contained in:
parent
dacc90d901
commit
ccf9477b7f
2 changed files with 41 additions and 14 deletions
|
|
@ -10,7 +10,10 @@ from langflow.utils.logger import logger
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langflow.cache.flow import InMemoryCache
|
||||
from langflow.graph import Graph
|
||||
|
||||
|
||||
class ChatHistory(Subject):
|
||||
|
|
@ -46,6 +49,7 @@ class ChatManager:
|
|||
self.chat_history = ChatHistory()
|
||||
self.cache_manager = cache_manager
|
||||
self.cache_manager.attach(self.update)
|
||||
self.in_memory_cache = InMemoryCache()
|
||||
|
||||
def on_chat_history_update(self):
|
||||
"""Send the last chat message to the client."""
|
||||
|
|
@ -99,24 +103,30 @@ class ChatManager:
|
|||
websocket = self.active_connections[client_id]
|
||||
await websocket.send_json(message.dict())
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict):
|
||||
async def close_connection(self, client_id: str, code: status, reason: str):
|
||||
if websocket := self.active_connections[client_id]:
|
||||
await websocket.close(code=code, reason=reason)
|
||||
self.disconnect(client_id)
|
||||
|
||||
async def process_message(
|
||||
self, client_id: str, payload: Dict, langchain_object: Any
|
||||
):
|
||||
# Process the graph data and chat message
|
||||
chat_message = payload.pop("message", "")
|
||||
chat_message = ChatMessage(message=chat_message)
|
||||
self.chat_history.add_message(client_id, chat_message)
|
||||
|
||||
graph_data = payload
|
||||
# graph_data = payload
|
||||
start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
|
||||
await self.send_json(client_id, start_resp)
|
||||
|
||||
is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1
|
||||
# is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1
|
||||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
|
||||
result, intermediate_steps = await process_graph(
|
||||
graph_data=graph_data,
|
||||
is_first_message=is_first_message,
|
||||
langchain_object=langchain_object,
|
||||
chat_message=chat_message,
|
||||
websocket=self.active_connections[client_id],
|
||||
)
|
||||
|
|
@ -149,6 +159,17 @@ class ChatManager:
|
|||
await self.send_json(client_id, response)
|
||||
self.chat_history.add_message(client_id, response)
|
||||
|
||||
def build(self, client_id: str, graph_data: Dict) -> bool:
|
||||
"""
|
||||
Build the langchain object and set the streaming options,
|
||||
then store it in the in-memory cache.
|
||||
"""
|
||||
logger.debug("Building langchain object")
|
||||
graph = Graph.from_payload(graph_data)
|
||||
langchain_object = graph.build()
|
||||
self.in_memory_cache.set(client_id, langchain_object)
|
||||
return client_id in self.in_memory_cache
|
||||
|
||||
async def handle_websocket(self, client_id: str, websocket: WebSocket):
|
||||
await self.connect(client_id, websocket)
|
||||
|
||||
|
|
@ -169,16 +190,24 @@ class ChatManager:
|
|||
continue
|
||||
|
||||
with self.cache_manager.set_client_id(client_id):
|
||||
await self.process_message(client_id, payload)
|
||||
if client_id not in self.in_memory_cache:
|
||||
self.close_connection(
|
||||
client_id=client_id,
|
||||
code=status.WS_1011_INTERNAL_ERROR,
|
||||
reason="Please, build the flow before sending messages",
|
||||
)
|
||||
else:
|
||||
langchain_object = self.in_memory_cache.get(client_id)
|
||||
await self.process_message(client_id, payload, langchain_object)
|
||||
|
||||
except Exception as e:
|
||||
# Handle any exceptions that might occur
|
||||
logger.exception(e)
|
||||
# send a message to the client
|
||||
await self.active_connections[client_id].close(
|
||||
code=status.WS_1011_INTERNAL_ERROR, reason=str(e)[:120]
|
||||
self.close_connection(
|
||||
client_id=client_id,
|
||||
code=status.WS_1011_INTERNAL_ERROR,
|
||||
reason=str(e)[:120],
|
||||
)
|
||||
self.disconnect(client_id)
|
||||
finally:
|
||||
try:
|
||||
connection = self.active_connections.get(client_id)
|
||||
|
|
|
|||
|
|
@ -12,12 +12,10 @@ from typing import Dict
|
|||
|
||||
|
||||
async def process_graph(
|
||||
graph_data: Dict,
|
||||
is_first_message: bool,
|
||||
langchain_object,
|
||||
chat_message: ChatMessage,
|
||||
websocket: WebSocket,
|
||||
):
|
||||
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
|
||||
langchain_object = try_setting_streaming_options(langchain_object, websocket)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue