🐛 fix(vertex/base.py): fix vertex_type assignment to handle uppercase template types

🐛 fix(interface/loading.py): fix custom_node instantiation to handle classes without initialize method
 feat(template/frontend_node/prompts.py): change type_name to match class name
 feat(template/frontend_node/tools.py): change type_name to match class name
🔥 chore(test_agents_template.py): remove test_agents_settings and update initialize_agent test
The vertex_type assignment in the Vertex class was not handling uppercase template types correctly. This has been fixed to handle both uppercase and lowercase types. The custom_node instantiation in the instantiate_class function was not handling classes without an initialize method correctly. This has been fixed to instantiate the class directly if the initialize method is not present. The type_name in the ZeroShotPromptNode and PythonFunctionToolNode classes have been changed to match the class name. The test_agents_settings test has been removed as it is no longer necessary and the initialize_agent test has been updated to match the new AgentInitializer class name.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-07 21:36:40 -03:00
commit 81d231c632
5 changed files with 13 additions and 20 deletions

View file

@ -48,10 +48,12 @@ class Vertex:
]
template_dict = self.data["node"]["template"]
# self.vertex_type = (
# self.data["type"] if "Tool" not in self.output else template_dict["_type"]
# )
self.vertex_type = template_dict["_type"]
self.vertex_type = (
self.data["type"]
if "Tool" not in self.output or template_dict["_type"].islower()
else template_dict["_type"]
)
# self.vertex_type = template_dict["_type"]
if self.base_type is None:
for base_type, value in ALL_TYPES_DICT.items():
if self.vertex_type in value:

View file

@ -33,8 +33,10 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
params = convert_params_to_sets(params)
params = convert_kwargs(params)
if node_type in CUSTOM_NODES:
if custom_agent := CUSTOM_NODES.get(node_type):
return custom_agent.initialize(**params)
if custom_node := CUSTOM_NODES.get(node_type):
if hasattr(custom_node, "initialize"):
return custom_node.initialize(**params)
return custom_node(**params)
class_object = import_by_type(_type=base_type, name=node_type)
return instantiate_based_on_type(class_object, base_type, node_type, params)

View file

@ -74,7 +74,7 @@ class BasePromptFrontendNode(FrontendNode):
class ZeroShotPromptNode(BasePromptFrontendNode):
name: str = "ZeroShotPrompt"
template: Template = Template(
type_name="zero_shot",
type_name="ZeroShotPrompt",
fields=[
TemplateField(
field_type="str",

View file

@ -108,7 +108,7 @@ class PythonFunctionToolNode(FrontendNode):
class PythonFunctionNode(FrontendNode):
name: str = "PythonFunction"
template: Template = Template(
type_name="python_function",
type_name="PythonFunction",
fields=[
TemplateField(
field_type="code",

View file

@ -1,15 +1,4 @@
from fastapi.testclient import TestClient
from langflow.settings import settings
# check that all agents are in settings.agents
# are in json_response["agents"]
def test_agents_settings(client: TestClient):
response = client.get("api/v1/all")
assert response.status_code == 200
json_response = response.json()
agents = json_response["agents"]
assert set(agents.keys()) == set(settings.agents)
def test_zero_shot_agent(client: TestClient):
@ -131,7 +120,7 @@ def test_initialize_agent(client: TestClient):
json_response = response.json()
agents = json_response["agents"]
initialize_agent = agents["initialize_agent"]
initialize_agent = agents["AgentInitializer"]
assert initialize_agent["base_classes"] == ["AgentExecutor", "function"]
template = initialize_agent["template"]