Refactor ChatService process_message method
This commit is contained in:
parent
5e615c0c14
commit
a901f89cd5
1 changed files with 8 additions and 7 deletions
|
|
@ -5,14 +5,15 @@ from typing import Any, Dict, List
|
|||
|
||||
import orjson
|
||||
from fastapi import WebSocket, status
|
||||
from loguru import logger
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.interface.utils import pil_to_base64
|
||||
from langflow.services import ServiceType, service_manager
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.chat.cache import Subject
|
||||
from langflow.services.chat.utils import process_graph
|
||||
from loguru import logger
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from .cache import cache_service
|
||||
|
||||
|
|
@ -117,7 +118,7 @@ class ChatService(Service):
|
|||
if "after sending" in str(exc):
|
||||
logger.error(f"Error closing connection: {exc}")
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict, langchain_object: Any):
|
||||
async def process_message(self, client_id: str, payload: Dict, build_result: Any):
|
||||
# Process the graph data and chat message
|
||||
chat_inputs = payload.pop("inputs", {})
|
||||
chatkey = payload.pop("chatKey", None)
|
||||
|
|
@ -134,12 +135,12 @@ class ChatService(Service):
|
|||
logger.debug("Generating result and thought")
|
||||
|
||||
result, intermediate_steps, raw_output = await process_graph(
|
||||
langchain_object=langchain_object,
|
||||
build_result=build_result,
|
||||
chat_inputs=chat_inputs,
|
||||
client_id=client_id,
|
||||
session_id=self.connection_ids[client_id],
|
||||
)
|
||||
self.set_cache(client_id, langchain_object)
|
||||
self.set_cache(client_id, build_result)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
@ -205,8 +206,8 @@ class ChatService(Service):
|
|||
continue
|
||||
|
||||
with self.chat_cache.set_client_id(client_id):
|
||||
if langchain_object := self.cache_service.get(client_id).get("result"):
|
||||
await self.process_message(client_id, payload, langchain_object)
|
||||
if build_result := self.cache_service.get(client_id).get("result"):
|
||||
await self.process_message(client_id, payload, build_result)
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Could not find a build result for client_id {client_id}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue