diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index e61eb6e26..46fe06408 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,33 +1,24 @@ from http import HTTPStatus from typing import Annotated, Optional, Union -from langflow.services.auth.utils import api_key_security, get_current_active_user - -from langflow.services.cache.utils import save_uploaded_file -from langflow.services.database.models.flow import Flow -from langflow.processing.process import process_graph_cached, process_tweaks -from langflow.services.database.models.user.user import User -from langflow.services.deps import ( - get_session_service, - get_settings_service, - get_task_service, -) -from loguru import logger -from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body, status import sqlalchemy as sa -from langflow.interface.custom.custom_component import CustomComponent - +from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status +from loguru import logger from langflow.api.v1.schemas import ( + CustomComponentCode, ProcessResponse, TaskResponse, TaskStatusResponse, UploadFileResponse, - CustomComponentCode, ) - - -from langflow.services.deps import get_session +from langflow.interface.types import build_langchain_template_custom_component, create_and_validate_component +from langflow.processing.process import process_graph_cached, process_tweaks +from langflow.services.auth.utils import api_key_security, get_current_active_user +from langflow.services.cache.utils import save_uploaded_file +from langflow.services.database.models.flow import Flow +from langflow.services.database.models.user.user import User +from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service try: from langflow.worker import process_graph_cached_task @@ -39,7 +30,6 @@ except ImportError: from sqlmodel import Session - from langflow.services.task.service import TaskService # build router @@ -212,28 +202,18 @@ async def custom_component( raw_code: CustomComponentCode, user: User = Depends(get_current_active_user), ): - from langflow.interface.types import ( - build_langchain_template_custom_component, - ) + component = create_and_validate_component(raw_code.code) - extractor = CustomComponent(code=raw_code.code) - extractor.is_check_valid() - - return build_langchain_template_custom_component(extractor, user_id=user.id) + return build_langchain_template_custom_component(component, user_id=user.id) @router.post("/custom_component/update", status_code=HTTPStatus.OK) async def custom_component_update( raw_code: CustomComponentCode, - field: str, user: User = Depends(get_current_active_user), ): - from langflow.interface.types import ( - build_langchain_template_custom_component, - ) + component = create_and_validate_component(raw_code.code) - extractor = CustomComponent(code=raw_code.code) - extractor.is_check_valid() - - component_node = build_langchain_template_custom_component(extractor, user_id=user.id) + component_node = build_langchain_template_custom_component(component, user_id=user.id, update_field=raw_code.field) # Update the field + return component_node diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 6338f9398..ba3cdc1ba 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -343,7 +343,7 @@ def build_langchain_template_custom_component( update_attributes(frontend_node, template_config) logger.debug("Updated attributes") - field_config = build_field_config(custom_component, user_id=user_id) + field_config = build_field_config(custom_component, user_id=user_id, update_field=update_field) logger.debug("Built field config") entrypoint_args = custom_component.get_function_entrypoint_args @@ -527,3 +527,9 @@ def merge_nested_dicts(dict1, dict2): else: dict1[key] = value return dict1 + + +def create_and_validate_component(code: str) -> CustomComponent: + component = CustomComponent(code=code) + component.is_check_valid() + return component