Add Documentation for lightly/models/modules #1704

Merged
Changes from 16 commits
51 changes: 45 additions & 6 deletions lightly/models/modules/ijepa.py
@@ -39,7 +39,6 @@ class IJEPAPredictor(vision_transformer.Encoder):
Percentage of elements set to zero after the MLP in the transformer.
attention_dropout:
Percentage of elements set to zero after the attention head.

"""

def __init__(
@@ -56,6 +55,8 @@ def __init__(
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
**kwargs,
):
"""Initializes the IJEPAPredictor with the specified dimensions."""

super().__init__(
seq_length=seq_length,
num_layers=num_layers,
@@ -81,7 +82,16 @@ def __init__(

@classmethod
def from_vit_encoder(cls, vit_encoder, num_patches):
"""Creates a I-JEPA predictor backbone (mhas and layernorm) from a torchvision ViT encoder."""
"""Creates an I-JEPA predictor backbone (multi-head attention and layernorm) from a torchvision ViT encoder.

Args:
vit_encoder: The Vision Transformer encoder from torchvision.
num_patches: The number of patches (tokens).

Returns:
IJEPAPredictor: An I-JEPA predictor backbone initialized from the ViT encoder.
"""

# Create a new instance with dummy values as they will be overwritten
# by the copied vit_encoder attributes
encoder = cls(
@@ -95,11 +105,27 @@ def from_vit_encoder(cls, vit_encoder, num_patches):
dropout=0,
attention_dropout=0,
)

# Copy attributes from the ViT encoder
encoder.layers = vit_encoder.layers
encoder.ln = vit_encoder.ln

return encoder
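
# Editor's note: a minimal usage sketch of the classmethod documented above.
# The import path and the patch count are assumptions for illustration, not part of this PR.
import torchvision
from lightly.models.modules.ijepa import IJEPAPredictor  # assumed import path

# Build a torchvision ViT and reuse its encoder blocks for the predictor.
vit = torchvision.models.vit_b_32(weights=None)
# A 224x224 image with 32x32 patches yields (224 // 32) ** 2 = 49 patches.
predictor = IJEPAPredictor.from_vit_encoder(vit.encoder, num_patches=49)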

def forward(self, x, masks_x, masks):
"""Forward pass of the IJEPAPredictor.

Args:
x:
Input tensor.
masks_x:
Mask indices for the input tensor.
masks:
Mask indices for the predicted tokens.

Returns:
The predicted output tensor.
"""
assert (masks is not None) and (
masks_x is not None
), "Cannot run predictor without mask indices"
@@ -160,7 +186,6 @@ class IJEPAEncoder(vision_transformer.Encoder):
Percentage of elements set to zero after the MLP in the transformer.
attention_dropout:
Percentage of elements set to zero after the attention head.

"""

def __init__(
@@ -174,6 +199,8 @@ def __init__(
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
"""Initializes the IJEPAEncoder with the specified dimensions."""

super().__init__(
seq_length=seq_length,
num_layers=num_layers,
@@ -188,6 +215,7 @@ def __init__(
@classmethod
def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder):
"""Creates a IJEPA encoder from a torchvision ViT encoder."""

# Create a new instance with dummy values as they will be overwritten
# by the copied vit_encoder attributes
encoder = cls(
@@ -221,6 +249,7 @@ def forward(
Returns:
Batch of encoded output tokens.
"""

input = input + self.interpolate_pos_encoding(input)
if idx_keep is not None:
input = utils.apply_masks(input, idx_keep)
@@ -236,7 +265,11 @@ def interpolate_pos_encoding(self, input: torch.Tensor):
input:
Input tensor with shape (batch_size, num_sequences).

Returns:
Interpolated positional embedding.

"""

# code copied from:
# https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291
npatch = input.shape[1] - 1
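
# Editor's note: a standalone sketch of the interpolation idea this docstring describes
# (resizing the grid of positional embeddings with bicubic interpolation). The sizes are
# illustrative; this is not the method itself.
import math
import torch
import torch.nn as nn

dim, old_n, new_n = 768, 7 * 7, 14 * 14
pos_embedding = torch.randn(1, old_n, dim)  # one embedding per patch position
# Reshape to a 2D grid, resize it, and flatten back into a sequence.
grid = pos_embedding.reshape(1, 7, 7, dim).permute(0, 3, 1, 2)
grid = nn.functional.interpolate(grid, scale_factor=math.sqrt(new_n / old_n), mode="bicubic")
resized = grid.permute(0, 2, 3, 1).reshape(1, -1, dim)  # shape (1, 196, 768)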
@@ -264,6 +297,7 @@ class IJEPABackbone(vision_transformer.VisionTransformer):
in the future.

Converts images into patches and encodes them. Code inspired by [1].

Note that this implementation uses a learned positional embedding while [0]
uses a fixed positional embedding.

@@ -342,6 +376,7 @@ def __init__(
@classmethod
def from_vit(cls, vit: vision_transformer.VisionTransformer):
"""Creates a IJEPABackbone from a torchvision ViT model."""

# Create a new instance with dummy values as they will be overwritten
# by the copied vit_encoder attributes
backbone = cls(
@@ -357,18 +392,20 @@ def from_vit(cls, vit: vision_transformer.VisionTransformer):
representation_size=vit.representation_size,
norm_layer=vit.norm_layer,
)

# Copy attributes from the ViT model
backbone.conv_proj = vit.conv_proj
backbone.class_token = vit.class_token
backbone.seq_length = vit.seq_length
backbone.heads = vit.heads
backbone.encoder = IJEPAEncoder.from_vit_encoder(vit.encoder)

return backbone
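
# Editor's note: a hedged usage sketch of the conversion documented above.
# The import path is an assumption based on this module's location.
import torch
import torchvision
from lightly.models.modules.ijepa import IJEPABackbone  # assumed import path

vit = torchvision.models.vit_b_32(weights=None)
backbone = IJEPABackbone.from_vit(vit)
images = torch.randn(2, 3, 224, 224)
# class_tokens = backbone(images)  # per the forward docstring below: shape (2, hidden_dim)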

def forward(
self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Returns encoded class tokens from a batch of images.
"""Returns encoded class tokens from a batch of images.

Args:
images:
@@ -382,8 +419,8 @@ def forward(
Returns:
Tensor with shape (batch_size, hidden_dim) containing the
encoded class token for every image.

"""

if idx_keep is not None:
if not isinstance(idx_keep, list):
idx_keep = [idx_keep]
@@ -421,6 +458,8 @@ def images_to_tokens(
Args:
images:
Tensor with shape (batch_size, channels, image_size, image_size).
prepend_class_token:
Whether to prepend the class token to the patch tokens.

Returns:
Tensor with shape (batch_size, sequence_length - 1, hidden_dim)
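
# Editor's note: to close out the ijepa.py diff, a standalone sketch of the patchify step
# that images_to_tokens documents (convolutional projection, then flattening the spatial
# grid into a token sequence). The sizes are illustrative.
import torch
import torch.nn as nn

images = torch.randn(2, 3, 224, 224)                       # (batch_size, channels, H, W)
conv_proj = nn.Conv2d(3, 768, kernel_size=32, stride=32)   # one token per 32x32 patch
x = conv_proj(images)                                      # (2, 768, 7, 7)
tokens = x.flatten(2).transpose(1, 2)                      # (2, 49, 768) patch tokens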
31 changes: 25 additions & 6 deletions lightly/models/modules/ijepa_timm.py
@@ -42,7 +42,6 @@ class IJEPAPredictorTIMM(nn.Module):
Percentage of elements set to zero after the attention head.
norm_layer:
Normalization layer.

"""

def __init__(
@@ -59,6 +58,8 @@ def __init__(
attn_drop_rate: float = 0.0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
"""Initializes the IJEPAPredictorTIMM with the specified dimensions."""

super().__init__()

self.predictor_embed = nn.Linear(mlp_dim, predictor_embed_dim, bias=True)
@@ -98,6 +99,20 @@ def forward(
masks_x: Union[List[torch.Tensor], torch.Tensor],
masks: Union[List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""Forward pass of the IJEPAPredictorTIMM.

Args:
x:
Input tensor.
masks_x:
Mask indices for the input tensor.
masks:
Mask indices for the predicted tokens.

Returns:
The predicted output tensor.
"""

assert (masks is not None) and (
masks_x is not None
), "Cannot run predictor without mask indices"
@@ -147,16 +162,20 @@ def repeat_interleave_batch(
def apply_masks(
self, x: torch.Tensor, masks: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
"""
"""Apply masks to the input tensor.

From https://github.com/facebookresearch/ijepa/blob/main/src/masks/utils.py
Apply masks to the input tensor.

Args:
x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
masks: tensor or list of tensors containing indices of patches in [N] to keep
x:
tensor of shape [B (batch-size), N (num-patches), D (feature-dim)].
masks:
tensor or list of tensors containing indices of patches in [N] to keep.

Returns:
tensor of shape [B, N', D] where N' is the number of patches to keep
Tensor of shape [B, N', D] where N' is the number of patches to keep.
"""

if not isinstance(masks, list):
masks = [masks]

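
# Editor's note: to close out the ijepa_timm.py diff, a standalone sketch of the
# gather-based masking that apply_masks documents: only the patches whose indices
# appear in each mask are kept.
import torch

B, N, D = 2, 16, 8
x = torch.randn(B, N, D)
masks = torch.tensor([[0, 3, 5], [1, 2, 4]])  # per-sample indices of patches to keep
kept = torch.gather(x, dim=1, index=masks.unsqueeze(-1).expand(-1, -1, D))
print(kept.shape)                             # torch.Size([2, 3, 8])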
31 changes: 28 additions & 3 deletions lightly/models/modules/masked_autoencoder.py
@@ -38,7 +38,6 @@ class MAEEncoder(vision_transformer.Encoder):
Percentage of elements set to zero after the MLP in the transformer.
attention_dropout:
Percentage of elements set to zero after the attention head.

"""

def __init__(
@@ -52,6 +51,8 @@ def __init__(
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
"""Initializes the MAEEncoder with the specified dimensions."""

super().__init__(
seq_length=seq_length,
num_layers=num_layers,
@@ -79,8 +80,8 @@ def from_vit_encoder(

Returns:
A MAEEncoder with the same architecture as vit_encoder.

"""

# Create a new instance with dummy values as they will be overwritten
# by the copied vit_encoder attributes
encoder = cls(
@@ -92,10 +93,13 @@ def from_vit_encoder(
dropout=0,
attention_dropout=0,
)

# Copy attributes from the ViT encoder
encoder.pos_embedding = vit_encoder.pos_embedding
encoder.dropout = vit_encoder.dropout
encoder.layers = vit_encoder.layers
encoder.ln = vit_encoder.ln

if initialize_weights:
encoder._initialize_weights()
return encoder
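
# Editor's note: a hedged usage sketch of the classmethod above; the import path is an
# assumption based on this module's location, and initialize_weights mirrors the flag
# used in the code shown here.
import torchvision
from lightly.models.modules.masked_autoencoder import MAEEncoder  # assumed import path

vit = torchvision.models.vit_b_32(weights=None)
encoder = MAEEncoder.from_vit_encoder(vit.encoder, initialize_weights=True)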
@@ -131,16 +135,22 @@ def interpolate_pos_encoding(self, input: torch.Tensor):
input:
Input tensor with shape (batch_size, num_sequences).

Returns:
Interpolated positional embedding.

"""
# code copied from:
# https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291
npatch = input.shape[1] - 1
N = self.pos_embedding.shape[1] - 1
if npatch == N:
return self.pos_embedding

# Separate the class embedding from the positional embeddings
class_emb = self.pos_embedding[:, 0]
pos_embedding = self.pos_embedding[:, 1:]
dim = input.shape[-1]

pos_embedding = nn.functional.interpolate(
pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
0, 3, 1, 2
@@ -215,6 +225,8 @@ def __init__(
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
conv_stem_configs: Optional[List[ConvStemConfig]] = None,
):
"""Initializes the MAEBackbone with the specified dimensions."""

super().__init__(
image_size=image_size,
patch_size=patch_size,
@@ -272,6 +284,8 @@ def from_vit(
representation_size=vit.representation_size,
norm_layer=vit.norm_layer,
)

# Copy attributes from the ViT model
backbone.conv_proj = vit.conv_proj
backbone.class_token = vit.class_token
backbone.seq_length = vit.seq_length
@@ -334,18 +348,23 @@ def images_to_tokens(
Args:
images:
Tensor with shape (batch_size, channels, image_size, image_size).
prepend_class_token:
Whether to prepend the class token to the patch tokens.

Returns:
Tensor with shape (batch_size, sequence_length - 1, hidden_dim)
containing the patch tokens.
"""

x = self.conv_proj(images)
tokens = x.flatten(2).transpose(1, 2)
if prepend_class_token:
tokens = utils.prepend_class_token(tokens, self.class_token)
return tokens

def _initialize_weights(self) -> None:
"""Initializes weights for the backbone components."""

# Initialize the patch embedding layer like a linear layer instead of conv
# layer.
w = self.conv_proj.weight.data
@@ -404,6 +423,8 @@ def __init__(
attention_dropout: float = 0.0,
norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
"""Initializes the MAEDecoder with the specified dimensions."""

super().__init__(
seq_length=seq_length,
num_layers=num_layers,
@@ -427,8 +448,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

Returns:
Tensor with shape (batch_size, seq_length, out_dim).

"""

out = self.embed(input)
out = self.decode(out)
return self.predict(out)
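
# Editor's note: a hedged sketch of the three-stage flow this forward pass documents
# (embed to the decoder width, decode with transformer blocks, predict per-token outputs).
# The modules below are illustrative stand-ins, not the MAEDecoder internals.
import torch
import torch.nn as nn

embed = nn.Linear(768, 512)                      # project encoder tokens to decoder width
decoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
decode = nn.TransformerEncoder(decoder_layer, num_layers=2)
predict = nn.Linear(512, 32 * 32 * 3)            # per-token patch reconstruction

tokens = torch.randn(4, 50, 768)                 # (batch_size, seq_length, embed_input_dim)
out = predict(decode(embed(tokens)))             # (4, 50, 3072)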
@@ -487,6 +508,8 @@ def _initialize_weights(self) -> None:


def _initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) -> None:
"""Initializes a 2D sine-cosine positional embedding."""

_, seq_length, hidden_dim = pos_embedding.shape
grid_size = int((seq_length - 1) ** 0.5)
sine_cosine_embedding = utils.get_2d_sine_cosine_positional_embedding(
@@ -502,6 +525,8 @@ def _initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) ->


def _initialize_linear_layers(module: Module) -> None:
"""Initializes linear layers in the given module."""

def init(mod: Module) -> None:
if isinstance(mod, Linear):
nn.init.xavier_uniform_(mod.weight)
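
# Editor's note: a standalone sketch of the general 2D sine-cosine positional embedding
# technique referenced above (not the exact lightly utility): concatenate 1D sine-cosine
# embeddings of each patch's row and column coordinates.
import torch

def sincos_1d(positions: torch.Tensor, dim: int) -> torch.Tensor:
    # positions: (M,) -> (M, dim); half sine terms, half cosine terms
    omega = 1.0 / 10000 ** (torch.arange(dim // 2, dtype=torch.float32) / (dim // 2))
    angles = positions.float()[:, None] * omega[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

grid_size, hidden_dim = 7, 768
ys, xs = torch.meshgrid(torch.arange(grid_size), torch.arange(grid_size), indexing="ij")
embedding = torch.cat(
    [sincos_1d(ys.flatten(), hidden_dim // 2), sincos_1d(xs.flatten(), hidden_dim // 2)],
    dim=1,
)  # (49, 768): one embedding per patch position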