From 26bc47c7006497e9f953a44ac332bc2d2f3170f8 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Fri, 11 Oct 2024 13:00:37 +0000 Subject: [PATCH] lint --- src/eva/core/data/dataloaders/dataloader.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/eva/core/data/dataloaders/dataloader.py b/src/eva/core/data/dataloaders/dataloader.py index 232a3f1e..70040c61 100644 --- a/src/eva/core/data/dataloaders/dataloader.py +++ b/src/eva/core/data/dataloaders/dataloader.py @@ -38,6 +38,12 @@ class DataLoader: Mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`. """ + num_workers: int | None = None + """How many workers to use for loading the data. + + By default, it will use the number of CPUs available. + """ + collate_fn: Callable | None = None """The batching process.""" @@ -53,16 +59,6 @@ class DataLoader: prefetch_factor: int | None = 2 """Number of batches loaded in advance by each worker.""" - num_workers: int | None = dataclasses.field(default=None) - """How many workers to use for loading the data. - - By default, it will use the number of CPUs available. - """ - - def __post_init__(self): - if self.num_workers is None: - self.num_workers = multiprocessing.cpu_count() - def __call__(self, dataset: datasets.TorchDataset) -> dataloader.DataLoader: """Returns the dataloader on the provided dataset. @@ -75,7 +71,7 @@ def __call__(self, dataset: datasets.TorchDataset) -> dataloader.DataLoader: shuffle=self.shuffle, sampler=self.sampler, batch_sampler=self.batch_sampler, - num_workers=self.num_workers, + num_workers=self.num_workers or multiprocessing.cpu_count(), collate_fn=self.collate_fn, pin_memory=self.pin_memory, drop_last=self.drop_last,