refactor(loading.py): make allowed_tools a list if it's not already a list or set

This commit is contained in:
Gabriel Almeida 2023-05-02 11:29:38 -03:00
commit c623b02bf2

View file

@ -1,5 +1,5 @@
import json
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Iterable, Optional
from langchain.agents import ZeroShotAgent
from langchain.agents import agent as agent_module
@ -195,6 +195,9 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs)
"""Load agent executor from agent class, tools and chain"""
allowed_tools = params["allowed_tools"]
llm_chain = params["llm_chain"]
# if allowed_tools is not a list or set, make it a list
if not isinstance(allowed_tools, (list, set)):
allowed_tools = [allowed_tools]
tool_names = [tool.name for tool in allowed_tools]
# Agent class requires an output_parser but Agent classes
# have a default output_parser.