diff --git a/lightly/models/modules/ijepa.py b/lightly/models/modules/ijepa.py index 88cc23bab..d891d1f33 100644 --- a/lightly/models/modules/ijepa.py +++ b/lightly/models/modules/ijepa.py @@ -39,7 +39,6 @@ class IJEPAPredictor(vision_transformer.Encoder): Percentage of elements set to zero after the MLP in the transformer. attention_dropout: Percentage of elements set to zero after the attention head. - """ def __init__( @@ -56,6 +55,8 @@ def __init__( norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), **kwargs, ): + """Initializes the IJEPAPredictor with the specified dimensions.""" + super().__init__( seq_length=seq_length, num_layers=num_layers, @@ -81,7 +82,16 @@ def __init__( @classmethod def from_vit_encoder(cls, vit_encoder, num_patches): - """Creates a I-JEPA predictor backbone (mhas and layernorm) from a torchvision ViT encoder.""" + """Creates an I-JEPA predictor backbone (multi-head attention and layernorm) from a torchvision ViT encoder. + + Args: + vit_encoder: The Vision Transformer encoder from torchvision. + num_patches: The number of patches (tokens). + + Returns: + IJEPAPredictor: An I-JEPA predictor backbone initialized from the ViT encoder. + """ + # Create a new instance with dummy values as they will be overwritten # by the copied vit_encoder attributes encoder = cls( @@ -95,11 +105,27 @@ def from_vit_encoder(cls, vit_encoder, num_patches): dropout=0, attention_dropout=0, ) + + # Copy attributes from the ViT encoder encoder.layers = vit_encoder.layers encoder.ln = vit_encoder.ln + return encoder def forward(self, x, masks_x, masks): + """Forward pass of the IJEPAPredictor. + + Args: + x: + Input tensor. + masks_x: + Mask indices for the input tensor. + masks: + Mask indices for the predicted tokens. + + Returns: + The predicted output tensor. + """ assert (masks is not None) and ( masks_x is not None ), "Cannot run predictor without mask indices" @@ -160,7 +186,6 @@ class IJEPAEncoder(vision_transformer.Encoder): Percentage of elements set to zero after the MLP in the transformer. attention_dropout: Percentage of elements set to zero after the attention head. - """ def __init__( @@ -174,6 +199,8 @@ def __init__( attention_dropout: float, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): + """Initializes the IJEPAEncoder with the specified dimensions.""" + super().__init__( seq_length=seq_length, num_layers=num_layers, @@ -188,6 +215,7 @@ def __init__( @classmethod def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder): """Creates a IJEPA encoder from a torchvision ViT encoder.""" + # Create a new instance with dummy values as they will be overwritten # by the copied vit_encoder attributes encoder = cls( @@ -221,6 +249,7 @@ def forward( Returns: Batch of encoded output tokens. """ + input = input + self.interpolate_pos_encoding(input) if idx_keep is not None: input = utils.apply_masks(input, idx_keep) @@ -236,7 +265,11 @@ def interpolate_pos_encoding(self, input: torch.Tensor): input: Input tensor with shape (batch_size, num_sequences). + Returns: + Interpolated positional embedding. + """ + # code copied from: # https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291 npatch = input.shape[1] - 1 @@ -264,6 +297,7 @@ class IJEPABackbone(vision_transformer.VisionTransformer): in the future. Converts images into patches and encodes them. Code inspired by [1]. + Note that this implementation uses a learned positional embedding while [0] uses a fixed positional embedding. @@ -342,6 +376,7 @@ def __init__( @classmethod def from_vit(cls, vit: vision_transformer.VisionTransformer): """Creates a IJEPABackbone from a torchvision ViT model.""" + # Create a new instance with dummy values as they will be overwritten # by the copied vit_encoder attributes backbone = cls( @@ -357,18 +392,20 @@ def from_vit(cls, vit: vision_transformer.VisionTransformer): representation_size=vit.representation_size, norm_layer=vit.norm_layer, ) + + # Copy attributes from the ViT model backbone.conv_proj = vit.conv_proj backbone.class_token = vit.class_token backbone.seq_length = vit.seq_length backbone.heads = vit.heads backbone.encoder = IJEPAEncoder.from_vit_encoder(vit.encoder) + return backbone def forward( self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None ) -> torch.Tensor: - """ - Returns encoded class tokens from a batch of images. + """Returns encoded class tokens from a batch of images. Args: images: @@ -382,8 +419,8 @@ def forward( Returns: Tensor with shape (batch_size, hidden_dim) containing the encoded class token for every image. - """ + if idx_keep is not None: if not isinstance(idx_keep, list): idx_keep = [idx_keep] @@ -421,6 +458,8 @@ def images_to_tokens( Args: images: Tensor with shape (batch_size, channels, image_size, image_size). + prepend_class_token: + Whether to prepend the class token to the patch tokens. Returns: Tensor with shape (batch_size, sequence_length - 1, hidden_dim) diff --git a/lightly/models/modules/ijepa_timm.py b/lightly/models/modules/ijepa_timm.py index b59e149a1..89e3dca76 100644 --- a/lightly/models/modules/ijepa_timm.py +++ b/lightly/models/modules/ijepa_timm.py @@ -42,7 +42,6 @@ class IJEPAPredictorTIMM(nn.Module): Percentage of elements set to zero after the attention head. norm_layer: Normalization layer. - """ def __init__( @@ -59,6 +58,8 @@ def __init__( attn_drop_rate: float = 0.0, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): + """Initializes the IJEPAPredictorTIMM with the specified dimensions.""" + super().__init__() self.predictor_embed = nn.Linear(mlp_dim, predictor_embed_dim, bias=True) @@ -98,6 +99,20 @@ def forward( masks_x: Union[List[torch.Tensor], torch.Tensor], masks: Union[List[torch.Tensor], torch.Tensor], ) -> torch.Tensor: + """Forward pass of the IJEPAPredictorTIMM. + + Args: + x: + Input tensor. + masks_x: + Mask indices for the input tensor. + masks: + Mask indices for the predicted tokens. + + Returns: + The predicted output tensor. + """ + assert (masks is not None) and ( masks_x is not None ), "Cannot run predictor without mask indices" @@ -147,16 +162,20 @@ def repeat_interleave_batch( def apply_masks( self, x: torch.Tensor, masks: Union[torch.Tensor, List[torch.Tensor]] ) -> torch.Tensor: - """ + """Apply masks to the input tensor. + From https://github.com/facebookresearch/ijepa/blob/main/src/masks/utils.py - Apply masks to the input tensor. Args: - x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] - masks: tensor or list of tensors containing indices of patches in [N] to keep + x: + tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]. + masks: + tensor or list of tensors containing indices of patches in [N] to keep. + Returns: - tensor of shape [B, N', D] where N' is the number of patches to keep + Tensor of shape [B, N', D] where N' is the number of patches to keep. """ + if not isinstance(masks, list): masks = [masks] diff --git a/lightly/models/modules/masked_autoencoder.py b/lightly/models/modules/masked_autoencoder.py index f7a5a2902..2e4f49719 100644 --- a/lightly/models/modules/masked_autoencoder.py +++ b/lightly/models/modules/masked_autoencoder.py @@ -38,7 +38,6 @@ class MAEEncoder(vision_transformer.Encoder): Percentage of elements set to zero after the MLP in the transformer. attention_dropout: Percentage of elements set to zero after the attention head. - """ def __init__( @@ -52,6 +51,8 @@ def __init__( attention_dropout: float, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): + """Initializes the MAEEncoder with the specified dimensions.""" + super().__init__( seq_length=seq_length, num_layers=num_layers, @@ -79,8 +80,8 @@ def from_vit_encoder( Returns: A MAEEncoder with the same architecture as vit_encoder. - """ + # Create a new instance with dummy values as they will be overwritten # by the copied vit_encoder attributes encoder = cls( @@ -92,10 +93,13 @@ def from_vit_encoder( dropout=0, attention_dropout=0, ) + + # Copy attributes from the ViT encoder encoder.pos_embedding = vit_encoder.pos_embedding encoder.dropout = vit_encoder.dropout encoder.layers = vit_encoder.layers encoder.ln = vit_encoder.ln + if initialize_weights: encoder._initialize_weights() return encoder @@ -131,6 +135,9 @@ def interpolate_pos_encoding(self, input: torch.Tensor): input: Input tensor with shape (batch_size, num_sequences). + Returns: + Interpolated positional embedding. + """ # code copied from: # https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291 @@ -138,9 +145,12 @@ def interpolate_pos_encoding(self, input: torch.Tensor): N = self.pos_embedding.shape[1] - 1 if npatch == N: return self.pos_embedding + + # Separate the class embedding from the positional embeddings class_emb = self.pos_embedding[:, 0] pos_embedding = self.pos_embedding[:, 1:] dim = input.shape[-1] + pos_embedding = nn.functional.interpolate( pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( 0, 3, 1, 2 @@ -215,6 +225,8 @@ def __init__( norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), conv_stem_configs: Optional[List[ConvStemConfig]] = None, ): + """Initializes the MAEBackbone with the specified dimensions.""" + super().__init__( image_size=image_size, patch_size=patch_size, @@ -272,6 +284,8 @@ def from_vit( representation_size=vit.representation_size, norm_layer=vit.norm_layer, ) + + # Copy attributes from the ViT model backbone.conv_proj = vit.conv_proj backbone.class_token = vit.class_token backbone.seq_length = vit.seq_length @@ -334,11 +348,14 @@ def images_to_tokens( Args: images: Tensor with shape (batch_size, channels, image_size, image_size). + prepend_class_token: + Whether to prepend the class token to the patch tokens. Returns: Tensor with shape (batch_size, sequence_length - 1, hidden_dim) containing the patch tokens. """ + x = self.conv_proj(images) tokens = x.flatten(2).transpose(1, 2) if prepend_class_token: @@ -346,6 +363,8 @@ def images_to_tokens( return tokens def _initialize_weights(self) -> None: + """Initializes weights for the backbone components.""" + # Initialize the patch embedding layer like a linear layer instead of conv # layer. w = self.conv_proj.weight.data @@ -404,6 +423,8 @@ def __init__( attention_dropout: float = 0.0, norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): + """Initializes the MAEDecoder with the specified dimensions.""" + super().__init__( seq_length=seq_length, num_layers=num_layers, @@ -427,8 +448,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: Returns: Tensor with shape (batch_size, seq_length, out_dim). - """ + out = self.embed(input) out = self.decode(out) return self.predict(out) @@ -487,6 +508,8 @@ def _initialize_weights(self) -> None: def _initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) -> None: + """Initializes a 2D sine-cosine positional embedding.""" + _, seq_length, hidden_dim = pos_embedding.shape grid_size = int((seq_length - 1) ** 0.5) sine_cosine_embedding = utils.get_2d_sine_cosine_positional_embedding( @@ -502,6 +525,8 @@ def _initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) -> def _initialize_linear_layers(module: Module) -> None: + """Initializes linear layers in the given module.""" + def init(mod: Module) -> None: if isinstance(mod, Linear): nn.init.xavier_uniform_(mod.weight) diff --git a/lightly/models/modules/masked_autoencoder_timm.py b/lightly/models/modules/masked_autoencoder_timm.py index ad034b468..5a4286ac5 100644 --- a/lightly/models/modules/masked_autoencoder_timm.py +++ b/lightly/models/modules/masked_autoencoder_timm.py @@ -66,6 +66,8 @@ def __init__( initialize_weights: bool = True, mask_token: Optional[Parameter] = None, ): + """Initializes the MAEDecoderTIMM with the specified parameters.""" + super().__init__() self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) @@ -75,7 +77,7 @@ def __init__( else mask_token ) - # positional encoding of the decoder + # Positional encoding of the decoder self.decoder_pos_embed = nn.Parameter( torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False ) # fixed sin-cos embedding @@ -112,8 +114,8 @@ def forward(self, input: Tensor) -> Tensor: Returns: Tensor with shape (batch_size, seq_length, out_dim). - """ + out = self.embed(input) out = self.decode(out) return self.predict(out) @@ -156,7 +158,7 @@ def decode(self, input: Tensor) -> Tensor: return output def predict(self, input: Tensor) -> Tensor: - """Predics pixel values from decoded tokens. + """Predicts pixel values from decoded tokens. Args: input: @@ -172,6 +174,8 @@ def predict(self, input: Tensor) -> Tensor: return out def _initialize_weights(self) -> None: + """Initializes weights for the decoder components.""" + torch.nn.init.normal_(self.mask_token, std=0.02) utils.initialize_2d_sine_cosine_positional_embedding( pos_embedding=self.decoder_pos_embed, has_class_token=True diff --git a/lightly/models/modules/masked_causal_vision_transformer.py b/lightly/models/modules/masked_causal_vision_transformer.py index 61a354503..e07e8b02c 100644 --- a/lightly/models/modules/masked_causal_vision_transformer.py +++ b/lightly/models/modules/masked_causal_vision_transformer.py @@ -68,6 +68,18 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: def _get_attention_mask( self, x: Tensor, mask: Optional[Tensor] ) -> Optional[Tensor]: + """Generates an attention mask for causal attention. + + Args: + x: + Input tensor of shape (batch_size, sequence_length, channels). + mask: + Mask tensor of shape (batch_size, sequence_length) indicating which tokens + should be masked. + + Returns: + Attention mask of shape (batch_size, 1, sequence_length, sequence_length). + """ B, N = x.shape[:2] # Only apply causal attention if mask is not None. This is a bit hacky, but it @@ -108,6 +120,34 @@ def __init__( norm_layer: Type[Module] = LayerNorm, mlp_layer: Type[Module] = Mlp, ) -> None: + """Initializes the MaskedCausalBlock with the specified parameters. + + Args: + dim: + Dimension of the input tokens. + num_heads: + Number of attention heads. + mlp_ratio: + Ratio of MLP hidden dim to embedding dim. + qkv_bias: + If True, add bias to the query, key, and value tensors. + qk_norm: + If True, apply layer normalization to queries and keys. + proj_drop: + Percentage of elements set to zero after the projection layer. + attn_drop: + Percentage of elements set to zero after the attention head. + init_values: + Initial values for the layer. + drop_path: + Drop path rate for the block. + act_layer: + Activation layer to use. + norm_layer: + Normalization layer to use. + mlp_layer: + MLP layer to use. + """ super().__init__( dim=dim, num_heads=num_heads, @@ -144,6 +184,9 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: causal attention, while unmasked tokens are used for bidirectional attention. If the mask is None, all tokens are used for bidirectional attention. + + Returns: + Output tensor after applying the attention block. """ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), mask=mask))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) @@ -191,6 +234,72 @@ def __init__( block_fn: Type[Module] = MaskedCausalBlock, mlp_layer: Type[Module] = Mlp, ) -> None: + """Initializes the MaskedCausalVisionTransformer with the specified parameters. + + Args: + img_size: + Input image size. + patch_size: + Width and height of the image patches. + in_chans: + Number of image input channels. + num_classes: + Number of classes for the classification head. + global_pool: + Global pooling type. + embed_dim: + Embedding dimension. + depth: + Depth of the transformer. + num_heads: + Number of attention heads. + mlp_ratio: + Ratio of MLP hidden dim to embedding dim. + qkv_bias: + If True, add bias to the query, key, and value tensors. + qk_norm: + If True, apply layer normalization to queries and keys. + init_values: + Initial values for the layer. + class_token: + If True, add class token to the embeddings. + no_embed_class: + If True, do not embed class token. + reg_tokens: + Number of regularization tokens. + pre_norm : + If True, apply layer normalization before the transformer. + fc_norm: + If True, apply layer normalization to the final fully connected layer. + dynamic_img_size: + If True, dynamically adjust the image size. + dynamic_img_pad: + If True, dynamically pad the image. + drop_rate: + Percentage of elements set to zero after the dropout layer. + pos_drop_rate: + Percentage of elements set to zero after the positional dropout layer. + patch_drop_rate: + Percentage of elements set to zero after the patch dropout layer. + proj_drop_rate: + Percentage of elements set to zero after the projection dropout layer. + attn_drop_rate: + Percentage of elements set to zero after the attention head dropout. + drop_path_rate: + Drop path rate for the block. + weight_init: + Weight initialization method. + embed_layer: + Callable that creates the embedding layer. + norm_layer: + Normalization layer to use. + act_layer: + Activation layer to use. + block_fn: + Block function to use. + mlp_layer: + MLP layer to use. + """ super().__init__( img_size=img_size, patch_size=patch_size, @@ -237,6 +346,9 @@ def forward_features(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: causal attention, while unmasked tokens are used for bidirectional attention. If the mask is None, all tokens are used for bidirectional attention. + + Returns: + Output tensor after applying the transformer blocks. """ x = self.patch_embed(x) x = self._pos_embed(x)