Merge branch 'two_edges' of personal:langflow-ai/langflow into two_edges
This commit is contained in:
commit
3224227f58
9 changed files with 246 additions and 56 deletions
|
|
@ -70,14 +70,14 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
inputs = [
|
||||
Input(
|
||||
name="base_url",
|
||||
type=Optional[str],
|
||||
field_type=Optional[str],
|
||||
display_name="Base URL",
|
||||
info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
|
||||
value="http://localhost:11434",
|
||||
),
|
||||
Input(
|
||||
name="model",
|
||||
type=str,
|
||||
field_type=str,
|
||||
display_name="Model Name",
|
||||
options=[], # This should be dynamically loaded if possible
|
||||
info="Refer to https://ollama.ai/library for more models.",
|
||||
|
|
@ -86,7 +86,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
),
|
||||
Input(
|
||||
name="mirostat",
|
||||
type=str,
|
||||
field_type=str,
|
||||
display_name="Mirostat",
|
||||
options=["Disabled", "Mirostat", "Mirostat 2.0"],
|
||||
info="Enable/disable Mirostat sampling for controlling perplexity.",
|
||||
|
|
@ -97,7 +97,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
),
|
||||
Input(
|
||||
name="mirostat_eta",
|
||||
type=Optional[float],
|
||||
field_type=Optional[float],
|
||||
display_name="Mirostat Eta",
|
||||
info="Learning rate for Mirostat algorithm.",
|
||||
advanced=True,
|
||||
|
|
@ -106,7 +106,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
),
|
||||
Input(
|
||||
name="mirostat_tau",
|
||||
type=Optional[float],
|
||||
field_type=Optional[float],
|
||||
display_name="Mirostat Tau",
|
||||
info="Controls the balance between coherence and diversity of the output.",
|
||||
advanced=True,
|
||||
|
|
@ -115,7 +115,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
),
|
||||
Input(
|
||||
name="temperature",
|
||||
type=float,
|
||||
field_type=float,
|
||||
display_name="Temperature",
|
||||
info="Controls the creativity of model responses.",
|
||||
value=0.8,
|
||||
|
|
@ -124,7 +124,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
Input(name="stream", type=bool, display_name="Stream", info=STREAM_INFO_TEXT, value=False),
|
||||
Input(
|
||||
name="system_message",
|
||||
type=Optional[str],
|
||||
field_type=Optional[str],
|
||||
display_name="System Message",
|
||||
info="System message to pass to the model.",
|
||||
advanced=True,
|
||||
|
|
@ -132,14 +132,14 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
),
|
||||
Input(
|
||||
name="headers",
|
||||
type=dict,
|
||||
field_type=dict,
|
||||
display_name="Headers",
|
||||
info="Additional headers to send with the request.",
|
||||
advanced=True,
|
||||
),
|
||||
Input(
|
||||
name="keep_alive_flag",
|
||||
type=str,
|
||||
field_type=str,
|
||||
display_params=["Keep", "Immediately", "Minute", "Hour", "sec"],
|
||||
display_name="Unload interval",
|
||||
info="Controls how the model unload interval is managed.",
|
||||
|
|
@ -148,7 +148,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
),
|
||||
Input(
|
||||
name="keep_alive",
|
||||
type=int,
|
||||
field_type=int,
|
||||
display_name="Interval",
|
||||
info="How long the model will stay loaded into memory.",
|
||||
value=None,
|
||||
|
|
|
|||
|
|
@ -7,10 +7,6 @@ from cachetools import TTLCache
|
|||
from langchain_core.documents import Document
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.custom.code_parser.utils import (
|
||||
extract_inner_type_from_generic_alias,
|
||||
extract_union_types_from_generic_alias,
|
||||
)
|
||||
from langflow.custom.custom_component.base_component import BaseComponent
|
||||
from langflow.helpers.flow import list_flows, load_flow, run_flow
|
||||
from langflow.schema import Record
|
||||
|
|
@ -20,6 +16,10 @@ from langflow.schema.message import Message
|
|||
from langflow.schema.schema import Log
|
||||
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.type_extraction.type_extraction import (
|
||||
extract_inner_type_from_generic_alias,
|
||||
extract_union_types_from_generic_alias,
|
||||
)
|
||||
from langflow.utils import validate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -327,7 +327,6 @@ class CustomComponent(BaseComponent):
|
|||
return []
|
||||
return_type = build_method["return_type"]
|
||||
|
||||
# If list or List is in the return type, then we remove it and return the inner type
|
||||
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
|
||||
list,
|
||||
List,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from loguru import logger
|
|||
from pydantic import BaseModel
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.custom.code_parser.utils import extract_inner_type
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.custom.directory_reader.utils import (
|
||||
abuild_custom_component_list_from_path,
|
||||
|
|
@ -26,6 +25,7 @@ from langflow.helpers.custom import format_type
|
|||
from langflow.schema import dotdict
|
||||
from langflow.template.field.base import Input
|
||||
from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode
|
||||
from langflow.type_extraction.type_extraction import extract_inner_type
|
||||
from langflow.utils import validate
|
||||
from langflow.utils.util import get_base_classes
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
from enum import Enum
|
||||
from types import GenericAlias
|
||||
from typing import Any, Callable, Optional, Union, _GenericAlias, _UnionGenericAlias, get_args, get_origin
|
||||
from typing import Any, Callable, GenericAlias, Optional, Union, _GenericAlias, _UnionGenericAlias
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
|
||||
|
||||
from langflow.field_typing import Text
|
||||
from langflow.field_typing.range_spec import RangeSpec
|
||||
from langflow.helpers.custom import format_type
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
|
||||
|
||||
class UndefinedType(Enum):
|
||||
|
|
@ -118,21 +120,8 @@ class Input(BaseModel):
|
|||
# this should be done for all types
|
||||
# How to check if v is a type?
|
||||
if isinstance(v, (type, _GenericAlias, GenericAlias, _UnionGenericAlias)):
|
||||
if isinstance(v, type):
|
||||
v = v.__name__
|
||||
else:
|
||||
origin = get_origin(v)
|
||||
args = get_args(v)
|
||||
if origin and args:
|
||||
v = f"{origin.__name__}[{', '.join(arg.__name__ if isinstance(arg, type) else str(arg) for arg in args)}]"
|
||||
# if v is union with None (e.g Union[someType, NoneType]) we need to remove NoneType
|
||||
# we can return Optional[someType] instead of Union[someType, NoneType]
|
||||
if "NoneType" in v:
|
||||
v = v.replace(", NoneType", "")
|
||||
v = v.replace("Union[", "Optional[")
|
||||
|
||||
else:
|
||||
v = str(v)
|
||||
v = post_process_type(v)[0]
|
||||
v = format_type(v)
|
||||
elif not isinstance(v, str):
|
||||
raise ValueError(f"type must be a string or a type, not {type(v)}")
|
||||
return v
|
||||
|
|
@ -196,15 +185,6 @@ class Output(BaseModel):
|
|||
if not self.selected:
|
||||
self.selected = self.types[0]
|
||||
|
||||
@field_validator("display_name", mode="before")
|
||||
def validate_display_name(cls, v, info):
|
||||
if not v:
|
||||
if info.data.get("name"):
|
||||
return info.data["name"]
|
||||
else:
|
||||
raise ValueError("If display_name is not set, name must be set")
|
||||
return v
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_model(self, handler):
|
||||
result = handler(self)
|
||||
|
|
@ -217,4 +197,8 @@ class Output(BaseModel):
|
|||
def validate_model(self):
|
||||
if self.value == UNDEFINED.value:
|
||||
self.value = UNDEFINED
|
||||
if self.name is None:
|
||||
raise ValueError("name must be set")
|
||||
if self.display_name is None:
|
||||
self.display_name = self.name
|
||||
return self
|
||||
|
|
|
|||
46
src/backend/base/langflow/template/field/inputs.py
Normal file
46
src/backend/base/langflow/template/field/inputs.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
from pydantic import SecretStr
|
||||
|
||||
from langflow.field_typing.constants import NestedDict
|
||||
from langflow.template.field.base import Input
|
||||
|
||||
|
||||
class StrInput(Input):
|
||||
field_type: str | type | None = str
|
||||
|
||||
|
||||
class SecretStrInput(Input):
|
||||
field_type: str | type | None = SecretStr
|
||||
password = True
|
||||
|
||||
|
||||
class IntInput(Input):
|
||||
field_type: str | type | None = int
|
||||
|
||||
|
||||
class FloatInput(Input):
|
||||
field_type: str | type | None = float
|
||||
|
||||
|
||||
class BoolInput(Input):
|
||||
field_type: str | type | None = bool
|
||||
|
||||
|
||||
class NestedDictInput(Input):
|
||||
field_type: str | type | None = NestedDict
|
||||
|
||||
|
||||
class DictInput(Input):
|
||||
field_type: str | type | None = dict
|
||||
|
||||
|
||||
class ListInput(Input):
|
||||
is_list = True
|
||||
|
||||
|
||||
class DropdownInput(Input):
|
||||
field_type: str | type | None = str
|
||||
options = []
|
||||
|
||||
|
||||
class FileInput(Input):
|
||||
field_type: str | type | None = str
|
||||
0
src/backend/base/langflow/type_extraction/__init__.py
Normal file
0
src/backend/base/langflow/type_extraction/__init__.py
Normal file
|
|
@ -1,15 +1,6 @@
|
|||
import re
|
||||
from types import GenericAlias
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_inner_type(return_type: str) -> str:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a list.
|
||||
"""
|
||||
if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE):
|
||||
return match[1]
|
||||
return return_type
|
||||
from typing import Any, List, Union
|
||||
|
||||
|
||||
def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
|
||||
|
|
@ -21,6 +12,15 @@ def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
|
|||
return return_type
|
||||
|
||||
|
||||
def extract_inner_type(return_type: str) -> str:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a list.
|
||||
"""
|
||||
if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE):
|
||||
return match[1]
|
||||
return return_type
|
||||
|
||||
|
||||
def extract_union_types(return_type: str) -> list[str]:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a list.
|
||||
|
|
@ -31,6 +31,47 @@ def extract_union_types(return_type: str) -> list[str]:
|
|||
return [item.strip() for item in return_types]
|
||||
|
||||
|
||||
def extract_uniont_types_from_generic_alias(return_type: GenericAlias) -> list:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a Union.
|
||||
"""
|
||||
if isinstance(return_type, list):
|
||||
return [
|
||||
_inner_arg
|
||||
for _type in return_type
|
||||
for _inner_arg in _type.__args__
|
||||
if _inner_arg not in set((Any, type(None), type(Any)))
|
||||
]
|
||||
|
||||
return list(return_type.__args__)
|
||||
|
||||
|
||||
def post_process_type(_type):
|
||||
"""
|
||||
Process the return type of a function.
|
||||
|
||||
Args:
|
||||
_type (Any): The return type of the function.
|
||||
|
||||
Returns:
|
||||
Union[List[Any], Any]: The processed return type.
|
||||
|
||||
"""
|
||||
if hasattr(_type, "__origin__") and _type.__origin__ in [
|
||||
list,
|
||||
List,
|
||||
]:
|
||||
_type = extract_inner_type_from_generic_alias(_type)
|
||||
|
||||
# If the return type is not a Union, then we just return it as a list
|
||||
inner_type = _type[0] if isinstance(_type, list) else _type
|
||||
if not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union:
|
||||
return _type if isinstance(_type, list) else [_type]
|
||||
# If the return type is a Union, then we need to parse it
|
||||
_type = extract_union_types_from_generic_alias(_type)
|
||||
return _type
|
||||
|
||||
|
||||
def extract_union_types_from_generic_alias(return_type: GenericAlias) -> list:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a Union.
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
{
|
||||
"status": "failed",
|
||||
"failedTests": []
|
||||
}
|
||||
124
tests/test_schema.py
Normal file
124
tests/test_schema.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from langflow.template import Input, Output
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
@pytest.fixture(name="client", autouse=True)
|
||||
def client_fixture():
|
||||
pass
|
||||
|
||||
|
||||
class TestInput:
|
||||
def test_field_type_str(self):
|
||||
input_obj = Input(field_type="str")
|
||||
assert input_obj.field_type == "str"
|
||||
|
||||
def test_field_type_type(self):
|
||||
input_obj = Input(field_type=int)
|
||||
assert input_obj.field_type == "int"
|
||||
|
||||
def test_invalid_field_type(self):
|
||||
with pytest.raises(ValidationError):
|
||||
Input(field_type=123)
|
||||
|
||||
def test_serialize_field_type(self):
|
||||
input_obj = Input(field_type="str")
|
||||
assert input_obj.serialize_field_type("str", None) == "str"
|
||||
|
||||
def test_validate_type_string(self):
|
||||
input_obj = Input(field_type="str")
|
||||
assert input_obj.field_type == "str"
|
||||
|
||||
def test_validate_type_class(self):
|
||||
input_obj = Input(field_type=int)
|
||||
assert input_obj.field_type == "int"
|
||||
|
||||
def test_post_process_type_function(self):
|
||||
assert post_process_type(int) == [int]
|
||||
assert post_process_type(list[int]) == [int]
|
||||
assert post_process_type(Union[int, str]) == [int, str]
|
||||
|
||||
def test_input_to_dict(self):
|
||||
input_obj = Input(field_type="str")
|
||||
assert input_obj.to_dict() == {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"fileTypes": [],
|
||||
"file_path": "",
|
||||
"password": False,
|
||||
"advanced": False,
|
||||
"title_case": False,
|
||||
}
|
||||
|
||||
|
||||
class TestOutput:
|
||||
def test_output_default(self):
|
||||
output_obj = Output(name="test_output")
|
||||
assert output_obj.name == "test_output"
|
||||
assert output_obj.value == UNDEFINED
|
||||
assert output_obj.cache is True
|
||||
|
||||
def test_output_add_types(self):
|
||||
output_obj = Output(name="test_output")
|
||||
output_obj.add_types(["str", "int"])
|
||||
assert output_obj.types == ["str", "int"]
|
||||
|
||||
def test_output_set_selected(self):
|
||||
output_obj = Output(name="test_output", types=["str", "int"])
|
||||
output_obj.set_selected()
|
||||
assert output_obj.selected == "str"
|
||||
|
||||
def test_output_to_dict(self):
|
||||
output_obj = Output(name="test_output")
|
||||
assert output_obj.to_dict() == {
|
||||
"types": [],
|
||||
"name": "test_output",
|
||||
"display_name": "test_output",
|
||||
"cache": True,
|
||||
"value": "__UNDEFINED__",
|
||||
}
|
||||
|
||||
def test_output_validate_display_name(self):
|
||||
output_obj = Output(name="test_output")
|
||||
assert output_obj.display_name == "test_output"
|
||||
|
||||
def test_output_validate_model(self):
|
||||
output_obj = Output(name="test_output", value="__UNDEFINED__")
|
||||
assert output_obj.validate_model() == output_obj
|
||||
|
||||
|
||||
class TestPostProcessType:
|
||||
def test_int_type(self):
|
||||
assert post_process_type(int) == [int]
|
||||
|
||||
def test_list_int_type(self):
|
||||
assert post_process_type(list[int]) == [int]
|
||||
|
||||
def test_union_type(self):
|
||||
assert post_process_type(Union[int, str]) == [int, str]
|
||||
|
||||
def test_custom_type(self):
|
||||
class CustomType:
|
||||
pass
|
||||
|
||||
assert post_process_type(CustomType) == [CustomType]
|
||||
|
||||
def test_list_custom_type(self):
|
||||
class CustomType:
|
||||
pass
|
||||
|
||||
assert post_process_type(list[CustomType]) == [CustomType]
|
||||
|
||||
def test_union_custom_type(self):
|
||||
class CustomType:
|
||||
pass
|
||||
|
||||
assert post_process_type(Union[CustomType, int]) == [CustomType, int]
|
||||
Loading…
Add table
Add a link
Reference in a new issue