Add RangeSpec class and functionality in the UI (#1192)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-10 16:53:00 -03:00 committed by GitHub
commit 2703dd06d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 88 additions and 8 deletions

View file

@ -73,7 +73,7 @@ The CustomComponent class serves as the foundation for creating custom component
| _`required: bool`_ | Makes the field required. |
| _`info: str`_ | Adds a tooltip to the field. |
| _`file_types: List[str]`_ | This is a requirement if the _`field_type`_ is _file_. Defines which file types will be accepted. For example, _json_, _yaml_ or _yml_. |
| _`range_spec: langflow.field_typing.RangeSpec`_ | This is a requirement if the _`field_type`_ is _`float`_. Defines the range of values accepted and the step size. If none is defined, the default is _`[-1, 1, 0.1]`_. |
- The CustomComponent class also provides helpful methods for specific tasks (e.g., to load and use other flows from the Langflow platform):
| Method Name | Description |

View file

@ -24,6 +24,7 @@ from .constants import (
Tool,
VectorStore,
)
from .range_spec import RangeSpec
__all__ = [
"NestedDict",
@ -48,5 +49,6 @@ __all__ = [
"BasePromptTemplate",
"ChatPromptTemplate",
"Prompt",
"RangeSpec",
"TemplateField",
]

View file

@ -0,0 +1,21 @@
from pydantic import BaseModel, field_validator
class RangeSpec(BaseModel):
min: float = -1.0
max: float = 1.0
step: float = 0.1
@field_validator("max")
@classmethod
def max_must_be_greater_than_min(cls, v, values, **kwargs):
if "min" in values.data and v <= values.data["min"]:
raise ValueError("max must be greater than min")
return v
@field_validator("step")
@classmethod
def step_must_be_positive(cls, v):
if v <= 0:
raise ValueError("step must be positive")
return v

View file

@ -8,6 +8,9 @@ from uuid import UUID
from cachetools import LRUCache, cached
from fastapi import HTTPException
from loguru import logger
from langflow.field_typing.range_spec import RangeSpec
from langflow.interface.agents.base import agent_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.custom.custom_component import CustomComponent
@ -31,7 +34,6 @@ from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE
from langflow.template.frontend_node.custom_components import CustomComponentFrontendNode
from langflow.utils.util import get_base_classes
from loguru import logger
# Used to get the base_classes list
@ -272,6 +274,10 @@ def update_field_dict(field_dict):
field_dict["value"] = field_dict["value"](field_dict.get("options", []))
field_dict["refresh"] = True
# Let's check if "range_spec" is a RangeSpec object
if "rangeSpec" in field_dict and isinstance(field_dict["rangeSpec"], RangeSpec):
field_dict["rangeSpec"] = field_dict["rangeSpec"].model_dump()
def add_extra_fields(frontend_node, field_config, function_args):
"""Add extra fields to the frontend node"""

View file

@ -3,6 +3,8 @@ from typing import Any, Callable, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from langflow.field_typing.range_spec import RangeSpec
class TemplateField(BaseModel):
model_config = ConfigDict()
@ -60,6 +62,9 @@ class TemplateField(BaseModel):
refresh: Optional[bool] = None
"""Specifies if the field should be refreshed. Defaults to False."""
range_spec: Optional[RangeSpec] = Field(None, serialization_alias="rangeSpec")
"""Range specification for the field. Defaults to None."""
def to_dict(self):
return self.model_dump(by_alias=True, exclude_none=True)
@ -68,3 +73,11 @@ class TemplateField(BaseModel):
if self.field_type == "file":
return value
return ""
@field_serializer("field_type")
def serialize_field_type(self, value, _info):
if value == "float":
# check if range_spec is set
if self.range_spec is None:
self.range_spec = RangeSpec()
return value

View file

@ -457,6 +457,7 @@ export default function ParameterComponent({
<FloatComponent
disabled={disabled}
value={data.node?.template[name].value ?? ""}
rangeSpec={data.node?.template[name].rangeSpec}
onChange={handleOnNewValue}
/>
</div>

View file

@ -472,6 +472,11 @@ export default function CodeTabsComponent({
templateField
].value
}
rangeSpec={
node.data.node.template[
templateField
].rangeSpec
}
onChange={(target) => {
setData((old) => {
let newInputList =

View file

@ -7,11 +7,14 @@ export default function FloatComponent({
value,
onChange,
disabled,
rangeSpec,
editNode = false,
}: FloatComponentType): JSX.Element {
const step = 0.1;
const min = -2;
const max = 2;
const step = rangeSpec?.step ?? 0.1;
const min = rangeSpec?.min ?? -2;
const max = rangeSpec?.max ?? 2;
console.log("FloatComponent", value, disabled, rangeSpec, editNode);
console.log("FloatComponent", step, min, max);
// Clear component state
useEffect(() => {
@ -40,7 +43,9 @@ export default function FloatComponent({
disabled={disabled}
className={editNode ? "input-edit-node" : ""}
placeholder={
editNode ? "Number -2 to 2" : "Type a number from minus two to two"
editNode
? `Enter a value between ${min} and ${max}`
: `Enter a value between ${min} and ${max}`
}
onChange={(event) => {
onChange(event.target.value);

View file

@ -1,5 +1,5 @@
import { useEffect } from "react";
import { FloatComponentType } from "../../types/components";
import { IntComponentType } from "../../types/components";
import { handleKeyDown } from "../../utils/reactflowUtils";
import { Input } from "../ui/input";
@ -9,7 +9,7 @@ export default function IntComponent({
disabled,
editNode = false,
id = "",
}: FloatComponentType): JSX.Element {
}: IntComponentType): JSX.Element {
const min = 0;
// Clear component state

View file

@ -360,6 +360,10 @@ const EditNodeModal = forwardRef(
<FloatComponent
disabled={disabled}
editNode={true}
rangeSpec={
myData.node!.template[templateParam]
.rangeSpec
}
value={
myData.node.template[templateParam]
.value ?? ""

View file

@ -138,10 +138,26 @@ export type DisclosureComponentType = {
}[];
};
};
export type RangeSpecType = {
min: number;
max: number;
step: number;
};
export type IntComponentType = {
value: string;
disabled?: boolean;
onChange: (value: string) => void;
editNode?: boolean;
id?: string;
};
export type FloatComponentType = {
value: string;
disabled?: boolean;
onChange: (value: string) => void;
rangeSpec: RangeSpecType;
editNode?: boolean;
id?: string;
};

View file

@ -88,6 +88,7 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["max_tokens"] == {
@ -118,6 +119,7 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["frequency_penalty"] == {
@ -133,6 +135,7 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["presence_penalty"] == {
@ -148,6 +151,7 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["n"] == {
@ -237,6 +241,7 @@ def test_openai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["logit_bias"] == {
@ -359,6 +364,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["model_kwargs"] == {
@ -403,6 +409,7 @@ def test_chat_open_ai(client: TestClient, logged_in_headers):
"list": False,
"advanced": False,
"info": "",
"rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1},
"fileTypes": [],
}
assert template["max_retries"] == {