🐛 fix(vertex/base.py): fix vertex_type assignment to handle uppercase template types
🐛 fix(interface/loading.py): fix custom_node instantiation to handle classes without initialize method ✨ feat(template/frontend_node/prompts.py): change type_name to match class name ✨ feat(template/frontend_node/tools.py): change type_name to match class name 🔥 chore(test_agents_template.py): remove test_agents_settings and update initialize_agent test The vertex_type assignment in the Vertex class was not handling uppercase template types correctly. This has been fixed to handle both uppercase and lowercase types. The custom_node instantiation in the instantiate_class function was not handling classes without an initialize method correctly. This has been fixed to instantiate the class directly if the initialize method is not present. The type_name in the ZeroShotPromptNode and PythonFunctionToolNode classes have been changed to match the class name. The test_agents_settings test has been removed as it is no longer necessary and the initialize_agent test has been updated to match the new AgentInitializer class name.
This commit is contained in:
parent
04bd0f43fb
commit
81d231c632
5 changed files with 13 additions and 20 deletions
|
|
@ -48,10 +48,12 @@ class Vertex:
|
|||
]
|
||||
|
||||
template_dict = self.data["node"]["template"]
|
||||
# self.vertex_type = (
|
||||
# self.data["type"] if "Tool" not in self.output else template_dict["_type"]
|
||||
# )
|
||||
self.vertex_type = template_dict["_type"]
|
||||
self.vertex_type = (
|
||||
self.data["type"]
|
||||
if "Tool" not in self.output or template_dict["_type"].islower()
|
||||
else template_dict["_type"]
|
||||
)
|
||||
# self.vertex_type = template_dict["_type"]
|
||||
if self.base_type is None:
|
||||
for base_type, value in ALL_TYPES_DICT.items():
|
||||
if self.vertex_type in value:
|
||||
|
|
|
|||
|
|
@ -33,8 +33,10 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
|||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
if node_type in CUSTOM_NODES:
|
||||
if custom_agent := CUSTOM_NODES.get(node_type):
|
||||
return custom_agent.initialize(**params)
|
||||
if custom_node := CUSTOM_NODES.get(node_type):
|
||||
if hasattr(custom_node, "initialize"):
|
||||
return custom_node.initialize(**params)
|
||||
return custom_node(**params)
|
||||
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
return instantiate_based_on_type(class_object, base_type, node_type, params)
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class BasePromptFrontendNode(FrontendNode):
|
|||
class ZeroShotPromptNode(BasePromptFrontendNode):
|
||||
name: str = "ZeroShotPrompt"
|
||||
template: Template = Template(
|
||||
type_name="zero_shot",
|
||||
type_name="ZeroShotPrompt",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class PythonFunctionToolNode(FrontendNode):
|
|||
class PythonFunctionNode(FrontendNode):
|
||||
name: str = "PythonFunction"
|
||||
template: Template = Template(
|
||||
type_name="python_function",
|
||||
type_name="PythonFunction",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="code",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue