🔧 (constants.py): Replace TypeAlias with TypeVar for LanguageModel in field_typing module

📝 (model.py): Import LanguageModel from field_typing module to use in LCModelComponent class
🔧 (utils.py): Replace get_all_types_from_type function with format_type function in custom module
🔧 (custom.py): Remove unused import get_args in format_type function in helpers module
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-20 13:36:38 -03:00
commit 8ef4b090f4
4 changed files with 8 additions and 20 deletions

View file

@ -2,12 +2,12 @@ import json
import warnings
from typing import Optional, Union
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import LLM
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langflow.base.models.exceptions import get_message_from_openai_exception
from langflow.custom import Component
from langflow.field_typing import LanguageModel
from langflow.schema.message import Message
@ -91,7 +91,7 @@ class LCModelComponent(Component):
return status_message
def get_chat_result(
self, runnable: BaseChatModel, stream: bool, input_value: str | Message, system_message: Optional[str] = None
self, runnable: LanguageModel, stream: bool, input_value: str | Message, system_message: Optional[str] = None
):
messages: list[Union[HumanMessage, SystemMessage]] = []
if not input_value and not system_message:

View file

@ -3,7 +3,6 @@ import contextlib
import re
import traceback
import warnings
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID
@ -22,7 +21,7 @@ from langflow.custom.directory_reader.utils import (
from langflow.custom.eval import eval_custom_component_code
from langflow.custom.schema import MissingDefault
from langflow.field_typing.range_spec import RangeSpec
from langflow.helpers.custom import get_all_types_from_type
from langflow.helpers.custom import format_type
from langflow.schema import dotdict
from langflow.template.field.base import Input
from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode
@ -371,8 +370,8 @@ def build_custom_component_template_from_inputs(
if output.types:
continue
return_types = custom_component.get_method_return_type(output.method)
all_types = [get_all_types_from_type(return_type) for return_type in return_types]
output.add_types(chain.from_iterable(all_types))
return_types = [format_type(return_type) for return_type in return_types]
output.add_types(return_types)
output.set_selected()
# Validate that there is not name overlap between inputs and outputs
frontend_node.validate()

View file

@ -1,4 +1,4 @@
from typing import Callable, Dict, Text, TypeAlias, Union
from typing import Callable, Dict, Text, TypeAlias, TypeVar, Union
from langchain.agents.agent import AgentExecutor
from langchain.chains.base import Chain
@ -16,9 +16,8 @@ from langchain_core.tools import Tool
from langchain_core.vectorstores import VectorStore
from langchain_text_splitters import TextSplitter
NestedDict: TypeAlias = Dict[str, Union[str, Dict]]
LanguageModel: TypeAlias = Union[BaseLanguageModel, BaseLLM, BaseChatModel]
LanguageModel = TypeVar("LanguageModel", BaseLanguageModel, BaseLLM, BaseChatModel)
Retriever: TypeAlias = BaseRetriever

View file

@ -1,4 +1,4 @@
from typing import Any, get_args
from typing import Any
def format_type(type_: Any) -> str:
@ -11,13 +11,3 @@ def format_type(type_: Any) -> str:
else:
type_ = str(type_)
return type_
def get_all_types_from_type(type_: Any) -> str:
args = get_args(type_)
if args:
formatted_types = [format_type(arg) for arg in args]
formatted_types.insert(0, format_type(type_))
return formatted_types
else:
return [format_type(type_)]