Refactor ChatService process_message method

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-02 22:28:07 -03:00
commit a901f89cd5

View file

@ -5,14 +5,15 @@ from typing import Any, Dict, List
import orjson
from fastapi import WebSocket, status
from loguru import logger
from starlette.websockets import WebSocketState
from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.interface.utils import pil_to_base64
from langflow.services import ServiceType, service_manager
from langflow.services.base import Service
from langflow.services.chat.cache import Subject
from langflow.services.chat.utils import process_graph
from loguru import logger
from starlette.websockets import WebSocketState
from .cache import cache_service
@ -117,7 +118,7 @@ class ChatService(Service):
if "after sending" in str(exc):
logger.error(f"Error closing connection: {exc}")
async def process_message(self, client_id: str, payload: Dict, langchain_object: Any):
async def process_message(self, client_id: str, payload: Dict, build_result: Any):
# Process the graph data and chat message
chat_inputs = payload.pop("inputs", {})
chatkey = payload.pop("chatKey", None)
@ -134,12 +135,12 @@ class ChatService(Service):
logger.debug("Generating result and thought")
result, intermediate_steps, raw_output = await process_graph(
langchain_object=langchain_object,
build_result=build_result,
chat_inputs=chat_inputs,
client_id=client_id,
session_id=self.connection_ids[client_id],
)
self.set_cache(client_id, langchain_object)
self.set_cache(client_id, build_result)
except Exception as e:
# Log stack trace
logger.exception(e)
@ -205,8 +206,8 @@ class ChatService(Service):
continue
with self.chat_cache.set_client_id(client_id):
if langchain_object := self.cache_service.get(client_id).get("result"):
await self.process_message(client_id, payload, langchain_object)
if build_result := self.cache_service.get(client_id).get("result"):
await self.process_message(client_id, payload, build_result)
else:
raise RuntimeError(f"Could not find a build result for client_id {client_id}")