Fix various issues and refactor codebase (#1196)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-11 14:59:20 -03:00 committed by GitHub
commit bf114172b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 532 additions and 476 deletions

1
.gitignore vendored
View file

@ -166,6 +166,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
.testmondata*
# Translations
*.mo

View file

@ -20,7 +20,7 @@ coverage:
tests:
@make install_backend
poetry run pytest tests
poetry run pytest tests --instafail
tests_frontend:
ifeq ($(UI), true)

693
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -125,6 +125,7 @@ pytest-xdist = "^3.3.1"
types-pywin32 = "^306.0.0.4"
types-google-cloud-ndb = "^2.2.0.0"
pytest-sugar = "^0.9.7"
pytest-instafail = "^0.5.0"
[tool.poetry.extras]

View file

@ -9,17 +9,19 @@ from typing import Optional
import httpx
import typer
from dotenv import load_dotenv
from langflow.main import setup_app
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from langflow.services.utils import initialize_services, initialize_settings_service
from langflow.utils.logger import configure, logger
from multiprocess import Process, cpu_count # type: ignore
from rich import box
from rich import print as rprint
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from sqlmodel import select
from langflow.main import setup_app
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from langflow.services.utils import initialize_services, initialize_settings_service
from langflow.utils.logger import configure, logger
console = Console()
@ -338,7 +340,7 @@ def superuser(
# Verify that the superuser was created
from langflow.services.database.models.user.model import User
user: User = session.query(User).filter(User.username == username).first()
user: User = session.exec(select(User).where(User.username == username)).first()
if user is None or not user.is_superuser:
typer.echo("Superuser creation failed.")
return

View file

@ -1,7 +1,10 @@
from typing import Optional
from langflow import CustomComponent
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.llms.base import BaseLLM
from pydantic.v1 import SecretStr
from langflow import CustomComponent
class QianfanChatEndpointComponent(CustomComponent):
@ -80,8 +83,8 @@ class QianfanChatEndpointComponent(CustomComponent):
try:
output = QianfanChatEndpoint( # type: ignore
model=model,
qianfan_ak=qianfan_ak,
qianfan_sk=qianfan_sk,
qianfan_ak=SecretStr(qianfan_ak) if qianfan_ak else None,
qianfan_sk=SecretStr(qianfan_sk) if qianfan_sk else None,
top_p=top_p,
temperature=temperature,
penalty_score=penalty_score,

View file

@ -165,6 +165,9 @@ class CustomComponent(Component):
return next(iter(classes), "")
@property
def template_config(self):
return self.build_template_config()
def build_template_config(self):
if not self.code:
return {}

View file

@ -31,7 +31,6 @@ from langflow.interface.utilities.base import utility_creator
from langflow.interface.vector_store.base import vectorstore_creator
from langflow.interface.wrappers.base import wrapper_creator
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE
from langflow.template.frontend_node.custom_components import CustomComponentFrontendNode
from langflow.utils.util import get_base_classes
@ -89,7 +88,7 @@ def process_type(field_type: str):
# TODO: Move to correct place
def add_new_custom_field(
template,
frontend_node: CustomComponentFrontendNode,
field_name: str,
field_type: str,
field_value: Any,
@ -115,8 +114,6 @@ def add_new_custom_field(
if "name" in field_config:
warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.")
field_config.pop("name", None)
required = field_config.pop("required", field_required)
placeholder = field_config.pop("placeholder", "")
@ -131,10 +128,11 @@ def add_new_custom_field(
display_name=display_name,
**sanitize_field_config(field_config),
)
template.get("template")[field_name] = new_field.model_dump(by_alias=True, exclude_none=True)
template.get("custom_fields")[field_name] = None
frontend_node.template.upsert_field(field_name, new_field)
if isinstance(frontend_node.custom_fields, dict):
frontend_node.custom_fields[field_name] = None
return template
return frontend_node
def sanitize_field_config(field_config: Dict):
@ -145,27 +143,22 @@ def sanitize_field_config(field_config: Dict):
# TODO: Move to correct place
def add_code_field(template, raw_code, field_config):
# Field with the Python code to allow update
def add_code_field(frontend_node: CustomComponentFrontendNode, raw_code, field_config):
code_field = TemplateField(
dynamic=True,
required=True,
placeholder="",
multiline=True,
value=raw_code,
password=False,
name="code",
advanced=field_config.pop("advanced", False),
field_type="code",
is_list=False,
)
frontend_node.template.add_field(code_field)
code_field = {
"code": {
"dynamic": True,
"required": True,
"placeholder": "",
"show": field_config.pop("show", True),
"multiline": True,
"value": raw_code,
"password": False,
"name": "code",
"advanced": field_config.pop("advanced", False),
"type": "code",
"list": False,
}
}
template.get("template")["code"] = code_field.get("code")
return template
return frontend_node
def extract_type_from_optional(field_type):
@ -182,28 +175,30 @@ def extract_type_from_optional(field_type):
return match[1] if match else None
def build_frontend_node(custom_component: CustomComponent):
def build_frontend_node(template_config):
"""Build a frontend node for a custom component"""
try:
return CustomComponentFrontendNode().to_dict().get(type(custom_component).__name__)
sanitized_template_config = sanitize_template_config(template_config)
return CustomComponentFrontendNode(**sanitized_template_config)
except Exception as exc:
logger.error(f"Error while building base frontend node: {exc}")
return None
raise exc
def update_attributes(frontend_node, template_config):
"""Update the display name and description of a frontend node"""
attributes = [
def sanitize_template_config(template_config):
"""Sanitize the template config"""
attributes = {
"display_name",
"description",
"beta",
"documentation",
"output_types",
]
for attribute in attributes:
if attribute in template_config:
frontend_node[attribute] = template_config[attribute]
}
for key in template_config.copy():
if key not in attributes:
template_config.pop(key, None)
return template_config
def build_field_config(
@ -318,7 +313,7 @@ def get_field_properties(extra_field):
return field_name, field_type, field_value, field_required
def add_base_classes(frontend_node, return_types: List[str]):
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add base classes to the frontend node"""
for return_type_instance in return_types:
if return_type_instance is None:
@ -333,11 +328,10 @@ def add_base_classes(frontend_node, return_types: List[str]):
base_classes = get_base_classes(return_type_instance)
for base_class in base_classes:
if base_class not in CLASSES_TO_REMOVE:
frontend_node.get("base_classes").append(base_class)
frontend_node.add_base_class(base_class)
def add_output_types(frontend_node, return_types: List[str]):
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
"""Add output types to the frontend node"""
for return_type in return_types:
if return_type is None:
@ -355,7 +349,7 @@ def add_output_types(frontend_node, return_types: List[str]):
else:
return_type = str(return_type)
frontend_node.get("output_types").append(return_type)
frontend_node.add_output_type(return_type)
def build_custom_component_template(
@ -366,14 +360,10 @@ def build_custom_component_template(
"""Build a custom component template for the langchain"""
try:
logger.debug("Building custom component template")
frontend_node = build_frontend_node(custom_component)
frontend_node = build_frontend_node(custom_component.template_config)
if frontend_node is None:
return None
logger.debug("Built base frontend node")
template_config = custom_component.build_template_config
update_attributes(frontend_node, template_config)
logger.debug("Updated attributes")
field_config = build_field_config(custom_component, user_id=user_id, update_field=update_field)
logger.debug("Built field config")
@ -386,7 +376,7 @@ def build_custom_component_template(
add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type)
add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type)
logger.debug("Added base classes")
return frontend_node
return frontend_node.to_dict(add_name=False)
except Exception as exc:
if isinstance(exc, HTTPException):
raise exc

View file

@ -37,10 +37,10 @@ class TemplateField(BaseModel):
password: bool = False
"""Specifies if the field is a password. Defaults to False."""
options: Union[list[str], Callable] = None
options: Optional[Union[list[str], Callable]] = None
"""List of options for the field. Only used when is_list=True. Default is an empty list."""
name: str = ""
name: Optional[str] = None
"""Name of the field. Default is an empty string."""
display_name: Optional[str] = None
@ -61,7 +61,7 @@ class TemplateField(BaseModel):
refresh: Optional[bool] = None
"""Specifies if the field should be refreshed. Defaults to False."""
range_spec: Optional[RangeSpec] = Field(None, serialization_alias="rangeSpec")
range_spec: Optional[RangeSpec] = Field(default=None, serialization_alias="rangeSpec")
"""Range specification for the field. Defaults to None."""
def to_dict(self):

View file

@ -1,13 +1,14 @@
import re
from collections import defaultdict
from typing import ClassVar, Dict, List, Optional
from typing import ClassVar, Dict, List, Optional, Union
from pydantic import BaseModel, Field, field_serializer, model_serializer
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE, FORCE_SHOW_FIELDS
from langflow.template.frontend_node.formatter import field_formatters
from langflow.template.template.base import Template
from langflow.utils import constants
from pydantic import BaseModel, Field, field_serializer, model_serializer
class FieldFormatters(BaseModel):
@ -43,7 +44,7 @@ class FrontendNode(BaseModel):
description: Optional[str] = None
base_classes: List[str]
name: str = ""
display_name: str = ""
display_name: Optional[str] = ""
documentation: str = ""
custom_fields: Optional[Dict] = defaultdict(list)
output_types: List[str] = []
@ -85,10 +86,12 @@ class FrontendNode(BaseModel):
return {name: result}
# For backwards compatibility
def to_dict(self) -> dict:
def to_dict(self, add_name=True) -> dict:
"""Returns a dict representation of the frontend node."""
return self.model_dump(by_alias=True, exclude_none=True)
dump = self.model_dump(by_alias=True, exclude_none=True)
if not add_name:
return dump.pop(self.name)
return dump
def add_extra_fields(self) -> None:
pass
@ -96,6 +99,20 @@ class FrontendNode(BaseModel):
def add_extra_base_classes(self) -> None:
pass
def add_base_class(self, base_class: Union[str, List[str]]) -> None:
"""Adds a base class to the frontend node."""
if isinstance(base_class, str):
self.base_classes.append(base_class)
elif isinstance(base_class, list):
self.base_classes.extend(base_class)
def add_output_type(self, output_type: Union[str, List[str]]) -> None:
"""Adds an output type to the frontend node."""
if isinstance(output_type, str):
self.output_types.append(output_type)
elif isinstance(output_type, list):
self.output_types.extend(output_type)
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
"""Formats a given field based on its attributes and value."""
@ -184,7 +201,8 @@ class FrontendNode(BaseModel):
@staticmethod
def handle_kwargs_field(field: TemplateField) -> None:
"""Handles kwargs field by setting certain attributes."""
if "kwargs" in field.name.lower():
if "kwargs" in (field.name or "").lower():
field.advanced = True
field.required = False
field.show = False

View file

@ -48,16 +48,16 @@ class ChainFrontendNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
if "name" == "RetrievalQA" and field.name == "memory":
key = field.name or ""
if "name" == "RetrievalQA" and key == "memory":
field.show = False
field.required = False
field.advanced = False
if "key" in field.name:
if "key" in key:
field.password = False
field.show = False
if field.name in ["input_key", "output_key"]:
if key in ["input_key", "output_key"]:
field.required = True
field.show = True
field.advanced = True
@ -71,26 +71,26 @@ class ChainFrontendNode(FrontendNode):
# field.value = field.value.template
# Separated for possible future changes
if field.name == "prompt" and field.value is None:
if key == "prompt" and field.value is None:
field.required = True
field.show = True
field.advanced = False
if field.name == "memory":
if key == "memory":
# field.required = False
field.show = True
field.advanced = False
if field.name == "verbose":
if key == "verbose":
field.required = False
field.show = False
field.advanced = True
if field.name == "llm":
if key == "llm":
field.required = True
field.show = True
field.advanced = False
field.field_type = "BaseLanguageModel" # temporary fix
field.is_list = False
if field.name == "return_source_documents":
if key == "return_source_documents":
field.required = False
field.show = True
field.advanced = True

View file

@ -22,13 +22,14 @@ class EmbeddingFrontendNode(FrontendNode):
@staticmethod
def format_vertex_field(field: TemplateField, name: str):
if "VertexAI" in name:
key = field.name or ""
advanced_fields = [
"verbose",
"top_p",
"top_k",
"max_output_tokens",
]
if field.name in advanced_fields:
if key in advanced_fields:
field.advanced = True
show_fields = [
"verbose",
@ -42,21 +43,22 @@ class EmbeddingFrontendNode(FrontendNode):
"top_k",
]
if field.name in show_fields:
if key in show_fields:
field.show = True
@staticmethod
def format_jina_fields(field: TemplateField):
if "jina" in field.name:
name = field.name or ""
if "jina" in name:
field.show = True
field.advanced = False
if "auth" in field.name or "token" in field.name:
if "auth" in name or "token" in name:
field.password = True
field.show = True
field.advanced = False
if field.name == "jina_api_url":
if name == "jina_api_url":
field.show = True
field.advanced = True
field.display_name = "Jina API URL"
@ -64,14 +66,15 @@ class EmbeddingFrontendNode(FrontendNode):
@staticmethod
def format_openai_fields(field: TemplateField):
if "openai" in field.name:
name = field.name or ""
if "openai" in name:
field.show = True
field.advanced = True
split_name = field.name.split("_")
split_name = name.split("_")
title_name = " ".join([s.capitalize() for s in split_name])
field.display_name = title_name.replace("Openai", "OpenAI").replace("Api", "API")
if "api_key" in field.name:
if "api_key" in name:
field.password = True
field.show = True
field.advanced = False
@ -83,13 +86,14 @@ class EmbeddingFrontendNode(FrontendNode):
EmbeddingFrontendNode.format_vertex_field(field, name)
field.advanced = not field.required
field.show = True
if field.name == "headers":
key = field.name or ""
if key == "headers":
field.show = False
if field.name == "model_kwargs":
if key == "model_kwargs":
field.field_type = "dict"
field.advanced = True
field.show = True
elif field.name in [
elif key in [
"model_name",
"temperature",
"model_file",
@ -99,9 +103,9 @@ class EmbeddingFrontendNode(FrontendNode):
]:
field.advanced = False
field.show = True
if field.name == "credentials":
if key == "credentials":
field.field_type = "file"
if name == "VertexAI" and field.name not in [
if name == "VertexAI" and key not in [
"callbacks",
"client",
"stop",

View file

@ -9,7 +9,7 @@ from langflow.utils.constants import ANTHROPIC_MODELS, CHAT_OPENAI_MODELS, OPENA
class OpenAIAPIKeyFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
if "api_key" in field.name and "OpenAI" in str(name):
if field.name and "api_key" in field.name and "OpenAI" in str(name):
field.display_name = "OpenAI API Key"
field.required = False
if field.value is None:
@ -25,14 +25,14 @@ class ModelSpecificFieldFormatter(FieldFormatter):
}
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
if name in self.MODEL_DICT and field.name == "model_name":
if field.name and name in self.MODEL_DICT and field.name == "model_name":
field.options = self.MODEL_DICT[name]
field.is_list = True
class KwargsFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
if "kwargs" in field.name.lower():
if field.name and "kwargs" in field.name.lower():
field.advanced = True
field.required = False
field.show = False
@ -40,11 +40,11 @@ class KwargsFormatter(FieldFormatter):
class APIKeyFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
if "api" in field.name.lower() and "key" in field.name.lower():
if field.name and "api" in field.name.lower() and "key" in field.name.lower():
field.required = False
field.advanced = False
field.display_name = field.name.replace("_", " ").title()
field.display_name = (field.name or "").replace("_", " ").title()
field.display_name = field.display_name.replace("Api", "API")
@ -94,7 +94,7 @@ class SpecialFieldFormatter(FieldFormatter):
class ShowFieldFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
key = field.name
key = field.name or ""
required = field.required
field.show = (
(required and key not in ["input_variables"])
@ -106,7 +106,7 @@ class ShowFieldFormatter(FieldFormatter):
class PasswordFieldFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
key = field.name
key = field.name or ""
show = field.show
if any(text in key.lower() for text in {"password", "token", "api", "key"}) and show:
field.password = True
@ -114,7 +114,7 @@ class PasswordFieldFormatter(FieldFormatter):
class MultilineFieldFormatter(FieldFormatter):
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
key = field.name
key = field.name or ""
if key in {
"suffix",
"prefix",

View file

@ -1,10 +1,9 @@
from typing import Optional
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.base import orjson_dumps
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.frontend_node.constants import CTRANSFORMERS_DEFAULT_CONFIG
from langflow.template.frontend_node.constants import OPENAI_API_BASE_INFO
from langflow.template.frontend_node.constants import CTRANSFORMERS_DEFAULT_CONFIG, OPENAI_API_BASE_INFO
class LLMFrontendNode(FrontendNode):
@ -24,6 +23,7 @@ class LLMFrontendNode(FrontendNode):
@staticmethod
def format_vertex_field(field: TemplateField, name: str):
key = field.name or ""
if "VertexAI" in name:
advanced_fields = [
"tuned_model_name",
@ -32,7 +32,7 @@ class LLMFrontendNode(FrontendNode):
"top_k",
"max_output_tokens",
]
if field.name in advanced_fields:
if key in advanced_fields:
field.advanced = True
show_fields = [
"tuned_model_name",
@ -47,20 +47,19 @@ class LLMFrontendNode(FrontendNode):
"top_k",
]
if field.name in show_fields:
if key in show_fields:
field.show = True
@staticmethod
def format_openai_field(field: TemplateField):
if "openai" in field.name.lower():
field.display_name = (field.name.title().replace("Openai", "OpenAI").replace("_", " ")).replace(
"Api", "API"
)
key = field.name or ""
if "openai" in key.lower():
field.display_name = (key.title().replace("Openai", "OpenAI").replace("_", " ")).replace("Api", "API")
if "key" not in field.name.lower() and "token" not in field.name.lower():
if "key" not in key.lower() and "token" not in key.lower():
field.password = False
if field.name == "openai_api_base":
if key == "openai_api_base":
field.info = OPENAI_API_BASE_INFO
def add_extra_base_classes(self) -> None:
@ -69,13 +68,14 @@ class LLMFrontendNode(FrontendNode):
@staticmethod
def format_azure_field(field: TemplateField):
if field.name == "model_name":
key = field.name or ""
if key == "model_name":
field.show = False # Azure uses deployment_name instead of model_name.
elif field.name == "openai_api_type":
elif key == "openai_api_type":
field.show = False
field.password = False
field.value = "azure"
elif field.name == "openai_api_version":
elif key == "openai_api_version":
field.password = False
@staticmethod
@ -85,7 +85,8 @@ class LLMFrontendNode(FrontendNode):
@staticmethod
def format_ctransformers_field(field: TemplateField):
if field.name == "config":
key = field.name or ""
if key == "config":
field.show = True
field.advanced = True
field.value = orjson_dumps(CTRANSFORMERS_DEFAULT_CONFIG, indent_2=True)
@ -105,10 +106,11 @@ class LLMFrontendNode(FrontendNode):
if name and "vertex" in name.lower():
LLMFrontendNode.format_vertex_field(field, name)
SHOW_FIELDS = ["repo_id"]
if field.name in SHOW_FIELDS:
key = field.name or ""
if key in SHOW_FIELDS:
field.show = True
if "api" in field.name and ("key" in field.name or ("token" in field.name and "tokens" not in field.name)):
if "api" in key and ("key" in key or ("token" in key and "tokens" not in key)):
field.password = True
field.show = True
# Required should be False to support
@ -116,7 +118,7 @@ class LLMFrontendNode(FrontendNode):
field.required = False
field.advanced = False
if field.name == "task":
if key == "task":
field.required = True
field.show = True
field.is_list = True
@ -124,13 +126,13 @@ class LLMFrontendNode(FrontendNode):
field.value = field.options[0]
field.advanced = True
if display_name := display_names_dict.get(field.name):
if display_name := display_names_dict.get(key):
field.display_name = display_name
if field.name == "model_kwargs":
if key == "model_kwargs":
field.field_type = "dict"
field.advanced = True
field.show = True
elif field.name in [
elif key in [
"model_name",
"temperature",
"model_file",
@ -140,9 +142,9 @@ class LLMFrontendNode(FrontendNode):
]:
field.advanced = False
field.show = True
if field.name == "credentials":
if key == "credentials":
field.field_type = "file"
if name == "VertexAI" and field.name not in [
if name == "VertexAI" and key not in [
"callbacks",
"client",
"stop",

View file

@ -1,6 +1,7 @@
from typing import Optional
from langchain.agents.mrkl import prompt
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.frontend_node.constants import DEFAULT_PROMPT, HUMAN_PROMPT, SYSTEM_PROMPT
@ -20,14 +21,15 @@ class PromptFrontendNode(FrontendNode):
"examples",
"format_instructions",
]
key = field.name or ""
if field.field_type == "StringPromptTemplate" and "Message" in str(name):
field.field_type = "prompt"
field.multiline = True
field.value = HUMAN_PROMPT if "Human" in field.name else SYSTEM_PROMPT
if field.name == "template" and field.value == "":
field.value = HUMAN_PROMPT if "Human" in key else SYSTEM_PROMPT
if key == "template" and field.value == "":
field.value = DEFAULT_PROMPT
if field.name in PROMPT_FIELDS:
if key and key in PROMPT_FIELDS:
field.field_type = "prompt"
field.advanced = False
@ -48,7 +50,8 @@ class PromptTemplateNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
if field.name == "examples":
if (field.name or "") == "examples":
field.advanced = False

View file

@ -3,7 +3,6 @@ from typing import List, Optional
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
BASIC_FIELDS = [
"work_dir",
"collection_name",
@ -313,7 +312,7 @@ class VectorStoreFrontendNode(FrontendNode):
field.show = True
field.advanced = False
field.is_list = True
elif "embedding" in field.name:
elif field.name and "embedding" in field.name:
# for backwards compatibility
field.name = "embedding"
field.required = True

View file

@ -1,8 +1,9 @@
from typing import Callable, Union
from pydantic import BaseModel, model_serializer
from langflow.template.field.base import TemplateField
from langflow.utils.constants import DIRECT_TYPES
from pydantic import BaseModel, model_serializer
class Template(BaseModel):
@ -39,3 +40,25 @@ class Template(BaseModel):
def add_field(self, field: TemplateField) -> None:
self.fields.append(field)
def get_field(self, field_name: str) -> TemplateField:
"""Returns the field with the given name."""
field = next((field for field in self.fields if field.name == field_name), None)
if field is None:
raise ValueError(f"Field {field_name} not found in template {self.type_name}")
return field
def update_field(self, field_name: str, field: TemplateField) -> None:
"""Updates the field with the given name."""
for idx, template_field in enumerate(self.fields):
if template_field.name == field_name:
self.fields[idx] = field
return
raise ValueError(f"Field {field_name} not found in template {self.type_name}")
def upsert_field(self, field_name: str, field: TemplateField) -> None:
"""Updates the field with the given name or adds it if it doesn't exist."""
try:
self.update_field(field_name, field)
except ValueError:
self.add_field(field)

View file

@ -1,8 +1,9 @@
from pathlib import Path
from tempfile import tempdir
from langflow.__main__ import app
import pytest
from langflow.__main__ import app
from langflow.services import deps
@ -33,4 +34,5 @@ def test_components_path(runner, client, default_settings):
def test_superuser(runner, client, session):
result = runner.invoke(app, ["superuser"], input="admin\nadmin\n")
assert result.exit_code == 0, result.stdout
assert "Superuser created successfully." in result.stdout
assert "Superuser creation failed." not in result.output, result.output
assert "Superuser created successfully." in result.output, result.output

View file

@ -120,7 +120,7 @@ def test_custom_component_build_template_config():
Test the build_template_config property of the CustomComponent class.
"""
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
config = custom_component.build_template_config
config = custom_component.template_config
assert isinstance(config, dict)