Refactor process_graph function to handle ChatDefinition with dict output key

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-02 23:19:55 -03:00
commit 99ef882801

View file

@ -3,11 +3,12 @@ from typing import Any
from langchain.agents import AgentExecutor
from langchain.chains.base import Chain
from langchain_core.runnables import Runnable
from loguru import logger
from langflow.api.v1.schemas import ChatMessage
from langflow.interface.utils import try_setting_streaming_options
from langflow.processing.base import get_result_and_steps
from langflow.utils.chat import ChatDefinition
from loguru import logger
LANGCHAIN_RUNNABLES = (Chain, Runnable, AgentExecutor)
@ -40,18 +41,19 @@ async def process_graph(
session_id=session_id,
)
elif isinstance(build_result, ChatDefinition):
result = await run_build_result(
raw_output = await run_build_result(
build_result,
chat_inputs,
client_id=client_id,
session_id=session_id,
)
if isinstance(result, dict):
if isinstance(raw_output, dict):
if not build_result.output_key:
raise ValueError("No output key provided to select the output from the result")
result = result[build_result.output_key]
raise ValueError("No output key provided to ChatDefinition when returning a dict.")
result = raw_output[build_result.output_key]
else:
result = raw_output
intermediate_steps = []
raw_output = result
else:
raise TypeError(f"Unknown type {type(build_result)}")
logger.debug("Generated result and intermediate_steps")
@ -63,4 +65,4 @@ async def process_graph(
async def run_build_result(build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str):
return build_result(**chat_inputs.message)
return build_result(inputs=chat_inputs.message)