From 00634a235408ef274c8957fb3c79da1980c58cfc Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 3 Nov 2023 16:38:56 -0300 Subject: [PATCH] Refactor field typing in LLMChainComponent and related modules --- src/backend/langflow/components/chains/LLMChain.py | 9 +++++++-- src/backend/langflow/field_typing/__init__.py | 4 ++++ src/backend/langflow/field_typing/constants.py | 3 ++- .../langflow/interface/custom/custom_component.py | 6 ++---- src/backend/langflow/interface/custom/utils.py | 11 +++++++++++ 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/backend/langflow/components/chains/LLMChain.py b/src/backend/langflow/components/chains/LLMChain.py index 12b38a27a..b14eef302 100644 --- a/src/backend/langflow/components/chains/LLMChain.py +++ b/src/backend/langflow/components/chains/LLMChain.py @@ -1,7 +1,12 @@ from langflow import CustomComponent from langchain.chains import LLMChain from typing import Optional, Union, Callable -from langflow.field_typing import PromptTemplate, BaseLanguageModel, BaseMemory, Chain +from langflow.field_typing import ( + BasePromptTemplate, + BaseLanguageModel, + BaseMemory, + Chain, +) class LLMChainComponent(CustomComponent): @@ -18,7 +23,7 @@ class LLMChainComponent(CustomComponent): def build( self, - prompt: PromptTemplate, + prompt: BasePromptTemplate, llm: BaseLanguageModel, memory: Optional[BaseMemory] = None, ) -> Union[Chain, Callable]: diff --git a/src/backend/langflow/field_typing/__init__.py b/src/backend/langflow/field_typing/__init__.py index 0a135818b..ceba9fded 100644 --- a/src/backend/langflow/field_typing/__init__.py +++ b/src/backend/langflow/field_typing/__init__.py @@ -17,6 +17,8 @@ from .constants import ( Tool, PromptTemplate, + ChatPromptTemplate, + BasePromptTemplate, Chain, BaseChatMemory, BaseLLM, @@ -54,4 +56,6 @@ __all__ = [ "Document", "AgentExecutor", "Callable", + "BasePromptTemplate", + "ChatPromptTemplate", ] diff --git a/src/backend/langflow/field_typing/constants.py b/src/backend/langflow/field_typing/constants.py index 401a0d0b7..2cbc09bb2 100644 --- a/src/backend/langflow/field_typing/constants.py +++ b/src/backend/langflow/field_typing/constants.py @@ -3,7 +3,7 @@ from langchain.chains.base import Chain from langchain.document_loaders.base import BaseLoader from langchain.llms.base import BaseLLM, BaseLanguageModel from langchain.memory.chat_memory import BaseChatMemory -from langchain.prompts import PromptTemplate +from langchain.prompts import PromptTemplate, ChatPromptTemplate from langchain.schema import BaseOutputParser, BaseRetriever, Document from langchain.schema.embeddings import Embeddings from langchain.schema.memory import BaseMemory @@ -27,6 +27,7 @@ LANGCHAIN_BASE_TYPES = { "BaseLLM": BaseLLM, "BaseLanguageModel": BaseLanguageModel, "PromptTemplate": PromptTemplate, + "ChatPromptTemplate": ChatPromptTemplate, "BaseLoader": BaseLoader, "Document": Document, "TextSplitter": TextSplitter, diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 440c93bb0..b1066eab0 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -5,7 +5,7 @@ from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES from langflow.interface.custom.component import Component from langflow.interface.custom.directory_reader import DirectoryReader from langflow.services.getters import get_db_service -from langflow.interface.custom.utils import extract_inner_type +from langflow.interface.custom.utils import extract_inner_type, extract_union_types from langflow.utils import validate @@ -152,9 +152,7 @@ class CustomComponent(Component): return [return_type] if return_type in self.return_type_valid_list else [] # If the return type is a Union, then we need to parse it - return_type = return_type.replace("Union", "").replace("[", "").replace("]", "") - return_type = return_type.split(",") - return_type = [item.strip() for item in return_type] + return_type = extract_union_types(return_type) return [item for item in return_type if item in self.return_type_valid_list] @property diff --git a/src/backend/langflow/interface/custom/utils.py b/src/backend/langflow/interface/custom/utils.py index 99b0d4bc6..9560b0f01 100644 --- a/src/backend/langflow/interface/custom/utils.py +++ b/src/backend/langflow/interface/custom/utils.py @@ -8,3 +8,14 @@ def extract_inner_type(return_type: str) -> str: if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE): return match[1] return return_type + + +def extract_union_types(return_type: str) -> list[str]: + """ + Extracts the inner type from a type hint that is a list. + """ + # If the return type is a Union, then we need to parse it + return_type = return_type.replace("Union", "").replace("[", "").replace("]", "") + return_type = return_type.split(",") + return_type = [item.strip() for item in return_type] + return return_type