diff --git a/src/backend/base/langflow/components/models/OllamaModel.py b/src/backend/base/langflow/components/models/OllamaModel.py index 12db058c2..2c7d3bae7 100644 --- a/src/backend/base/langflow/components/models/OllamaModel.py +++ b/src/backend/base/langflow/components/models/OllamaModel.py @@ -70,14 +70,14 @@ class ChatOllamaComponent(LCModelComponent): inputs = [ Input( name="base_url", - type=Optional[str], + field_type=Optional[str], display_name="Base URL", info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.", value="http://localhost:11434", ), Input( name="model", - type=str, + field_type=str, display_name="Model Name", options=[], # This should be dynamically loaded if possible info="Refer to https://ollama.ai/library for more models.", @@ -86,7 +86,7 @@ class ChatOllamaComponent(LCModelComponent): ), Input( name="mirostat", - type=str, + field_type=str, display_name="Mirostat", options=["Disabled", "Mirostat", "Mirostat 2.0"], info="Enable/disable Mirostat sampling for controlling perplexity.", @@ -97,7 +97,7 @@ class ChatOllamaComponent(LCModelComponent): ), Input( name="mirostat_eta", - type=Optional[float], + field_type=Optional[float], display_name="Mirostat Eta", info="Learning rate for Mirostat algorithm.", advanced=True, @@ -106,7 +106,7 @@ class ChatOllamaComponent(LCModelComponent): ), Input( name="mirostat_tau", - type=Optional[float], + field_type=Optional[float], display_name="Mirostat Tau", info="Controls the balance between coherence and diversity of the output.", advanced=True, @@ -115,7 +115,7 @@ class ChatOllamaComponent(LCModelComponent): ), Input( name="temperature", - type=float, + field_type=float, display_name="Temperature", info="Controls the creativity of model responses.", value=0.8, @@ -124,7 +124,7 @@ class ChatOllamaComponent(LCModelComponent): Input(name="stream", type=bool, display_name="Stream", info=STREAM_INFO_TEXT, value=False), Input( name="system_message", - type=Optional[str], + field_type=Optional[str], display_name="System Message", info="System message to pass to the model.", advanced=True, @@ -132,14 +132,14 @@ class ChatOllamaComponent(LCModelComponent): ), Input( name="headers", - type=dict, + field_type=dict, display_name="Headers", info="Additional headers to send with the request.", advanced=True, ), Input( name="keep_alive_flag", - type=str, + field_type=str, display_params=["Keep", "Immediately", "Minute", "Hour", "sec"], display_name="Unload interval", info="Controls how the model unload interval is managed.", @@ -148,7 +148,7 @@ class ChatOllamaComponent(LCModelComponent): ), Input( name="keep_alive", - type=int, + field_type=int, display_name="Interval", info="How long the model will stay loaded into memory.", value=None, diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 4a5218b71..548ad1dd7 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -7,10 +7,6 @@ from cachetools import TTLCache from langchain_core.documents import Document from pydantic import BaseModel -from langflow.custom.code_parser.utils import ( - extract_inner_type_from_generic_alias, - extract_union_types_from_generic_alias, -) from langflow.custom.custom_component.base_component import BaseComponent from langflow.helpers.flow import list_flows, load_flow, run_flow from langflow.schema import Record @@ -20,6 +16,10 @@ from langflow.schema.message import Message from langflow.schema.schema import Log from langflow.services.deps import get_storage_service, get_variable_service, session_scope from langflow.services.storage.service import StorageService +from langflow.type_extraction.type_extraction import ( + extract_inner_type_from_generic_alias, + extract_union_types_from_generic_alias, +) from langflow.utils import validate if TYPE_CHECKING: @@ -327,7 +327,6 @@ class CustomComponent(BaseComponent): return [] return_type = build_method["return_type"] - # If list or List is in the return type, then we remove it and return the inner type if hasattr(return_type, "__origin__") and return_type.__origin__ in [ list, List, diff --git a/src/backend/base/langflow/custom/utils.py b/src/backend/base/langflow/custom/utils.py index d6f80ce7e..19943d499 100644 --- a/src/backend/base/langflow/custom/utils.py +++ b/src/backend/base/langflow/custom/utils.py @@ -11,7 +11,6 @@ from loguru import logger from pydantic import BaseModel from langflow.custom import CustomComponent -from langflow.custom.code_parser.utils import extract_inner_type from langflow.custom.custom_component.component import Component from langflow.custom.directory_reader.utils import ( abuild_custom_component_list_from_path, @@ -26,6 +25,7 @@ from langflow.helpers.custom import format_type from langflow.schema import dotdict from langflow.template.field.base import Input from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode +from langflow.type_extraction.type_extraction import extract_inner_type from langflow.utils import validate from langflow.utils.util import get_base_classes diff --git a/src/backend/base/langflow/template/field/base.py b/src/backend/base/langflow/template/field/base.py index fe260c890..22f86b3c8 100644 --- a/src/backend/base/langflow/template/field/base.py +++ b/src/backend/base/langflow/template/field/base.py @@ -1,11 +1,13 @@ from enum import Enum -from types import GenericAlias -from typing import Any, Callable, Optional, Union, _GenericAlias, _UnionGenericAlias, get_args, get_origin +from typing import Any, Callable, GenericAlias, Optional, Union, _GenericAlias, _UnionGenericAlias + from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator from langflow.field_typing import Text from langflow.field_typing.range_spec import RangeSpec +from langflow.helpers.custom import format_type +from langflow.type_extraction.type_extraction import post_process_type class UndefinedType(Enum): @@ -118,21 +120,8 @@ class Input(BaseModel): # this should be done for all types # How to check if v is a type? if isinstance(v, (type, _GenericAlias, GenericAlias, _UnionGenericAlias)): - if isinstance(v, type): - v = v.__name__ - else: - origin = get_origin(v) - args = get_args(v) - if origin and args: - v = f"{origin.__name__}[{', '.join(arg.__name__ if isinstance(arg, type) else str(arg) for arg in args)}]" - # if v is union with None (e.g Union[someType, NoneType]) we need to remove NoneType - # we can return Optional[someType] instead of Union[someType, NoneType] - if "NoneType" in v: - v = v.replace(", NoneType", "") - v = v.replace("Union[", "Optional[") - - else: - v = str(v) + v = post_process_type(v)[0] + v = format_type(v) elif not isinstance(v, str): raise ValueError(f"type must be a string or a type, not {type(v)}") return v @@ -196,15 +185,6 @@ class Output(BaseModel): if not self.selected: self.selected = self.types[0] - @field_validator("display_name", mode="before") - def validate_display_name(cls, v, info): - if not v: - if info.data.get("name"): - return info.data["name"] - else: - raise ValueError("If display_name is not set, name must be set") - return v - @model_serializer(mode="wrap") def serialize_model(self, handler): result = handler(self) @@ -217,4 +197,8 @@ class Output(BaseModel): def validate_model(self): if self.value == UNDEFINED.value: self.value = UNDEFINED + if self.name is None: + raise ValueError("name must be set") + if self.display_name is None: + self.display_name = self.name return self diff --git a/src/backend/base/langflow/template/field/inputs.py b/src/backend/base/langflow/template/field/inputs.py new file mode 100644 index 000000000..c71061f6e --- /dev/null +++ b/src/backend/base/langflow/template/field/inputs.py @@ -0,0 +1,46 @@ +from pydantic import SecretStr + +from langflow.field_typing.constants import NestedDict +from langflow.template.field.base import Input + + +class StrInput(Input): + field_type: str | type | None = str + + +class SecretStrInput(Input): + field_type: str | type | None = SecretStr + password = True + + +class IntInput(Input): + field_type: str | type | None = int + + +class FloatInput(Input): + field_type: str | type | None = float + + +class BoolInput(Input): + field_type: str | type | None = bool + + +class NestedDictInput(Input): + field_type: str | type | None = NestedDict + + +class DictInput(Input): + field_type: str | type | None = dict + + +class ListInput(Input): + is_list = True + + +class DropdownInput(Input): + field_type: str | type | None = str + options = [] + + +class FileInput(Input): + field_type: str | type | None = str diff --git a/src/backend/base/langflow/type_extraction/__init__.py b/src/backend/base/langflow/type_extraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/base/langflow/custom/code_parser/utils.py b/src/backend/base/langflow/type_extraction/type_extraction.py similarity index 51% rename from src/backend/base/langflow/custom/code_parser/utils.py rename to src/backend/base/langflow/type_extraction/type_extraction.py index 0f97b4c7b..f22a8eebe 100644 --- a/src/backend/base/langflow/custom/code_parser/utils.py +++ b/src/backend/base/langflow/type_extraction/type_extraction.py @@ -1,15 +1,6 @@ import re from types import GenericAlias -from typing import Any - - -def extract_inner_type(return_type: str) -> str: - """ - Extracts the inner type from a type hint that is a list. - """ - if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE): - return match[1] - return return_type +from typing import Any, List, Union def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any: @@ -21,6 +12,15 @@ def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any: return return_type +def extract_inner_type(return_type: str) -> str: + """ + Extracts the inner type from a type hint that is a list. + """ + if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE): + return match[1] + return return_type + + def extract_union_types(return_type: str) -> list[str]: """ Extracts the inner type from a type hint that is a list. @@ -31,6 +31,47 @@ def extract_union_types(return_type: str) -> list[str]: return [item.strip() for item in return_types] +def extract_uniont_types_from_generic_alias(return_type: GenericAlias) -> list: + """ + Extracts the inner type from a type hint that is a Union. + """ + if isinstance(return_type, list): + return [ + _inner_arg + for _type in return_type + for _inner_arg in _type.__args__ + if _inner_arg not in set((Any, type(None), type(Any))) + ] + + return list(return_type.__args__) + + +def post_process_type(_type): + """ + Process the return type of a function. + + Args: + _type (Any): The return type of the function. + + Returns: + Union[List[Any], Any]: The processed return type. + + """ + if hasattr(_type, "__origin__") and _type.__origin__ in [ + list, + List, + ]: + _type = extract_inner_type_from_generic_alias(_type) + + # If the return type is not a Union, then we just return it as a list + inner_type = _type[0] if isinstance(_type, list) else _type + if not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union: + return _type if isinstance(_type, list) else [_type] + # If the return type is a Union, then we need to parse it + _type = extract_union_types_from_generic_alias(_type) + return _type + + def extract_union_types_from_generic_alias(return_type: GenericAlias) -> list: """ Extracts the inner type from a type hint that is a Union. diff --git a/test-results/.last-run.json b/test-results/.last-run.json deleted file mode 100644 index 8460acf95..000000000 --- a/test-results/.last-run.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "status": "failed", - "failedTests": [] -} diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 000000000..ef8879cab --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,124 @@ +from typing import Union + +import pytest +from langflow.template import Input, Output +from langflow.template.field.base import UNDEFINED +from langflow.type_extraction.type_extraction import post_process_type +from pydantic import ValidationError + + +@pytest.fixture(name="client", autouse=True) +def client_fixture(): + pass + + +class TestInput: + def test_field_type_str(self): + input_obj = Input(field_type="str") + assert input_obj.field_type == "str" + + def test_field_type_type(self): + input_obj = Input(field_type=int) + assert input_obj.field_type == "int" + + def test_invalid_field_type(self): + with pytest.raises(ValidationError): + Input(field_type=123) + + def test_serialize_field_type(self): + input_obj = Input(field_type="str") + assert input_obj.serialize_field_type("str", None) == "str" + + def test_validate_type_string(self): + input_obj = Input(field_type="str") + assert input_obj.field_type == "str" + + def test_validate_type_class(self): + input_obj = Input(field_type=int) + assert input_obj.field_type == "int" + + def test_post_process_type_function(self): + assert post_process_type(int) == [int] + assert post_process_type(list[int]) == [int] + assert post_process_type(Union[int, str]) == [int, str] + + def test_input_to_dict(self): + input_obj = Input(field_type="str") + assert input_obj.to_dict() == { + "type": "str", + "required": False, + "placeholder": "", + "list": False, + "show": True, + "multiline": False, + "fileTypes": [], + "file_path": "", + "password": False, + "advanced": False, + "title_case": False, + } + + +class TestOutput: + def test_output_default(self): + output_obj = Output(name="test_output") + assert output_obj.name == "test_output" + assert output_obj.value == UNDEFINED + assert output_obj.cache is True + + def test_output_add_types(self): + output_obj = Output(name="test_output") + output_obj.add_types(["str", "int"]) + assert output_obj.types == ["str", "int"] + + def test_output_set_selected(self): + output_obj = Output(name="test_output", types=["str", "int"]) + output_obj.set_selected() + assert output_obj.selected == "str" + + def test_output_to_dict(self): + output_obj = Output(name="test_output") + assert output_obj.to_dict() == { + "types": [], + "name": "test_output", + "display_name": "test_output", + "cache": True, + "value": "__UNDEFINED__", + } + + def test_output_validate_display_name(self): + output_obj = Output(name="test_output") + assert output_obj.display_name == "test_output" + + def test_output_validate_model(self): + output_obj = Output(name="test_output", value="__UNDEFINED__") + assert output_obj.validate_model() == output_obj + + +class TestPostProcessType: + def test_int_type(self): + assert post_process_type(int) == [int] + + def test_list_int_type(self): + assert post_process_type(list[int]) == [int] + + def test_union_type(self): + assert post_process_type(Union[int, str]) == [int, str] + + def test_custom_type(self): + class CustomType: + pass + + assert post_process_type(CustomType) == [CustomType] + + def test_list_custom_type(self): + class CustomType: + pass + + assert post_process_type(list[CustomType]) == [CustomType] + + def test_union_custom_type(self): + class CustomType: + pass + + assert post_process_type(Union[CustomType, int]) == [CustomType, int]