From 8adf26e8dea0f5c898289e763850aca43d83a865 Mon Sep 17 00:00:00 2001 From: Ibis Prevedello Date: Fri, 7 Apr 2023 16:10:44 -0300 Subject: [PATCH] feat: add text splitter --- src/backend/langflow/config.yaml | 3 ++ src/backend/langflow/graph/base.py | 2 - src/backend/langflow/graph/graph.py | 3 ++ src/backend/langflow/graph/nodes.py | 5 +++ .../langflow/interface/custom_lists.py | 13 ++++-- .../interface/documentLoaders/base.py | 1 - .../interface/documentLoaders/custom.py | 20 ---------- .../langflow/interface/importing/utils.py | 6 +++ src/backend/langflow/interface/listing.py | 8 ++++ src/backend/langflow/interface/loading.py | 6 +++ .../interface/textSplitters/__init__.py | 3 ++ .../langflow/interface/textSplitters/base.py | 40 +++++++++++++++++++ src/backend/langflow/interface/types.py | 2 + .../langflow/interface/vectorStore/base.py | 4 +- src/backend/langflow/settings.py | 2 + src/frontend/src/utils.ts | 3 ++ 16 files changed, 93 insertions(+), 28 deletions(-) create mode 100644 src/backend/langflow/interface/textSplitters/__init__.py create mode 100644 src/backend/langflow/interface/textSplitters/base.py diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index be7e82099..8bd929036 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -63,4 +63,7 @@ documentloaders: - TextLoader - WebBaseLoader +textsplitters: + - CharacterTextSplitter + dev: false diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/base.py index 57bf797eb..2a74a8233 100644 --- a/src/backend/langflow/graph/base.py +++ b/src/backend/langflow/graph/base.py @@ -153,8 +153,6 @@ class Node: result = result.run # type: ignore elif hasattr(result, "get_function"): result = result.get_function() # type: ignore - elif value.base_type == "documentloaders": - result = result.load() self.params[key] = result elif isinstance(value, list) and all( diff --git a/src/backend/langflow/graph/graph.py b/src/backend/langflow/graph/graph.py index a7f908ef2..4045ee329 100644 --- a/src/backend/langflow/graph/graph.py +++ b/src/backend/langflow/graph/graph.py @@ -10,6 +10,7 @@ from langflow.graph.nodes import ( LLMNode, MemoryNode, PromptNode, + TextSplitterNode, ToolkitNode, ToolNode, VectorStoreNode, @@ -22,6 +23,7 @@ from langflow.interface.embeddings.base import embedding_creator from langflow.interface.llms.base import llm_creator from langflow.interface.memories.base import memory_creator from langflow.interface.prompts.base import prompt_creator +from langflow.interface.textSplitters.base import textsplitter_creator from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.tools.base import tool_creator from langflow.interface.tools.constants import FILE_TOOLS @@ -126,6 +128,7 @@ class Graph: **{t: EmbeddingNode for t in embedding_creator.to_list()}, **{t: VectorStoreNode for t in vectorstore_creator.to_list()}, **{t: DocumentLoaderNode for t in documentloader_creator.to_list()}, + **{t: TextSplitterNode for t in textsplitter_creator.to_list()}, } if node_type in FILE_TOOLS: diff --git a/src/backend/langflow/graph/nodes.py b/src/backend/langflow/graph/nodes.py index 0d90fe333..cda327e6d 100644 --- a/src/backend/langflow/graph/nodes.py +++ b/src/backend/langflow/graph/nodes.py @@ -147,3 +147,8 @@ class VectorStoreNode(Node): class MemoryNode(Node): def __init__(self, data: Dict): super().__init__(data, base_type="memory") + + +class TextSplitterNode(Node): + def __init__(self, data: Dict): + super().__init__(data, base_type="textsplitters") diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index 746c58325..2ffc18b14 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -1,6 +1,6 @@ +import inspect from typing import Any -## LLM from langchain import ( chains, document_loaders, @@ -8,6 +8,7 @@ from langchain import ( llms, memory, requests, + text_splitter, vectorstores, ) from langchain.agents import agent_toolkits @@ -15,16 +16,17 @@ from langchain.chat_models import ChatOpenAI from langflow.interface.importing.utils import import_class -## LLM +## LLMs llm_type_to_cls_dict = llms.type_to_cls_dict llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore -## Chain +## Chains chain_type_to_cls_dict: dict[str, Any] = { chain_name: import_class(f"langchain.chains.{chain_name}") for chain_name in chains.__all__ } +## Toolkits toolkit_type_to_loader_dict: dict[str, Any] = { toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}") # if toolkit_name is lower case it is a loader @@ -69,3 +71,8 @@ documentloaders_type_to_cls_dict: dict[str, Any] = { ) for documentloader_name in document_loaders.__all__ } + +## Text Splitters +textsplitter_type_to_cls_dict: dict[str, Any] = dict( + inspect.getmembers(text_splitter, inspect.isclass) +) diff --git a/src/backend/langflow/interface/documentLoaders/base.py b/src/backend/langflow/interface/documentLoaders/base.py index 5aa729b62..9510a1d9e 100644 --- a/src/backend/langflow/interface/documentLoaders/base.py +++ b/src/backend/langflow/interface/documentLoaders/base.py @@ -16,7 +16,6 @@ class DocumentLoaderCreator(LangChainTypeCreator): # Drop some types that are reimplemented with the same name types.pop("TextLoader") - types.pop("WebBaseLoader") for name, documentloader in CUSTOM_DOCUMENTLOADERS.items(): types[name] = documentloader diff --git a/src/backend/langflow/interface/documentLoaders/custom.py b/src/backend/langflow/interface/documentLoaders/custom.py index f142314fa..22baac582 100644 --- a/src/backend/langflow/interface/documentLoaders/custom.py +++ b/src/backend/langflow/interface/documentLoaders/custom.py @@ -3,8 +3,6 @@ from typing import List from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -from langchain.document_loaders.web_base import WebBaseLoader as LCWebBaseLoader -from langchain.text_splitter import CharacterTextSplitter class TextLoader(BaseLoader): @@ -18,25 +16,7 @@ class TextLoader(BaseLoader): """Load from file path.""" documents = [Document(page_content=self.file, metadata={"source": "loaded"})] - text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) - - return text_splitter.split_documents(documents) - - -class WebBaseLoader(LCWebBaseLoader): - def load(self) -> List[Document]: - """Load data into document objects.""" - soup = self.scrape() - text = soup.get_text() - metadata = {"source": self.web_path} - documents = [Document(page_content=text, metadata=metadata)] - - text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) - - return text_splitter.split_documents(documents) - CUSTOM_DOCUMENTLOADERS = { "TextLoader": TextLoader, - "WebBaseLoader": WebBaseLoader, } diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index 62f81a90a..f72ded60b 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -44,6 +44,7 @@ def import_by_type(_type: str, name: str) -> Any: "embeddings": import_embedding, "vectorstores": import_vectorstore, "documentloaders": import_documentloader, + "textsplitters": import_textsplitter, } if _type == "llms": key = "chat" if "chat" in name.lower() else "llm" @@ -135,3 +136,8 @@ def import_documentloader(documentloader: str) -> Any: return CUSTOM_DOCUMENTLOADERS[documentloader] return import_class(f"langchain.document_loaders.{documentloader}") + + +def import_textsplitter(textsplitter: str) -> Any: + """Import textsplitter from textsplitter name""" + return import_class(f"langchain.text_splitter.{textsplitter}") diff --git a/src/backend/langflow/interface/listing.py b/src/backend/langflow/interface/listing.py index b11b3cef9..bf9558e92 100644 --- a/src/backend/langflow/interface/listing.py +++ b/src/backend/langflow/interface/listing.py @@ -1,10 +1,14 @@ from langflow.interface.agents.base import agent_creator from langflow.interface.chains.base import chain_creator +from langflow.interface.documentLoaders.base import documentloader_creator +from langflow.interface.embeddings.base import embedding_creator from langflow.interface.llms.base import llm_creator from langflow.interface.memories.base import memory_creator from langflow.interface.prompts.base import prompt_creator +from langflow.interface.textSplitters.base import textsplitter_creator from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.tools.base import tool_creator +from langflow.interface.vectorStore.base import vectorstore_creator from langflow.interface.wrappers.base import wrapper_creator @@ -18,6 +22,10 @@ def get_type_dict(): "memory": memory_creator.to_list(), "toolkits": toolkits_creator.to_list(), "wrappers": wrapper_creator.to_list(), + "documentLoaders": documentloader_creator.to_list(), + "vectorStore": vectorstore_creator.to_list(), + "embeddings": embedding_creator.to_list(), + "textSplitters": textsplitter_creator.to_list(), } diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 4426644b9..292f3d944 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -62,6 +62,12 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any: return class_object(**params) elif base_type == "vectorstores": return class_object.from_documents(**params) + elif base_type == "documentloaders": + return class_object(**params).load() + elif base_type == "textsplitters": + documents = params.pop("documents") + text_splitter = class_object(**params) + return text_splitter.split_documents(documents) else: return class_object(**params) diff --git a/src/backend/langflow/interface/textSplitters/__init__.py b/src/backend/langflow/interface/textSplitters/__init__.py new file mode 100644 index 000000000..9da97b697 --- /dev/null +++ b/src/backend/langflow/interface/textSplitters/__init__.py @@ -0,0 +1,3 @@ +from langflow.interface.textSplitters.base import TextSplitterCreator + +__all__ = ["TextSplitterCreator"] diff --git a/src/backend/langflow/interface/textSplitters/base.py b/src/backend/langflow/interface/textSplitters/base.py new file mode 100644 index 000000000..bbc8ab863 --- /dev/null +++ b/src/backend/langflow/interface/textSplitters/base.py @@ -0,0 +1,40 @@ +from typing import Dict, List, Optional + +from langflow.interface.base import LangChainTypeCreator +from langflow.interface.custom_lists import textsplitter_type_to_cls_dict +from langflow.settings import settings +from langflow.utils.util import build_template_from_class + + +class TextSplitterCreator(LangChainTypeCreator): + type_name: str = "textsplitters" + + @property + def type_to_loader_dict(self) -> Dict: + return textsplitter_type_to_cls_dict + + def get_signature(self, name: str) -> Optional[Dict]: + """Get the signature of a text splitter.""" + try: + signature = build_template_from_class(name, textsplitter_type_to_cls_dict) + + signature["template"]["documents"] = { + "type": "BaseLoader", + "required": True, + "show": True, + "name": "documents", + } + + return signature + except ValueError as exc: + raise ValueError(f"Text Splitter {name} not found") from exc + + def to_list(self) -> List[str]: + return [ + textsplitter.__name__ + for textsplitter in self.type_to_loader_dict.values() + if textsplitter.__name__ in settings.textsplitters or settings.dev + ] + + +textsplitter_creator = TextSplitterCreator() diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index bf3cea372..61bfdfa7b 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -5,6 +5,7 @@ from langflow.interface.embeddings.base import embedding_creator from langflow.interface.llms.base import llm_creator from langflow.interface.memories.base import memory_creator from langflow.interface.prompts.base import prompt_creator +from langflow.interface.textSplitters.base import textsplitter_creator from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.tools.base import tool_creator from langflow.interface.vectorStore.base import vectorstore_creator @@ -40,6 +41,7 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union embedding_creator, vectorstore_creator, documentloader_creator, + textsplitter_creator, ] all_types = {} diff --git a/src/backend/langflow/interface/vectorStore/base.py b/src/backend/langflow/interface/vectorStore/base.py index 10e3a8768..15dfd2886 100644 --- a/src/backend/langflow/interface/vectorStore/base.py +++ b/src/backend/langflow/interface/vectorStore/base.py @@ -20,11 +20,11 @@ class VectorstoreCreator(LangChainTypeCreator): signature["template"] = { "documents": { - "type": "BaseLoader", + "type": "TextSplitter", "required": True, "show": True, "name": "documents", - "display_name": "Document Loader", + "display_name": "Text Splitter", }, "embedding": { "type": "Embeddings", diff --git a/src/backend/langflow/settings.py b/src/backend/langflow/settings.py index c0fd66c58..c5377c85a 100644 --- a/src/backend/langflow/settings.py +++ b/src/backend/langflow/settings.py @@ -17,6 +17,7 @@ class Settings(BaseSettings): documentloaders: List[str] = [] wrappers: List[str] = [] toolkits: List[str] = [] + textsplitters: List[str] = [] dev: bool = False class Config: @@ -40,6 +41,7 @@ class Settings(BaseSettings): self.memories = new_settings.memories or [] self.wrappers = new_settings.wrappers or [] self.toolkits = new_settings.toolkits or [] + self.textsplitters = new_settings.textsplitters or [] self.dev = new_settings.dev or False diff --git a/src/frontend/src/utils.ts b/src/frontend/src/utils.ts index ab95e5cad..090827a7a 100644 --- a/src/frontend/src/utils.ts +++ b/src/frontend/src/utils.ts @@ -79,6 +79,7 @@ export const nodeColors: {[char: string]: string} = { embeddings:"#FF9135", documentloaders:"#FF9135", vectorstores: "#FF9135", + textsplitters: "#FF9135", toolkits:"#DB2C2C", wrappers:"#E6277A", unknown:"#9CA3AF" @@ -98,6 +99,7 @@ export const nodeNames:{[char: string]: string} = { vectorstores: "Vector Stores", toolkits:"Toolkits", wrappers:"Wrappers", + textsplitters: "Text Splitters", unknown:"Unknown" }; @@ -114,6 +116,7 @@ export const nodeIcons:{[char: string]: React.ForwardRefExoticComponent