Update function base classes to use Callable

instead of function
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-02 22:45:49 -03:00
commit d87b6228df
8 changed files with 18 additions and 18 deletions

View file

@ -36,7 +36,7 @@ TOOL_INPUTS = {
field_type="BaseLanguageModel", required=True, is_list=False, show=True
),
"func": TemplateField(
field_type="function",
field_type="Callable",
required=True,
is_list=False,
show=True,
@ -126,7 +126,7 @@ class ToolCreator(LangChainTypeCreator):
elif tool_type in CUSTOM_TOOLS:
# Get custom tool params
params = self.type_to_loader_dict[name]["params"] # type: ignore
base_classes = ["function"]
base_classes = ["Callable"]
if node := customs.get_custom_nodes("tools").get(tool_type):
return node
elif tool_type in FILE_TOOLS:

View file

@ -206,7 +206,7 @@ class InitializeAgentNode(FrontendNode):
],
)
description: str = """Construct a zero shot agent from an LLM and tools."""
base_classes: list[str] = ["AgentExecutor", "function"]
base_classes: list[str] = ["AgentExecutor", "Callable"]
def to_dict(self):
return super().to_dict()

View file

@ -140,7 +140,7 @@ class SeriesCharacterChainNode(FrontendNode):
"Chain",
"ConversationChain",
"SeriesCharacterChain",
"function",
"Callable",
]
@ -241,7 +241,7 @@ class CombineDocsChainNode(FrontendNode):
],
)
description: str = """Load question answering chain."""
base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]
base_classes: list[str] = ["BaseCombineDocumentsChain", "Callable"]
def to_dict(self):
return super().to_dict()

View file

@ -35,7 +35,7 @@ class ToolNode(FrontendNode):
),
TemplateField(
name="func",
field_type="function",
field_type="Callable",
required=True,
is_list=False,
show=True,
@ -135,7 +135,7 @@ class PythonFunctionNode(FrontendNode):
],
)
description: str = "Python function to be executed."
base_classes: list[str] = ["function"]
base_classes: list[str] = ["Callable"]
def to_dict(self):
return super().to_dict()

View file

@ -60,7 +60,7 @@ def build_template_from_function(
# the output to be a function
base_classes = get_base_classes(_class)
if add_function:
base_classes.append("function")
base_classes.append("Callable")
return {
"template": format_dict(variables, name),
@ -114,7 +114,7 @@ def build_template_from_class(
# Adding function to base classes to allow
# the output to be a function
if add_function:
base_classes.append("function")
base_classes.append("Callable")
return {
"template": format_dict(variables, name),
"description": docs.short_description or "",
@ -178,7 +178,7 @@ def build_template_from_method(
# Adding function to base classes to allow the output to be a function
if add_function:
base_classes.append("function")
base_classes.append("Callable")
return {
"template": format_dict(variables, class_name),

View file

@ -12,7 +12,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
"ZeroShotAgent",
"BaseSingleActionAgent",
"Agent",
"function",
"Callable",
}
template = zero_shot_agent["template"]
@ -202,7 +202,7 @@ def test_initialize_agent(client: TestClient, logged_in_headers):
agents = json_response["agents"]
initialize_agent = agents["AgentInitializer"]
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
assert initialize_agent["base_classes"] == ["AgentExecutor", "Callable"]
template = initialize_agent["template"]
assert template["agent"] == {

View file

@ -22,7 +22,7 @@ def test_conversation_chain(client: TestClient, logged_in_headers):
"ConversationChain",
"LLMChain",
"Chain",
"function",
"Callable",
}
template = chain["template"]
@ -111,7 +111,7 @@ def test_llm_chain(client: TestClient, logged_in_headers):
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {
"function",
"Callable",
"LLMChain",
"Chain",
}
@ -182,7 +182,7 @@ def test_llm_checker_chain(client: TestClient, logged_in_headers):
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {
"function",
"Callable",
"LLMCheckerChain",
"Chain",
}
@ -216,7 +216,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
chain = chains["LLMMathChain"]
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {
"function",
"Callable",
"LLMMathChain",
"Chain",
}
@ -309,7 +309,7 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
assert set(chain["base_classes"]) == {
"function",
"Callable",
"LLMChain",
"BaseCustomChain",
"Chain",

View file

@ -69,7 +69,7 @@ def test_build_template_from_function():
"ExampleClass1", type_to_loader_dict, add_function=True
)
assert result_with_function is not None
assert "function" in result_with_function["base_classes"]
assert "Callable" in result_with_function["base_classes"]
# Test with invalid name
with pytest.raises(ValueError, match=r".* not found"):