diff --git a/src/backend/langflow/api/utils.py b/src/backend/langflow/api/utils.py index d88d2525f..b7b893dda 100644 --- a/src/backend/langflow/api/utils.py +++ b/src/backend/langflow/api/utils.py @@ -1,11 +1,12 @@ +import warnings from pathlib import Path from typing import TYPE_CHECKING, List + from fastapi import HTTPException +from platformdirs import user_cache_dir + from langflow.services.store.schema import StoreComponentCreate from langflow.services.store.utils import get_lf_version_from_pypi -import warnings - -from platformdirs import user_cache_dir if TYPE_CHECKING: from langflow.services.database.models.flow.model import Flow @@ -62,7 +63,7 @@ def build_input_keys_response(langchain_object, artifacts): return input_keys_response -def update_frontend_node_with_template_values(frontend_node, raw_template_data): +def update_frontend_node_with_template_values(frontend_node, raw_frontend_node): """ Updates the given frontend node with values from the raw template data. @@ -70,19 +71,28 @@ def update_frontend_node_with_template_values(frontend_node, raw_template_data): :param raw_template_data: A dict representing raw template data. :return: Updated frontend node. """ - if not is_valid_data(frontend_node, raw_template_data): + if not is_valid_data(frontend_node, raw_frontend_node): return frontend_node - update_template_values(frontend_node["template"], raw_template_data.template) + # Check if the display_name is different than "CustomComponent" + # if so, update the display_name in the frontend_node + if raw_frontend_node["display_name"] != "CustomComponent": + frontend_node["display_name"] = raw_frontend_node["display_name"] + + update_template_values(frontend_node["template"], raw_frontend_node["template"]) return frontend_node -def is_valid_data(frontend_node, raw_template_data): +def raw_frontend_data_is_valid(raw_frontend_data): + """Check if the raw frontend data is valid for processing.""" + return "template" in raw_frontend_data and "display_name" in raw_frontend_data + + +def is_valid_data(frontend_node, raw_frontend_data): """Check if the data is valid for processing.""" - return ( - frontend_node and "template" in frontend_node and raw_template_data and hasattr(raw_template_data, "template") - ) + + return frontend_node and "template" in frontend_node and raw_frontend_data_is_valid(raw_frontend_data) def update_template_values(frontend_template, raw_template): diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index d4bb083ef..6ebb04182 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -3,6 +3,9 @@ from typing import Annotated, Optional, Union import sqlalchemy as sa from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status +from loguru import logger +from sqlmodel import select + from langflow.api.utils import update_frontend_node_with_template_values from langflow.api.v1.schemas import ( CustomComponentCode, @@ -20,8 +23,6 @@ from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow from langflow.services.database.models.user.model import User from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service -from loguru import logger -from sqlmodel import select try: from langflow.worker import process_graph_cached_task @@ -31,9 +32,10 @@ except ImportError: raise NotImplementedError("Celery is not installed") -from langflow.services.task.service import TaskService from sqlmodel import Session +from langflow.services.task.service import TaskService + # build router router = APIRouter(tags=["Base"]) @@ -218,7 +220,7 @@ async def custom_component( built_frontend_node = build_custom_component_template(component, user_id=user.id) - built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code) + built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node) return built_frontend_node diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index 2d3da1d21..8c8c4efd7 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -3,11 +3,12 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union from uuid import UUID +from pydantic import BaseModel, Field, field_validator + from langflow.services.database.models.api_key.model import ApiKeyRead from langflow.services.database.models.base import orjson_dumps from langflow.services.database.models.flow import FlowCreate, FlowRead from langflow.services.database.models.user import UserRead -from pydantic import BaseModel, Field, field_validator class BuildStatus(Enum): @@ -157,7 +158,7 @@ class StreamData(BaseModel): class CustomComponentCode(BaseModel): code: str field: Optional[str] = None - template: Optional[dict] = None + frontend_node: Optional[dict] = None class CustomComponentResponseError(BaseModel): diff --git a/src/backend/langflow/components/custom_components/CustomComponent.py b/src/backend/langflow/components/custom_components/CustomComponent.py index 35a6036dd..533ccb727 100644 --- a/src/backend/langflow/components/custom_components/CustomComponent.py +++ b/src/backend/langflow/components/custom_components/CustomComponent.py @@ -3,8 +3,6 @@ from langflow.field_typing import Data class Component(CustomComponent): - display_name: str = "Custom Component" - description: str = "Create any custom component you want!" documentation: str = "http://docs.langflow.org/components/custom" def build_config(self): diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 438cc50cf..4ff58fd27 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -19,8 +19,8 @@ from langflow.utils import validate class CustomComponent(Component): - display_name: Optional[str] = "Custom Component" - description: Optional[str] = "Custom Component" + display_name: Optional[str] = None + description: Optional[str] = None code: Optional[str] = None field_config: dict = {} code_class_base_inheritance: ClassVar[str] = "CustomComponent" diff --git a/src/backend/langflow/template/frontend_node/custom_components.py b/src/backend/langflow/template/frontend_node/custom_components.py index 22fd70814..75019f85e 100644 --- a/src/backend/langflow/template/frontend_node/custom_components.py +++ b/src/backend/langflow/template/frontend_node/custom_components.py @@ -3,7 +3,6 @@ from typing import Optional from langflow.template.field.base import TemplateField from langflow.template.frontend_node.base import FrontendNode from langflow.template.template.base import Template -from pydantic import field_serializer DEFAULT_CUSTOM_COMPONENT_CODE = """from langflow import CustomComponent from typing import Optional, List, Dict, Union @@ -47,7 +46,7 @@ class Component(CustomComponent): class CustomComponentFrontendNode(FrontendNode): name: str = "CustomComponent" - display_name: str = "Custom Component" + display_name: Optional[str] = "CustomComponent" beta: bool = True template: Template = Template( type_name="CustomComponent", @@ -67,9 +66,3 @@ class CustomComponentFrontendNode(FrontendNode): ) description: Optional[str] = None base_classes: list[str] = [] - - @field_serializer("display_name") - def process_display_name(self, display_name: str) -> str: - """Sets the display name of the frontend node.""" - - return display_name diff --git a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx index b08b92bb9..4c81e29a5 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx @@ -156,11 +156,16 @@ export default function ParameterComponent({ }; const handleNodeClass = (newNodeClass: APIClassType, code?: string): void => { + if (!data.node) return; if (data.node!.template[name].value !== newNodeClass.template[name].value) { takeSnapshot(); } - data.node = newNodeClass; - data.node.template[name].value = code; + data.node! = { + ...newNodeClass, + description: newNodeClass.description ?? data.node!.description, + display_name: newNodeClass.display_name ?? data.node!.display_name, + }; + data.node!.template[name].value = code; updateNodeInternals(data.id); // Set state to pending //@ts-ignore @@ -174,19 +179,22 @@ export default function ParameterComponent({ } renderTooltips(); let flow = flows.find((flow) => flow.id === tabId); - if (reactFlowInstance && flow && flow.data) { - cleanEdges({ - flow: { - edges: flow.data!.edges, - nodes: flow.data!.nodes, - }, - updateEdge: (edge) => { - reactFlowInstance.setEdges(edge); - updateNodeInternals(data.id); - }, - }); - updateFlow(flow); - } + setTimeout(() => { + //timeout necessary because ReactFlow updates are not async + if (reactFlowInstance && flow && flow.data) { + cleanEdges({ + flow: { + edges: flow.data!.edges, + nodes: flow.data!.nodes, + }, + updateEdge: (edge) => { + reactFlowInstance.setEdges(edge); + updateNodeInternals(data.id); + }, + }); + updateFlow(flow); + } + }, 50); }; const [errorDuplicateKey, setErrorDuplicateKey] = useState(false); diff --git a/src/frontend/src/CustomNodes/GenericNode/index.tsx b/src/frontend/src/CustomNodes/GenericNode/index.tsx index 73ab9aeef..032349cfd 100644 --- a/src/frontend/src/CustomNodes/GenericNode/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/index.tsx @@ -16,7 +16,7 @@ import { validationStatusType } from "../../types/components"; import { NodeDataType } from "../../types/flow"; import { handleKeyDown, scapedJSONStringfy } from "../../utils/reactflowUtils"; import { nodeColors, nodeIconsLucide } from "../../utils/styleUtils"; -import { classNames, getFieldTitle } from "../../utils/utils"; +import { classNames, cn, getFieldTitle } from "../../utils/utils"; import ParameterComponent from "./components/parameterComponent"; export default function GenericNode({ @@ -107,6 +107,8 @@ export default function GenericNode({ const showNode = data.showNode ?? true; + const nameEditable = data.node?.flow || data.type === "CustomComponent"; + return ( <> @@ -164,7 +166,7 @@ export default function GenericNode({ /> {showNode && (
- {data.node?.flow && inputName ? ( + {nameEditable && inputName ? (
{ @@ -172,6 +174,7 @@ export default function GenericNode({ if (nodeName.trim() !== "") { setNodeName(nodeName); data.node!.display_name = nodeName; + updateNodeInternals(data.id); } else { setNodeName(data.node!.display_name); } @@ -194,7 +197,7 @@ export default function GenericNode({
{data.node?.display_name}
- {data.node?.flow && ( + {nameEditable && ( - {data.node?.description !== "" && - showNode && - data.node?.flow && - inputDescription ? ( -