From 773c526445f557aefda20c5ffc0094f896c76106 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Tue, 22 Oct 2024 12:29:24 +0200 Subject: [PATCH] add virchow2 model to registry --- .../advanced/replicate_evaluations.md | 17 ++++++- .../models/transforms/extract_cls_features.py | 26 ++++++---- .../transforms/extract_patch_features.py | 34 +++++++++---- .../models/networks/backbones/_utils.py | 12 +++++ .../networks/backbones/pathology/__init__.py | 2 + .../networks/backbones/pathology/histai.py | 4 +- .../networks/backbones/pathology/mahmood.py | 11 +--- .../networks/backbones/pathology/paige.py | 51 +++++++++++++++++++ .../core/models/wrappers/test_huggingface.py | 5 ++ 9 files changed, 130 insertions(+), 32 deletions(-) create mode 100644 src/eva/vision/models/networks/backbones/pathology/paige.py diff --git a/docs/user-guide/advanced/replicate_evaluations.md b/docs/user-guide/advanced/replicate_evaluations.md index 6baa4f09..ccee2bbb 100644 --- a/docs/user-guide/advanced/replicate_evaluations.md +++ b/docs/user-guide/advanced/replicate_evaluations.md @@ -204,6 +204,19 @@ IN_FEATURES=1024 \ eva predict_fit --config configs/vision/pathology/offline/.yaml ``` +### Virchow2 (paige.ai) - DINOv2 ViT-H14 (3.1M Slides) [[8]](#references) +To evaluate [paige.ai's](https://www.paige.ai/) FM with DINOv2 ViT-H14 backbone, pretrained on +a proprietary dataset of 3.1M slides, available for download on +[HuggingFace](https://huggingface.co/paige-ai/Virchow2), run: + +``` +MODEL_NAME=paige/virchow2 \ +NORMALIZE_MEAN="[0.485,0.456,0.406]" \ +NORMALIZE_STD="[0.229,0.224,0.225]" \ +IN_FEATURES=1280 \ +eva predict_fit --config configs/vision/pathology/offline/.yaml +``` + ## References @@ -219,4 +232,6 @@ eva predict_fit --config configs/vision/pathology/offline/.yaml [6]: Xu, Hanwen, et al. "A whole-slide foundation model for digital pathology from real-world data." Nature (2024): 1-8. - [7]: Nechaev, Dmitry, Alexey Pchelnikov, and Ekaterina Ivanova. 
"Hibou: A Family of Foundational Vision Transformers for Pathology." arXiv preprint arXiv:2406.05074 (2024). \ No newline at end of file + [7]: Nechaev, Dmitry, Alexey Pchelnikov, and Ekaterina Ivanova. "Hibou: A Family of Foundational Vision Transformers for Pathology." arXiv preprint arXiv:2406.05074 (2024). + + [8]: Zimmermann, Eric, et al. "Virchow 2: Scaling Self-Supervised Mixed Magnification Models in Pathology." arXiv preprint arXiv:2408.00738 (2024). \ No newline at end of file diff --git a/src/eva/core/models/transforms/extract_cls_features.py b/src/eva/core/models/transforms/extract_cls_features.py index 4fe693e8..2a3612a4 100644 --- a/src/eva/core/models/transforms/extract_cls_features.py +++ b/src/eva/core/models/transforms/extract_cls_features.py @@ -7,13 +7,20 @@ class ExtractCLSFeatures: """Extracts the CLS token from a ViT model output.""" - def __init__(self, cls_index: int = 0) -> None: + def __init__( + self, cls_index: int = 0, num_reg_tokens: int = 0, include_patch_tokens: bool = False + ) -> None: """Initializes the transformation. Args: cls_index: The index of the CLS token in the output tensor. + num_reg_tokens: The number of register tokens in the model output. + include_patch_tokens: Whether to concat the mean aggregated patch tokens with + the cls token. """ self._cls_index = cls_index + self._num_reg_tokens = num_reg_tokens + self._include_patch_tokens = include_patch_tokens def __call__( self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling @@ -23,11 +30,12 @@ def __call__( Args: tensor: The tensor representing the model output. 
""" - if isinstance(tensor, torch.Tensor): - transformed_tensor = tensor[:, self._cls_index, :] - elif isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling): - transformed_tensor = tensor.last_hidden_state[:, self._cls_index, :] - else: - raise ValueError(f"Unsupported type {type(tensor)}") - - return transformed_tensor + if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling): + tensor = tensor.last_hidden_state + + cls_token = tensor[:, self._cls_index, :] + if self._include_patch_tokens: + patch_tokens = tensor[:, 1 + self._num_reg_tokens :, :] + return torch.cat([cls_token, patch_tokens.mean(1)], dim=-1) + + return cls_token diff --git a/src/eva/core/models/transforms/extract_patch_features.py b/src/eva/core/models/transforms/extract_patch_features.py index 0a87d27d..6ac9ac98 100644 --- a/src/eva/core/models/transforms/extract_patch_features.py +++ b/src/eva/core/models/transforms/extract_patch_features.py @@ -10,13 +10,23 @@ class ExtractPatchFeatures: """Extracts the patch features from a ViT model output.""" - def __init__(self, ignore_remaining_dims: bool = False) -> None: + def __init__( + self, + has_cls_token: bool = True, + num_reg_tokens: int = 0, + ignore_remaining_dims: bool = False, + ) -> None: """Initializes the transformation. Args: + has_cls_token: If set to `True`, the model output is expected to have + a classification token. + num_reg_tokens: The number of register tokens in the model output. ignore_remaining_dims: If set to `True`, ignore the remaining dimensions of the patch grid if it is not a square number. """ + self._has_cls_token = has_cls_token + self._num_reg_tokens = num_reg_tokens self._ignore_remaining_dims = ignore_remaining_dims def __call__( @@ -31,17 +41,19 @@ def __call__( A tensor (batch_size, hidden_size, n_patches_height, n_patches_width) representing the model output. 
""" + num_skip = int(self._has_cls_token) + self._num_reg_tokens if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling): - features = tensor.last_hidden_state[:, 1:, :].permute(0, 2, 1) - batch_size, hidden_size, patch_grid = features.shape - height = width = int(math.sqrt(patch_grid)) - if height * width != patch_grid: - if self._ignore_remaining_dims: - features = features[:, :, -height * width :] - else: - raise ValueError(f"Patch grid size must be a square number {patch_grid}.") - patch_embeddings = features.view(batch_size, hidden_size, height, width) + features = tensor.last_hidden_state[:, num_skip:, :].permute(0, 2, 1) else: - raise ValueError(f"Unsupported type {type(tensor)}") + features = tensor[:, num_skip:, :].permute(0, 2, 1) + + batch_size, hidden_size, patch_grid = features.shape + height = width = int(math.sqrt(patch_grid)) + if height * width != patch_grid: + if self._ignore_remaining_dims: + features = features[:, :, -height * width :] + else: + raise ValueError(f"Patch grid size must be a square number {patch_grid}.") + patch_embeddings = features.view(batch_size, hidden_size, height, width) return [patch_embeddings] diff --git a/src/eva/vision/models/networks/backbones/_utils.py b/src/eva/vision/models/networks/backbones/_utils.py index f68fc98f..4e2d844b 100644 --- a/src/eva/vision/models/networks/backbones/_utils.py +++ b/src/eva/vision/models/networks/backbones/_utils.py @@ -1,7 +1,9 @@ """Utilis for backbone networks.""" +import os from typing import Any, Dict, Tuple +import huggingface_hub from torch import nn from eva import models @@ -37,3 +39,13 @@ def load_hugingface_model( tensor_transforms=tensor_transforms, model_kwargs=model_kwargs, ) + + +def huggingface_login(hf_token: str | None = None): + token = hf_token or os.environ.get("HF_TOKEN") + if not token: + raise ValueError( + "Please provide a HuggingFace token to download the model. " + "You can either pass it as an argument or set the env variable HF_TOKEN." 
+ ) + huggingface_hub.login(token=token) diff --git a/src/eva/vision/models/networks/backbones/pathology/__init__.py b/src/eva/vision/models/networks/backbones/pathology/__init__.py index a4b83dc8..48222ab8 100644 --- a/src/eva/vision/models/networks/backbones/pathology/__init__.py +++ b/src/eva/vision/models/networks/backbones/pathology/__init__.py @@ -13,6 +13,7 @@ from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16 from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon +from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2 __all__ = [ "kaiko_vitb16", @@ -28,4 +29,5 @@ "prov_gigapath", "histai_hibou_b", "histai_hibou_l", + "paige_virchow2", ] diff --git a/src/eva/vision/models/networks/backbones/pathology/histai.py b/src/eva/vision/models/networks/backbones/pathology/histai.py index e74fba4a..20dc890e 100644 --- a/src/eva/vision/models/networks/backbones/pathology/histai.py +++ b/src/eva/vision/models/networks/backbones/pathology/histai.py @@ -23,7 +23,7 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul model_name="histai/hibou-B", out_indices=out_indices, model_kwargs={"trust_remote_code": True}, - transform_args={"ignore_remaining_dims": True} if out_indices is not None else None, + transform_args={"num_reg_tokens": 4} if out_indices is not None else None, ) @@ -42,5 +42,5 @@ def histai_hibou_l(out_indices: int | Tuple[int, ...] 
| None = None) -> nn.Modul model_name="histai/hibou-L", out_indices=out_indices, model_kwargs={"trust_remote_code": True}, - transform_args={"ignore_remaining_dims": True} if out_indices is not None else None, + transform_args={"num_reg_tokens": 4} if out_indices is not None else None, ) diff --git a/src/eva/vision/models/networks/backbones/pathology/mahmood.py b/src/eva/vision/models/networks/backbones/pathology/mahmood.py index 80219f29..9bb9cac1 100644 --- a/src/eva/vision/models/networks/backbones/pathology/mahmood.py +++ b/src/eva/vision/models/networks/backbones/pathology/mahmood.py @@ -9,6 +9,7 @@ from torch import nn from eva.vision.models import wrappers +from eva.vision.models.networks.backbones import _utils from eva.vision.models.networks.backbones.registry import register_model @@ -31,19 +32,11 @@ def mahmood_uni( Returns: The model instance. """ - token = hf_token or os.environ.get("HF_TOKEN") - if not token: - raise ValueError( - "Please provide a HuggingFace token to download the model. " - "You can either pass it as an argument or set the env variable HF_TOKEN." - ) - checkpoint_path = os.path.join(download_dir, "pytorch_model.bin") - if not os.path.exists(checkpoint_path): logger.info(f"Downloading the model checkpoint to {download_dir} ...") os.makedirs(download_dir, exist_ok=True) - huggingface_hub.login(token=token) + _utils.huggingface_login(hf_token) huggingface_hub.hf_hub_download( "MahmoodLab/UNI", filename="pytorch_model.bin", diff --git a/src/eva/vision/models/networks/backbones/pathology/paige.py b/src/eva/vision/models/networks/backbones/pathology/paige.py new file mode 100644 index 00000000..dfb41aa6 --- /dev/null +++ b/src/eva/vision/models/networks/backbones/pathology/paige.py @@ -0,0 +1,51 @@ +"""Pathology FMs from paige.ai. 
+ +Source: https://huggingface.co/paige-ai/ +""" + +from typing import Tuple + +import timm +import torch.nn as nn + +from eva.core.models import transforms +from eva.vision.models import wrappers +from eva.vision.models.networks.backbones import _utils +from eva.vision.models.networks.backbones.registry import register_model + + +@register_model("paige/virchow2") +def paige_virchow2( + dynamic_img_size: bool = True, + out_indices: int | Tuple[int, ...] | None = None, + hf_token: str | None = None, + include_patch_tokens: bool = False, +) -> nn.Module: + """Initializes the Virchow2 pathology FM by paige.ai. + + Args: + dynamic_img_size: Support different input image sizes by allowing to change + the grid size (interpolate abs and/or ROPE pos) in the forward pass. + out_indices: Whether and which multi-level patch embeddings to return. + include_patch_tokens: Whether to combine the mean aggregated patch tokens with cls token. + hf_token: HuggingFace token to download the model. + + Returns: + The model instance. 
+ """ + _utils.huggingface_login(hf_token) + return wrappers.TimmModel( + model_name="hf-hub:paige-ai/Virchow2", + out_indices=out_indices, + pretrained=True, + model_kwargs={ + "dynamic_img_size": dynamic_img_size, + "mlp_layer": timm.layers.SwiGLUPacked, + "act_layer": nn.SiLU, + }, + tensor_transforms=( + transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens) + if out_indices is None + else None + ), + ) diff --git a/tests/eva/core/models/wrappers/test_huggingface.py b/tests/eva/core/models/wrappers/test_huggingface.py index 468eac8c..9b136170 100644 --- a/tests/eva/core/models/wrappers/test_huggingface.py +++ b/tests/eva/core/models/wrappers/test_huggingface.py @@ -14,6 +14,11 @@ [ ("hf-internal-testing/tiny-random-ViTModel", None, (16, 226, 32)), ("hf-internal-testing/tiny-random-ViTModel", transforms.ExtractCLSFeatures(), (16, 32)), + ( + "hf-internal-testing/tiny-random-ViTModel", + transforms.ExtractCLSFeatures(include_patch_tokens=True), + (16, 64), + ), ], ) def test_huggingface_model(