🐛 fix(manager.py): add unique session id to avoid multiple clients with the same id
🐛 fix(utils.py): pass session id to process_graph function to ensure unique session id for each client 🐛 fix(base.py): pass session id to setup_callbacks function to ensure unique session id for each client
This commit is contained in:
parent
55bd1dc6ec
commit
3efecc8b9b
3 changed files with 25 additions and 20 deletions
|
|
@ -1,4 +1,5 @@
|
|||
from collections import defaultdict
|
||||
import uuid
|
||||
from fastapi import WebSocket, status
|
||||
from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.cache import cache_manager
|
||||
|
|
@ -45,6 +46,7 @@ class ChatHistory(Subject):
|
|||
class ChatManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.connection_ids: Dict[str, str] = {}
|
||||
self.chat_history = ChatHistory()
|
||||
self.cache_manager = cache_manager
|
||||
self.cache_manager.attach(self.update)
|
||||
|
|
@ -90,9 +92,13 @@ class ChatManager:
|
|||
async def connect(self, client_id: str, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections[client_id] = websocket
|
||||
# This is to avoid having multiple clients with the same id
|
||||
#! Temporary solution
|
||||
self.connection_ids[client_id] = f"{client_id}-{uuid.uuid4()}"
|
||||
|
||||
def disconnect(self, client_id: str):
|
||||
self.active_connections.pop(client_id, None)
|
||||
self.connection_ids.pop(client_id, None)
|
||||
|
||||
async def send_message(self, client_id: str, message: str):
|
||||
websocket = self.active_connections[client_id]
|
||||
|
|
@ -134,6 +140,7 @@ class ChatManager:
|
|||
langchain_object=langchain_object,
|
||||
chat_inputs=chat_inputs,
|
||||
websocket=self.active_connections[client_id],
|
||||
session_id=self.connection_ids[client_id],
|
||||
)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ async def process_graph(
|
|||
langchain_object,
|
||||
chat_inputs: ChatMessage,
|
||||
websocket: WebSocket,
|
||||
session_id: str,
|
||||
):
|
||||
langchain_object = try_setting_streaming_options(langchain_object, websocket)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
|
@ -27,7 +28,10 @@ async def process_graph(
|
|||
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = await get_result_and_steps(
|
||||
langchain_object, chat_inputs.message, websocket=websocket
|
||||
langchain_object,
|
||||
chat_inputs.message,
|
||||
websocket=websocket,
|
||||
session_id=session_id,
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
return result, intermediate_steps
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from langflow.api.v1.callback import (
|
|||
StreamingLLMCallbackHandler,
|
||||
)
|
||||
from langflow.processing.process import fix_memory_inputs, format_actions
|
||||
|
||||
from langflow.utils.logger import logger
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
|
@ -12,7 +13,7 @@ if TYPE_CHECKING:
|
|||
from langfuse.callback import CallbackHandler # type: ignore
|
||||
|
||||
|
||||
def setup_callbacks(sync, **kwargs):
|
||||
def setup_callbacks(sync, trace_id, **kwargs):
|
||||
"""Setup callbacks for langchain object"""
|
||||
callbacks = []
|
||||
if sync:
|
||||
|
|
@ -20,31 +21,22 @@ def setup_callbacks(sync, **kwargs):
|
|||
else:
|
||||
callbacks.append(AsyncStreamingLLMCallbackHandler(**kwargs))
|
||||
|
||||
if langfuse_callback := get_langfuse_callback():
|
||||
if langfuse_callback := get_langfuse_callback(trace_id=trace_id):
|
||||
logger.debug("Langfuse callback loaded")
|
||||
callbacks.append(langfuse_callback)
|
||||
return callbacks
|
||||
|
||||
|
||||
def get_langfuse_callback():
|
||||
from langflow.settings import settings
|
||||
def get_langfuse_callback(trace_id):
|
||||
from langflow.services.plugins.langfuse import LangfuseInstance
|
||||
from langfuse.callback import CreateTrace
|
||||
|
||||
logger.debug("Initializing langfuse callback")
|
||||
if settings.LANGFUSE_PUBLIC_KEY and settings.LANGFUSE_SECRET_KEY:
|
||||
if langfuse := LangfuseInstance.get():
|
||||
logger.debug("Langfuse credentials found")
|
||||
try:
|
||||
from langfuse.callback import CallbackHandler # type: ignore
|
||||
|
||||
return CallbackHandler(
|
||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||
host=settings.LANGFUSE_HOST,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Error importing langfuse callback. "
|
||||
"Please install langfuse with `pip install langfuse`"
|
||||
) from exc
|
||||
trace = langfuse.trace(CreateTrace(id=trace_id))
|
||||
return trace.getNewHandler()
|
||||
except Exception as exc:
|
||||
logger.error(f"Error initializing langfuse callback: {exc}")
|
||||
|
||||
|
|
@ -82,12 +74,14 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa
|
|||
logger.error(f"Error fixing memory inputs: {exc}")
|
||||
|
||||
try:
|
||||
callbacks = setup_callbacks(sync=False, **kwargs)
|
||||
trace_id = kwargs.pop("session_id", None)
|
||||
callbacks = setup_callbacks(sync=False, trace_id=trace_id, **kwargs)
|
||||
output = await langchain_object.acall(inputs, callbacks=callbacks)
|
||||
except Exception as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
callbacks = setup_callbacks(sync=True, **kwargs)
|
||||
trace_id = kwargs.pop("session_id", None)
|
||||
callbacks = setup_callbacks(sync=True, trace_id=trace_id, **kwargs)
|
||||
output = langchain_object(inputs, callbacks=callbacks)
|
||||
|
||||
# if langfuse callback is present, run callback.langfuse.flush()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue