From 8ef4b090f411f0026c8aa178ffc41516d755ca9f Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 20 Jun 2024 13:36:38 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20(constants.py):=20Replace=20Type?= =?UTF-8?q?Alias=20with=20TypeVar=20for=20LanguageModel=20in=20field=5Ftyp?= =?UTF-8?q?ing=20module?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📝 (model.py): Import LanguageModel from field_typing module to use in LCModelComponent class 🔧 (utils.py): Replace get_all_types_from_type function with format_type function in custom module 🔧 (custom.py): Remove unused import get_args in format_type function in helpers module --- src/backend/base/langflow/base/models/model.py | 4 ++-- src/backend/base/langflow/custom/utils.py | 7 +++---- src/backend/base/langflow/field_typing/constants.py | 5 ++--- src/backend/base/langflow/helpers/custom.py | 12 +----------- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/backend/base/langflow/base/models/model.py b/src/backend/base/langflow/base/models/model.py index 76f98290a..11256734d 100644 --- a/src/backend/base/langflow/base/models/model.py +++ b/src/backend/base/langflow/base/models/model.py @@ -2,12 +2,12 @@ import json import warnings from typing import Optional, Union -from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import LLM from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langflow.base.models.exceptions import get_message_from_openai_exception from langflow.custom import Component +from langflow.field_typing import LanguageModel from langflow.schema.message import Message @@ -91,7 +91,7 @@ class LCModelComponent(Component): return status_message def get_chat_result( - self, runnable: BaseChatModel, stream: bool, input_value: str | Message, system_message: Optional[str] = None + self, runnable: LanguageModel, stream: bool, input_value: str | Message, system_message: Optional[str] = None ): messages: list[Union[HumanMessage, SystemMessage]] = [] if not input_value and not system_message: diff --git a/src/backend/base/langflow/custom/utils.py b/src/backend/base/langflow/custom/utils.py index f2a2450a1..49bb201a8 100644 --- a/src/backend/base/langflow/custom/utils.py +++ b/src/backend/base/langflow/custom/utils.py @@ -3,7 +3,6 @@ import contextlib import re import traceback import warnings -from itertools import chain from typing import Any, Dict, List, Optional, Tuple, Union from uuid import UUID @@ -22,7 +21,7 @@ from langflow.custom.directory_reader.utils import ( from langflow.custom.eval import eval_custom_component_code from langflow.custom.schema import MissingDefault from langflow.field_typing.range_spec import RangeSpec -from langflow.helpers.custom import get_all_types_from_type +from langflow.helpers.custom import format_type from langflow.schema import dotdict from langflow.template.field.base import Input from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode @@ -371,8 +370,8 @@ def build_custom_component_template_from_inputs( if output.types: continue return_types = custom_component.get_method_return_type(output.method) - all_types = [get_all_types_from_type(return_type) for return_type in return_types] - output.add_types(chain.from_iterable(all_types)) + return_types = [format_type(return_type) for return_type in return_types] + output.add_types(return_types) output.set_selected() # Validate that there is not name overlap between inputs and outputs frontend_node.validate() diff --git a/src/backend/base/langflow/field_typing/constants.py b/src/backend/base/langflow/field_typing/constants.py index feae7880d..8ba7e544d 100644 --- a/src/backend/base/langflow/field_typing/constants.py +++ b/src/backend/base/langflow/field_typing/constants.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Text, TypeAlias, Union +from typing import Callable, Dict, Text, TypeAlias, TypeVar, Union from langchain.agents.agent import AgentExecutor from langchain.chains.base import Chain @@ -16,9 +16,8 @@ from langchain_core.tools import Tool from langchain_core.vectorstores import VectorStore from langchain_text_splitters import TextSplitter - NestedDict: TypeAlias = Dict[str, Union[str, Dict]] -LanguageModel: TypeAlias = Union[BaseLanguageModel, BaseLLM, BaseChatModel] +LanguageModel = TypeVar("LanguageModel", BaseLanguageModel, BaseLLM, BaseChatModel) Retriever: TypeAlias = BaseRetriever diff --git a/src/backend/base/langflow/helpers/custom.py b/src/backend/base/langflow/helpers/custom.py index 61ab4f46b..bdbb128f4 100644 --- a/src/backend/base/langflow/helpers/custom.py +++ b/src/backend/base/langflow/helpers/custom.py @@ -1,4 +1,4 @@ -from typing import Any, get_args +from typing import Any def format_type(type_: Any) -> str: @@ -11,13 +11,3 @@ def format_type(type_: Any) -> str: else: type_ = str(type_) return type_ - - -def get_all_types_from_type(type_: Any) -> str: - args = get_args(type_) - if args: - formatted_types = [format_type(arg) for arg in args] - formatted_types.insert(0, format_type(type_)) - return formatted_types - else: - return [format_type(type_)]