refactor(chat.py, chat_manager.py, schemas.py, run.py): add chat history to ChatManager and ChatMessage schema

feat(chat.py, chat_manager.py): add error handling for async_get_result_and_steps
feat(chat.py): add client_id to websocket endpoint
feat(schemas.py): add data_type field to ChatResponse schema
refactor(run.py): memoize build_langchain_object_with_caching function with maxsize of 10
This commit is contained in:
Gabriel Almeida 2023-04-19 21:28:05 -03:00
commit 7d183ff57e
4 changed files with 87 additions and 24 deletions

View file

@ -7,7 +7,6 @@ router = APIRouter()
chat_manager = ChatManager()
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
client_id = str(uuid4())
@router.websocket("/ws/{client_id}")
async def websocket_endpoint(client_id: str, websocket: WebSocket):
await chat_manager.handle_websocket(client_id, websocket)

View file

@ -5,7 +5,7 @@ import json
from langflow.api.schemas import ChatMessage, ChatResponse
from langflow.interface.run import (
get_result_and_steps,
async_get_result_and_steps,
load_or_build_langchain_object,
)
from langflow.utils.logger import logger
@ -38,22 +38,25 @@ class ChatManager:
websocket = self.active_connections[client_id]
await websocket.send_text(message)
async def send_json(self, client_id: str, message: Dict):
async def send_json(self, client_id: str, message: ChatMessage):
websocket = self.active_connections[client_id]
await websocket.send_json(message)
self.chat_history.add_message(client_id, message)
await websocket.send_json(message.dict())
async def process_message(self, client_id: str, payload: Dict):
# Process the graph data and chat message
chat_message = payload.pop("message", "")
chat_message = ChatMessage(sender="user", message=chat_message)
self.chat_history.add_message(client_id, chat_message)
graph_data = payload
start_resp = ChatResponse(
sender="bot", message="", type="start", intermediate_steps=""
sender="bot", message=None, type="start", intermediate_steps=""
)
await self.send_json(client_id, start_resp.dict())
await self.send_json(client_id, start_resp)
is_first_message = len(graph_data.get("chatHistory", [])) == 0
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
logger.debug("Loaded langchain object")
@ -64,15 +67,20 @@ class ChatManager:
)
# Generate result and thought
logger.debug("Generating result and thought")
result, intermediate_steps = get_result_and_steps(
langchain_object, chat_message.message
)
logger.debug("Generated result and intermediate_steps")
# Save the message to chat history
self.chat_history.add_message(client_id, chat_message)
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await async_get_result_and_steps(
langchain_object, chat_message.message or ""
)
logger.debug("Generated result and intermediate_steps")
except Exception as e:
# Log stack trace
logger.exception(e)
error_resp = ChatResponse(
sender="bot", message=str(e), type="error", intermediate_steps=""
)
await self.send_json(client_id, error_resp)
return
# Send a response back to the frontend, if needed
response = ChatResponse(
sender="bot",
@ -80,16 +88,16 @@ class ChatManager:
intermediate_steps=intermediate_steps or "",
type="end",
)
await self.send_json(client_id, response.dict())
await self.send_json(client_id, response)
async def handle_websocket(self, client_id: str, websocket: WebSocket):
await self.connect(client_id, websocket)
try:
chat_history = self.chat_history.get_history(client_id)
await websocket.send_text(json.dumps(chat_history))
await websocket.send_json(json.dumps(chat_history))
while True:
json_payload = await websocket.receive_text()
json_payload = await websocket.receive_json()
payload = json.loads(json_payload)
await self.process_message(client_id, payload)
except Exception as e:

View file

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Union
from pydantic import BaseModel, validator
@ -6,7 +6,7 @@ class ChatMessage(BaseModel):
"""Chat message schema."""
sender: str
message: str
message: Union[str, None] = None
@validator("sender")
def sender_must_be_bot_or_you(cls, v):
@ -21,6 +21,7 @@ class ChatResponse(ChatMessage):
intermediate_steps: str
type: str
data: Any = None
data_type: str = ""
@validator("type")
def validate_message_type(cls, v):

View file

@ -31,7 +31,7 @@ def load_or_build_langchain_object(data_graph, is_first_message=False):
return build_langchain_object_with_caching(data_graph)
@memoize_dict(maxsize=1)
@memoize_dict(maxsize=10)
def build_langchain_object_with_caching(data_graph):
"""
Build langchain object from data_graph.
@ -235,6 +235,61 @@ def get_result_and_steps(langchain_object, message: str):
return result, thought
async def async_get_result_and_steps(langchain_object, message: str):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
chat_input = None
memory_key = ""
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
memory_key = langchain_object.memory.memory_key
if hasattr(langchain_object, "input_keys"):
for key in langchain_object.input_keys:
if key not in [memory_key, "chat_history"]:
chat_input = {key: message}
else:
chat_input = message # type: ignore
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
fix_memory_inputs(langchain_object)
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
try:
if hasattr(langchain_object, "acall"):
output = await langchain_object.acall(chat_input)
else:
output = langchain_object(chat_input)
except ValueError as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
output = langchain_object.run(chat_input)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
if intermediate_steps:
thought = format_intermediate_steps(intermediate_steps)
else:
thought = output_buffer.getvalue()
except Exception as exc:
raise ValueError(f"Error: {str(exc)}") from exc
return result, thought
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
"""Get result and thought from extracted json"""
try: