Temporary fix for intermediate steps

This commit is contained in:
Gabriel Almeida 2023-05-13 21:49:06 -03:00
commit 2bc7d3afd8

View file

@@ -1,6 +1,6 @@
import contextlib
import io
from typing import Any, Dict
from typing import Any, Dict, List, Tuple
from chromadb.errors import NotEnoughElementsException # type: ignore
@@ -8,6 +8,7 @@ from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLM
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
from langflow.graph.graph import Graph
from langflow.utils.logger import logger
from langchain.schema import AgentAction
def load_langchain_object(data_graph, is_first_message=False):
@@ -175,33 +176,25 @@ async def get_result_and_steps(langchain_object, message: str, **kwargs):
langchain_object.return_intermediate_steps = True
fix_memory_inputs(langchain_object)
try:
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
except Exception as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
output = langchain_object(chat_input, callbacks=sync_callbacks)
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
try:
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
output = await langchain_object.acall(
chat_input, callbacks=async_callbacks
)
except Exception as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
output = langchain_object(chat_input, callbacks=sync_callbacks)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
if intermediate_steps:
thought = format_intermediate_steps(intermediate_steps)
else:
thought = output_buffer.getvalue()
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
thought = format_actions(intermediate_steps) if intermediate_steps else ""
except NotEnoughElementsException as exc:
raise ValueError(
"Error: Not enough documents for ChromaDB to index. Try reducing chunk size in TextSplitter."
@@ -257,7 +250,7 @@ def get_result_and_thought(langchain_object, message: str):
else output
)
if intermediate_steps:
thought = format_intermediate_steps(intermediate_steps)
thought = format_actions(intermediate_steps)
else:
thought = output_buffer.getvalue()
@@ -266,19 +259,17 @@ def get_result_and_thought(langchain_object, message: str):
return result, thought
def format_intermediate_steps(intermediate_steps):
    """Render (action, observation) pairs as an AgentExecutor-style trace.

    Each step contributes the action's thought log, tool name, tool input,
    and the resulting observation; the last observation is echoed again as
    the final answer, mimicking LangChain's verbose chain output.

    Args:
        intermediate_steps: iterable of ``(action, observation)`` pairs,
            where ``action`` exposes ``.log``, ``.tool`` and ``.tool_input``.

    Returns:
        The formatted multi-line trace as a single string.
    """
    formatted_chain = "> Entering new AgentExecutor chain...\n"
    # Sentinel so we can tell "no steps at all" apart from a step whose
    # observation is None/falsy.
    _missing = object()
    observation = _missing
    for action, observation in intermediate_steps:
        formatted_chain += (
            f" {action.log}\nAction: {action.tool}\nAction Input: {action.tool_input}\n"
        )
        formatted_chain += f"Observation: {observation}\n"
    # Bug fix: the original read `observation` unconditionally after the
    # loop, raising NameError when intermediate_steps was empty.
    if observation is not _missing:
        final_answer = f"Final Answer: {observation}\n"
        formatted_chain += f"Thought: I now know the final answer\n{final_answer}\n"
    formatted_chain += "> Finished chain.\n"
    return formatted_chain
def format_actions(actions: "List[Tuple[AgentAction, str]]") -> str:
    """Format a list of (AgentAction, answer) tuples into a string.

    Each entry is rendered as a ``Log:`` line, optionally ``Tool:`` /
    ``Tool Input:`` lines, an ``Answer:`` line, and a trailing blank line.

    Args:
        actions: pairs of agent action (exposing ``.log``, ``.tool``,
            ``.tool_input``) and the observed answer string.

    Returns:
        The joined multi-line string; empty string for an empty list.
    """
    output = []
    for action, answer in actions:
        output.append(f"Log: {action.log}")
        # Only surface the tool fields when the log does not already embed
        # them. Note: "Action Input" always contains "Action" as a substring,
        # so a single membership test suffices — the original's second
        # clause (`"Action Input" not in log`) was redundant.
        if "Action" not in action.log:
            output.append(f"Tool: {action.tool}")
            output.append(f"Tool Input: {action.tool_input}")
        output.append(f"Answer: {answer}")
        output.append("")  # blank line separating entries
    return "\n".join(output)