feat: migrate text splitters to Component syntax (#2530)
* feat: migrate text splitters to Component syntax
* [autofix.ci] apply automated fixes
---------
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
cb8185237a
commit
86aaab0cec
8 changed files with 2923 additions and 2390 deletions
0
src/backend/base/langflow/base/textsplitters/__init__.py
Normal file
0
src/backend/base/langflow/base/textsplitters/__init__.py
Normal file
58
src/backend/base/langflow/base/textsplitters/model.py
Normal file
58
src/backend/base/langflow/base/textsplitters/model.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
from langchain_text_splitters import TextSplitter
|
||||
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.io import Output
|
||||
from langflow.schema import Data
|
||||
from langflow.utils.util import build_loader_repr_from_data
|
||||
|
||||
|
||||
class LCTextSplitterComponent(Component):
    """Shared base for components that wrap a LangChain ``TextSplitter``.

    Subclasses provide the raw input via ``get_data_input`` and the configured
    splitter via ``build_text_splitter``; ``split_data`` implements the common
    convert -> split -> wrap pipeline exposed through the ``data`` output.
    """

    trace_type = "text_splitter"
    outputs = [
        Output(display_name="Data", name="data", method="split_data"),
    ]

    def _validate_outputs(self):
        """Check that the output this component depends on is declared.

        Raises:
            ValueError: If no declared output is bound to a required method, or
                if the component does not define that method.
        """
        # The required entries are *method* names, so they are matched against
        # each output's ``method`` rather than its ``name``: the output declared
        # above is named "data" and bound to "split_data", so a lookup of a
        # method name among output names could never succeed.
        required_output_methods = ["split_data"]
        output_methods = {output.method for output in self.outputs}
        for method_name in required_output_methods:
            if method_name not in output_methods:
                raise ValueError(f"Output with method '{method_name}' must be defined.")
            if not hasattr(self, method_name):
                raise ValueError(f"Method '{method_name}' must be defined.")

    def split_data(self) -> list[Data]:
        """Split the component's input with the splitter built by the subclass.

        A single (non-list) input is treated as a one-element list. ``Data``
        items are converted with ``to_lc_document()``; anything else is handed
        to the splitter unchanged.

        Returns:
            The value produced by ``self.to_data`` for the split documents.
        """
        data_input = self.get_data_input()
        if not isinstance(data_input, list):
            data_input = [data_input]

        documents = [item.to_lc_document() if isinstance(item, Data) else item for item in data_input]

        # Built after the input is normalised, matching the original call order.
        splitter = self.build_text_splitter()
        docs = splitter.split_documents(documents)
        data = self.to_data(docs)
        # Store the representation computed by build_loader_repr_from_data on the component.
        self.repr_value = build_loader_repr_from_data(data)
        return data

    # NOTE(review): @abstractmethod is only enforced at instantiation when the
    # class's metaclass derives from ABCMeta - confirm that holds for Component.
    @abstractmethod
    def get_data_input(self) -> Any:
        """
        Get the data input.
        """

    @abstractmethod
    def build_text_splitter(self) -> TextSplitter:
        """
        Build the text splitter.
        """
|
||||
|
|
@ -1,24 +1,58 @@
|
|||
from typing import List
|
||||
from typing import List, Any
|
||||
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
from langchain_text_splitters import CharacterTextSplitter, TextSplitter
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.base.textsplitters.model import LCTextSplitterComponent
|
||||
from langflow.inputs import IntInput, DataInput, MessageTextInput
|
||||
from langflow.schema import Data
|
||||
from langflow.utils.util import unescape_string
|
||||
|
||||
|
||||
class CharacterTextSplitterComponent(CustomComponent):
|
||||
class CharacterTextSplitterComponent(LCTextSplitterComponent):
|
||||
display_name = "CharacterTextSplitter"
|
||||
description = "Splitting text that looks at characters."
|
||||
description = "Split text by number of characters."
|
||||
documentation = "https://docs.langflow.org/components/text-splitters#charactertextsplitter"
|
||||
name = "CharacterTextSplitter"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"inputs": {"display_name": "Input", "input_types": ["Document", "Data"]},
|
||||
"chunk_overlap": {"display_name": "Chunk Overlap", "default": 200},
|
||||
"chunk_size": {"display_name": "Chunk Size", "default": 1000},
|
||||
"separator": {"display_name": "Separator", "default": "\n"},
|
||||
}
|
||||
inputs = [
|
||||
IntInput(
|
||||
name="chunk_size",
|
||||
display_name="Chunk Size",
|
||||
info="The maximum length of each chunk.",
|
||||
value=1000,
|
||||
),
|
||||
IntInput(
|
||||
name="chunk_overlap",
|
||||
display_name="Chunk Overlap",
|
||||
info="The amount of overlap between chunks.",
|
||||
value=200,
|
||||
),
|
||||
DataInput(
|
||||
name="data_input",
|
||||
display_name="Input",
|
||||
info="The texts to split.",
|
||||
input_types=["Document", "Data"],
|
||||
),
|
||||
MessageTextInput(
|
||||
name="separator",
|
||||
display_name="Separator",
|
||||
info='The characters to split on.\nIf left empty defaults to "\\n\\n".',
|
||||
),
|
||||
]
|
||||
|
||||
def get_data_input(self) -> Any:
|
||||
return self.data_input
|
||||
|
||||
def build_text_splitter(self) -> TextSplitter:
|
||||
if self.separator:
|
||||
separator = unescape_string(self.separator)
|
||||
else:
|
||||
separator = "\n\n"
|
||||
return CharacterTextSplitter(
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
chunk_size=self.chunk_size,
|
||||
separator=separator,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1,85 +1,47 @@
|
|||
from typing import List, Optional
|
||||
from typing import Any
|
||||
|
||||
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
|
||||
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.schema import Data
|
||||
from langflow.base.textsplitters.model import LCTextSplitterComponent
|
||||
from langflow.inputs import IntInput, DataInput, DropdownInput
|
||||
|
||||
|
||||
class LanguageRecursiveTextSplitterComponent(CustomComponent):
|
||||
class LanguageRecursiveTextSplitterComponent(LCTextSplitterComponent):
|
||||
display_name: str = "Language Recursive Text Splitter"
|
||||
description: str = "Split text into chunks of a specified length based on language."
|
||||
documentation: str = "https://docs.langflow.org/components/text-splitters#languagerecursivetextsplitter"
|
||||
name = "LanguageRecursiveTextSplitter"
|
||||
|
||||
def build_config(self):
|
||||
options = [x.value for x in Language]
|
||||
return {
|
||||
"inputs": {"display_name": "Input", "input_types": ["Document", "Data"]},
|
||||
"separator_type": {
|
||||
"display_name": "Separator Type",
|
||||
"info": "The type of separator to use.",
|
||||
"field_type": "str",
|
||||
"options": options,
|
||||
"value": "Python",
|
||||
},
|
||||
"separators": {
|
||||
"display_name": "Separators",
|
||||
"info": "The characters to split on.",
|
||||
"is_list": True,
|
||||
},
|
||||
"chunk_size": {
|
||||
"display_name": "Chunk Size",
|
||||
"info": "The maximum length of each chunk.",
|
||||
"field_type": "int",
|
||||
"value": 1000,
|
||||
},
|
||||
"chunk_overlap": {
|
||||
"display_name": "Chunk Overlap",
|
||||
"info": "The amount of overlap between chunks.",
|
||||
"field_type": "int",
|
||||
"value": 200,
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
inputs = [
|
||||
IntInput(
|
||||
name="chunk_size",
|
||||
display_name="Chunk Size",
|
||||
info="The maximum length of each chunk.",
|
||||
value=1000,
|
||||
),
|
||||
IntInput(
|
||||
name="chunk_overlap",
|
||||
display_name="Chunk Overlap",
|
||||
info="The amount of overlap between chunks.",
|
||||
value=200,
|
||||
),
|
||||
DataInput(
|
||||
name="data_input",
|
||||
display_name="Input",
|
||||
info="The texts to split.",
|
||||
input_types=["Document", "Data"],
|
||||
),
|
||||
DropdownInput(
|
||||
name="code_language", display_name="Code Language", options=[x.value for x in Language], value="python"
|
||||
),
|
||||
]
|
||||
|
||||
def build(
|
||||
self,
|
||||
inputs: List[Data],
|
||||
chunk_size: Optional[int] = 1000,
|
||||
chunk_overlap: Optional[int] = 200,
|
||||
separator_type: str = "Python",
|
||||
) -> list[Data]:
|
||||
"""
|
||||
Split text into chunks of a specified length.
|
||||
def get_data_input(self) -> Any:
|
||||
return self.data_input
|
||||
|
||||
Args:
|
||||
separators (list[str]): The characters to split on.
|
||||
chunk_size (int): The maximum length of each chunk.
|
||||
chunk_overlap (int): The amount of overlap between chunks.
|
||||
length_function (function): The function to use to calculate the length of the text.
|
||||
|
||||
Returns:
|
||||
list[str]: The chunks of text.
|
||||
"""
|
||||
|
||||
# Make sure chunk_size and chunk_overlap are ints
|
||||
if isinstance(chunk_size, str):
|
||||
chunk_size = int(chunk_size)
|
||||
if isinstance(chunk_overlap, str):
|
||||
chunk_overlap = int(chunk_overlap)
|
||||
|
||||
splitter = RecursiveCharacterTextSplitter.from_language(
|
||||
language=Language(separator_type),
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
def build_text_splitter(self) -> TextSplitter:
|
||||
return RecursiveCharacterTextSplitter.from_language(
|
||||
language=Language(self.code_language),
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
)
|
||||
documents = []
|
||||
for _input in inputs:
|
||||
if isinstance(_input, Data):
|
||||
documents.append(_input.to_lc_document())
|
||||
else:
|
||||
documents.append(_input)
|
||||
docs = splitter.split_documents(documents)
|
||||
data = self.to_data(docs)
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -1,15 +1,13 @@
|
|||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
from langflow.custom import Component
|
||||
from typing import Any
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from langflow.base.textsplitters.model import LCTextSplitterComponent
|
||||
from langflow.inputs.inputs import DataInput, IntInput, MessageTextInput
|
||||
from langflow.schema import Data
|
||||
from langflow.template.field.base import Output
|
||||
from langflow.utils.util import build_loader_repr_from_data, unescape_string
|
||||
from langflow.utils.util import unescape_string
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitterComponent(Component):
|
||||
class RecursiveCharacterTextSplitterComponent(LCTextSplitterComponent):
|
||||
display_name: str = "Recursive Character Text Splitter"
|
||||
description: str = "Split text into chunks of a specified length."
|
||||
description: str = "Split text trying to keep all related text together."
|
||||
documentation: str = "https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
|
||||
name = "RecursiveCharacterTextSplitter"
|
||||
|
||||
|
|
@ -39,49 +37,20 @@ class RecursiveCharacterTextSplitterComponent(Component):
|
|||
is_list=True,
|
||||
),
|
||||
]
|
||||
outputs = [
|
||||
Output(display_name="Data", name="data", method="split_data"),
|
||||
]
|
||||
|
||||
def split_data(self) -> list[Data]:
|
||||
"""
|
||||
Split text into chunks of a specified length.
|
||||
def get_data_input(self) -> Any:
|
||||
return self.data_input
|
||||
|
||||
Args:
|
||||
separators (list[str] | None): The characters to split on.
|
||||
chunk_size (int): The maximum length of each chunk.
|
||||
chunk_overlap (int): The amount of overlap between chunks.
|
||||
|
||||
Returns:
|
||||
list[str]: The chunks of text.
|
||||
"""
|
||||
|
||||
if self.separators == "":
|
||||
self.separators: list[str] | None = None
|
||||
elif self.separators:
|
||||
def build_text_splitter(self) -> TextSplitter:
|
||||
if not self.separators:
|
||||
separators: list[str] | None = None
|
||||
else:
|
||||
# check if the separators list has escaped characters
|
||||
# if there are escaped characters, unescape them
|
||||
self.separators = [unescape_string(x) for x in self.separators]
|
||||
separators = [unescape_string(x) for x in self.separators]
|
||||
|
||||
# Make sure chunk_size and chunk_overlap are ints
|
||||
if self.chunk_size:
|
||||
self.chunk_size: int = int(self.chunk_size)
|
||||
if self.chunk_overlap:
|
||||
self.chunk_overlap: int = int(self.chunk_overlap)
|
||||
splitter = RecursiveCharacterTextSplitter(
|
||||
separators=self.separators,
|
||||
return RecursiveCharacterTextSplitter(
|
||||
separators=separators,
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
)
|
||||
documents = []
|
||||
if not isinstance(self.data_input, list):
|
||||
self.data_input: list[Data] = [self.data_input]
|
||||
for _input in self.data_input:
|
||||
if isinstance(_input, Data):
|
||||
documents.append(_input.to_lc_document())
|
||||
else:
|
||||
documents.append(_input)
|
||||
docs = splitter.split_documents(documents)
|
||||
data = self.to_data(docs)
|
||||
self.repr_value = build_loader_repr_from_data(data)
|
||||
return data
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue