diff --git a/src/backend/langflow/chat/manager.py b/src/backend/langflow/chat/manager.py index 2c3427a12..a04da871e 100644 --- a/src/backend/langflow/chat/manager.py +++ b/src/backend/langflow/chat/manager.py @@ -1,4 +1,5 @@ from collections import defaultdict +import uuid from fastapi import WebSocket, status from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse from langflow.cache import cache_manager @@ -45,6 +46,7 @@ class ChatHistory(Subject): class ChatManager: def __init__(self): self.active_connections: Dict[str, WebSocket] = {} + self.connection_ids: Dict[str, str] = {} self.chat_history = ChatHistory() self.cache_manager = cache_manager self.cache_manager.attach(self.update) @@ -90,9 +92,13 @@ class ChatManager: async def connect(self, client_id: str, websocket: WebSocket): await websocket.accept() self.active_connections[client_id] = websocket + # This is to avoid having multiple clients with the same id + #! Temporary solution + self.connection_ids[client_id] = f"{client_id}-{uuid.uuid4()}" def disconnect(self, client_id: str): self.active_connections.pop(client_id, None) + self.connection_ids.pop(client_id, None) async def send_message(self, client_id: str, message: str): websocket = self.active_connections[client_id] @@ -134,6 +140,7 @@ class ChatManager: langchain_object=langchain_object, chat_inputs=chat_inputs, websocket=self.active_connections[client_id], + session_id=self.connection_ids[client_id], ) except Exception as e: # Log stack trace diff --git a/src/backend/langflow/chat/utils.py b/src/backend/langflow/chat/utils.py index 17c976eb9..e11b38b64 100644 --- a/src/backend/langflow/chat/utils.py +++ b/src/backend/langflow/chat/utils.py @@ -9,6 +9,7 @@ async def process_graph( langchain_object, chat_inputs: ChatMessage, websocket: WebSocket, + session_id: str, ): langchain_object = try_setting_streaming_options(langchain_object, websocket) logger.debug("Loaded langchain object") @@ -27,7 +28,10 @@ async def process_graph( logger.debug("Generating result and thought") result, intermediate_steps = await get_result_and_steps( - langchain_object, chat_inputs.message, websocket=websocket + langchain_object, + chat_inputs.message, + websocket=websocket, + session_id=session_id, ) logger.debug("Generated result and intermediate_steps") return result, intermediate_steps diff --git a/src/backend/langflow/processing/base.py b/src/backend/langflow/processing/base.py index 4e2b1d716..bbd68a47b 100644 --- a/src/backend/langflow/processing/base.py +++ b/src/backend/langflow/processing/base.py @@ -4,6 +4,7 @@ from langflow.api.v1.callback import ( StreamingLLMCallbackHandler, ) from langflow.processing.process import fix_memory_inputs, format_actions + from langflow.utils.logger import logger from langchain.agents.agent import AgentExecutor from langchain.callbacks.base import BaseCallbackHandler @@ -12,7 +13,7 @@ if TYPE_CHECKING: from langfuse.callback import CallbackHandler # type: ignore -def setup_callbacks(sync, **kwargs): +def setup_callbacks(sync, trace_id, **kwargs): """Setup callbacks for langchain object""" callbacks = [] if sync: @@ -20,31 +21,22 @@ def setup_callbacks(sync, **kwargs): else: callbacks.append(AsyncStreamingLLMCallbackHandler(**kwargs)) - if langfuse_callback := get_langfuse_callback(): + if langfuse_callback := get_langfuse_callback(trace_id=trace_id): logger.debug("Langfuse callback loaded") callbacks.append(langfuse_callback) return callbacks -def get_langfuse_callback(): - from langflow.settings import settings +def get_langfuse_callback(trace_id): + from langflow.services.plugins.langfuse import LangfuseInstance + from langfuse.callback import CreateTrace logger.debug("Initializing langfuse callback") - if settings.LANGFUSE_PUBLIC_KEY and settings.LANGFUSE_SECRET_KEY: + if langfuse := LangfuseInstance.get(): logger.debug("Langfuse credentials found") try: - from langfuse.callback import CallbackHandler # type: ignore - - return CallbackHandler( - public_key=settings.LANGFUSE_PUBLIC_KEY, - secret_key=settings.LANGFUSE_SECRET_KEY, - host=settings.LANGFUSE_HOST, - ) - except ImportError as exc: - raise ImportError( - "Error importing langfuse callback. " - "Please install langfuse with `pip install langfuse`" - ) from exc + trace = langfuse.trace(CreateTrace(id=trace_id)) + return trace.getNewHandler() except Exception as exc: logger.error(f"Error initializing langfuse callback: {exc}") @@ -82,12 +74,14 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa logger.error(f"Error fixing memory inputs: {exc}") try: - callbacks = setup_callbacks(sync=False, **kwargs) + trace_id = kwargs.pop("session_id", None) + callbacks = setup_callbacks(sync=False, trace_id=trace_id, **kwargs) output = await langchain_object.acall(inputs, callbacks=callbacks) except Exception as exc: # make the error message more informative logger.debug(f"Error: {str(exc)}") - callbacks = setup_callbacks(sync=True, **kwargs) + trace_id = kwargs.pop("session_id", None) + callbacks = setup_callbacks(sync=True, trace_id=trace_id, **kwargs) output = langchain_object(inputs, callbacks=callbacks) # if langfuse callback is present, run callback.langfuse.flush()