diff --git a/configs/vision/pathology/offline/classification/camelyon16.yaml b/configs/vision/pathology/offline/classification/camelyon16.yaml
index 0edbcfa7..6d898517 100644
--- a/configs/vision/pathology/offline/classification/camelyon16.yaml
+++ b/configs/vision/pathology/offline/classification/camelyon16.yaml
@@ -108,6 +108,7 @@ data:
             height: 224
             target_mpp: 0.25
             split: train
+            coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
             image_transforms:
               class_path: eva.vision.data.transforms.common.ResizeAndCrop
               init_args:
diff --git a/configs/vision/pathology/offline/classification/camelyon16_small.yaml b/configs/vision/pathology/offline/classification/camelyon16_small.yaml
index 99429325..133350b5 100644
--- a/configs/vision/pathology/offline/classification/camelyon16_small.yaml
+++ b/configs/vision/pathology/offline/classification/camelyon16_small.yaml
@@ -108,6 +108,7 @@ data:
             height: 224
             target_mpp: 0.25
             split: train
+            coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
             image_transforms:
               class_path: eva.vision.data.transforms.common.ResizeAndCrop
               init_args:
diff --git a/configs/vision/pathology/offline/classification/panda.yaml b/configs/vision/pathology/offline/classification/panda.yaml
index 2753f281..b88138c5 100644
--- a/configs/vision/pathology/offline/classification/panda.yaml
+++ b/configs/vision/pathology/offline/classification/panda.yaml
@@ -107,6 +107,7 @@ data:
             height: 224
             target_mpp: 0.5
             split: train
+            coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
             image_transforms:
               class_path: eva.vision.data.transforms.common.ResizeAndCrop
               init_args:
diff --git a/configs/vision/pathology/offline/classification/panda_small.yaml b/configs/vision/pathology/offline/classification/panda_small.yaml
index bfd598f2..53735a7c 100644
--- a/configs/vision/pathology/offline/classification/panda_small.yaml
+++ b/configs/vision/pathology/offline/classification/panda_small.yaml
@@ -107,6 +107,7 @@ data:
             height: 224
             target_mpp: 0.5
             split: train
+            coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
             image_transforms:
               class_path: eva.vision.data.transforms.common.ResizeAndCrop
               init_args:
diff --git a/src/eva/core/callbacks/writers/embeddings/base.py b/src/eva/core/callbacks/writers/embeddings/base.py
index f4930cd7..6cdde5ab 100644
--- a/src/eva/core/callbacks/writers/embeddings/base.py
+++ b/src/eva/core/callbacks/writers/embeddings/base.py
@@ -172,15 +172,14 @@ def _get_item_metadata(

     def _check_if_exists(self) -> None:
         """Checks if the output directory already exists and if it should be overwritten."""
-        try:
-            os.makedirs(self._output_dir, exist_ok=self._overwrite)
-        except FileExistsError as e:
+        os.makedirs(self._output_dir, exist_ok=True)
+        if os.path.exists(os.path.join(self._output_dir, "manifest.csv")) and not self._overwrite:
             raise FileExistsError(
                 f"The embeddings output directory already exists: {self._output_dir}. This "
                 "either means that they have been computed before or that a wrong output "
                 "directory is being used. Consider using `eva fit` instead, selecting a "
                 "different output directory or setting overwrite=True."
-            ) from e
+            )
         os.makedirs(self._output_dir, exist_ok=True)
diff --git a/src/eva/vision/data/datasets/classification/camelyon16.py b/src/eva/vision/data/datasets/classification/camelyon16.py
index a7ace11c..e8abb527 100644
--- a/src/eva/vision/data/datasets/classification/camelyon16.py
+++ b/src/eva/vision/data/datasets/classification/camelyon16.py
@@ -87,6 +87,7 @@ def __init__(
         target_mpp: float = 0.5,
         backend: str = "openslide",
         image_transforms: Callable | None = None,
+        coords_path: str | None = None,
         seed: int = 42,
     ) -> None:
         """Initializes the dataset.
@@ -100,6 +101,7 @@ def __init__(
             target_mpp: Target microns per pixel (mpp) for the patches.
             backend: The backend to use for reading the whole-slide images.
             image_transforms: Transforms to apply to the extracted image patches.
+            coords_path: File path to save the patch coordinates as .csv.
             seed: Random seed for reproducibility.
         """
         self._split = split
@@ -119,6 +121,7 @@ def __init__(
             target_mpp=target_mpp,
             backend=backend,
             image_transforms=image_transforms,
+            coords_path=coords_path,
         )

     @property
@@ -207,7 +210,7 @@ def load_target(self, index: int) -> torch.Tensor:

     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
-        return {"wsi_id": self.filename(index).split(".")[0]}
+        return wsi.MultiWsiDataset.load_metadata(self, index)

     def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
         """Loads the file paths of the corresponding dataset split."""
diff --git a/src/eva/vision/data/datasets/classification/panda.py b/src/eva/vision/data/datasets/classification/panda.py
index ffa00ab3..fb089c47 100644
--- a/src/eva/vision/data/datasets/classification/panda.py
+++ b/src/eva/vision/data/datasets/classification/panda.py
@@ -49,6 +49,7 @@ def __init__(
         target_mpp: float = 0.5,
         backend: str = "openslide",
         image_transforms: Callable | None = None,
+        coords_path: str | None = None,
         seed: int = 42,
     ) -> None:
         """Initializes the dataset.
@@ -62,6 +63,7 @@ def __init__(
             target_mpp: Target microns per pixel (mpp) for the patches.
             backend: The backend to use for reading the whole-slide images.
             image_transforms: Transforms to apply to the extracted image patches.
+            coords_path: File path to save the patch coordinates as .csv.
             seed: Random seed for reproducibility.
         """
         self._split = split
@@ -80,6 +82,7 @@ def __init__(
             target_mpp=target_mpp,
             backend=backend,
             image_transforms=image_transforms,
+            coords_path=coords_path,
         )

     @property
@@ -132,7 +135,7 @@ def load_target(self, index: int) -> torch.Tensor:

     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
-        return {"wsi_id": self.filename(index).split(".")[0]}
+        return wsi.MultiWsiDataset.load_metadata(self, index)

     def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
         """Loads the file paths of the corresponding dataset split."""
diff --git a/src/eva/vision/data/datasets/classification/wsi.py b/src/eva/vision/data/datasets/classification/wsi.py
index 9e1cae52..e0b4f83a 100644
--- a/src/eva/vision/data/datasets/classification/wsi.py
+++ b/src/eva/vision/data/datasets/classification/wsi.py
@@ -35,6 +35,7 @@ def __init__(
         split: Literal["train", "val", "test"] | None = None,
         image_transforms: Callable | None = None,
         column_mapping: Dict[str, str] = default_column_mapping,
+        coords_path: str | None = None,
     ):
         """Initializes the dataset.

@@ -51,6 +52,7 @@ def __init__(
             split: The split of the dataset to load.
             image_transforms: Transforms to apply to the extracted image patches.
             column_mapping: Mapping of the columns in the manifest file.
+            coords_path: File path to save the patch coordinates as .csv.
         """
         self._split = split
         self._column_mapping = self.default_column_mapping | column_mapping
@@ -66,6 +68,7 @@ def __init__(
             target_mpp=target_mpp,
             backend=backend,
             image_transforms=image_transforms,
+            coords_path=coords_path,
         )

     @override
@@ -88,7 +91,7 @@ def load_target(self, index: int) -> np.ndarray:

     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
-        return {"wsi_id": self.filename(index).split(".")[0]}
+        return wsi.MultiWsiDataset.load_metadata(self, index)

     def _load_manifest(self, manifest_path: str) -> pd.DataFrame:
         df = pd.read_csv(manifest_path)
diff --git a/src/eva/vision/data/datasets/wsi.py b/src/eva/vision/data/datasets/wsi.py
index fe83ca63..803493ab 100644
--- a/src/eva/vision/data/datasets/wsi.py
+++ b/src/eva/vision/data/datasets/wsi.py
@@ -2,8 +2,9 @@

 import bisect
 import os
-from typing import Callable, List
+from typing import Any, Callable, Dict, List

+import pandas as pd
 from loguru import logger
 from torch.utils.data import dataset as torch_datasets
 from torchvision import tv_tensors
@@ -85,6 +86,17 @@ def __getitem__(self, index: int) -> tv_tensors.Image:
         patch = self._apply_transforms(patch)
         return patch

+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        """Loads the metadata for the patch at the specified index."""
+        x, y = self._coords.x_y[index]
+        return {
+            "x": x,
+            "y": y,
+            "width": self._coords.width,
+            "height": self._coords.height,
+            "level_idx": self._coords.level_idx,
+        }
+
     def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
         if self._image_transforms is not None:
             image = self._image_transforms(image)
@@ -105,6 +117,7 @@ def __init__(
         overwrite_mpp: float | None = None,
         backend: str = "openslide",
         image_transforms: Callable | None = None,
+        coords_path: str | None = None,
     ):
         """Initializes a new dataset instance.

@@ -118,6 +131,7 @@ def __init__(
             sampler: The sampler to use for sampling patch coordinates.
             backend: The backend to use for reading the whole-slide images.
             image_transforms: Transforms to apply to the extracted image patches.
+            coords_path: File path to save the patch coordinates as .csv.
""" super().__init__() @@ -130,6 +144,7 @@ def __init__( self._sampler = sampler self._backend = backend self._image_transforms = image_transforms + self._coords_path = coords_path self._concat_dataset: torch_datasets.ConcatDataset @@ -146,6 +161,7 @@ def cumulative_sizes(self) -> List[int]: @override def configure(self) -> None: self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets()) + self._save_coords_to_file() @override def __len__(self) -> int: @@ -159,6 +175,12 @@ def __getitem__(self, index: int) -> tv_tensors.Image: def filename(self, index: int) -> str: return os.path.basename(self._file_paths[self._get_dataset_idx(index)]) + def load_metadata(self, index: int) -> Dict[str, Any]: + """Loads the metadata for the patch at the specified index.""" + dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index) + patch_metadata = self.datasets[dataset_index].load_metadata(sample_index) + return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata + def _load_datasets(self) -> list[WsiDataset]: logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...") wsi_datasets = [] @@ -185,3 +207,17 @@ def _load_datasets(self) -> list[WsiDataset]: def _get_dataset_idx(self, index: int) -> int: return bisect.bisect_right(self.cumulative_sizes, index) + + def _get_sample_idx(self, index: int) -> int: + dataset_idx = self._get_dataset_idx(index) + return index if dataset_idx == 0 else index - self.cumulative_sizes[dataset_idx - 1] + + def _save_coords_to_file(self): + if self._coords_path is not None: + coords = [ + {"file": self._file_paths[i]} | dataset._coords.to_dict() + for i, dataset in enumerate(self.datasets) + ] + os.makedirs(os.path.abspath(os.path.join(self._coords_path, os.pardir)), exist_ok=True) + pd.DataFrame(coords).to_csv(self._coords_path, index=False) + logger.info(f"Saved patch coordinates to: {self._coords_path}") diff --git a/src/eva/vision/data/wsi/patching/coordinates.py b/src/eva/vision/data/wsi/patching/coordinates.py index bab7e0be..0152115f 100644 --- a/src/eva/vision/data/wsi/patching/coordinates.py +++ b/src/eva/vision/data/wsi/patching/coordinates.py @@ -2,7 +2,7 @@ import dataclasses import functools -from typing import List, Tuple +from typing import Any, Dict, List, Tuple from eva.vision.data.wsi import backends from eva.vision.data.wsi.patching import samplers @@ -75,6 +75,14 @@ def from_file( return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask")) + def to_dict(self, include_keys: List[str] | None = None) -> Dict[str, Any]: + """Convert the coordinates to a dictionary.""" + include_keys = include_keys or ["x_y", "width", "height", "level_idx"] + coord_dict = dataclasses.asdict(self) + if include_keys: + coord_dict = {key: coord_dict[key] for key in include_keys} + return coord_dict + @functools.lru_cache(LRU_CACHE_SIZE) def get_cached_coords( diff --git a/tests/eva/vision/data/datasets/classification/test_camelyon16.py b/tests/eva/vision/data/datasets/classification/test_camelyon16.py index c7c7277c..58594fb3 100644 --- a/tests/eva/vision/data/datasets/classification/test_camelyon16.py +++ b/tests/eva/vision/data/datasets/classification/test_camelyon16.py @@ -69,6 +69,11 @@ def _check_batch_shape(batch: Any): assert isinstance(target, torch.Tensor) assert isinstance(metadata, dict) assert "wsi_id" in metadata + assert "x" in metadata + assert "y" in metadata + assert "width" in metadata + assert "height" in metadata + assert "level_idx" in metadata 
 @pytest.fixture
diff --git a/tests/eva/vision/data/datasets/classification/test_panda.py b/tests/eva/vision/data/datasets/classification/test_panda.py
index 783cc341..ce993a88 100644
--- a/tests/eva/vision/data/datasets/classification/test_panda.py
+++ b/tests/eva/vision/data/datasets/classification/test_panda.py
@@ -102,6 +102,11 @@ def _check_batch_shape(batch: Any):
     assert isinstance(target, torch.Tensor)
     assert isinstance(metadata, dict)
     assert "wsi_id" in metadata
+    assert "x" in metadata
+    assert "y" in metadata
+    assert "width" in metadata
+    assert "height" in metadata
+    assert "level_idx" in metadata


 @pytest.fixture
diff --git a/tests/eva/vision/data/datasets/classification/test_wsi.py b/tests/eva/vision/data/datasets/classification/test_wsi.py
index d14573d8..c2dc4bcc 100644
--- a/tests/eva/vision/data/datasets/classification/test_wsi.py
+++ b/tests/eva/vision/data/datasets/classification/test_wsi.py
@@ -79,6 +79,11 @@ def _check_batch_shape(batch: Any):

     assert isinstance(metadata, dict)
     assert "wsi_id" in metadata
+    assert "x" in metadata
+    assert "y" in metadata
+    assert "width" in metadata
+    assert "height" in metadata
+    assert "level_idx" in metadata


 @pytest.fixture
diff --git a/tests/eva/vision/data/datasets/test_wsi.py b/tests/eva/vision/data/datasets/test_wsi.py
index 87959a60..5c01a59d 100644
--- a/tests/eva/vision/data/datasets/test_wsi.py
+++ b/tests/eva/vision/data/datasets/test_wsi.py
@@ -1,8 +1,10 @@
 """WsiDataset & MultiWsiDataset tests."""

 import os
+import pathlib
 from typing import Tuple

+import pandas as pd
 import pytest

 from eva.vision.data import datasets
@@ -69,14 +71,14 @@ def test_patch_shape(width: int, height: int, target_mpp: float, root: str, back
     assert dataset[0].shape == (3, scaled_width, scaled_height)


-def test_multi_dataset(root: str):
+def test_multi_dataset(root: str, tmp_path: pathlib.Path):
     """Test MultiWsiDataset with multiple whole-slide image paths."""
+    coords_path = (tmp_path / "coords.csv").as_posix()
     file_paths = [
         os.path.join(root, "0/a.tiff"),
         os.path.join(root, "0/b.tiff"),
         os.path.join(root, "1/a.tiff"),
     ]
-
     width, height = 32, 32
     dataset = datasets.MultiWsiDataset(
         root=root,
@@ -86,6 +88,7 @@ def test_multi_dataset(root: str):
         target_mpp=0.25,
         sampler=samplers.GridSampler(max_samples=None),
         backend="openslide",
+        coords_path=coords_path,
     )

     dataset.setup()
@@ -94,6 +97,11 @@ def test_multi_dataset(root: str):
     assert len(dataset) == _expected_n_patches(layer_shape, width, height, (0, 0)) * len(file_paths)
     assert dataset.cumulative_sizes == [64, 128, 192]

+    assert os.path.exists(coords_path)
+    df_coords = pd.read_csv(coords_path)
+    assert "file" in df_coords.columns
+    assert "x_y" in df_coords.columns
+

 def _expected_n_patches(layer_shape, width, height, overlap):
     """Calculate the expected number of patches."""
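Usage sketch (not part of the patch): the snippet below illustrates how the new `coords_path` argument and the patch-level `load_metadata` introduced above are meant to work together. It mirrors `tests/eva/vision/data/datasets/test_wsi.py`; the root directory, slide paths, and the `file_paths=` keyword are placeholders/assumptions and should be checked against the `MultiWsiDataset` constructor in `src/eva/vision/data/datasets/wsi.py`.

```python
# Illustrative sketch only: placeholder paths; arguments mirrored from test_wsi.py.
import os

import pandas as pd

from eva.vision.data import datasets
from eva.vision.data.wsi.patching import samplers

root = "./data/wsi"  # placeholder directory containing the slides
coords_path = os.path.join(root, "coords_train.csv")

dataset = datasets.MultiWsiDataset(
    root=root,
    file_paths=[os.path.join(root, "0/a.tiff")],  # placeholder slides (assumed keyword name)
    width=224,
    height=224,
    target_mpp=0.25,
    sampler=samplers.GridSampler(max_samples=None),
    backend="openslide",
    coords_path=coords_path,  # new: persist the sampled patch coordinates as CSV
)
dataset.setup()  # builds the per-slide datasets and writes the coordinates file

# Each patch now reports its source slide plus its position and extraction level.
print(dataset.load_metadata(0))  # e.g. {"wsi_id": "a", "x": ..., "y": ..., "width": ..., "height": ..., "level_idx": ...}

# The CSV holds one row per WSI: the "file" path plus the serialized PatchCoordinates fields.
print(pd.read_csv(coords_path).columns.tolist())  # ["file", "x_y", "width", "height", "level_idx"]
```

The design keeps coordinate dumping opt-in: when `coords_path` is left as `None`, `configure()` behaves exactly as before and no file is written.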