refactor: Update RecursiveCharacterTextSplitterComponent to use new input classes

This commit updates the RecursiveCharacterTextSplitterComponent class in the RecursiveCharacterTextSplitter.py file to use the new input classes from langflow.inputs.inputs module. The StrInput class has been replaced with TextInput, and the Document class has been replaced with DataInput. This change improves code organization and ensures compatibility with the latest input classes.
This commit is contained in:
ogabrielluiz 2024-06-19 01:04:05 -03:00
commit 7a089e7d71

View file

@ -1,52 +1,48 @@
from typing import Optional
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langflow.custom import CustomComponent
from langflow.custom import Component
from langflow.inputs.inputs import DataInput, IntInput, TextInput
from langflow.schema import Data
from langflow.template.field.base import Output
from langflow.utils.util import build_loader_repr_from_data, unescape_string
class RecursiveCharacterTextSplitterComponent(CustomComponent):
class RecursiveCharacterTextSplitterComponent(Component):
display_name: str = "Recursive Character Text Splitter"
description: str = "Split text into chunks of a specified length."
documentation: str = "https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
def build_config(self):
return {
"inputs": {
"display_name": "Input",
"info": "The texts to split.",
"input_types": ["Document", "Data"],
},
"separators": {
"display_name": "Separators",
"info": 'The characters to split on.\nIf left empty defaults to ["\\n\\n", "\\n", " ", ""].',
"is_list": True,
},
"chunk_size": {
"display_name": "Chunk Size",
"info": "The maximum length of each chunk.",
"field_type": "int",
"value": 1000,
},
"chunk_overlap": {
"display_name": "Chunk Overlap",
"info": "The amount of overlap between chunks.",
"field_type": "int",
"value": 200,
},
"code": {"show": False},
}
inputs = [
IntInput(
name="chunk_size",
display_name="Chunk Size",
info="The maximum length of each chunk.",
value=1000,
),
IntInput(
name="chunk_overlap",
display_name="Chunk Overlap",
info="The amount of overlap between chunks.",
value=200,
),
DataInput(
name="data_input",
display_name="Input",
info="The texts to split.",
input_types=["Document", "Data"],
),
TextInput(
name="separators",
display_name="Separators",
info='The characters to split on.\nIf left empty defaults to ["\\n\\n", "\\n", " ", ""].',
is_list=True,
),
]
outputs = [
Output(display_name="Data", name="data", method="build"),
]
def build(
self,
inputs: list[Document],
separators: Optional[list[str]] = None,
chunk_size: Optional[int] = 1000,
chunk_overlap: Optional[int] = 200,
) -> list[Data]:
def build(self) -> list[Data]:
"""
Split text into chunks of a specified length.
@ -54,31 +50,30 @@ class RecursiveCharacterTextSplitterComponent(CustomComponent):
separators (list[str]): The characters to split on.
chunk_size (int): The maximum length of each chunk.
chunk_overlap (int): The amount of overlap between chunks.
length_function (function): The function to use to calculate the length of the text.
Returns:
list[str]: The chunks of text.
"""
if separators == "":
separators = None
elif separators:
if self.separators == "":
self.separators = None
elif self.separators:
# check if the separators list has escaped characters
# if there are escaped characters, unescape them
separators = [unescape_string(x) for x in separators]
self.separators = [unescape_string(x) for x in self.separators]
# Make sure chunk_size and chunk_overlap are ints
if isinstance(chunk_size, str):
chunk_size = int(chunk_size)
if isinstance(chunk_overlap, str):
chunk_overlap = int(chunk_overlap)
if isinstance(self.chunk_size, str):
self.chunk_size = int(self.chunk_size)
if isinstance(self.chunk_overlap, str):
self.chunk_overlap = int(self.chunk_overlap)
splitter = RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=self.separators,
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
)
documents = []
for _input in inputs:
for _input in self.data_input:
if isinstance(_input, Data):
documents.append(_input.to_lc_document())
else: