diff --git a/src/backend/langflow/interface/embeddings/base.py b/src/backend/langflow/interface/embeddings/base.py index 061b1d3b5..933d8ad39 100644 --- a/src/backend/langflow/interface/embeddings/base.py +++ b/src/backend/langflow/interface/embeddings/base.py @@ -1,8 +1,10 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator from langflow.interface.custom_lists import embedding_type_to_cls_dict from langflow.settings import settings +from langflow.template.base import FrontendNode +from langflow.template.nodes import EmbeddingFrontendNode from langflow.utils.logger import logger from langflow.utils.util import build_template_from_class @@ -14,6 +16,10 @@ class EmbeddingCreator(LangChainTypeCreator): def type_to_loader_dict(self) -> Dict: return embedding_type_to_cls_dict + @property + def frontend_node_class(self) -> Type[FrontendNode]: + return EmbeddingFrontendNode + def get_signature(self, name: str) -> Optional[Dict]: """Get the signature of an embedding.""" try: diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 6cd9246b8..93135067b 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -108,6 +108,7 @@ def instantiate_toolkit(node_type, class_object, params): def instantiate_embedding(class_object, params): params.pop("model", None) + params.pop("headers", None) try: return class_object(**params) except ValidationError: diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index 6ce08e57b..a4d3fa1cf 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -614,3 +614,11 @@ class LLMFrontendNode(FrontendNode): elif field.name in ["model_name", "temperature"]: field.advanced = False field.show = True + + +class EmbeddingFrontendNode(FrontendNode): + @staticmethod + def format_field(field: TemplateField, name: Optional[str] = None) -> None: + FrontendNode.format_field(field, name) + if field.name == "headers": + field.show = False