Fix various issues and refactor codebase (#1196)
This commit is contained in:
commit
bf114172b9
19 changed files with 532 additions and 476 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -166,6 +166,7 @@ coverage.xml
|
|||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
.testmondata*
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
|
|
|
|||
2
Makefile
2
Makefile
|
|
@ -20,7 +20,7 @@ coverage:
|
|||
|
||||
tests:
|
||||
@make install_backend
|
||||
poetry run pytest tests
|
||||
poetry run pytest tests --instafail
|
||||
|
||||
tests_frontend:
|
||||
ifeq ($(UI), true)
|
||||
|
|
|
|||
693
poetry.lock
generated
693
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -125,6 +125,7 @@ pytest-xdist = "^3.3.1"
|
|||
types-pywin32 = "^306.0.0.4"
|
||||
types-google-cloud-ndb = "^2.2.0.0"
|
||||
pytest-sugar = "^0.9.7"
|
||||
pytest-instafail = "^0.5.0"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
|
|
|||
|
|
@ -9,17 +9,19 @@ from typing import Optional
|
|||
import httpx
|
||||
import typer
|
||||
from dotenv import load_dotenv
|
||||
from langflow.main import setup_app
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from langflow.services.utils import initialize_services, initialize_settings_service
|
||||
from langflow.utils.logger import configure, logger
|
||||
from multiprocess import Process, cpu_count # type: ignore
|
||||
from rich import box
|
||||
from rich import print as rprint
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.main import setup_app
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from langflow.services.utils import initialize_services, initialize_settings_service
|
||||
from langflow.utils.logger import configure, logger
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -338,7 +340,7 @@ def superuser(
|
|||
# Verify that the superuser was created
|
||||
from langflow.services.database.models.user.model import User
|
||||
|
||||
user: User = session.query(User).filter(User.username == username).first()
|
||||
user: User = session.exec(select(User).where(User.username == username)).first()
|
||||
if user is None or not user.is_superuser:
|
||||
typer.echo("Superuser creation failed.")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from typing import Optional
|
||||
from langflow import CustomComponent
|
||||
|
||||
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
||||
from langchain.llms.base import BaseLLM
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class QianfanChatEndpointComponent(CustomComponent):
|
||||
|
|
@ -80,8 +83,8 @@ class QianfanChatEndpointComponent(CustomComponent):
|
|||
try:
|
||||
output = QianfanChatEndpoint( # type: ignore
|
||||
model=model,
|
||||
qianfan_ak=qianfan_ak,
|
||||
qianfan_sk=qianfan_sk,
|
||||
qianfan_ak=SecretStr(qianfan_ak) if qianfan_ak else None,
|
||||
qianfan_sk=SecretStr(qianfan_sk) if qianfan_sk else None,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
penalty_score=penalty_score,
|
||||
|
|
|
|||
|
|
@ -165,6 +165,9 @@ class CustomComponent(Component):
|
|||
return next(iter(classes), "")
|
||||
|
||||
@property
|
||||
def template_config(self):
|
||||
return self.build_template_config()
|
||||
|
||||
def build_template_config(self):
|
||||
if not self.code:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ from langflow.interface.utilities.base import utility_creator
|
|||
from langflow.interface.vector_store.base import vectorstore_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE
|
||||
from langflow.template.frontend_node.custom_components import CustomComponentFrontendNode
|
||||
from langflow.utils.util import get_base_classes
|
||||
|
||||
|
|
@ -89,7 +88,7 @@ def process_type(field_type: str):
|
|||
|
||||
# TODO: Move to correct place
|
||||
def add_new_custom_field(
|
||||
template,
|
||||
frontend_node: CustomComponentFrontendNode,
|
||||
field_name: str,
|
||||
field_type: str,
|
||||
field_value: Any,
|
||||
|
|
@ -115,8 +114,6 @@ def add_new_custom_field(
|
|||
|
||||
if "name" in field_config:
|
||||
warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.")
|
||||
field_config.pop("name", None)
|
||||
|
||||
required = field_config.pop("required", field_required)
|
||||
placeholder = field_config.pop("placeholder", "")
|
||||
|
||||
|
|
@ -131,10 +128,11 @@ def add_new_custom_field(
|
|||
display_name=display_name,
|
||||
**sanitize_field_config(field_config),
|
||||
)
|
||||
template.get("template")[field_name] = new_field.model_dump(by_alias=True, exclude_none=True)
|
||||
template.get("custom_fields")[field_name] = None
|
||||
frontend_node.template.upsert_field(field_name, new_field)
|
||||
if isinstance(frontend_node.custom_fields, dict):
|
||||
frontend_node.custom_fields[field_name] = None
|
||||
|
||||
return template
|
||||
return frontend_node
|
||||
|
||||
|
||||
def sanitize_field_config(field_config: Dict):
|
||||
|
|
@ -145,27 +143,22 @@ def sanitize_field_config(field_config: Dict):
|
|||
|
||||
|
||||
# TODO: Move to correct place
|
||||
def add_code_field(template, raw_code, field_config):
|
||||
# Field with the Python code to allow update
|
||||
def add_code_field(frontend_node: CustomComponentFrontendNode, raw_code, field_config):
|
||||
code_field = TemplateField(
|
||||
dynamic=True,
|
||||
required=True,
|
||||
placeholder="",
|
||||
multiline=True,
|
||||
value=raw_code,
|
||||
password=False,
|
||||
name="code",
|
||||
advanced=field_config.pop("advanced", False),
|
||||
field_type="code",
|
||||
is_list=False,
|
||||
)
|
||||
frontend_node.template.add_field(code_field)
|
||||
|
||||
code_field = {
|
||||
"code": {
|
||||
"dynamic": True,
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"show": field_config.pop("show", True),
|
||||
"multiline": True,
|
||||
"value": raw_code,
|
||||
"password": False,
|
||||
"name": "code",
|
||||
"advanced": field_config.pop("advanced", False),
|
||||
"type": "code",
|
||||
"list": False,
|
||||
}
|
||||
}
|
||||
template.get("template")["code"] = code_field.get("code")
|
||||
|
||||
return template
|
||||
return frontend_node
|
||||
|
||||
|
||||
def extract_type_from_optional(field_type):
|
||||
|
|
@ -182,28 +175,30 @@ def extract_type_from_optional(field_type):
|
|||
return match[1] if match else None
|
||||
|
||||
|
||||
def build_frontend_node(custom_component: CustomComponent):
|
||||
def build_frontend_node(template_config):
|
||||
"""Build a frontend node for a custom component"""
|
||||
try:
|
||||
return CustomComponentFrontendNode().to_dict().get(type(custom_component).__name__)
|
||||
|
||||
sanitized_template_config = sanitize_template_config(template_config)
|
||||
return CustomComponentFrontendNode(**sanitized_template_config)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while building base frontend node: {exc}")
|
||||
return None
|
||||
raise exc
|
||||
|
||||
|
||||
def update_attributes(frontend_node, template_config):
|
||||
"""Update the display name and description of a frontend node"""
|
||||
attributes = [
|
||||
def sanitize_template_config(template_config):
|
||||
"""Sanitize the template config"""
|
||||
attributes = {
|
||||
"display_name",
|
||||
"description",
|
||||
"beta",
|
||||
"documentation",
|
||||
"output_types",
|
||||
]
|
||||
for attribute in attributes:
|
||||
if attribute in template_config:
|
||||
frontend_node[attribute] = template_config[attribute]
|
||||
}
|
||||
for key in template_config.copy():
|
||||
if key not in attributes:
|
||||
template_config.pop(key, None)
|
||||
|
||||
return template_config
|
||||
|
||||
|
||||
def build_field_config(
|
||||
|
|
@ -318,7 +313,7 @@ def get_field_properties(extra_field):
|
|||
return field_name, field_type, field_value, field_required
|
||||
|
||||
|
||||
def add_base_classes(frontend_node, return_types: List[str]):
|
||||
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
|
||||
"""Add base classes to the frontend node"""
|
||||
for return_type_instance in return_types:
|
||||
if return_type_instance is None:
|
||||
|
|
@ -333,11 +328,10 @@ def add_base_classes(frontend_node, return_types: List[str]):
|
|||
base_classes = get_base_classes(return_type_instance)
|
||||
|
||||
for base_class in base_classes:
|
||||
if base_class not in CLASSES_TO_REMOVE:
|
||||
frontend_node.get("base_classes").append(base_class)
|
||||
frontend_node.add_base_class(base_class)
|
||||
|
||||
|
||||
def add_output_types(frontend_node, return_types: List[str]):
|
||||
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
|
||||
"""Add output types to the frontend node"""
|
||||
for return_type in return_types:
|
||||
if return_type is None:
|
||||
|
|
@ -355,7 +349,7 @@ def add_output_types(frontend_node, return_types: List[str]):
|
|||
else:
|
||||
return_type = str(return_type)
|
||||
|
||||
frontend_node.get("output_types").append(return_type)
|
||||
frontend_node.add_output_type(return_type)
|
||||
|
||||
|
||||
def build_custom_component_template(
|
||||
|
|
@ -366,14 +360,10 @@ def build_custom_component_template(
|
|||
"""Build a custom component template for the langchain"""
|
||||
try:
|
||||
logger.debug("Building custom component template")
|
||||
frontend_node = build_frontend_node(custom_component)
|
||||
frontend_node = build_frontend_node(custom_component.template_config)
|
||||
|
||||
if frontend_node is None:
|
||||
return None
|
||||
logger.debug("Built base frontend node")
|
||||
template_config = custom_component.build_template_config
|
||||
|
||||
update_attributes(frontend_node, template_config)
|
||||
logger.debug("Updated attributes")
|
||||
field_config = build_field_config(custom_component, user_id=user_id, update_field=update_field)
|
||||
logger.debug("Built field config")
|
||||
|
|
@ -386,7 +376,7 @@ def build_custom_component_template(
|
|||
add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type)
|
||||
add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type)
|
||||
logger.debug("Added base classes")
|
||||
return frontend_node
|
||||
return frontend_node.to_dict(add_name=False)
|
||||
except Exception as exc:
|
||||
if isinstance(exc, HTTPException):
|
||||
raise exc
|
||||
|
|
|
|||
|
|
@ -37,10 +37,10 @@ class TemplateField(BaseModel):
|
|||
password: bool = False
|
||||
"""Specifies if the field is a password. Defaults to False."""
|
||||
|
||||
options: Union[list[str], Callable] = None
|
||||
options: Optional[Union[list[str], Callable]] = None
|
||||
"""List of options for the field. Only used when is_list=True. Default is an empty list."""
|
||||
|
||||
name: str = ""
|
||||
name: Optional[str] = None
|
||||
"""Name of the field. Default is an empty string."""
|
||||
|
||||
display_name: Optional[str] = None
|
||||
|
|
@ -61,7 +61,7 @@ class TemplateField(BaseModel):
|
|||
refresh: Optional[bool] = None
|
||||
"""Specifies if the field should be refreshed. Defaults to False."""
|
||||
|
||||
range_spec: Optional[RangeSpec] = Field(None, serialization_alias="rangeSpec")
|
||||
range_spec: Optional[RangeSpec] = Field(default=None, serialization_alias="rangeSpec")
|
||||
"""Range specification for the field. Defaults to None."""
|
||||
|
||||
def to_dict(self):
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
import re
|
||||
from collections import defaultdict
|
||||
from typing import ClassVar, Dict, List, Optional
|
||||
from typing import ClassVar, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, model_serializer
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE, FORCE_SHOW_FIELDS
|
||||
from langflow.template.frontend_node.formatter import field_formatters
|
||||
from langflow.template.template.base import Template
|
||||
from langflow.utils import constants
|
||||
from pydantic import BaseModel, Field, field_serializer, model_serializer
|
||||
|
||||
|
||||
class FieldFormatters(BaseModel):
|
||||
|
|
@ -43,7 +44,7 @@ class FrontendNode(BaseModel):
|
|||
description: Optional[str] = None
|
||||
base_classes: List[str]
|
||||
name: str = ""
|
||||
display_name: str = ""
|
||||
display_name: Optional[str] = ""
|
||||
documentation: str = ""
|
||||
custom_fields: Optional[Dict] = defaultdict(list)
|
||||
output_types: List[str] = []
|
||||
|
|
@ -85,10 +86,12 @@ class FrontendNode(BaseModel):
|
|||
return {name: result}
|
||||
|
||||
# For backwards compatibility
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self, add_name=True) -> dict:
|
||||
"""Returns a dict representation of the frontend node."""
|
||||
|
||||
return self.model_dump(by_alias=True, exclude_none=True)
|
||||
dump = self.model_dump(by_alias=True, exclude_none=True)
|
||||
if not add_name:
|
||||
return dump.pop(self.name)
|
||||
return dump
|
||||
|
||||
def add_extra_fields(self) -> None:
|
||||
pass
|
||||
|
|
@ -96,6 +99,20 @@ class FrontendNode(BaseModel):
|
|||
def add_extra_base_classes(self) -> None:
|
||||
pass
|
||||
|
||||
def add_base_class(self, base_class: Union[str, List[str]]) -> None:
|
||||
"""Adds a base class to the frontend node."""
|
||||
if isinstance(base_class, str):
|
||||
self.base_classes.append(base_class)
|
||||
elif isinstance(base_class, list):
|
||||
self.base_classes.extend(base_class)
|
||||
|
||||
def add_output_type(self, output_type: Union[str, List[str]]) -> None:
|
||||
"""Adds an output type to the frontend node."""
|
||||
if isinstance(output_type, str):
|
||||
self.output_types.append(output_type)
|
||||
elif isinstance(output_type, list):
|
||||
self.output_types.extend(output_type)
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
"""Formats a given field based on its attributes and value."""
|
||||
|
|
@ -184,7 +201,8 @@ class FrontendNode(BaseModel):
|
|||
@staticmethod
|
||||
def handle_kwargs_field(field: TemplateField) -> None:
|
||||
"""Handles kwargs field by setting certain attributes."""
|
||||
if "kwargs" in field.name.lower():
|
||||
|
||||
if "kwargs" in (field.name or "").lower():
|
||||
field.advanced = True
|
||||
field.required = False
|
||||
field.show = False
|
||||
|
|
|
|||
|
|
@ -48,16 +48,16 @@ class ChainFrontendNode(FrontendNode):
|
|||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
|
||||
if "name" == "RetrievalQA" and field.name == "memory":
|
||||
key = field.name or ""
|
||||
if "name" == "RetrievalQA" and key == "memory":
|
||||
field.show = False
|
||||
field.required = False
|
||||
|
||||
field.advanced = False
|
||||
if "key" in field.name:
|
||||
if "key" in key:
|
||||
field.password = False
|
||||
field.show = False
|
||||
if field.name in ["input_key", "output_key"]:
|
||||
if key in ["input_key", "output_key"]:
|
||||
field.required = True
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
|
|
@ -71,26 +71,26 @@ class ChainFrontendNode(FrontendNode):
|
|||
# field.value = field.value.template
|
||||
|
||||
# Separated for possible future changes
|
||||
if field.name == "prompt" and field.value is None:
|
||||
if key == "prompt" and field.value is None:
|
||||
field.required = True
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
if field.name == "memory":
|
||||
if key == "memory":
|
||||
# field.required = False
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
if field.name == "verbose":
|
||||
if key == "verbose":
|
||||
field.required = False
|
||||
field.show = False
|
||||
field.advanced = True
|
||||
if field.name == "llm":
|
||||
if key == "llm":
|
||||
field.required = True
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
field.field_type = "BaseLanguageModel" # temporary fix
|
||||
field.is_list = False
|
||||
|
||||
if field.name == "return_source_documents":
|
||||
if key == "return_source_documents":
|
||||
field.required = False
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
|
|
|
|||
|
|
@ -22,13 +22,14 @@ class EmbeddingFrontendNode(FrontendNode):
|
|||
@staticmethod
|
||||
def format_vertex_field(field: TemplateField, name: str):
|
||||
if "VertexAI" in name:
|
||||
key = field.name or ""
|
||||
advanced_fields = [
|
||||
"verbose",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"max_output_tokens",
|
||||
]
|
||||
if field.name in advanced_fields:
|
||||
if key in advanced_fields:
|
||||
field.advanced = True
|
||||
show_fields = [
|
||||
"verbose",
|
||||
|
|
@ -42,21 +43,22 @@ class EmbeddingFrontendNode(FrontendNode):
|
|||
"top_k",
|
||||
]
|
||||
|
||||
if field.name in show_fields:
|
||||
if key in show_fields:
|
||||
field.show = True
|
||||
|
||||
@staticmethod
|
||||
def format_jina_fields(field: TemplateField):
|
||||
if "jina" in field.name:
|
||||
name = field.name or ""
|
||||
if "jina" in name:
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
|
||||
if "auth" in field.name or "token" in field.name:
|
||||
if "auth" in name or "token" in name:
|
||||
field.password = True
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
|
||||
if field.name == "jina_api_url":
|
||||
if name == "jina_api_url":
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
field.display_name = "Jina API URL"
|
||||
|
|
@ -64,14 +66,15 @@ class EmbeddingFrontendNode(FrontendNode):
|
|||
|
||||
@staticmethod
|
||||
def format_openai_fields(field: TemplateField):
|
||||
if "openai" in field.name:
|
||||
name = field.name or ""
|
||||
if "openai" in name:
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
split_name = field.name.split("_")
|
||||
split_name = name.split("_")
|
||||
title_name = " ".join([s.capitalize() for s in split_name])
|
||||
field.display_name = title_name.replace("Openai", "OpenAI").replace("Api", "API")
|
||||
|
||||
if "api_key" in field.name:
|
||||
if "api_key" in name:
|
||||
field.password = True
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
|
|
@ -83,13 +86,14 @@ class EmbeddingFrontendNode(FrontendNode):
|
|||
EmbeddingFrontendNode.format_vertex_field(field, name)
|
||||
field.advanced = not field.required
|
||||
field.show = True
|
||||
if field.name == "headers":
|
||||
key = field.name or ""
|
||||
if key == "headers":
|
||||
field.show = False
|
||||
if field.name == "model_kwargs":
|
||||
if key == "model_kwargs":
|
||||
field.field_type = "dict"
|
||||
field.advanced = True
|
||||
field.show = True
|
||||
elif field.name in [
|
||||
elif key in [
|
||||
"model_name",
|
||||
"temperature",
|
||||
"model_file",
|
||||
|
|
@ -99,9 +103,9 @@ class EmbeddingFrontendNode(FrontendNode):
|
|||
]:
|
||||
field.advanced = False
|
||||
field.show = True
|
||||
if field.name == "credentials":
|
||||
if key == "credentials":
|
||||
field.field_type = "file"
|
||||
if name == "VertexAI" and field.name not in [
|
||||
if name == "VertexAI" and key not in [
|
||||
"callbacks",
|
||||
"client",
|
||||
"stop",
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from langflow.utils.constants import ANTHROPIC_MODELS, CHAT_OPENAI_MODELS, OPENA
|
|||
|
||||
class OpenAIAPIKeyFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
if "api_key" in field.name and "OpenAI" in str(name):
|
||||
if field.name and "api_key" in field.name and "OpenAI" in str(name):
|
||||
field.display_name = "OpenAI API Key"
|
||||
field.required = False
|
||||
if field.value is None:
|
||||
|
|
@ -25,14 +25,14 @@ class ModelSpecificFieldFormatter(FieldFormatter):
|
|||
}
|
||||
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
if name in self.MODEL_DICT and field.name == "model_name":
|
||||
if field.name and name in self.MODEL_DICT and field.name == "model_name":
|
||||
field.options = self.MODEL_DICT[name]
|
||||
field.is_list = True
|
||||
|
||||
|
||||
class KwargsFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
if "kwargs" in field.name.lower():
|
||||
if field.name and "kwargs" in field.name.lower():
|
||||
field.advanced = True
|
||||
field.required = False
|
||||
field.show = False
|
||||
|
|
@ -40,11 +40,11 @@ class KwargsFormatter(FieldFormatter):
|
|||
|
||||
class APIKeyFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
if "api" in field.name.lower() and "key" in field.name.lower():
|
||||
if field.name and "api" in field.name.lower() and "key" in field.name.lower():
|
||||
field.required = False
|
||||
field.advanced = False
|
||||
|
||||
field.display_name = field.name.replace("_", " ").title()
|
||||
field.display_name = (field.name or "").replace("_", " ").title()
|
||||
field.display_name = field.display_name.replace("Api", "API")
|
||||
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ class SpecialFieldFormatter(FieldFormatter):
|
|||
|
||||
class ShowFieldFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
key = field.name
|
||||
key = field.name or ""
|
||||
required = field.required
|
||||
field.show = (
|
||||
(required and key not in ["input_variables"])
|
||||
|
|
@ -106,7 +106,7 @@ class ShowFieldFormatter(FieldFormatter):
|
|||
|
||||
class PasswordFieldFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
key = field.name
|
||||
key = field.name or ""
|
||||
show = field.show
|
||||
if any(text in key.lower() for text in {"password", "token", "api", "key"}) and show:
|
||||
field.password = True
|
||||
|
|
@ -114,7 +114,7 @@ class PasswordFieldFormatter(FieldFormatter):
|
|||
|
||||
class MultilineFieldFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
key = field.name
|
||||
key = field.name or ""
|
||||
if key in {
|
||||
"suffix",
|
||||
"prefix",
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
from typing import Optional
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.constants import CTRANSFORMERS_DEFAULT_CONFIG
|
||||
from langflow.template.frontend_node.constants import OPENAI_API_BASE_INFO
|
||||
from langflow.template.frontend_node.constants import CTRANSFORMERS_DEFAULT_CONFIG, OPENAI_API_BASE_INFO
|
||||
|
||||
|
||||
class LLMFrontendNode(FrontendNode):
|
||||
|
|
@ -24,6 +23,7 @@ class LLMFrontendNode(FrontendNode):
|
|||
|
||||
@staticmethod
|
||||
def format_vertex_field(field: TemplateField, name: str):
|
||||
key = field.name or ""
|
||||
if "VertexAI" in name:
|
||||
advanced_fields = [
|
||||
"tuned_model_name",
|
||||
|
|
@ -32,7 +32,7 @@ class LLMFrontendNode(FrontendNode):
|
|||
"top_k",
|
||||
"max_output_tokens",
|
||||
]
|
||||
if field.name in advanced_fields:
|
||||
if key in advanced_fields:
|
||||
field.advanced = True
|
||||
show_fields = [
|
||||
"tuned_model_name",
|
||||
|
|
@ -47,20 +47,19 @@ class LLMFrontendNode(FrontendNode):
|
|||
"top_k",
|
||||
]
|
||||
|
||||
if field.name in show_fields:
|
||||
if key in show_fields:
|
||||
field.show = True
|
||||
|
||||
@staticmethod
|
||||
def format_openai_field(field: TemplateField):
|
||||
if "openai" in field.name.lower():
|
||||
field.display_name = (field.name.title().replace("Openai", "OpenAI").replace("_", " ")).replace(
|
||||
"Api", "API"
|
||||
)
|
||||
key = field.name or ""
|
||||
if "openai" in key.lower():
|
||||
field.display_name = (key.title().replace("Openai", "OpenAI").replace("_", " ")).replace("Api", "API")
|
||||
|
||||
if "key" not in field.name.lower() and "token" not in field.name.lower():
|
||||
if "key" not in key.lower() and "token" not in key.lower():
|
||||
field.password = False
|
||||
|
||||
if field.name == "openai_api_base":
|
||||
if key == "openai_api_base":
|
||||
field.info = OPENAI_API_BASE_INFO
|
||||
|
||||
def add_extra_base_classes(self) -> None:
|
||||
|
|
@ -69,13 +68,14 @@ class LLMFrontendNode(FrontendNode):
|
|||
|
||||
@staticmethod
|
||||
def format_azure_field(field: TemplateField):
|
||||
if field.name == "model_name":
|
||||
key = field.name or ""
|
||||
if key == "model_name":
|
||||
field.show = False # Azure uses deployment_name instead of model_name.
|
||||
elif field.name == "openai_api_type":
|
||||
elif key == "openai_api_type":
|
||||
field.show = False
|
||||
field.password = False
|
||||
field.value = "azure"
|
||||
elif field.name == "openai_api_version":
|
||||
elif key == "openai_api_version":
|
||||
field.password = False
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -85,7 +85,8 @@ class LLMFrontendNode(FrontendNode):
|
|||
|
||||
@staticmethod
|
||||
def format_ctransformers_field(field: TemplateField):
|
||||
if field.name == "config":
|
||||
key = field.name or ""
|
||||
if key == "config":
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
field.value = orjson_dumps(CTRANSFORMERS_DEFAULT_CONFIG, indent_2=True)
|
||||
|
|
@ -105,10 +106,11 @@ class LLMFrontendNode(FrontendNode):
|
|||
if name and "vertex" in name.lower():
|
||||
LLMFrontendNode.format_vertex_field(field, name)
|
||||
SHOW_FIELDS = ["repo_id"]
|
||||
if field.name in SHOW_FIELDS:
|
||||
key = field.name or ""
|
||||
if key in SHOW_FIELDS:
|
||||
field.show = True
|
||||
|
||||
if "api" in field.name and ("key" in field.name or ("token" in field.name and "tokens" not in field.name)):
|
||||
if "api" in key and ("key" in key or ("token" in key and "tokens" not in key)):
|
||||
field.password = True
|
||||
field.show = True
|
||||
# Required should be False to support
|
||||
|
|
@ -116,7 +118,7 @@ class LLMFrontendNode(FrontendNode):
|
|||
field.required = False
|
||||
field.advanced = False
|
||||
|
||||
if field.name == "task":
|
||||
if key == "task":
|
||||
field.required = True
|
||||
field.show = True
|
||||
field.is_list = True
|
||||
|
|
@ -124,13 +126,13 @@ class LLMFrontendNode(FrontendNode):
|
|||
field.value = field.options[0]
|
||||
field.advanced = True
|
||||
|
||||
if display_name := display_names_dict.get(field.name):
|
||||
if display_name := display_names_dict.get(key):
|
||||
field.display_name = display_name
|
||||
if field.name == "model_kwargs":
|
||||
if key == "model_kwargs":
|
||||
field.field_type = "dict"
|
||||
field.advanced = True
|
||||
field.show = True
|
||||
elif field.name in [
|
||||
elif key in [
|
||||
"model_name",
|
||||
"temperature",
|
||||
"model_file",
|
||||
|
|
@ -140,9 +142,9 @@ class LLMFrontendNode(FrontendNode):
|
|||
]:
|
||||
field.advanced = False
|
||||
field.show = True
|
||||
if field.name == "credentials":
|
||||
if key == "credentials":
|
||||
field.field_type = "file"
|
||||
if name == "VertexAI" and field.name not in [
|
||||
if name == "VertexAI" and key not in [
|
||||
"callbacks",
|
||||
"client",
|
||||
"stop",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain.agents.mrkl import prompt
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.constants import DEFAULT_PROMPT, HUMAN_PROMPT, SYSTEM_PROMPT
|
||||
|
|
@ -20,14 +21,15 @@ class PromptFrontendNode(FrontendNode):
|
|||
"examples",
|
||||
"format_instructions",
|
||||
]
|
||||
key = field.name or ""
|
||||
if field.field_type == "StringPromptTemplate" and "Message" in str(name):
|
||||
field.field_type = "prompt"
|
||||
field.multiline = True
|
||||
field.value = HUMAN_PROMPT if "Human" in field.name else SYSTEM_PROMPT
|
||||
if field.name == "template" and field.value == "":
|
||||
field.value = HUMAN_PROMPT if "Human" in key else SYSTEM_PROMPT
|
||||
if key == "template" and field.value == "":
|
||||
field.value = DEFAULT_PROMPT
|
||||
|
||||
if field.name in PROMPT_FIELDS:
|
||||
if key and key in PROMPT_FIELDS:
|
||||
field.field_type = "prompt"
|
||||
field.advanced = False
|
||||
|
||||
|
|
@ -48,7 +50,8 @@ class PromptTemplateNode(FrontendNode):
|
|||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
if field.name == "examples":
|
||||
|
||||
if (field.name or "") == "examples":
|
||||
field.advanced = False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from typing import List, Optional
|
|||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
|
||||
|
||||
BASIC_FIELDS = [
|
||||
"work_dir",
|
||||
"collection_name",
|
||||
|
|
@ -313,7 +312,7 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
field.show = True
|
||||
field.advanced = False
|
||||
field.is_list = True
|
||||
elif "embedding" in field.name:
|
||||
elif field.name and "embedding" in field.name:
|
||||
# for backwards compatibility
|
||||
field.name = "embedding"
|
||||
field.required = True
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Callable, Union
|
||||
|
||||
from pydantic import BaseModel, model_serializer
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.utils.constants import DIRECT_TYPES
|
||||
from pydantic import BaseModel, model_serializer
|
||||
|
||||
|
||||
class Template(BaseModel):
|
||||
|
|
@ -39,3 +40,25 @@ class Template(BaseModel):
|
|||
|
||||
def add_field(self, field: TemplateField) -> None:
|
||||
self.fields.append(field)
|
||||
|
||||
def get_field(self, field_name: str) -> TemplateField:
|
||||
"""Returns the field with the given name."""
|
||||
field = next((field for field in self.fields if field.name == field_name), None)
|
||||
if field is None:
|
||||
raise ValueError(f"Field {field_name} not found in template {self.type_name}")
|
||||
return field
|
||||
|
||||
def update_field(self, field_name: str, field: TemplateField) -> None:
|
||||
"""Updates the field with the given name."""
|
||||
for idx, template_field in enumerate(self.fields):
|
||||
if template_field.name == field_name:
|
||||
self.fields[idx] = field
|
||||
return
|
||||
raise ValueError(f"Field {field_name} not found in template {self.type_name}")
|
||||
|
||||
def upsert_field(self, field_name: str, field: TemplateField) -> None:
|
||||
"""Updates the field with the given name or adds it if it doesn't exist."""
|
||||
try:
|
||||
self.update_field(field_name, field)
|
||||
except ValueError:
|
||||
self.add_field(field)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from pathlib import Path
|
||||
from tempfile import tempdir
|
||||
from langflow.__main__ import app
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.__main__ import app
|
||||
from langflow.services import deps
|
||||
|
||||
|
||||
|
|
@ -33,4 +34,5 @@ def test_components_path(runner, client, default_settings):
|
|||
def test_superuser(runner, client, session):
|
||||
result = runner.invoke(app, ["superuser"], input="admin\nadmin\n")
|
||||
assert result.exit_code == 0, result.stdout
|
||||
assert "Superuser created successfully." in result.stdout
|
||||
assert "Superuser creation failed." not in result.output, result.output
|
||||
assert "Superuser created successfully." in result.output, result.output
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ def test_custom_component_build_template_config():
|
|||
Test the build_template_config property of the CustomComponent class.
|
||||
"""
|
||||
custom_component = CustomComponent(code=code_default, function_entrypoint_name="build")
|
||||
config = custom_component.build_template_config
|
||||
config = custom_component.template_config
|
||||
assert isinstance(config, dict)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue