diff --git a/src/backend/langflow/template/frontend_node/embeddings.py b/src/backend/langflow/template/frontend_node/embeddings.py index d21e12e73..d466b11a4 100644 --- a/src/backend/langflow/template/frontend_node/embeddings.py +++ b/src/backend/langflow/template/frontend_node/embeddings.py @@ -22,6 +22,22 @@ class EmbeddingFrontendNode(FrontendNode): field.display_name = "Jina API URL" field.password = False + @staticmethod + def format_openai_fields(field: TemplateField): + if "openai" in field.name: + field.show = True + field.advanced = True + split_name = field.name.split("_") + title_name = " ".join([s.capitalize() for s in split_name]) + field.display_name = title_name.replace("Openai", "OpenAI").replace( + "Api", "API" + ) + + if "api_key" in field.name: + field.password = True + field.show = True + field.advanced = False + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: FrontendNode.format_field(field, name) @@ -30,9 +46,6 @@ class EmbeddingFrontendNode(FrontendNode): if field.name == "headers": field.show = False - if "openai" in field.name: - field.show = True - field.advanced = "api_key" not in field.name - # Format Jina fields EmbeddingFrontendNode.format_jina_fields(field) + EmbeddingFrontendNode.format_openai_fields(field) diff --git a/tests/test_embeddings_template.py b/tests/test_embeddings_template.py new file mode 100644 index 000000000..7334c2abd --- /dev/null +++ b/tests/test_embeddings_template.py @@ -0,0 +1,59 @@ +from langflow.template.field.base import TemplateField +from langflow.template.frontend_node.embeddings import EmbeddingFrontendNode + + +def test_format_jina_fields(): + field = TemplateField(name="jina") + EmbeddingFrontendNode.format_jina_fields(field) + assert field.show is True + assert field.advanced is False + + field = TemplateField(name="auth") + EmbeddingFrontendNode.format_jina_fields(field) + assert field.password is True + assert field.show is True + assert field.advanced is False + + field = TemplateField(name="jina_api_url") + EmbeddingFrontendNode.format_jina_fields(field) + assert field.show is True + assert field.advanced is True + assert field.display_name == "Jina API URL" + assert field.password is False + + +def test_format_openai_fields(): + field = TemplateField(name="openai") + EmbeddingFrontendNode.format_openai_fields(field) + assert field.show is True + assert field.advanced is True + assert field.display_name == "OpenAI" + + field = TemplateField(name="openai_api_key") + EmbeddingFrontendNode.format_openai_fields(field) + assert field.password is True + assert field.show is True + assert field.advanced is False + + +def test_format_field(): + field = TemplateField(name="headers") + EmbeddingFrontendNode.format_field(field) + assert field.show is False + + field = TemplateField(name="jina") + EmbeddingFrontendNode.format_field(field) + assert field.advanced is False + assert field.show is True + + field = TemplateField(name="openai") + EmbeddingFrontendNode.format_field(field) + assert field.advanced is True + assert field.show is True + assert field.display_name == "OpenAI" + + field = TemplateField(name="test_field", required=True) + EmbeddingFrontendNode.format_field(field) + assert field.advanced is False + assert field.show is True + assert field.required is True