Refactor field typing in LLMChainComponent and

related modules
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-03 16:38:56 -03:00
commit 00634a2354
5 changed files with 26 additions and 7 deletions

View file

@ -1,7 +1,12 @@
from langflow import CustomComponent
from langchain.chains import LLMChain
from typing import Optional, Union, Callable
from langflow.field_typing import PromptTemplate, BaseLanguageModel, BaseMemory, Chain
from langflow.field_typing import (
BasePromptTemplate,
BaseLanguageModel,
BaseMemory,
Chain,
)
class LLMChainComponent(CustomComponent):
@ -18,7 +23,7 @@ class LLMChainComponent(CustomComponent):
def build(
self,
prompt: PromptTemplate,
prompt: BasePromptTemplate,
llm: BaseLanguageModel,
memory: Optional[BaseMemory] = None,
) -> Union[Chain, Callable]:

View file

@ -17,6 +17,8 @@
from .constants import (
Tool,
PromptTemplate,
ChatPromptTemplate,
BasePromptTemplate,
Chain,
BaseChatMemory,
BaseLLM,
@ -54,4 +56,6 @@ __all__ = [
"Document",
"AgentExecutor",
"Callable",
"BasePromptTemplate",
"ChatPromptTemplate",
]

View file

@ -3,7 +3,7 @@ from langchain.chains.base import Chain
from langchain.document_loaders.base import BaseLoader
from langchain.llms.base import BaseLLM, BaseLanguageModel
from langchain.memory.chat_memory import BaseChatMemory
from langchain.prompts import PromptTemplate
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain.schema import BaseOutputParser, BaseRetriever, Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.memory import BaseMemory
@ -27,6 +27,7 @@ LANGCHAIN_BASE_TYPES = {
"BaseLLM": BaseLLM,
"BaseLanguageModel": BaseLanguageModel,
"PromptTemplate": PromptTemplate,
"ChatPromptTemplate": ChatPromptTemplate,
"BaseLoader": BaseLoader,
"Document": Document,
"TextSplitter": TextSplitter,

View file

@ -5,7 +5,7 @@ from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
from langflow.interface.custom.component import Component
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.services.getters import get_db_service
from langflow.interface.custom.utils import extract_inner_type
from langflow.interface.custom.utils import extract_inner_type, extract_union_types
from langflow.utils import validate
@ -152,9 +152,7 @@ class CustomComponent(Component):
return [return_type] if return_type in self.return_type_valid_list else []
# If the return type is a Union, then we need to parse it
return_type = return_type.replace("Union", "").replace("[", "").replace("]", "")
return_type = return_type.split(",")
return_type = [item.strip() for item in return_type]
return_type = extract_union_types(return_type)
return [item for item in return_type if item in self.return_type_valid_list]
@property

View file

@ -8,3 +8,14 @@ def extract_inner_type(return_type: str) -> str:
if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE):
return match[1]
return return_type
def extract_union_types(return_type: str) -> list[str]:
"""
Extracts the inner type from a type hint that is a list.
"""
# If the return type is a Union, then we need to parse it
return_type = return_type.replace("Union", "").replace("[", "").replace("]", "")
return_type = return_type.split(",")
return_type = [item.strip() for item in return_type]
return return_type