🐛 fix(manager.py): add unique session id to avoid multiple clients with the same id

🐛 fix(utils.py): pass session id to process_graph function to ensure unique session id for each client
🐛 fix(base.py): pass session id to setup_callbacks function to ensure unique session id for each client
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-30 09:50:19 -03:00
commit 3efecc8b9b
3 changed files with 25 additions and 20 deletions

View file

@ -1,4 +1,5 @@
from collections import defaultdict
import uuid
from fastapi import WebSocket, status
from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.cache import cache_manager
@ -45,6 +46,7 @@ class ChatHistory(Subject):
class ChatManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.connection_ids: Dict[str, str] = {}
self.chat_history = ChatHistory()
self.cache_manager = cache_manager
self.cache_manager.attach(self.update)
@ -90,9 +92,13 @@ class ChatManager:
async def connect(self, client_id: str, websocket: WebSocket):
await websocket.accept()
self.active_connections[client_id] = websocket
# This is to avoid having multiple clients with the same id
#! Temporary solution
self.connection_ids[client_id] = f"{client_id}-{uuid.uuid4()}"
def disconnect(self, client_id: str):
self.active_connections.pop(client_id, None)
self.connection_ids.pop(client_id, None)
async def send_message(self, client_id: str, message: str):
websocket = self.active_connections[client_id]
@ -134,6 +140,7 @@ class ChatManager:
langchain_object=langchain_object,
chat_inputs=chat_inputs,
websocket=self.active_connections[client_id],
session_id=self.connection_ids[client_id],
)
except Exception as e:
# Log stack trace

View file

@ -9,6 +9,7 @@ async def process_graph(
langchain_object,
chat_inputs: ChatMessage,
websocket: WebSocket,
session_id: str,
):
langchain_object = try_setting_streaming_options(langchain_object, websocket)
logger.debug("Loaded langchain object")
@ -27,7 +28,10 @@ async def process_graph(
logger.debug("Generating result and thought")
result, intermediate_steps = await get_result_and_steps(
langchain_object, chat_inputs.message, websocket=websocket
langchain_object,
chat_inputs.message,
websocket=websocket,
session_id=session_id,
)
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps

View file

@ -4,6 +4,7 @@ from langflow.api.v1.callback import (
StreamingLLMCallbackHandler,
)
from langflow.processing.process import fix_memory_inputs, format_actions
from langflow.utils.logger import logger
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
@ -12,7 +13,7 @@ if TYPE_CHECKING:
from langfuse.callback import CallbackHandler # type: ignore
def setup_callbacks(sync, **kwargs):
def setup_callbacks(sync, trace_id, **kwargs):
"""Setup callbacks for langchain object"""
callbacks = []
if sync:
@ -20,31 +21,22 @@ def setup_callbacks(sync, **kwargs):
else:
callbacks.append(AsyncStreamingLLMCallbackHandler(**kwargs))
if langfuse_callback := get_langfuse_callback():
if langfuse_callback := get_langfuse_callback(trace_id=trace_id):
logger.debug("Langfuse callback loaded")
callbacks.append(langfuse_callback)
return callbacks
def get_langfuse_callback():
from langflow.settings import settings
def get_langfuse_callback(trace_id):
from langflow.services.plugins.langfuse import LangfuseInstance
from langfuse.callback import CreateTrace
logger.debug("Initializing langfuse callback")
if settings.LANGFUSE_PUBLIC_KEY and settings.LANGFUSE_SECRET_KEY:
if langfuse := LangfuseInstance.get():
logger.debug("Langfuse credentials found")
try:
from langfuse.callback import CallbackHandler # type: ignore
return CallbackHandler(
public_key=settings.LANGFUSE_PUBLIC_KEY,
secret_key=settings.LANGFUSE_SECRET_KEY,
host=settings.LANGFUSE_HOST,
)
except ImportError as exc:
raise ImportError(
"Error importing langfuse callback. "
"Please install langfuse with `pip install langfuse`"
) from exc
trace = langfuse.trace(CreateTrace(id=trace_id))
return trace.getNewHandler()
except Exception as exc:
logger.error(f"Error initializing langfuse callback: {exc}")
@ -82,12 +74,14 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa
logger.error(f"Error fixing memory inputs: {exc}")
try:
callbacks = setup_callbacks(sync=False, **kwargs)
trace_id = kwargs.pop("session_id", None)
callbacks = setup_callbacks(sync=False, trace_id=trace_id, **kwargs)
output = await langchain_object.acall(inputs, callbacks=callbacks)
except Exception as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
callbacks = setup_callbacks(sync=True, **kwargs)
trace_id = kwargs.pop("session_id", None)
callbacks = setup_callbacks(sync=True, trace_id=trace_id, **kwargs)
output = langchain_object(inputs, callbacks=callbacks)
# if langfuse callback is present, run callback.langfuse.flush()