Add support for different types of build_result in process_graph function

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-02 22:36:17 -03:00
commit e69f3cfdef

View file

@ -1,20 +1,28 @@
from typing import Any
from langchain.agents import AgentExecutor
from langchain.chains.base import Chain
from langchain_core.runnables import Runnable
from loguru import logger
from langflow.api.v1.schemas import ChatMessage
from langflow.interface.utils import try_setting_streaming_options
from langflow.processing.base import get_result_and_steps
from langflow.utils.chat import ChatAdapter
LANGCHAIN_RUNNABLES = (Chain, Runnable, AgentExecutor)
async def process_graph(
langchain_object,
build_result,
chat_inputs: ChatMessage,
client_id: str,
session_id: str,
):
langchain_object = try_setting_streaming_options(langchain_object)
build_result = try_setting_streaming_options(build_result)
logger.debug("Loaded langchain object")
if langchain_object is None:
if build_result is None:
# Raise user facing error
raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.")
@ -25,15 +33,31 @@ async def process_graph(
chat_inputs.message = {}
logger.debug("Generating result and thought")
result, intermediate_steps, raw_output = await get_result_and_steps(
langchain_object,
chat_inputs.message,
client_id=client_id,
session_id=session_id,
)
if isinstance(build_result, LANGCHAIN_RUNNABLES):
result, intermediate_steps, raw_output = await get_result_and_steps(
build_result,
chat_inputs.message,
client_id=client_id,
session_id=session_id,
)
elif isinstance(build_result, ChatAdapter):
result = await run_build_result(
build_result,
chat_inputs,
client_id=client_id,
session_id=session_id,
)
intermediate_steps = []
raw_output = result
else:
raise TypeError(f"Unknown type {type(build_result)}")
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps, raw_output
except Exception as e:
# Log stack trace
logger.exception(e)
raise e
async def run_build_result(build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str):
return build_result(**chat_inputs.message)