Refactor callback handling in base.py

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-06 11:05:11 -03:00
commit 667e7c42ff

View file

@ -2,8 +2,10 @@ from typing import TYPE_CHECKING, List, Union
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
from langflow.api.v1.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler
from langflow.api.v1.callback import (AsyncStreamingLLMCallbackHandler,
StreamingLLMCallbackHandler)
from langflow.processing.process import fix_memory_inputs, format_actions
from langflow.services.deps import get_plugins_service
from loguru import logger
if TYPE_CHECKING:
@ -18,9 +20,10 @@ def setup_callbacks(sync, trace_id, **kwargs):
else:
callbacks.append(AsyncStreamingLLMCallbackHandler(**kwargs))
if langfuse_callback := get_langfuse_callback(trace_id=trace_id):
logger.debug("Langfuse callback loaded")
callbacks.append(langfuse_callback)
plugin_service = get_plugins_service()
plugin_callbacks = plugin_service.get_callbacks(_id=trace_id)
if plugin_callbacks:
callbacks.extend(plugin_callbacks)
return callbacks