From 8b0163843079f0c5cea41031349c9c97dff9635a Mon Sep 17 00:00:00 2001 From: guarin <43336610+guarin@users.noreply.github.com> Date: Thu, 25 Jul 2024 15:45:04 +0200 Subject: [PATCH] Add option for boolean masking in masked vision transformer (#1600) * Add option for boolean masking in MaskedVisionTransformer * Replace add_prefix_tokens with prepend_prefix_tokens * Move preprocess method to MaskedVisionTransformer * Move common docstrings to MaskedVisionTransformer --- .../modules/masked_vision_transformer.py | 197 +++++++++++++- .../modules/masked_vision_transformer_timm.py | 140 +--------- .../masked_vision_transformer_torchvision.py | 96 +------ lightly/models/utils.py | 24 +- .../modules/masked_vision_transformer_test.py | 242 ++++++++++++++---- ...t_masked_vision_transformer_torchvision.py | 8 +- tests/models/test_ModelUtils.py | 45 ++++ 7 files changed, 484 insertions(+), 268 deletions(-) diff --git a/lightly/models/modules/masked_vision_transformer.py b/lightly/models/modules/masked_vision_transformer.py index 9df55a58c..66d2079c2 100644 --- a/lightly/models/modules/masked_vision_transformer.py +++ b/lightly/models/modules/masked_vision_transformer.py @@ -2,7 +2,9 @@ from typing import List, Optional, Tuple from torch import Tensor -from torch.nn import Module +from torch.nn import Module, Parameter + +from lightly.models import utils class MaskedVisionTransformer(ABC, Module): @@ -14,6 +16,10 @@ class MaskedVisionTransformer(ABC, Module): tokenization of images, and various operations needed for the transformer. """ + # This is not defined as a property for backwards compatibility. + # New models should define this as a property. + mask_token: Parameter + @property @abstractmethod def sequence_length(self) -> int: @@ -25,7 +31,36 @@ def forward( images: Tensor, idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, + mask: Optional[Tensor] = None, ) -> Tensor: + """Returns encoded class tokens from a batch of images. 
+ + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + idx_mask: + Tensor with shape (batch_size, num_tokens_to_mask) where each + entry is an index of the token to mask in the respective batch. + Indices must be in the range [0, sequence_length). + If set, the indexed tokens are masked with self.mask_token. + Cannot be used in combination with mask argument. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + Indices must be in the range [0, sequence_length). + If set, only the indexed tokens will be forwarded. + Is applied after any masking operation. + mask: + Boolean tensor with shape (batch_size, sequence_length) indicating + which tokens should be masked. Tokens where the mask is True will be + replaced with the mask token. + Cannot be used in combination with idx_mask argument. + + Returns: + Tensor with shape (batch_size, embed_dim) containing the encoded class token + for every image. + + """ ... @abstractmethod @@ -35,7 +70,40 @@ def forward_intermediates( idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, norm: bool = False, + mask: Optional[Tensor] = None, ) -> Tuple[Tensor, List[Tensor]]: + """Encode input images and return features from the intermediate layers. + + Args: + images: + Tensor with shape (batch_size, channels, image_height, image_width). + idx_mask: + Tensor with shape (batch_size, num_tokens_to_mask) where each + entry is an index of the token to mask in the respective batch. + Indices must be in the range [0, sequence_length). + If specified, the indexed tokens are masked with self.mask_token. + Cannot be used in combination with mask argument. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + Indices must be in the range [0, sequence_length). + If set, only the indexed tokens will be forwarded. 
+ Is applied after any masking operation. + norm: + Apply norm layer to all intermediates. + mask: + Boolean tensor with shape (batch_size, sequence_length) indicating + which tokens should be masked. Tokens where the mask is True will be + replaced with the mask token. + Cannot be used in combination with idx_mask argument. + + Returns: + Tuple of batch of encoded output tokens and a list of intermediate features. + The encoded output tokens have shape (batch_size, embed_dim) and each + intermediate feature has shape (batch_size, sequence_length, embed_dim). + If idx_keep is set, only num_tokens_to_keep tokens per sequence are + returned. + """ ... @abstractmethod @@ -44,17 +112,142 @@ def encode( images: Tensor, idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, + mask: Optional[Tensor] = None, ) -> Tensor: + """Encode input images. + + Args: + images: + Tensor with shape (batch_size, channels, image_height, image_width). + idx_mask: + Tensor with shape (batch_size, num_tokens_to_mask) where each + entry is an index of the token to mask in the respective batch. + Indices must be in the range [0, sequence_length). + If specified, the indexed tokens are masked with self.mask_token. + Cannot be used in combination with mask argument. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + Indices must be in the range [0, sequence_length). + If set, only the indexed tokens will be encoded. + Is applied after any masking operation. + mask: + Boolean tensor with shape (batch_size, sequence_length) indicating + which tokens should be masked. Tokens where the mask is True will be + replaced with the mask token. + Cannot be used in combination with idx_mask argument. + + Returns: + Tensor with shape (batch_size, sequence_length, embed_dim) containing the + encoded output tokens. If idx_keep is set, only num_tokens_to_keep tokens + per sequence are returned. + """ ... 
+ def preprocess( + self, + images: Tensor, + idx_mask: Optional[Tensor] = None, + idx_keep: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + ) -> Tensor: + """Convert images to tokens, add positional embeddings, and apply masking. + + Args: + images: + Tensor with shape (batch_size, channels, image_height, image_width). + idx_mask: + Tensor with shape (batch_size, num_tokens_to_mask) where each + entry is an index of the token to mask in the respective batch. + Indices must be in the range [0, sequence_length). + If specified, the indexed tokens are masked with self.mask_token. + Cannot be used in combination with mask argument. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + Indices must be in the range [0, sequence_length). + If set, only the indexed tokens will be returned. + Is applied after any masking operation. + mask: + Boolean tensor with shape (batch_size, sequence_length) indicating which + tokens should be masked. Tokens where the mask is True will be masked with + self.mask_token. Cannot be used in combination with idx_mask argument. + + Returns: + Tensor with shape (batch_size, sequence_length, embed_dim) containing the + preprocessed tokens. If idx_keep is set, only num_tokens_to_keep tokens + per sequence are returned. Any class or prefix tokens are prepended to the + sequence. 
+ """ + if idx_mask is not None and mask is not None: + raise ValueError("idx_mask and mask cannot both be set at the same time.") + + # convert images to tokens + tokens = self.images_to_tokens(images) + # add prefix tokens if needed + tokens = self.prepend_prefix_tokens(tokens) + + if idx_mask is not None: + tokens = utils.mask_at_index( + tokens=tokens, index=idx_mask, mask_token=self.mask_token + ) + elif mask is not None: + tokens = utils.mask_bool( + tokens=tokens, mask=mask, mask_token=self.mask_token + ) + + # add positional encoding + tokens = self.add_pos_embed(tokens) + + if idx_keep is not None: + tokens = utils.get_at_index(tokens, idx_keep) + + return tokens + @abstractmethod def images_to_tokens(self, images: Tensor) -> Tensor: + """Converts images into patch tokens. + + Args: + images: + Tensor with shape (batch_size, channels, image_height, image_width). + + Returns: + Tensor with shape (batch_size, num_patches, embed_dim) containing the + patch tokens (excluding prefix tokens). + """ ... - @abstractmethod + # Keep for backwards compatibility. def add_prefix_tokens(self, x: Tensor) -> Tensor: + return self.prepend_prefix_tokens(x) + + @abstractmethod + def prepend_prefix_tokens(self, x: Tensor) -> Tensor: + """Prepends prefix tokens to the input patch tokens. + + Args: + x: + Tensor with shape (batch_size, num_patches, embed_dim) containing patch + tokens. + + Returns: + Tensor with shape (batch_size, sequence_length, embed_dim) containing + the prefix and patch tokens. The prefix tokens are prepended to the + sequence. + """ ... @abstractmethod def add_pos_embed(self, x: Tensor) -> Tensor: + """Adds positional embeddings to the input tokens. + + Args: + x: + Tensor with shape (batch_size, sequence_length, embed_dim) containing + the input tokens. Must include prefix tokens. + + Returns: + Tensor after adding positional embeddings, with the same shape as the input. + """ ... 
diff --git a/lightly/models/modules/masked_vision_transformer_timm.py b/lightly/models/modules/masked_vision_transformer_timm.py index 4079023d4..6467c6a1e 100644 --- a/lightly/models/modules/masked_vision_transformer_timm.py +++ b/lightly/models/modules/masked_vision_transformer_timm.py @@ -64,28 +64,9 @@ def forward( images: Tensor, idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, + mask: Optional[Tensor] = None, ) -> Tensor: - """Returns encoded class tokens from a batch of images. - - Args: - images: - Tensor with shape (batch_size, channels, image_size, image_size). - idx_mask: - Tensor with shape (batch_size, num_tokens_to_mask) where each - entry is an index of the token to mask in the respective batch. - If specified, the indexed tokens are masked with self.mask_token. - idx_keep: - Tensor with shape (batch_size, num_tokens_to_keep) where each - entry is an index of the token to keep in the respective batch. - If specified, only the indexed tokens will be passed to the - encoder. - - Returns: - Tensor with shape (batch_size, vit.embed_dim) containing the - encoded class token for every image. - - """ - x = self.encode(images, idx_mask=idx_mask, idx_keep=idx_keep) + x = self.encode(images, idx_mask=idx_mask, idx_keep=idx_keep, mask=mask) if self.vit.attn_pool is not None: x = self.vit.attn_pool(x) elif self.vit.global_pool == "avg": @@ -100,29 +81,12 @@ def forward_intermediates( idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, norm: bool = False, + mask: Optional[Tensor] = None, ) -> Tuple[Tensor, List[Tensor]]: - """Encode input images and return features from the intermediate layers. - - Args: - images: - Batch of input images. - norm: - Apply norm layer to all intermediates. - idx_mask: - Tensor with shape (batch_size, num_tokens_to_mask) where each - entry is an index of the token to mask in the respective batch. - If specified, the indexed tokens are masked with self.mask_token. 
- idx_keep: - Tensor with shape (batch_size, num_tokens_to_keep) where each - entry is an index of the token to keep in the respective batch. - If specified, only the indexed tokens will be encoded. - - Returns: - Tuple of batch of encoded output tokens and a list of intermediate features - from each layer with shape (batch_size, self.sequence_length, vit.embed_dim). - """ # preprocess images, convert to tokens and add positional embeddings - tokens = self.preprocess(images=images, idx_mask=idx_mask, idx_keep=idx_keep) + tokens = self.preprocess( + images=images, idx_mask=idx_mask, idx_keep=idx_keep, mask=mask + ) # normalization layer tokens = self.vit.norm_pre(tokens) @@ -136,70 +100,16 @@ def forward_intermediates( return out, intermediates - def preprocess( - self, - images: Tensor, - idx_mask: Optional[Tensor] = None, - idx_keep: Optional[Tensor] = None, - ) -> Tensor: - """Convert images to tokens, add positional embeddings, and apply masking. - - Args: - images: - Batch of input images. - idx_mask: - Tensor with shape (batch_size, num_tokens_to_mask) where each - entry is an index of the token to mask in the respective batch. - If specified, the indexed tokens are masked with self.mask_token. - idx_keep: - Tensor with shape (batch_size, num_tokens_to_keep) where each - entry is an index of the token to keep in the respective batch. - If specified, only the indexed tokens will be encoded. - - Returns: - Batch of preprocessed tokens. 
- """ - # convert images to tokens - tokens = self.images_to_tokens(images) - # add prefix tokens if needed - tokens = self.add_prefix_tokens(tokens) - - if idx_mask is not None: - tokens = utils.mask_at_index(tokens, idx_mask, self.mask_token) - # add positional encoding - tokens = self.add_pos_embed(tokens) - - if idx_keep is not None: - tokens = utils.get_at_index(tokens, idx_keep) - - return tokens - def encode( self, images: Tensor, idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, + mask: Optional[Tensor] = None, ) -> Tensor: - """Encode input images. - - Args: - images: - Batch of input images. - idx_mask: - Tensor with shape (batch_size, num_tokens_to_mask) where each - entry is an index of the token to mask in the respective batch. - If specified, the indexed tokens are masked with self.mask_token. - idx_keep: - Tensor with shape (batch_size, num_tokens_to_keep) where each - entry is an index of the token to keep in the respective batch. - If specified, only the indexed tokens will be encoded. - - Returns: - Batch of encoded output tokens. - """ # preprocess images, convert to tokens and add positional embeddings tokens: Tensor = self.preprocess( - images=images, idx_mask=idx_mask, idx_keep=idx_keep + images=images, idx_mask=idx_mask, idx_keep=idx_keep, mask=mask ) # normalization layer tokens = self.vit.norm_pre(tokens) @@ -210,34 +120,13 @@ def encode( return tokens def images_to_tokens(self, images: Tensor) -> Tensor: - """Converts images into patch tokens. - - Args: - images: - Tensor with shape (batch_size, channels, image_size, image_size). - - Returns: - Tensor with shape (batch_size, vit.patch_embed.num_patches, vit.embed_dim) - containing the patch tokens (excluding prefix tokens). 
- """ tokens: Tensor = self.vit.patch_embed(images) if self.vit.dynamic_img_size: tokens = tokens.permute(0, 3, 1, 2) # NHWC -> NCHW tokens = tokens.flatten(2).transpose(1, 2) # NCHW -> NLC return tokens - def add_prefix_tokens(self, x: Tensor) -> Tensor: - """Adds prefix tokens to image patch tokens. - - Args: - x: - Tensor with shape (batch_size, vit.patch_embed.num_patches, vit.embed_dim) - containing the image patch tokens - - Returns: - Tensor with shape (batch_size, self.sequence_length, vit.embed_dim) containing - the image patch tokens and prefix tokens. - """ + def prepend_prefix_tokens(self, x: Tensor) -> Tensor: prefix_tokens = [] if self.vit.cls_token is not None: prefix_tokens.append(self.vit.cls_token.expand(x.shape[0], -1, -1)) @@ -248,17 +137,6 @@ def add_prefix_tokens(self, x: Tensor) -> Tensor: return x def add_pos_embed(self, x: Tensor) -> Tensor: - """Adds positional embeddings to the input tensor based on the Vision Transformer - (ViT) architecture in vit. - - Args: - x: - Input tensor with shape (batch_size, self.sequence_length, vit.embed_dim). - - Returns: - Tensor after adding positional embeddings, with the same shape as the input. - """ - x_prefix = x[:, : self.vit.num_prefix_tokens, :] x = x[:, self.vit.num_prefix_tokens :, :] if self.vit.dynamic_img_size: diff --git a/lightly/models/modules/masked_vision_transformer_torchvision.py b/lightly/models/modules/masked_vision_transformer_torchvision.py index 7a24c196c..149be8ef4 100644 --- a/lightly/models/modules/masked_vision_transformer_torchvision.py +++ b/lightly/models/modules/masked_vision_transformer_torchvision.py @@ -93,28 +93,9 @@ def forward( images: Tensor, idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, + mask: Optional[Tensor] = None, ) -> Tensor: - """Returns encoded class tokens from a batch of images. - - Args: - images: - Tensor with shape (batch_size, channels, image_size, image_size). 
- idx_mask: - Tensor with shape (batch_size, num_tokens_to_mask) where each - entry is an index of the token to mask in the respective batch. - If specified, the indexed tokens are masked with self.mask_token. - idx_keep: - Tensor with shape (batch_size, num_tokens_to_keep) where each - entry is an index of the token to keep in the respective batch. - If specified, only the indexed tokens will be passed to the - encoder. - - Returns: - Tensor with shape (batch_size, vit.hidden_dim) containing the - encoded class token for every image. - - """ - out = self.encode(images, idx_mask=idx_mask, idx_keep=idx_keep) + out = self.encode(images, idx_mask=idx_mask, idx_keep=idx_keep, mask=mask) class_token = out[:, 0] return class_token @@ -124,6 +105,7 @@ def forward_intermediates( idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, norm: bool = False, + mask: Optional[Tensor] = None, ) -> Tuple[Tensor, List[Tensor]]: raise NotImplementedError( "forward_intermediates is not implemented for this model." @@ -134,85 +116,29 @@ def encode( images: Tensor, idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, + mask: Optional[Tensor] = None, ) -> Tensor: - """Encode input images. - - Args: - input: - Batch of input images. - idx_mask: - Tensor with shape (batch_size, num_tokens_to_mask) where each - entry is an index of the token to mask in the respective batch. - If specified, the indexed tokens are masked with self.mask_token. - idx_keep: - Tensor with shape (batch_size, num_tokens_to_keep) where each - entry is an index of the token to keep in the respective batch. - If specified, only the indexed tokens will be encoded. - - Returns: - Batch of encoded output tokens. 
- """ - # convert images to tokens - input = self.images_to_tokens(images) - # add prefix tokens if needed - input = self.add_prefix_tokens(input) - - if idx_mask is not None: - input = utils.mask_at_index(input, idx_mask, self.mask_token) - # add positional encoding - input = self.add_pos_embed(input) - - if idx_keep is not None: - input = utils.get_at_index(input, idx_keep) + tokens = self.preprocess( + images=images, idx_mask=idx_mask, idx_keep=idx_keep, mask=mask + ) out: Tensor = self.vit.encoder.ln( - self.vit.encoder.layers(self.vit.encoder.dropout(input)) + self.vit.encoder.layers(self.vit.encoder.dropout(tokens)) ) return out def images_to_tokens(self, images: Tensor) -> Tensor: - """Converts images into patch tokens. - - Args: - images: - Tensor with shape (batch_size, channels, image_size, image_size). - - Returns: - Tensor with shape (batch_size, vit.seq_length-1, vit.hidden_dim) containing - the image patch tokens. - """ x = self.vit.conv_proj(images) tokens: Tensor = x.flatten(2).transpose(1, 2) return tokens - def add_prefix_tokens(self, x: Tensor, prepend_class_token: bool = True) -> Tensor: - """Adds class token to image patch tokens. - - Args: - x: - Tensor with shape (batch_size, vit.seq_length-1, vit.hidden_dim) - containing the image patch tokens - prepend_class_token: - Boolean flag that determines if a class token should be prepended. - - Returns: - Tensor with shape (batch_size, vit.seq_length, vit.hidden_dim) containing - the image patch tokens and class tokens. - """ + def prepend_prefix_tokens( + self, x: Tensor, prepend_class_token: bool = True + ) -> Tensor: if prepend_class_token: x = utils.prepend_class_token(x, self.vit.class_token) return x def add_pos_embed(self, x: Tensor) -> Tensor: - """Adds positional embeddings to the input tensor based on the Vision Transformer - (ViT) architecture in vit. - - Args: - x: - Input tensor with shape (batch_size, self.sequence_length, vit.hidden_dim). 
- - Returns: - Tensor after adding positional embeddings, with the same shape as the input. - """ # TODO(Ersi:1/24) This adds positional encoding to the prefix tokens as well. # Give the option of not doing so, as is the case for TIMM. x = x + self.interpolate_pos_encoding(x) diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 83c164247..816edac66 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -371,7 +371,8 @@ def set_at_index( def mask_at_index( tokens: torch.Tensor, index: torch.Tensor, mask_token: torch.Tensor ) -> torch.Tensor: - """Copies mask token into the input tensor at the given indices. + """Returns a tensor where the tokens at the given indices are replaced by the + mask token. Args: tokens: @@ -391,6 +392,27 @@ def mask_at_index( return (1 - mask) * tokens + mask * mask_token +def mask_bool(tokens: Tensor, mask: Tensor, mask_token: Tensor) -> Tensor: + """Returns a tensor with tokens replaced by the mask tokens in all positions where + the mask is True. + + Args: + tokens: + Tokens tensor with shape (batch_size, sequence_length, dim). + mask: + Boolean mask tensor with shape (batch_size, sequence_length). + mask_token: + Mask token with shape (1, 1, dim). + + Returns: + Tokens tensor with shape (batch_size, sequence_length, dim) where tokens[i, j] + is replaced by the mask token if mask[i, j] is True. + """ + # Convert to int for multiplication. 
+ mask = mask.unsqueeze(-1).to(torch.bool).to(torch.int) + return (1 - mask) * tokens + mask * mask_token + + def prepend_class_token( tokens: torch.Tensor, class_token: torch.Tensor ) -> torch.Tensor: diff --git a/tests/models/modules/masked_vision_transformer_test.py b/tests/models/modules/masked_vision_transformer_test.py index ad87352ef..6d995a4c4 100644 --- a/tests/models/modules/masked_vision_transformer_test.py +++ b/tests/models/modules/masked_vision_transformer_test.py @@ -1,9 +1,10 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Tuple import pytest import torch from pytest_mock import MockerFixture +from torch import Tensor from torch.nn import Parameter from lightly.models import utils @@ -63,9 +64,20 @@ def test_sequence_length(self) -> None: expected_sequence_length = 49 + 1 + 0 assert model.sequence_length == expected_sequence_length + @pytest.mark.slow @pytest.mark.parametrize("device", ["cpu", "cuda"]) - @pytest.mark.parametrize("mask_ratio", [None, 0.6]) - def test_forward(self, device: str, mask_ratio: Optional[float]) -> None: + @pytest.mark.parametrize( + "idx_mask_ratio, bool_mask_ratio", + [(None, None), (0.6, None), (None, 0.6)], + ) + @pytest.mark.parametrize("idx_keep_none", [False, True]) + def test_forward( + self, + device: str, + idx_mask_ratio: Optional[float], + bool_mask_ratio: Optional[float], + idx_keep_none: bool, + ) -> None: if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available.") @@ -77,27 +89,37 @@ def test_forward(self, device: str, mask_ratio: Optional[float]) -> None: ).to(device) images = torch.rand(batch_size, 3, 224, 224).to(device) - idx_keep = None - if mask_ratio is not None: - idx_keep, _ = utils.random_token_mask( - size=(batch_size, model.sequence_length), - device=device, - mask_ratio=mask_ratio, - ) + idx_keep, idx_mask, mask = self.get_masks( + batch_size=batch_size, + idx_mask_ratio=idx_mask_ratio, + 
bool_mask_ratio=bool_mask_ratio, + sequence_length=model.sequence_length, + device=device, + idx_keep_none=idx_keep_none, + ) - class_tokens = model(images=images, idx_keep=idx_keep) + class_tokens = model( + images=images, idx_keep=idx_keep, idx_mask=idx_mask, mask=mask + ) - # output shape must be correct + # Output shape must be correct. assert class_tokens.shape == (batch_size, embed_dim) - # output must have reasonable numbers + # Output must have reasonable numbers. assert torch.all(torch.isfinite(class_tokens)) + @pytest.mark.slow @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize( - "mask_ratio, expected_sequence_length", [(None, 50), (0.6, 20)] + "idx_mask_ratio, bool_mask_ratio", + [(None, None), (0.6, None), (None, 0.6)], ) + @pytest.mark.parametrize("idx_keep_none", [False, True]) def test_forward_intermediates( - self, device: str, mask_ratio: Optional[float], expected_sequence_length: int + self, + device: str, + idx_mask_ratio: Optional[float], + bool_mask_ratio: Optional[float], + idx_keep_none: bool, ) -> None: if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available.") @@ -110,41 +132,104 @@ def test_forward_intermediates( ).to(device) images = torch.rand(batch_size, 3, 224, 224).to(device) - idx_keep = None - if mask_ratio is not None: - idx_keep, _ = utils.random_token_mask( - size=(batch_size, model.sequence_length), - device=device, - mask_ratio=mask_ratio, - ) + idx_keep, idx_mask, mask = self.get_masks( + batch_size=batch_size, + idx_mask_ratio=idx_mask_ratio, + bool_mask_ratio=bool_mask_ratio, + sequence_length=model.sequence_length, + device=device, + idx_keep_none=idx_keep_none, + ) output, intermediates = model.forward_intermediates( - images=images, idx_keep=idx_keep + images=images, idx_keep=idx_keep, idx_mask=idx_mask, mask=mask + ) + expected_output = model.encode( + images=images, idx_keep=idx_keep, idx_mask=idx_mask, mask=mask ) - expected_output = model.encode(images=images, 
idx_keep=idx_keep) - # output shape must be correct - assert output.shape == (batch_size, expected_sequence_length, embed_dim) - # output should be same as from encode + # Output shape must be correct. + assert output.shape == expected_output.shape + # Output should be same as from encode. assert torch.allclose(output, expected_output) - # intermediates must have reasonable numbers + # Intermediates must have reasonable numbers. for intermediate in intermediates: - assert intermediate.shape == ( - batch_size, - expected_sequence_length, - embed_dim, - ) + assert intermediate.shape == expected_output.shape assert torch.all(torch.isfinite(intermediate)) + @pytest.mark.slow @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize( - "mask_ratio,expected_sequence_length", [(None, 50), (0.6, 20)] + "idx_mask_ratio, bool_mask_ratio", + [(None, None), (0.6, None), (None, 0.6)], ) + @pytest.mark.parametrize("idx_keep_none", [False, True]) def test_encode( self, device: str, - mask_ratio: Optional[float], + idx_mask_ratio: Optional[float], + bool_mask_ratio: Optional[float], + idx_keep_none: bool, + ) -> None: + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available.") + + torch.manual_seed(0) + batch_size = 8 + embed_dim = 768 + model = self.get_masked_vit( + patch_size=32, + depth=2, + num_heads=2, + embed_dim=embed_dim, + ).to(device) + images = torch.rand(batch_size, 3, 224, 224).to(device) + + idx_keep, idx_mask, mask = self.get_masks( + batch_size=batch_size, + idx_mask_ratio=idx_mask_ratio, + bool_mask_ratio=bool_mask_ratio, + sequence_length=model.sequence_length, + device=device, + idx_keep_none=idx_keep_none, + ) + + tokens = model.encode( + images=images, idx_keep=idx_keep, idx_mask=idx_mask, mask=mask + ) + # Output shape must be correct. + assert tokens.ndim == 3 + assert tokens.shape[0] == batch_size + # Sequence length depends on idx_keep. 
+ assert 20 <= tokens.shape[1] <= 50 + assert tokens.shape[2] == embed_dim + # Output must have reasonable numbers. + assert torch.all(torch.isfinite(tokens)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize( + ( + "idx_mask_ratio, bool_mask_ratio, idx_keep_none, expected_sequence_length, " + "expected_num_masked" + ), + [ + (None, None, False, 50, 0), + (None, None, True, 50, 0), + # No masked tokens because idx_keep is set, so only unmasked tokens are returned. + (0.6, None, False, 20, 0), + (0.6, None, True, 50, 30), + (None, 0.6, False, 50, 30), + (None, 0.6, True, 50, 30), + ], + ) + def test_preprocess( + self, + device: str, + idx_mask_ratio: Optional[float], + bool_mask_ratio: Optional[float], + idx_keep_none: bool, expected_sequence_length: int, + expected_num_masked: int, ) -> None: if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available.") @@ -152,22 +237,49 @@ def test_encode( torch.manual_seed(0) batch_size = 8 embed_dim = 768 + # Create mask token with large value to check if it is used correctly. 
+ mask_token = Parameter(torch.zeros(1, 1, embed_dim) + 10_000) model = self.get_masked_vit( - patch_size=32, depth=2, num_heads=2, embed_dim=embed_dim + patch_size=32, + depth=2, + num_heads=2, + embed_dim=embed_dim, + mask_token=mask_token, ).to(device) images = torch.rand(batch_size, 3, 224, 224).to(device) - idx_keep = None - if mask_ratio is not None: - idx_keep, _ = utils.random_token_mask( - size=(batch_size, model.sequence_length), - device=device, - mask_ratio=mask_ratio, - ) + idx_keep, idx_mask, mask = self.get_masks( + batch_size=batch_size, + idx_mask_ratio=idx_mask_ratio, + bool_mask_ratio=bool_mask_ratio, + sequence_length=model.sequence_length, + device=device, + idx_keep_none=idx_keep_none, + ) - tokens = model.encode(images=images, idx_keep=idx_keep) + tokens = model.preprocess( + images=images, idx_keep=idx_keep, idx_mask=idx_mask, mask=mask + ) assert tokens.shape == (batch_size, expected_sequence_length, embed_dim) - assert torch.all(torch.isfinite(tokens)) + # Check if the mask token was used correctly. Note that the actual value + # can be smaller than 10_000 because the positional embedding is added. This is + # why we check for 1_000 instead. 
+ assert (tokens > 1_000).sum() == batch_size * expected_num_masked * embed_dim + + def test_preprocess__fail_idx_mask_and_mask(self) -> None: + batch_size = 8 + model = self.get_masked_vit(patch_size=32, depth=2, num_heads=2, embed_dim=768) + _, idx_mask, mask = self.get_masks( + batch_size=batch_size, + idx_mask_ratio=0.6, + bool_mask_ratio=0.6, + sequence_length=model.sequence_length, + ) + images = images = torch.rand(batch_size, 3, 224, 224) + with pytest.raises( + ValueError, match="idx_mask and mask cannot both be set at the same time" + ): + model.preprocess(images=images, idx_mask=idx_mask, mask=mask) @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_images_to_tokens(self, device: str) -> None: @@ -193,7 +305,7 @@ def test_images_to_tokens(self, device: str) -> None: # (True, 2, 52), TODO(Guarin, 07/2024): Support reg_tokens > 0 ], ) - def test_add_prefix_tokens( + def test_prepend_prefix_tokens( self, device: str, class_token: bool, @@ -212,7 +324,11 @@ def test_add_prefix_tokens( reg_tokens=reg_tokens, ).to(device) x = torch.rand(2, 49, 768).to(device) - assert model.add_prefix_tokens(x).shape == (2, expected_sequence_length, 768) + assert model.prepend_prefix_tokens(x).shape == ( + 2, + expected_sequence_length, + 768, + ) @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize( @@ -244,3 +360,35 @@ def test_add_pos_embed( ).to(device) x = torch.rand(2, model.sequence_length, 768).to(device) assert model.add_pos_embed(x).shape == (2, expected_sequence_length, 768) + + def get_masks( + self, + batch_size: int, + idx_mask_ratio: Optional[float], + bool_mask_ratio: Optional[float], + sequence_length: int, + device: Optional[str] = None, + idx_keep_none: bool = False, + ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + idx_keep = None + idx_mask = None + if idx_mask_ratio is not None: + idx_keep, idx_mask = utils.random_token_mask( + size=(batch_size, sequence_length), + mask_ratio=idx_mask_ratio, + 
device=device, + ) + if idx_keep_none: + idx_keep = None + + mask = None + if bool_mask_ratio is not None: + # Create random boolean mask that has exactly bool_mask_ratio of values + # set to true. + n = int(batch_size * sequence_length) + n_masked = int(n * bool_mask_ratio) + mask = torch.randperm(n).reshape(batch_size, sequence_length) + mask = mask < n_masked + mask = mask.to(device).to(torch.bool) + assert mask.sum() == n_masked + return idx_keep, idx_mask, mask diff --git a/tests/models/modules/test_masked_vision_transformer_torchvision.py b/tests/models/modules/test_masked_vision_transformer_torchvision.py index 5b06b8cb7..32d212f33 100644 --- a/tests/models/modules/test_masked_vision_transformer_torchvision.py +++ b/tests/models/modules/test_masked_vision_transformer_torchvision.py @@ -84,12 +84,16 @@ def test__init__weight_initialization__skip(self, mocker: MockerFixture) -> None @pytest.mark.skip(reason="Torchvision ViT does not support forward intermediates") def test_forward_intermediates( - self, device: str, mask_ratio: Optional[float], expected_sequence_length: int + self, + device: str, + idx_mask_ratio: Optional[float], + bool_mask_ratio: Optional[float], + expected_sequence_length: int, ) -> None: ... 
@pytest.mark.skip(reason="Torchvision ViT does not support reg tokens") - def test_add_prefix_tokens( + def test_prepend_prefix_tokens( self, device: str, class_token: bool, diff --git a/tests/models/test_ModelUtils.py b/tests/models/test_ModelUtils.py index 26ed066ed..46563f79a 100644 --- a/tests/models/test_ModelUtils.py +++ b/tests/models/test_ModelUtils.py @@ -5,6 +5,7 @@ import pytest import torch import torch.nn as nn +from torch import Tensor from torch.nn import Identity from lightly.models import utils @@ -301,6 +302,50 @@ def test_random_token_mask_cuda(self): self._test_random_token_mask_parameters(device="cuda") +@pytest.mark.parametrize( + "mask, expected", + [ + ( + [ + [0, 0, 0], + [0, 0, 0], + ], + [ + [[2, 2], [2, 2], [2, 2]], + [[2, 2], [2, 2], [2, 2]], + ], + ), + ( + [ + [1, 1, 1], + [1, 1, 1], + ], + [ + [[3, 3], [3, 3], [3, 3]], + [[3, 3], [3, 3], [3, 3]], + ], + ), + ( + [ + [0, 1, 0], + [1, 0, 1], + ], + [ + [[2, 2], [3, 3], [2, 2]], + [[3, 3], [2, 2], [3, 3]], + ], + ), + ], +) +def test_mask_ones(mask: Tensor, expected: Tensor) -> None: + tokens = torch.zeros(2, 3, 2) + 2 + mask_token = torch.zeros(1, 1, 2) + 3 + result = utils.mask_bool( + tokens=tokens, mask=torch.tensor(mask, dtype=torch.bool), mask_token=mask_token + ) + assert torch.allclose(result, torch.tensor(expected, dtype=torch.float)) + + @pytest.mark.parametrize( "x, y, expected", [