refactor: Update RecursiveCharacterTextSplitterComponent to use new input classes
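This commit updates the RecursiveCharacterTextSplitterComponent class in the RecursiveCharacterTextSplitter.py file to use the new input classes from the langflow.inputs.inputs module. The component now derives from Component instead of CustomComponent, and the build_config() dictionary has been replaced with declarative input definitions: IntInput for the chunk size and chunk overlap, TextInput for the separators, and DataInput in place of the Document-typed `inputs` argument. The build() method no longer takes arguments and reads these values from the component's attributes instead. This change improves code organization and ensures compatibility with the latest input classes.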
This commit is contained in:
parent
2c2cc968e8
commit
7a089e7d71
1 changed file with 46 additions and 51 deletions
|
|
@ -1,52 +1,48 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.custom import Component
|
||||
from langflow.inputs.inputs import DataInput, IntInput, TextInput
|
||||
from langflow.schema import Data
|
||||
from langflow.template.field.base import Output
|
||||
from langflow.utils.util import build_loader_repr_from_data, unescape_string
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitterComponent(CustomComponent):
|
||||
class RecursiveCharacterTextSplitterComponent(Component):
|
||||
display_name: str = "Recursive Character Text Splitter"
|
||||
description: str = "Split text into chunks of a specified length."
|
||||
documentation: str = "https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"inputs": {
|
||||
"display_name": "Input",
|
||||
"info": "The texts to split.",
|
||||
"input_types": ["Document", "Data"],
|
||||
},
|
||||
"separators": {
|
||||
"display_name": "Separators",
|
||||
"info": 'The characters to split on.\nIf left empty defaults to ["\\n\\n", "\\n", " ", ""].',
|
||||
"is_list": True,
|
||||
},
|
||||
"chunk_size": {
|
||||
"display_name": "Chunk Size",
|
||||
"info": "The maximum length of each chunk.",
|
||||
"field_type": "int",
|
||||
"value": 1000,
|
||||
},
|
||||
"chunk_overlap": {
|
||||
"display_name": "Chunk Overlap",
|
||||
"info": "The amount of overlap between chunks.",
|
||||
"field_type": "int",
|
||||
"value": 200,
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
inputs = [
|
||||
IntInput(
|
||||
name="chunk_size",
|
||||
display_name="Chunk Size",
|
||||
info="The maximum length of each chunk.",
|
||||
value=1000,
|
||||
),
|
||||
IntInput(
|
||||
name="chunk_overlap",
|
||||
display_name="Chunk Overlap",
|
||||
info="The amount of overlap between chunks.",
|
||||
value=200,
|
||||
),
|
||||
DataInput(
|
||||
name="data_input",
|
||||
display_name="Input",
|
||||
info="The texts to split.",
|
||||
input_types=["Document", "Data"],
|
||||
),
|
||||
TextInput(
|
||||
name="separators",
|
||||
display_name="Separators",
|
||||
info='The characters to split on.\nIf left empty defaults to ["\\n\\n", "\\n", " ", ""].',
|
||||
is_list=True,
|
||||
),
|
||||
]
|
||||
outputs = [
|
||||
Output(display_name="Data", name="data", method="build"),
|
||||
]
|
||||
|
||||
def build(
|
||||
self,
|
||||
inputs: list[Document],
|
||||
separators: Optional[list[str]] = None,
|
||||
chunk_size: Optional[int] = 1000,
|
||||
chunk_overlap: Optional[int] = 200,
|
||||
) -> list[Data]:
|
||||
def build(self) -> list[Data]:
|
||||
"""
|
||||
Split text into chunks of a specified length.
|
||||
|
||||
|
|
@ -54,31 +50,30 @@ class RecursiveCharacterTextSplitterComponent(CustomComponent):
|
|||
separators (list[str]): The characters to split on.
|
||||
chunk_size (int): The maximum length of each chunk.
|
||||
chunk_overlap (int): The amount of overlap between chunks.
|
||||
length_function (function): The function to use to calculate the length of the text.
|
||||
|
||||
Returns:
|
||||
list[str]: The chunks of text.
|
||||
"""
|
||||
|
||||
if separators == "":
|
||||
separators = None
|
||||
elif separators:
|
||||
if self.separators == "":
|
||||
self.separators = None
|
||||
elif self.separators:
|
||||
# check if the separators list has escaped characters
|
||||
# if there are escaped characters, unescape them
|
||||
separators = [unescape_string(x) for x in separators]
|
||||
self.separators = [unescape_string(x) for x in self.separators]
|
||||
|
||||
# Make sure chunk_size and chunk_overlap are ints
|
||||
if isinstance(chunk_size, str):
|
||||
chunk_size = int(chunk_size)
|
||||
if isinstance(chunk_overlap, str):
|
||||
chunk_overlap = int(chunk_overlap)
|
||||
if isinstance(self.chunk_size, str):
|
||||
self.chunk_size = int(self.chunk_size)
|
||||
if isinstance(self.chunk_overlap, str):
|
||||
self.chunk_overlap = int(self.chunk_overlap)
|
||||
splitter = RecursiveCharacterTextSplitter(
|
||||
separators=separators,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separators=self.separators,
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
)
|
||||
documents = []
|
||||
for _input in inputs:
|
||||
for _input in self.data_input:
|
||||
if isinstance(_input, Data):
|
||||
documents.append(_input.to_lc_document())
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue