refactor(chat.py, chat_manager.py, schemas.py, run.py): add chat history to ChatManager and ChatMessage schema
feat(chat.py, chat_manager.py): add error handling for async_get_result_and_steps feat(chat.py): add client_id to websocket endpoint feat(schemas.py): add data_type field to ChatResponse schema refactor(run.py): memoize build_langchain_object_with_caching function with maxsize of 10
This commit is contained in:
parent
9b1f86b681
commit
7d183ff57e
4 changed files with 87 additions and 24 deletions
|
|
@ -7,7 +7,6 @@ router = APIRouter()
|
|||
chat_manager = ChatManager()
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
client_id = str(uuid4())
|
||||
@router.websocket("/ws/{client_id}")
|
||||
async def websocket_endpoint(client_id: str, websocket: WebSocket):
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import json
|
|||
from langflow.api.schemas import ChatMessage, ChatResponse
|
||||
|
||||
from langflow.interface.run import (
|
||||
get_result_and_steps,
|
||||
async_get_result_and_steps,
|
||||
load_or_build_langchain_object,
|
||||
)
|
||||
from langflow.utils.logger import logger
|
||||
|
|
@ -38,22 +38,25 @@ class ChatManager:
|
|||
websocket = self.active_connections[client_id]
|
||||
await websocket.send_text(message)
|
||||
|
||||
async def send_json(self, client_id: str, message: Dict):
|
||||
async def send_json(self, client_id: str, message: ChatMessage):
|
||||
websocket = self.active_connections[client_id]
|
||||
await websocket.send_json(message)
|
||||
self.chat_history.add_message(client_id, message)
|
||||
await websocket.send_json(message.dict())
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict):
|
||||
# Process the graph data and chat message
|
||||
|
||||
chat_message = payload.pop("message", "")
|
||||
chat_message = ChatMessage(sender="user", message=chat_message)
|
||||
self.chat_history.add_message(client_id, chat_message)
|
||||
|
||||
graph_data = payload
|
||||
start_resp = ChatResponse(
|
||||
sender="bot", message="", type="start", intermediate_steps=""
|
||||
sender="bot", message=None, type="start", intermediate_steps=""
|
||||
)
|
||||
await self.send_json(client_id, start_resp.dict())
|
||||
await self.send_json(client_id, start_resp)
|
||||
|
||||
is_first_message = len(graph_data.get("chatHistory", [])) == 0
|
||||
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
|
||||
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
|
|
@ -64,15 +67,20 @@ class ChatManager:
|
|||
)
|
||||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = get_result_and_steps(
|
||||
langchain_object, chat_message.message
|
||||
)
|
||||
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
# Save the message to chat history
|
||||
self.chat_history.add_message(client_id, chat_message)
|
||||
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = await async_get_result_and_steps(
|
||||
langchain_object, chat_message.message or ""
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
error_resp = ChatResponse(
|
||||
sender="bot", message=str(e), type="error", intermediate_steps=""
|
||||
)
|
||||
await self.send_json(client_id, error_resp)
|
||||
return
|
||||
# Send a response back to the frontend, if needed
|
||||
response = ChatResponse(
|
||||
sender="bot",
|
||||
|
|
@ -80,16 +88,16 @@ class ChatManager:
|
|||
intermediate_steps=intermediate_steps or "",
|
||||
type="end",
|
||||
)
|
||||
await self.send_json(client_id, response.dict())
|
||||
await self.send_json(client_id, response)
|
||||
|
||||
async def handle_websocket(self, client_id: str, websocket: WebSocket):
|
||||
await self.connect(client_id, websocket)
|
||||
try:
|
||||
chat_history = self.chat_history.get_history(client_id)
|
||||
await websocket.send_text(json.dumps(chat_history))
|
||||
await websocket.send_json(json.dumps(chat_history))
|
||||
|
||||
while True:
|
||||
json_payload = await websocket.receive_text()
|
||||
json_payload = await websocket.receive_json()
|
||||
payload = json.loads(json_payload)
|
||||
await self.process_message(client_id, payload)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any
|
||||
from typing import Any, Union
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
|
|
@ -6,7 +6,7 @@ class ChatMessage(BaseModel):
|
|||
"""Chat message schema."""
|
||||
|
||||
sender: str
|
||||
message: str
|
||||
message: Union[str, None] = None
|
||||
|
||||
@validator("sender")
|
||||
def sender_must_be_bot_or_you(cls, v):
|
||||
|
|
@ -21,6 +21,7 @@ class ChatResponse(ChatMessage):
|
|||
intermediate_steps: str
|
||||
type: str
|
||||
data: Any = None
|
||||
data_type: str = ""
|
||||
|
||||
@validator("type")
|
||||
def validate_message_type(cls, v):
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def load_or_build_langchain_object(data_graph, is_first_message=False):
|
|||
return build_langchain_object_with_caching(data_graph)
|
||||
|
||||
|
||||
@memoize_dict(maxsize=1)
|
||||
@memoize_dict(maxsize=10)
|
||||
def build_langchain_object_with_caching(data_graph):
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
|
|
@ -235,6 +235,61 @@ def get_result_and_steps(langchain_object, message: str):
|
|||
return result, thought
|
||||
|
||||
|
||||
async def async_get_result_and_steps(langchain_object, message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
chat_input = None
|
||||
memory_key = ""
|
||||
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
|
||||
if hasattr(langchain_object, "input_keys"):
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
else:
|
||||
chat_input = message # type: ignore
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
try:
|
||||
if hasattr(langchain_object, "acall"):
|
||||
output = await langchain_object.acall(chat_input)
|
||||
else:
|
||||
output = langchain_object(chat_input)
|
||||
except ValueError as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
output = langchain_object.run(chat_input)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
|
||||
result = (
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
if intermediate_steps:
|
||||
thought = format_intermediate_steps(intermediate_steps)
|
||||
else:
|
||||
thought = output_buffer.getvalue()
|
||||
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Error: {str(exc)}") from exc
|
||||
return result, thought
|
||||
|
||||
|
||||
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue