Add optional typing for List in GatherRecordsComponent

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-28 13:41:04 -03:00
commit 9cada18943

View file

@ -1,6 +1,6 @@
from concurrent import futures
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from langflow import CustomComponent
from langflow.schema import Record
@ -12,21 +12,30 @@ class GatherRecordsComponent(CustomComponent):
def build_config(self) -> Dict[str, Any]:
return {
"path": {"display_name": "Path"},
"types": {
"display_name": "Types",
"info": "File types to load. Leave empty to load all types.",
},
"depth": {"display_name": "Depth", "info": "Depth to search for files."},
"max_concurrency": {"display_name": "Max Concurrency", "advanced": True},
"load_hidden": {
"display_name": "Load Hidden Files",
"value": False,
"display_name": "Load Hidden",
"advanced": True,
"info": "If true, hidden files will be loaded.",
},
"max_concurrency": {
"display_name": "Max Concurrency",
"value": 10,
"recursive": {
"display_name": "Recursive",
"advanced": True,
"info": "If true, the search will be recursive.",
},
"silent_errors": {
"display_name": "Silent Errors",
"advanced": True,
"info": "If true, errors will not raise an exception.",
},
"path": {"display_name": "Local Directory"},
"recursive": {"display_name": "Recursive", "value": True, "advanced": True},
"use_multithreading": {
"display_name": "Use Multithreading",
"value": True,
"advanced": True,
},
}
@ -61,7 +70,9 @@ class GatherRecordsComponent(CustomComponent):
glob = "**/*" if recursive else "*"
paths = walk_level(path_obj, depth) if depth else path_obj.glob(glob)
file_paths = [str(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)]
file_paths = [
str(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)
]
return file_paths
@ -91,13 +102,20 @@ class GatherRecordsComponent(CustomComponent):
use_multithreading: bool,
) -> List[Record]:
if use_multithreading:
records = self.parallel_load_records(file_paths, silent_errors, max_concurrency)
records = self.parallel_load_records(
file_paths, silent_errors, max_concurrency
)
else:
records = [self.parse_file_to_record(file_path, silent_errors) for file_path in file_paths]
records = [
self.parse_file_to_record(file_path, silent_errors)
for file_path in file_paths
]
records = list(filter(None, records))
return records
def parallel_load_records(self, file_paths: List[str], silent_errors: bool, max_concurrency: int) -> List[Record]:
def parallel_load_records(
self, file_paths: List[str], silent_errors: bool, max_concurrency: int
) -> List[Record]:
with futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
loaded_files = executor.map(
lambda file_path: self.parse_file_to_record(file_path, silent_errors),
@ -108,7 +126,7 @@ class GatherRecordsComponent(CustomComponent):
def build(
self,
path: str,
types: List[str] = None,
types: Optional[List[str]] = None,
depth: int = 0,
max_concurrency: int = 2,
load_hidden: bool = False,
@ -116,14 +134,23 @@ class GatherRecordsComponent(CustomComponent):
silent_errors: bool = False,
use_multithreading: bool = True,
) -> List[Record]:
if types is None:
types = []
resolved_path = self.resolve_path(path)
file_paths = self.retrieve_file_paths(resolved_path, types, load_hidden, recursive, depth)
file_paths = self.retrieve_file_paths(
resolved_path, types, load_hidden, recursive, depth
)
loaded_records = []
if use_multithreading:
loaded_records = self.parallel_load_records(file_paths, silent_errors, max_concurrency)
loaded_records = self.parallel_load_records(
file_paths, silent_errors, max_concurrency
)
else:
loaded_records = [self.parse_file_to_record(file_path, silent_errors) for file_path in file_paths]
loaded_records = [
self.parse_file_to_record(file_path, silent_errors)
for file_path in file_paths
]
loaded_records = list(filter(None, loaded_records))
self.status = loaded_records
return loaded_records