From 0022cd832896390cbac74b907fdafa2f339ecc7c Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Wed, 20 Mar 2024 11:22:26 +0100 Subject: [PATCH] added check if dataset exists & raise error if not --- src/eva/vision/data/datasets/_validators.py | 10 ++++++++++ src/eva/vision/data/datasets/classification/bach.py | 9 +++++---- src/eva/vision/data/datasets/classification/crc.py | 7 ++++--- src/eva/vision/data/datasets/classification/mhist.py | 4 ++++ .../data/datasets/classification/patch_camelyon.py | 1 + .../data/datasets/classification/total_segmentator.py | 1 + 6 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/eva/vision/data/datasets/_validators.py b/src/eva/vision/data/datasets/_validators.py index ce0755e6..0a727411 100644 --- a/src/eva/vision/data/datasets/_validators.py +++ b/src/eva/vision/data/datasets/_validators.py @@ -1,5 +1,7 @@ """Dataset validation related functions.""" +import os + from typing_extensions import List, Tuple from eva.vision.data.datasets import vision @@ -42,3 +44,11 @@ def check_dataset_integrity( f"({(dataset_classes[0], dataset_classes[-1])}) does not match the expected " f"ones ({first_and_last_labels}). {_SUFFIX_ERROR_MESSAGE}" ) + + +def check_dataset_exists(dataset_dir: str, download_available: bool) -> None: + if not os.path.isdir(dataset_dir): + error_message = "Dataset not found at '{dataset_dir}'." + if download_available: + error_message += " You can set `download=True` to download the dataset automatically." + raise FileNotFoundError(error_message) diff --git a/src/eva/vision/data/datasets/classification/bach.py b/src/eva/vision/data/datasets/classification/bach.py index c7900005..935ab609 100644 --- a/src/eva/vision/data/datasets/classification/bach.py +++ b/src/eva/vision/data/datasets/classification/bach.py @@ -96,24 +96,25 @@ def class_to_idx(self) -> Dict[str, int]: return {"Benign": 0, "InSitu": 1, "Invasive": 2, "Normal": 3} @property - def dataset_path(self) -> str: + def _dataset_path(self) -> str: """Returns the path of the image data of the dataset.""" return os.path.join(self._root, "ICIAR2018_BACH_Challenge", "Photos") @override def filename(self, index: int) -> str: image_path, _ = self._samples[self._indices[index]] - return os.path.relpath(image_path, self.dataset_path) + return os.path.relpath(image_path, self._dataset_path) @override def prepare_data(self) -> None: if self._download: self._download_dataset() + _validators.check_dataset_exists(self._root, True) @override def configure(self) -> None: self._samples = folder.make_dataset( - directory=self.dataset_path, + directory=self._dataset_path, class_to_idx=self.class_to_idx, extensions=(".tif"), ) @@ -145,7 +146,7 @@ def __len__(self) -> int: def _download_dataset(self) -> None: """Downloads the dataset.""" for resource in self._resources: - if os.path.isdir(self.dataset_path): + if os.path.isdir(self._dataset_path): continue self._print_license() diff --git a/src/eva/vision/data/datasets/classification/crc.py b/src/eva/vision/data/datasets/classification/crc.py index 4717e0be..5c661d45 100644 --- a/src/eva/vision/data/datasets/classification/crc.py +++ b/src/eva/vision/data/datasets/classification/crc.py @@ -95,12 +95,13 @@ def class_to_idx(self) -> Dict[str, int]: @override def filename(self, index: int) -> str: image_path, *_ = self._samples[index] - return os.path.relpath(image_path, self._dataset_dir) + return os.path.relpath(image_path, self._dataset_path) @override def prepare_data(self) -> None: if self._download: self._download_dataset() + _validators.check_dataset_exists(self._root, True) @override def configure(self) -> None: @@ -135,7 +136,7 @@ def __len__(self) -> int: return len(self._samples) @property - def _dataset_dir(self) -> str: + def _dataset_path(self) -> str: """Returns the full path of dataset directory.""" dataset_dirs = { "train": os.path.join(self._root, "NCT-CRC-HE-100K"), @@ -150,7 +151,7 @@ def _dataset_dir(self) -> str: def _make_dataset(self) -> List[Tuple[str, int]]: """Builds the dataset for the specified split.""" dataset = folder.make_dataset( - directory=self._dataset_dir, + directory=self._dataset_path, class_to_idx=self.class_to_idx, extensions=(".tif"), ) diff --git a/src/eva/vision/data/datasets/classification/mhist.py b/src/eva/vision/data/datasets/classification/mhist.py index 9f4bda6c..75297183 100644 --- a/src/eva/vision/data/datasets/classification/mhist.py +++ b/src/eva/vision/data/datasets/classification/mhist.py @@ -56,6 +56,10 @@ def filename(self, index: int) -> str: image_filename, _ = self._samples[index] return image_filename + @override + def prepare_data(self) -> None: + _validators.check_dataset_exists(self._root, False) + @override def configure(self) -> None: self._samples = self._make_dataset() diff --git a/src/eva/vision/data/datasets/classification/patch_camelyon.py b/src/eva/vision/data/datasets/classification/patch_camelyon.py index e18ee36e..e9eaa5f5 100644 --- a/src/eva/vision/data/datasets/classification/patch_camelyon.py +++ b/src/eva/vision/data/datasets/classification/patch_camelyon.py @@ -114,6 +114,7 @@ def filename(self, index: int) -> str: def prepare_data(self) -> None: if self._download: self._download_dataset() + _validators.check_dataset_exists(self._root, True) @override def validate(self) -> None: diff --git a/src/eva/vision/data/datasets/classification/total_segmentator.py b/src/eva/vision/data/datasets/classification/total_segmentator.py index f9ddac26..c7c0c88d 100644 --- a/src/eva/vision/data/datasets/classification/total_segmentator.py +++ b/src/eva/vision/data/datasets/classification/total_segmentator.py @@ -108,6 +108,7 @@ def filename(self, index: int) -> str: def prepare_data(self) -> None: if self._download: self._download_dataset() + _validators.check_dataset_exists(self._root, True) @override def configure(self) -> None: