diff --git a/.gitignore b/.gitignore index 2fdc54c33..031ae8795 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ lightning_logs/ **lightning_logs/ **/__MACOSX +datasets/ docs/source/tutorials/package/* docs/source/tutorials/platform/* docs/source/tutorials_source/platform/data diff --git a/examples/pytorch/ijepa.py b/examples/pytorch/ijepa.py new file mode 100644 index 000000000..eb4730e04 --- /dev/null +++ b/examples/pytorch/ijepa.py @@ -0,0 +1,117 @@ +import copy + +import torch +import torchvision +from torch import nn +from torch.nn import functional as F +from tqdm import tqdm + +from lightly.data.collate import IJEPAMaskCollator +from lightly.models import utils +from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor +from lightly.transforms.ijepa_transform import IJEPATransform + + +class IJEPA(nn.Module): + def __init__(self, vit_encoder, vit_predictor, momentum_scheduler): + super().__init__() + self.encoder = IJEPABackbone.from_vit(vit_encoder) + self.predictor = IJEPAPredictor.from_vit_encoder( + vit_predictor.encoder, + (vit_predictor.image_size // vit_predictor.patch_size) ** 2, + ) + self.target_encoder = copy.deepcopy(self.encoder) + self.momentum_scheduler = momentum_scheduler + + def forward_target(self, imgs, masks_enc, masks_pred): + with torch.no_grad(): + h = self.target_encoder(imgs) + h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim + B = len(h) + # -- create targets (masked regions of h) + h = utils.apply_masks(h, masks_pred) + h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc)) + return h + + def forward_context(self, imgs, masks_enc, masks_pred): + z = self.encoder(imgs, masks_enc) + z = self.predictor(z, masks_enc, masks_pred) + return z + + def forward(self, imgs, masks_enc, masks_pred): + z = self.forward_context(imgs, masks_enc, masks_pred) + h = self.forward_target(imgs, masks_enc, masks_pred) + return z, h + + def update_target_encoder( + self, + ): + with torch.no_grad(): + m = next(self.momentum_scheduler) + for param_q, param_k in zip( + self.encoder.parameters(), self.target_encoder.parameters() + ): + param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data) + + +collator = IJEPAMaskCollator( + input_size=(224, 224), + patch_size=32, +) + +transform = IJEPATransform() + +# we ignore object detection annotations by setting target_transform to return 0 +# or create a dataset from a folder containing images or videos: +# dataset = LightlyDataset("path/to/folder") +dataset = torchvision.datasets.VOCDetection( + "datasets/pascal_voc", + download=True, + transform=transform, + target_transform=lambda t: 0, +) +data_loader = torch.utils.data.DataLoader( + dataset, collate_fn=collator, batch_size=10, persistent_workers=False +) + +ema = (0.996, 1.0) +ipe_scale = 1.0 +ipe = len(data_loader) +num_epochs = 10 +momentum_scheduler = ( + ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) + for i in range(int(ipe * num_epochs * ipe_scale) + 1) +) + +vit_for_predictor = torchvision.models.vit_b_32(pretrained=False) +vit_for_embedder = torchvision.models.vit_b_32(pretrained=False) +model = IJEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler) + +criterion = nn.SmoothL1Loss() +optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4) +device = "cuda" if torch.cuda.is_available() else "cpu" +model.to(device) + +print("Starting Training") +for epoch in range(num_epochs): + total_loss = 0 + for udata, masks_enc, masks_pred in tqdm(data_loader): + + def load_imgs(): + # -- unsupervised imgs + imgs 
= udata[0].to(device, non_blocking=True) + masks_1 = [u.to(device, non_blocking=True) for u in masks_enc] + masks_2 = [u.to(device, non_blocking=True) for u in masks_pred] + return (imgs, masks_1, masks_2) + + imgs, masks_enc, masks_pred = load_imgs() + z, h = model(imgs, masks_enc, masks_pred) + loss = criterion(z, h) + total_loss += loss.detach() + loss.backward() + optimizer.step() + optimizer.zero_grad() + model.update_target_encoder() + + avg_loss = total_loss / len(data_loader) + print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") diff --git a/examples/pytorch_lightning/ijepa.py b/examples/pytorch_lightning/ijepa.py new file mode 100644 index 000000000..464090415 --- /dev/null +++ b/examples/pytorch_lightning/ijepa.py @@ -0,0 +1 @@ +# TODO diff --git a/examples/pytorch_lightning_distributed/ijepa.py b/examples/pytorch_lightning_distributed/ijepa.py new file mode 100644 index 000000000..464090415 --- /dev/null +++ b/examples/pytorch_lightning_distributed/ijepa.py @@ -0,0 +1 @@ +# TODO diff --git a/lightly/data/collate.py b/lightly/data/collate.py index 25720171d..3a935951f 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -3,6 +3,8 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved +import math +from multiprocessing import Value from typing import List, Optional, Tuple, Union from warnings import warn @@ -1345,6 +1347,176 @@ def forward( return (views_global, views_local, grids_global, grids_local), labels, fnames +class IJEPAMaskCollator: + """Collator for IJEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + + Code inspired by [1]. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + """ + + def __init__( + self, + input_size=(224, 224), + patch_size=16, + enc_mask_scale=(0.2, 0.8), + pred_mask_scale=(0.2, 0.8), + aspect_ratio=(0.3, 3.0), + nenc=1, + npred=2, + min_keep=4, + allow_overlap=False, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.patch_size = patch_size + self.height, self.width = ( + input_size[0] // patch_size, + input_size[1] // patch_size, + ) + self.enc_mask_scale = enc_mask_scale + self.pred_mask_scale = pred_mask_scale + self.aspect_ratio = aspect_ratio + self.nenc = nenc + self.npred = npred + self.min_keep = min_keep # minimum number of patches to keep + self.allow_overlap = ( + allow_overlap # whether to allow overlap b/w enc and pred masks + ) + self._itr_counter = Value("i", -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def _sample_block_size(self, generator, scale, aspect_ratio_scale): + _rand = torch.rand(1, generator=generator).item() + # -- Sample block scale + min_s, max_s = scale + mask_scale = min_s + _rand * (max_s - min_s) + max_keep = int(self.height * self.width * mask_scale) + # -- Sample block aspect-ratio + min_ar, max_ar = aspect_ratio_scale + aspect_ratio = min_ar + _rand * (max_ar - min_ar) + # -- Compute block height and width (given scale and aspect-ratio) + h = int(round(math.sqrt(max_keep * aspect_ratio))) + w = int(round(math.sqrt(max_keep / aspect_ratio))) + while h >= self.height: + h -= 1 + while w >= self.width: + w -= 1 + + return (h, w) + + def _sample_block_mask(self, b_size, acceptable_regions=None): + h, w = b_size + + def constrain_mask(mask, tries=0): + """Helper to restrict 
given mask to a set of acceptable regions""" + N = max(int(len(acceptable_regions) - tries), 0) + for k in range(N): + mask *= acceptable_regions[k] + + # -- + # -- Loop to sample masks until we find a valid one + tries = 0 + timeout = og_timeout = 20 + valid_mask = False + while not valid_mask: + # -- Sample block top-left corner + top = torch.randint(0, self.height - h, (1,)) + left = torch.randint(0, self.width - w, (1,)) + mask = torch.zeros((self.height, self.width), dtype=torch.int32) + mask[top : top + h, left : left + w] = 1 + # -- Constrain mask to a set of acceptable regions + if acceptable_regions is not None: + constrain_mask(mask, tries) + mask = torch.nonzero(mask.flatten()) + # -- If mask too small try again + valid_mask = len(mask) > self.min_keep + if not valid_mask: + timeout -= 1 + if timeout == 0: + tries += 1 + timeout = og_timeout + mask = mask.squeeze() + # -- + mask_complement = torch.ones((self.height, self.width), dtype=torch.int32) + mask_complement[top : top + h, left : left + w] = 0 + # -- + return mask, mask_complement + + def __call__(self, batch): + """ + Create encoder and predictor masks when collating imgs into a batch + # 1. sample enc block (size + location) using seed + # 2. sample pred block (size) using seed + # 3. sample several enc block locations for each image (w/o seed) + # 4. sample several pred block locations for each image (w/o seed) + # 5. return enc mask and pred mask + """ + B = len(batch) + + collated_batch = torch.utils.data.default_collate(batch) + + seed = self.step() + g = torch.Generator() + g.manual_seed(seed) + p_size = self._sample_block_size( + generator=g, + scale=self.pred_mask_scale, + aspect_ratio_scale=self.aspect_ratio, + ) + e_size = self._sample_block_size( + generator=g, scale=self.enc_mask_scale, aspect_ratio_scale=(1.0, 1.0) + ) + + collated_masks_pred, collated_masks_enc = [], [] + min_keep_pred = self.height * self.width + min_keep_enc = self.height * self.width + for _ in range(B): + masks_p, masks_C = [], [] + for _ in range(self.npred): + mask, mask_C = self._sample_block_mask(p_size) + masks_p.append(mask) + masks_C.append(mask_C) + min_keep_pred = min(min_keep_pred, len(mask)) + collated_masks_pred.append(masks_p) + + acceptable_regions = masks_C + + if self.allow_overlap: + acceptable_regions = None + + masks_e = [] + for _ in range(self.nenc): + mask, _ = self._sample_block_mask( + e_size, acceptable_regions=acceptable_regions + ) + masks_e.append(mask) + min_keep_enc = min(min_keep_enc, len(mask)) + collated_masks_enc.append(masks_e) + + collated_masks_pred = [ + [cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred + ] + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + # -- + collated_masks_enc = [ + [cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc + ] + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + + return collated_batch, collated_masks_enc, collated_masks_pred + + def _deprecation_warning_collate_functions() -> None: warn( "Collate functions are deprecated and will be removed in favor of transforms in v1.4.0.\n" diff --git a/lightly/models/modules/ijepa.py b/lightly/models/modules/ijepa.py new file mode 100644 index 000000000..3eb14a247 --- /dev/null +++ b/lightly/models/modules/ijepa.py @@ -0,0 +1,496 @@ +import math +from functools import partial +from typing import Callable, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torchvision.models import vision_transformer 
+from torchvision.models.vision_transformer import ConvStemConfig + +from lightly.models import utils + + +class IJEPAPredictor(vision_transformer.Encoder): + """Predictor for the I-JEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + + Predict patch embeddings. Code inspired by [1]. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + seq_length: + Token sequence length, including the class token. + num_layers: + Number of transformer blocks. + num_heads: + Number of attention heads. + hidden_dim: + Dimension of the input and output tokens. + predictor_embed_dim: + Dimension of inner predicted tokens + mlp_dim: + Dimension of the MLP in the transformer block. + dropout: + Percentage of elements set to zero after the MLP in the transformer. + attention_dropout: + Percentage of elements set to zero after the attention head. + + """ + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + predictor_embed_dim: int, + num_patches: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + **kwargs + ): + super().__init__( + seq_length=seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + self.predictor_embed = nn.Linear(mlp_dim, predictor_embed_dim, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + self.predictor_proj = nn.Linear(predictor_embed_dim, mlp_dim, bias=True) + self.predictor_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False + ) + predictor_pos_embed = _get_2d_sincos_pos_embed( + self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False + ) + self.predictor_pos_embed.data.copy_( + torch.from_numpy(predictor_pos_embed).float().unsqueeze(0) + ) + + @classmethod + def from_vit_encoder(cls, vit_encoder, num_patches): + """Creates a I-JEPA predictor backbone (mhas and layernorm) from a torchvision ViT encoder.""" + # Create a new instance with dummy values as they will be overwritten + # by the copied vit_encoder attributes + encoder = cls( + seq_length=1, + num_layers=1, + num_heads=1, + hidden_dim=1, + predictor_embed_dim=768, + mlp_dim=768, + num_patches=num_patches, + dropout=0, + attention_dropout=0, + ) + encoder.layers = vit_encoder.layers + encoder.ln = vit_encoder.ln + return encoder + + def forward(self, x, masks_x, masks): + assert (masks is not None) and ( + masks_x is not None + ), "Cannot run predictor without mask indices" + + if not isinstance(masks_x, list): + masks_x = [masks_x] + + if not isinstance(masks, list): + masks = [masks] + + B = len(x) // len(masks_x) + x = self.predictor_embed(x) + x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + + x += utils.apply_masks(x_pos_embed, masks_x) + _, N_ctxt, _ = x.shape + + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = utils.apply_masks(pos_embs, masks) + pos_embs = utils.repeat_interleave_batch(pos_embs, B, repeat=len(masks_x)) + pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1) + + pred_tokens += pos_embs + x = x.repeat(len(masks), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + x = self.ln(self.layers(x)) + + x = x[:, N_ctxt:] + 
x = self.predictor_proj(x) + + return x + + +class IJEPAEncoder(vision_transformer.Encoder): + """Encoder for the I-JEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + + Encodes patch embeddings. Code inspired by [1]. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + seq_length: + Token sequence length, including the class token. + num_layers: + Number of transformer blocks. + num_heads: + Number of attention heads. + hidden_dim: + Dimension of the input and output tokens. + mlp_dim: + Dimension of the MLP in the transformer block. + dropout: + Percentage of elements set to zero after the MLP in the transformer. + attention_dropout: + Percentage of elements set to zero after the attention head. + + """ + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__( + seq_length=seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + + @classmethod + def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder): + """Creates a IJEPA encoder from a torchvision ViT encoder.""" + # Create a new instance with dummy values as they will be overwritten + # by the copied vit_encoder attributes + encoder = cls( + seq_length=1, + num_layers=1, + num_heads=1, + hidden_dim=1, + mlp_dim=1, + dropout=0, + attention_dropout=0, + ) + encoder.pos_embedding = vit_encoder.pos_embedding + encoder.dropout = vit_encoder.dropout + encoder.layers = vit_encoder.layers + encoder.ln = vit_encoder.ln + return encoder + + def forward( + self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Encode input tokens. + + Args: + input: + Batch of token sequences. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be encoded. + + Returns: + Batch of encoded output tokens. + """ + input = input + self.interpolate_pos_encoding(input) + if idx_keep is not None: + input = utils.apply_masks(input, idx_keep) + return self.ln(self.layers(self.dropout(input))) + + def interpolate_pos_encoding(self, input: torch.Tensor): + """Returns the interpolated positional embedding for the given input. + + This function interpolates self.pos_embedding for all tokens in the input, + ignoring the class token. This allows encoding variable sized images. + + Args: + input: + Input tensor with shape (batch_size, num_sequences). 
+ + """ + # code copied from: + # https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291 + npatch = input.shape[1] - 1 + N = self.pos_embedding.shape[1] - 1 + if npatch == N: + return self.pos_embedding + class_emb = self.pos_embedding[:, 0] + pos_embedding = self.pos_embedding[:, 1:] + dim = input.shape[-1] + pos_embedding = nn.functional.interpolate( + pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(npatch / N), + mode="bicubic", + ) + pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) + + +class IJEPABackbone(vision_transformer.VisionTransformer): + """Encoder for the I-JEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + + Converts images into patches and encodes them. Code inspired by [1]. + Note that this implementation uses a learned positional embedding while [0] + uses a fixed positional embedding. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + image_size: + Input image size. + patch_size: + Width and height of the image patches. image_size must be a multiple + of patch_size. + num_layers: + Number of transformer blocks. + num_heads: + Number of attention heads. + hidden_dim: + Dimension of the input and output tokens. + mlp_dim: + Dimension of the MLP in the transformer block. + dropout: + Percentage of elements set to zero after the MLP in the transformer. + attention_dropout: + Percentage of elements set to zero after the attention head. + num_classes: + Number of classes for the classification head. Currently not used. + representation_size: + If specified, an additional linear layer is added before the + classification head to change the token dimension from hidden_dim + to representation_size. Currently not used. + norm_layer: + Callable that creates a normalization layer. 
+ + """ + + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0, + attention_dropout: float = 0, + num_classes: int = 1000, + representation_size: Optional[int] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + conv_stem_configs: Optional[List[ConvStemConfig]] = None, + ): + super().__init__( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + num_classes=num_classes, + representation_size=representation_size, + norm_layer=norm_layer, + conv_stem_configs=conv_stem_configs, + ) + self.encoder = IJEPAEncoder( + seq_length=self.seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + + @classmethod + def from_vit(cls, vit: vision_transformer.VisionTransformer): + """Creates a IJEPABackbone from a torchvision ViT model.""" + # Create a new instance with dummy values as they will be overwritten + # by the copied vit_encoder attributes + backbone = cls( + image_size=vit.image_size, + patch_size=vit.patch_size, + num_layers=1, + num_heads=1, + hidden_dim=vit.hidden_dim, + mlp_dim=vit.mlp_dim, + dropout=vit.dropout, + attention_dropout=vit.attention_dropout, + num_classes=vit.num_classes, + representation_size=vit.representation_size, + norm_layer=vit.norm_layer, + ) + backbone.conv_proj = vit.conv_proj + backbone.class_token = vit.class_token + backbone.seq_length = vit.seq_length + backbone.heads = vit.heads + backbone.encoder = IJEPAEncoder.from_vit_encoder(vit.encoder) + return backbone + + def forward( + self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Returns encoded class tokens from a batch of images. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be passed to the + encoder. + + Returns: + Tensor with shape (batch_size, hidden_dim) containing the + encoded class token for every image. + + """ + if idx_keep is not None: + if not isinstance(idx_keep, list): + idx_keep = [idx_keep] + + out = self.encode(images, idx_keep) + return out + + def encode( + self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Returns encoded class and patch tokens from images. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be passed to the + encoder. + + Returns: + Tensor with shape (batch_size, sequence_length, hidden_dim) + containing the encoded class and patch tokens for every image. + + """ + out = self.images_to_tokens(images, prepend_class_token=True) + return self.encoder(out, idx_keep) + + def images_to_tokens( + self, images: torch.Tensor, prepend_class_token: bool + ) -> torch.Tensor: + """Converts images into patch tokens. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). 
+ + Returns: + Tensor with shape (batch_size, sequence_length - 1, hidden_dim) + containing the patch tokens. + """ + x = self.conv_proj(images) + tokens = x.flatten(2).transpose(1, 2) + if prepend_class_token: + tokens = utils.prepend_class_token(tokens, self.class_token) + return tokens + + +def _get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = _get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def _get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def _get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid length + return: + pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = np.arange(grid_size, dtype=float) + pos_embed = _get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def _get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/lightly/models/utils.py b/lightly/models/utils.py index d2ccce2fd..7215395a3 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -568,3 +568,31 @@ def get_weight_decay_parameters( else: params.append(param) return params, params_no_weight_decay + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + return _no_grad_trunc_normal(tensor, mean, std, a, b) + + +def apply_masks(x, masks): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors containing indices of patches in [N] to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + return torch.cat(all_x, dim=0) + + +def repeat_interleave_batch(x, B, repeat): + N = len(x) // B + x = torch.cat( + [ + torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) + for i in range(N) + ], + dim=0, + ) + return x diff --git a/lightly/transforms/ijepa_transform.py b/lightly/transforms/ijepa_transform.py new file mode 100644 index 000000000..321dba66a --- /dev/null +++ b/lightly/transforms/ijepa_transform.py @@ -0,0 +1,58 @@ +from typing import Tuple, 
Union
+
+import torchvision.transforms as T
+from PIL.Image import Image
+from torch import Tensor
+
+from lightly.transforms.utils import IMAGENET_NORMALIZE
+
+
+class IJEPATransform:
+    """Implements the augmentations for I-JEPA [0, 1].
+
+    Experimental: Support for I-JEPA is experimental; there might be breaking changes
+    in the future.
+
+    - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243
+    - [1]: https://github.com/facebookresearch/ijepa
+
+    Attributes:
+        input_size:
+            Size of the input image in pixels.
+        min_scale:
+            Minimum size of the randomized crop relative to the input_size.
+        normalize:
+            Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize.
+
+    """
+
+    def __init__(
+        self,
+        input_size: Union[int, Tuple[int, int]] = 224,
+        min_scale: float = 0.2,
+        normalize: dict = IMAGENET_NORMALIZE,
+    ):
+        transforms = [
+            T.RandomResizedCrop(
+                input_size,
+                scale=(min_scale, 1.0),
+                interpolation=T.InterpolationMode.BICUBIC,
+            ),
+            T.RandomHorizontalFlip(),
+            T.ToTensor(),
+        ]
+        if normalize:
+            transforms.append(T.Normalize(mean=normalize["mean"], std=normalize["std"]))
+
+        self.transform = T.Compose(transforms)
+
+    def __call__(self, image: Union[Tensor, Image]) -> Tensor:
+        """Applies the transforms to the input image.
+
+        Args:
+            image:
+                The input image to apply the transforms to.
+
+        Returns:
+            The transformed image.
+
+        """
+        return self.transform(image)
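As a usage sketch for the pieces added above (batch size 4 and random tensors are chosen purely for illustration; the imported modules are the ones introduced in this diff):

# Usage sketch: shapes produced by IJEPAMaskCollator and consumed by the new mask
# utilities. Values are random; a real pipeline passes the collator to a DataLoader
# as in examples/pytorch/ijepa.py.
import torch

from lightly.data.collate import IJEPAMaskCollator
from lightly.models import utils

collator = IJEPAMaskCollator(input_size=(224, 224), patch_size=32)  # 7x7 = 49 patches

# A toy batch of 4 (image, label) pairs, as a dataset with target_transform=lambda t: 0
# would yield them.
batch = [(torch.randn(3, 224, 224), 0) for _ in range(4)]
collated_batch, masks_enc, masks_pred = collator(batch)
images, labels = collated_batch  # images: [4, 3, 224, 224]

# masks_enc: list of nenc (default 1) index tensors, each of shape [batch_size, k_enc]
# masks_pred: list of npred (default 2) index tensors, each of shape [batch_size, k_pred]
print(len(masks_enc), masks_enc[0].shape)
print(len(masks_pred), masks_pred[0].shape)

# apply_masks gathers the indexed patch tokens and stacks one copy per mask along the
# batch dimension: [len(masks) * batch_size, k, hidden_dim].
tokens = torch.randn(4, 49, 768)  # dummy patch embeddings, one per 32x32 patch
context = utils.apply_masks(tokens, masks_enc)
targets = utils.apply_masks(tokens, masks_pred)

# repeat_interleave_batch repeats each per-mask block of the batch so the targets line
# up with the predictor output, as in IJEPA.forward_target above:
# [npred * nenc * batch_size, k_pred, hidden_dim].
targets = utils.repeat_interleave_batch(targets, B=4, repeat=len(masks_enc))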