Refactor field typing in LLMChainComponent and
related modules
This commit is contained in:
parent
9b23dae530
commit
00634a2354
5 changed files with 26 additions and 7 deletions
|
|
@ -1,7 +1,12 @@
|
|||
from langflow import CustomComponent
|
||||
from langchain.chains import LLMChain
|
||||
from typing import Optional, Union, Callable
|
||||
from langflow.field_typing import PromptTemplate, BaseLanguageModel, BaseMemory, Chain
|
||||
from langflow.field_typing import (
|
||||
BasePromptTemplate,
|
||||
BaseLanguageModel,
|
||||
BaseMemory,
|
||||
Chain,
|
||||
)
|
||||
|
||||
|
||||
class LLMChainComponent(CustomComponent):
|
||||
|
|
@ -18,7 +23,7 @@ class LLMChainComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
prompt: PromptTemplate,
|
||||
prompt: BasePromptTemplate,
|
||||
llm: BaseLanguageModel,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
) -> Union[Chain, Callable]:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@
|
|||
from .constants import (
|
||||
Tool,
|
||||
PromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
BasePromptTemplate,
|
||||
Chain,
|
||||
BaseChatMemory,
|
||||
BaseLLM,
|
||||
|
|
@ -54,4 +56,6 @@ __all__ = [
|
|||
"Document",
|
||||
"AgentExecutor",
|
||||
"Callable",
|
||||
"BasePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from langchain.chains.base import Chain
|
|||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.llms.base import BaseLLM, BaseLanguageModel
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts import PromptTemplate, ChatPromptTemplate
|
||||
from langchain.schema import BaseOutputParser, BaseRetriever, Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.memory import BaseMemory
|
||||
|
|
@ -27,6 +27,7 @@ LANGCHAIN_BASE_TYPES = {
|
|||
"BaseLLM": BaseLLM,
|
||||
"BaseLanguageModel": BaseLanguageModel,
|
||||
"PromptTemplate": PromptTemplate,
|
||||
"ChatPromptTemplate": ChatPromptTemplate,
|
||||
"BaseLoader": BaseLoader,
|
||||
"Document": Document,
|
||||
"TextSplitter": TextSplitter,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
|||
from langflow.interface.custom.component import Component
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.services.getters import get_db_service
|
||||
from langflow.interface.custom.utils import extract_inner_type
|
||||
from langflow.interface.custom.utils import extract_inner_type, extract_union_types
|
||||
|
||||
from langflow.utils import validate
|
||||
|
||||
|
|
@ -152,9 +152,7 @@ class CustomComponent(Component):
|
|||
return [return_type] if return_type in self.return_type_valid_list else []
|
||||
|
||||
# If the return type is a Union, then we need to parse it
|
||||
return_type = return_type.replace("Union", "").replace("[", "").replace("]", "")
|
||||
return_type = return_type.split(",")
|
||||
return_type = [item.strip() for item in return_type]
|
||||
return_type = extract_union_types(return_type)
|
||||
return [item for item in return_type if item in self.return_type_valid_list]
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -8,3 +8,14 @@ def extract_inner_type(return_type: str) -> str:
|
|||
if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE):
|
||||
return match[1]
|
||||
return return_type
|
||||
|
||||
|
||||
def extract_union_types(return_type: str) -> list[str]:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a list.
|
||||
"""
|
||||
# If the return type is a Union, then we need to parse it
|
||||
return_type = return_type.replace("Union", "").replace("[", "").replace("]", "")
|
||||
return_type = return_type.split(",")
|
||||
return_type = [item.strip() for item in return_type]
|
||||
return return_type
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue