diff --git a/tests/conftest.py b/tests/conftest.py index 328a168ad..79704e3b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -250,7 +250,7 @@ class CSVLoaderComponent(CustomComponent): "column_name": {"field_type": "str", "required": True}, } - def build(self, filename: str, column_name: str) -> List[Document]: + def build(self, filename: str, column_name: str) -> Document: # Load the CSV file df = pd.read_csv(filename) @@ -265,3 +265,23 @@ class CSVLoaderComponent(CustomComponent): documents.append(Document(page_content=str(content), metadata=metadata)) return documents""" + + +@pytest.fixture +def filter_docs(): + return """from langchain.schema import Document +from langflow.interface.custom.base import CustomComponent +from typing import List + +class DocumentFilterByLengthComponent(CustomComponent): + display_name: str = "Document Filter By Length" + field_config = { + "documents": {"field_type": "Document", "required": True}, + "max_length": {"field_type": "int", "required": True}, + } + + def build(self, documents: List[Document], max_length: int) -> List[Document]: + # Filter the documents by length + filtered_documents = [doc for doc in documents if len(doc.page_content) <= max_length] + + return filtered_documents"""