Merge branch 'two_edges' of personal:langflow-ai/langflow into two_edges

This commit is contained in:
anovazzi1 2024-06-12 11:01:39 -03:00
commit 3224227f58
9 changed files with 246 additions and 56 deletions

View file

@ -70,14 +70,14 @@ class ChatOllamaComponent(LCModelComponent):
inputs = [
Input(
name="base_url",
type=Optional[str],
field_type=Optional[str],
display_name="Base URL",
info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
value="http://localhost:11434",
),
Input(
name="model",
type=str,
field_type=str,
display_name="Model Name",
options=[], # This should be dynamically loaded if possible
info="Refer to https://ollama.ai/library for more models.",
@ -86,7 +86,7 @@ class ChatOllamaComponent(LCModelComponent):
),
Input(
name="mirostat",
type=str,
field_type=str,
display_name="Mirostat",
options=["Disabled", "Mirostat", "Mirostat 2.0"],
info="Enable/disable Mirostat sampling for controlling perplexity.",
@ -97,7 +97,7 @@ class ChatOllamaComponent(LCModelComponent):
),
Input(
name="mirostat_eta",
type=Optional[float],
field_type=Optional[float],
display_name="Mirostat Eta",
info="Learning rate for Mirostat algorithm.",
advanced=True,
@ -106,7 +106,7 @@ class ChatOllamaComponent(LCModelComponent):
),
Input(
name="mirostat_tau",
type=Optional[float],
field_type=Optional[float],
display_name="Mirostat Tau",
info="Controls the balance between coherence and diversity of the output.",
advanced=True,
@ -115,7 +115,7 @@ class ChatOllamaComponent(LCModelComponent):
),
Input(
name="temperature",
type=float,
field_type=float,
display_name="Temperature",
info="Controls the creativity of model responses.",
value=0.8,
@ -124,7 +124,7 @@ class ChatOllamaComponent(LCModelComponent):
Input(name="stream", type=bool, display_name="Stream", info=STREAM_INFO_TEXT, value=False),
Input(
name="system_message",
type=Optional[str],
field_type=Optional[str],
display_name="System Message",
info="System message to pass to the model.",
advanced=True,
@ -132,14 +132,14 @@ class ChatOllamaComponent(LCModelComponent):
),
Input(
name="headers",
type=dict,
field_type=dict,
display_name="Headers",
info="Additional headers to send with the request.",
advanced=True,
),
Input(
name="keep_alive_flag",
type=str,
field_type=str,
display_params=["Keep", "Immediately", "Minute", "Hour", "sec"],
display_name="Unload interval",
info="Controls how the model unload interval is managed.",
@ -148,7 +148,7 @@ class ChatOllamaComponent(LCModelComponent):
),
Input(
name="keep_alive",
type=int,
field_type=int,
display_name="Interval",
info="How long the model will stay loaded into memory.",
value=None,

View file

@ -7,10 +7,6 @@ from cachetools import TTLCache
from langchain_core.documents import Document
from pydantic import BaseModel
from langflow.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.custom.custom_component.base_component import BaseComponent
from langflow.helpers.flow import list_flows, load_flow, run_flow
from langflow.schema import Record
@ -20,6 +16,10 @@ from langflow.schema.message import Message
from langflow.schema.schema import Log
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
from langflow.services.storage.service import StorageService
from langflow.type_extraction.type_extraction import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.utils import validate
if TYPE_CHECKING:
@ -327,7 +327,6 @@ class CustomComponent(BaseComponent):
return []
return_type = build_method["return_type"]
# If list or List is in the return type, then we remove it and return the inner type
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
list,
List,

View file

@ -11,7 +11,6 @@ from loguru import logger
from pydantic import BaseModel
from langflow.custom import CustomComponent
from langflow.custom.code_parser.utils import extract_inner_type
from langflow.custom.custom_component.component import Component
from langflow.custom.directory_reader.utils import (
abuild_custom_component_list_from_path,
@ -26,6 +25,7 @@ from langflow.helpers.custom import format_type
from langflow.schema import dotdict
from langflow.template.field.base import Input
from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode
from langflow.type_extraction.type_extraction import extract_inner_type
from langflow.utils import validate
from langflow.utils.util import get_base_classes

View file

@ -1,11 +1,13 @@
from enum import Enum
from types import GenericAlias
from typing import Any, Callable, Optional, Union, _GenericAlias, _UnionGenericAlias, get_args, get_origin
from typing import Any, Callable, GenericAlias, Optional, Union, _GenericAlias, _UnionGenericAlias
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_serializer, model_validator
from langflow.field_typing import Text
from langflow.field_typing.range_spec import RangeSpec
from langflow.helpers.custom import format_type
from langflow.type_extraction.type_extraction import post_process_type
class UndefinedType(Enum):
@ -118,21 +120,8 @@ class Input(BaseModel):
# this should be done for all types
# How to check if v is a type?
if isinstance(v, (type, _GenericAlias, GenericAlias, _UnionGenericAlias)):
if isinstance(v, type):
v = v.__name__
else:
origin = get_origin(v)
args = get_args(v)
if origin and args:
v = f"{origin.__name__}[{', '.join(arg.__name__ if isinstance(arg, type) else str(arg) for arg in args)}]"
# if v is union with None (e.g Union[someType, NoneType]) we need to remove NoneType
# we can return Optional[someType] instead of Union[someType, NoneType]
if "NoneType" in v:
v = v.replace(", NoneType", "")
v = v.replace("Union[", "Optional[")
else:
v = str(v)
v = post_process_type(v)[0]
v = format_type(v)
elif not isinstance(v, str):
raise ValueError(f"type must be a string or a type, not {type(v)}")
return v
@ -196,15 +185,6 @@ class Output(BaseModel):
if not self.selected:
self.selected = self.types[0]
@field_validator("display_name", mode="before")
def validate_display_name(cls, v, info):
if not v:
if info.data.get("name"):
return info.data["name"]
else:
raise ValueError("If display_name is not set, name must be set")
return v
@model_serializer(mode="wrap")
def serialize_model(self, handler):
result = handler(self)
@ -217,4 +197,8 @@ class Output(BaseModel):
def validate_model(self):
if self.value == UNDEFINED.value:
self.value = UNDEFINED
if self.name is None:
raise ValueError("name must be set")
if self.display_name is None:
self.display_name = self.name
return self

View file

@ -0,0 +1,46 @@
from pydantic import SecretStr
from langflow.field_typing.constants import NestedDict
from langflow.template.field.base import Input
class StrInput(Input):
field_type: str | type | None = str
class SecretStrInput(Input):
field_type: str | type | None = SecretStr
password = True
class IntInput(Input):
field_type: str | type | None = int
class FloatInput(Input):
field_type: str | type | None = float
class BoolInput(Input):
field_type: str | type | None = bool
class NestedDictInput(Input):
field_type: str | type | None = NestedDict
class DictInput(Input):
field_type: str | type | None = dict
class ListInput(Input):
is_list = True
class DropdownInput(Input):
field_type: str | type | None = str
options = []
class FileInput(Input):
field_type: str | type | None = str

View file

@ -1,15 +1,6 @@
import re
from types import GenericAlias
from typing import Any
def extract_inner_type(return_type: str) -> str:
"""
Extracts the inner type from a type hint that is a list.
"""
if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE):
return match[1]
return return_type
from typing import Any, List, Union
def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
@ -21,6 +12,15 @@ def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
return return_type
def extract_inner_type(return_type: str) -> str:
"""
Extracts the inner type from a type hint that is a list.
"""
if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE):
return match[1]
return return_type
def extract_union_types(return_type: str) -> list[str]:
"""
Extracts the inner type from a type hint that is a list.
@ -31,6 +31,47 @@ def extract_union_types(return_type: str) -> list[str]:
return [item.strip() for item in return_types]
def extract_uniont_types_from_generic_alias(return_type: GenericAlias) -> list:
"""
Extracts the inner type from a type hint that is a Union.
"""
if isinstance(return_type, list):
return [
_inner_arg
for _type in return_type
for _inner_arg in _type.__args__
if _inner_arg not in set((Any, type(None), type(Any)))
]
return list(return_type.__args__)
def post_process_type(_type):
"""
Process the return type of a function.
Args:
_type (Any): The return type of the function.
Returns:
Union[List[Any], Any]: The processed return type.
"""
if hasattr(_type, "__origin__") and _type.__origin__ in [
list,
List,
]:
_type = extract_inner_type_from_generic_alias(_type)
# If the return type is not a Union, then we just return it as a list
inner_type = _type[0] if isinstance(_type, list) else _type
if not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union:
return _type if isinstance(_type, list) else [_type]
# If the return type is a Union, then we need to parse it
_type = extract_union_types_from_generic_alias(_type)
return _type
def extract_union_types_from_generic_alias(return_type: GenericAlias) -> list:
"""
Extracts the inner type from a type hint that is a Union.

View file

@ -1,4 +0,0 @@
{
"status": "failed",
"failedTests": []
}

124
tests/test_schema.py Normal file
View file

@ -0,0 +1,124 @@
from typing import Union
import pytest
from langflow.template import Input, Output
from langflow.template.field.base import UNDEFINED
from langflow.type_extraction.type_extraction import post_process_type
from pydantic import ValidationError
@pytest.fixture(name="client", autouse=True)
def client_fixture():
pass
class TestInput:
def test_field_type_str(self):
input_obj = Input(field_type="str")
assert input_obj.field_type == "str"
def test_field_type_type(self):
input_obj = Input(field_type=int)
assert input_obj.field_type == "int"
def test_invalid_field_type(self):
with pytest.raises(ValidationError):
Input(field_type=123)
def test_serialize_field_type(self):
input_obj = Input(field_type="str")
assert input_obj.serialize_field_type("str", None) == "str"
def test_validate_type_string(self):
input_obj = Input(field_type="str")
assert input_obj.field_type == "str"
def test_validate_type_class(self):
input_obj = Input(field_type=int)
assert input_obj.field_type == "int"
def test_post_process_type_function(self):
assert post_process_type(int) == [int]
assert post_process_type(list[int]) == [int]
assert post_process_type(Union[int, str]) == [int, str]
def test_input_to_dict(self):
input_obj = Input(field_type="str")
assert input_obj.to_dict() == {
"type": "str",
"required": False,
"placeholder": "",
"list": False,
"show": True,
"multiline": False,
"fileTypes": [],
"file_path": "",
"password": False,
"advanced": False,
"title_case": False,
}
class TestOutput:
def test_output_default(self):
output_obj = Output(name="test_output")
assert output_obj.name == "test_output"
assert output_obj.value == UNDEFINED
assert output_obj.cache is True
def test_output_add_types(self):
output_obj = Output(name="test_output")
output_obj.add_types(["str", "int"])
assert output_obj.types == ["str", "int"]
def test_output_set_selected(self):
output_obj = Output(name="test_output", types=["str", "int"])
output_obj.set_selected()
assert output_obj.selected == "str"
def test_output_to_dict(self):
output_obj = Output(name="test_output")
assert output_obj.to_dict() == {
"types": [],
"name": "test_output",
"display_name": "test_output",
"cache": True,
"value": "__UNDEFINED__",
}
def test_output_validate_display_name(self):
output_obj = Output(name="test_output")
assert output_obj.display_name == "test_output"
def test_output_validate_model(self):
output_obj = Output(name="test_output", value="__UNDEFINED__")
assert output_obj.validate_model() == output_obj
class TestPostProcessType:
def test_int_type(self):
assert post_process_type(int) == [int]
def test_list_int_type(self):
assert post_process_type(list[int]) == [int]
def test_union_type(self):
assert post_process_type(Union[int, str]) == [int, str]
def test_custom_type(self):
class CustomType:
pass
assert post_process_type(CustomType) == [CustomType]
def test_list_custom_type(self):
class CustomType:
pass
assert post_process_type(list[CustomType]) == [CustomType]
def test_union_custom_type(self):
class CustomType:
pass
assert post_process_type(Union[CustomType, int]) == [CustomType, int]