Update function base classes to use Callable
instead of function
This commit is contained in:
parent
5c5ef227d1
commit
d87b6228df
8 changed files with 18 additions and 18 deletions
|
|
@ -36,7 +36,7 @@ TOOL_INPUTS = {
|
|||
field_type="BaseLanguageModel", required=True, is_list=False, show=True
|
||||
),
|
||||
"func": TemplateField(
|
||||
field_type="function",
|
||||
field_type="Callable",
|
||||
required=True,
|
||||
is_list=False,
|
||||
show=True,
|
||||
|
|
@ -126,7 +126,7 @@ class ToolCreator(LangChainTypeCreator):
|
|||
elif tool_type in CUSTOM_TOOLS:
|
||||
# Get custom tool params
|
||||
params = self.type_to_loader_dict[name]["params"] # type: ignore
|
||||
base_classes = ["function"]
|
||||
base_classes = ["Callable"]
|
||||
if node := customs.get_custom_nodes("tools").get(tool_type):
|
||||
return node
|
||||
elif tool_type in FILE_TOOLS:
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ class InitializeAgentNode(FrontendNode):
|
|||
],
|
||||
)
|
||||
description: str = """Construct a zero shot agent from an LLM and tools."""
|
||||
base_classes: list[str] = ["AgentExecutor", "function"]
|
||||
base_classes: list[str] = ["AgentExecutor", "Callable"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ class SeriesCharacterChainNode(FrontendNode):
|
|||
"Chain",
|
||||
"ConversationChain",
|
||||
"SeriesCharacterChain",
|
||||
"function",
|
||||
"Callable",
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -241,7 +241,7 @@ class CombineDocsChainNode(FrontendNode):
|
|||
],
|
||||
)
|
||||
description: str = """Load question answering chain."""
|
||||
base_classes: list[str] = ["BaseCombineDocumentsChain", "function"]
|
||||
base_classes: list[str] = ["BaseCombineDocumentsChain", "Callable"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class ToolNode(FrontendNode):
|
|||
),
|
||||
TemplateField(
|
||||
name="func",
|
||||
field_type="function",
|
||||
field_type="Callable",
|
||||
required=True,
|
||||
is_list=False,
|
||||
show=True,
|
||||
|
|
@ -135,7 +135,7 @@ class PythonFunctionNode(FrontendNode):
|
|||
],
|
||||
)
|
||||
description: str = "Python function to be executed."
|
||||
base_classes: list[str] = ["function"]
|
||||
base_classes: list[str] = ["Callable"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ def build_template_from_function(
|
|||
# the output to be a function
|
||||
base_classes = get_base_classes(_class)
|
||||
if add_function:
|
||||
base_classes.append("function")
|
||||
base_classes.append("Callable")
|
||||
|
||||
return {
|
||||
"template": format_dict(variables, name),
|
||||
|
|
@ -114,7 +114,7 @@ def build_template_from_class(
|
|||
# Adding function to base classes to allow
|
||||
# the output to be a function
|
||||
if add_function:
|
||||
base_classes.append("function")
|
||||
base_classes.append("Callable")
|
||||
return {
|
||||
"template": format_dict(variables, name),
|
||||
"description": docs.short_description or "",
|
||||
|
|
@ -178,7 +178,7 @@ def build_template_from_method(
|
|||
|
||||
# Adding function to base classes to allow the output to be a function
|
||||
if add_function:
|
||||
base_classes.append("function")
|
||||
base_classes.append("Callable")
|
||||
|
||||
return {
|
||||
"template": format_dict(variables, class_name),
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers):
|
|||
"ZeroShotAgent",
|
||||
"BaseSingleActionAgent",
|
||||
"Agent",
|
||||
"function",
|
||||
"Callable",
|
||||
}
|
||||
template = zero_shot_agent["template"]
|
||||
|
||||
|
|
@ -202,7 +202,7 @@ def test_initialize_agent(client: TestClient, logged_in_headers):
|
|||
agents = json_response["agents"]
|
||||
|
||||
initialize_agent = agents["AgentInitializer"]
|
||||
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
|
||||
assert initialize_agent["base_classes"] == ["AgentExecutor", "Callable"]
|
||||
template = initialize_agent["template"]
|
||||
|
||||
assert template["agent"] == {
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def test_conversation_chain(client: TestClient, logged_in_headers):
|
|||
"ConversationChain",
|
||||
"LLMChain",
|
||||
"Chain",
|
||||
"function",
|
||||
"Callable",
|
||||
}
|
||||
|
||||
template = chain["template"]
|
||||
|
|
@ -111,7 +111,7 @@ def test_llm_chain(client: TestClient, logged_in_headers):
|
|||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"function",
|
||||
"Callable",
|
||||
"LLMChain",
|
||||
"Chain",
|
||||
}
|
||||
|
|
@ -182,7 +182,7 @@ def test_llm_checker_chain(client: TestClient, logged_in_headers):
|
|||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"function",
|
||||
"Callable",
|
||||
"LLMCheckerChain",
|
||||
"Chain",
|
||||
}
|
||||
|
|
@ -216,7 +216,7 @@ def test_llm_math_chain(client: TestClient, logged_in_headers):
|
|||
chain = chains["LLMMathChain"]
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"function",
|
||||
"Callable",
|
||||
"LLMMathChain",
|
||||
"Chain",
|
||||
}
|
||||
|
|
@ -309,7 +309,7 @@ def test_series_character_chain(client: TestClient, logged_in_headers):
|
|||
|
||||
# Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects
|
||||
assert set(chain["base_classes"]) == {
|
||||
"function",
|
||||
"Callable",
|
||||
"LLMChain",
|
||||
"BaseCustomChain",
|
||||
"Chain",
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def test_build_template_from_function():
|
|||
"ExampleClass1", type_to_loader_dict, add_function=True
|
||||
)
|
||||
assert result_with_function is not None
|
||||
assert "function" in result_with_function["base_classes"]
|
||||
assert "Callable" in result_with_function["base_classes"]
|
||||
|
||||
# Test with invalid name
|
||||
with pytest.raises(ValueError, match=r".* not found"):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue