Ersi lig 3912 refactor mae to use timm vit #1461

Merged
merged 59 commits on Mar 6, 2024

Changes from 7 commits

Commits (59)
951d43e
Add MAE evaluation
guarin May 30, 2023
503cc44
Add stochastic depth dropout
guarin May 31, 2023
ac43499
Add MAE
guarin May 31, 2023
15bfe3a
Drop assertion
guarin May 31, 2023
49c85c0
Fix smooth cross entropy loss and mixup
guarin May 31, 2023
9d95783
Update comments
guarin May 31, 2023
0bb601d
Add layer lr decay and weight decay
guarin Jun 5, 2023
d7d69af
Update comment
guarin Jun 5, 2023
ec05437
Add test for MAE images_to_tokens
guarin Jun 5, 2023
923a606
Disable BN update
guarin Jun 5, 2023
bdce8a6
Add BN before classification head
guarin Jun 6, 2023
316f918
Format
guarin Jun 6, 2023
a6943fd
Fix BN freezing
guarin Jun 6, 2023
1a2b454
Cleanup
guarin Jun 6, 2023
bc066ae
Use torch.no_grad instead of deactivating gradients manually
guarin Jun 6, 2023
d56e340
Create new stochastic depth instances
guarin Jun 6, 2023
5ed6803
Add mask token to learnable params
guarin Jun 6, 2023
4f0baf1
Add sine-cosine positional embedding
guarin Jun 6, 2023
9c4a8cf
Initialize parameters as in paper
guarin Jun 6, 2023
9904c10
Merge branch 'master' into guarin-lig-3056-add-mae-imagenet-benchmark
guarin Dec 6, 2023
83edd1c
Fix types
guarin Dec 6, 2023
e27946e
Format
guarin Dec 6, 2023
0672b0a
Merge branch 'guarin-lig-3056-add-mae-imagenet-benchmark' of github.c…
ersi-lightly Dec 17, 2023
45433c5
adjusted to existing interface
ersi-lightly Dec 18, 2023
c5cab9e
draft
ersi-lightly Dec 19, 2023
017168e
remove
ersi-lightly Dec 19, 2023
278423b
added modifications
ersi-lightly Jan 4, 2024
fde116c
added mae implementation with timm and example
ersi-lightly Jan 5, 2024
f008645
formatted
ersi-lightly Jan 5, 2024
c97112e
fixed import
ersi-lightly Jan 5, 2024
2e55d6b
removed
ersi-lightly Jan 5, 2024
484add1
fixed typing
ersi-lightly Jan 5, 2024
971c19a
addressed comments
ersi-lightly Jan 9, 2024
1ec7470
fixed typing and formatted
ersi-lightly Jan 9, 2024
76ee356
addressed comments
ersi-lightly Jan 9, 2024
edb2d42
added docstring and formatted
ersi-lightly Jan 9, 2024
f00d320
removed images to tokens method
ersi-lightly Jan 10, 2024
cc263fe
Ersi lig 3910 update mae benchmark code (#1468)
ersi-lightly Feb 23, 2024
f7a532b
resolved conflict
ersi-lightly Feb 23, 2024
cc9d4ac
resolved conflicts
ersi-lightly Feb 23, 2024
eda5ac0
formatted
ersi-lightly Feb 23, 2024
0993ec5
adjusted examples
ersi-lightly Feb 25, 2024
4842901
removed comment
ersi-lightly Feb 25, 2024
bc9c6c3
added test
ersi-lightly Feb 26, 2024
73cb2ec
added message in case of ImportError
ersi-lightly Feb 26, 2024
12eeb14
fixed skipping of test
ersi-lightly Feb 26, 2024
ab15124
removed example
ersi-lightly Feb 26, 2024
c22b2ff
handling the TIMM dependency
ersi-lightly Feb 26, 2024
3fb9e86
added note to docs for MAE installation
ersi-lightly Feb 26, 2024
17760fd
added unit tests for MAE with torchvision
ersi-lightly Feb 26, 2024
44cf4e8
removed unecessary maks token definition
ersi-lightly Feb 26, 2024
07fdae4
addressed comments
ersi-lightly Feb 29, 2024
a0f87ac
moved test to separate file
ersi-lightly Feb 29, 2024
f61b708
added typing
ersi-lightly Feb 29, 2024
0f8a927
fixed import
ersi-lightly Feb 29, 2024
a18bccd
fixes typing
ersi-lightly Feb 29, 2024
b16bba8
fixed typing
ersi-lightly Feb 29, 2024
90a2ee0
fixed typing
ersi-lightly Feb 29, 2024
d89808f
Ersi lig 4471 cleanup and merge mae branch (#1510)
ersi-lightly Mar 5, 2024
2 changes: 1 addition & 1 deletion docs/source/getting_started/install.rst
@@ -33,7 +33,7 @@ If you want to work with video files you need to additionally install
pip install av

If you want to use the Masked Autoencoder you need to additionally install
`TIMM <https://timm.fast.ai/>`_.
`TIMM <https://github.com/huggingface/pytorch-image-models>`_.

.. code-block:: bash

23 changes: 5 additions & 18 deletions examples/pytorch/mae.py
@@ -1,24 +1,13 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import sys

import torch
import torchvision
from timm.models.vision_transformer import vit_base_patch32_224
from torch import nn

from lightly.utils import dependency

if dependency.timm_vit_available():
from timm.models.vision_transformer import vit_base_patch32_224
else:
sys.exit(1)

from lightly.models import utils
from lightly.models.modules import (
masked_autoencoder_timm,
masked_vision_transformer_timm,
)
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms.mae_transform import MAETransform


@@ -30,16 +19,14 @@ def __init__(self, vit):
self.mask_ratio = 0.75
self.patch_size = vit.patch_embed.patch_size[0]

self.backbone = masked_vision_transformer_timm.MaskedVisionTransformerTIMM(
vit=vit
)
self.backbone = MaskedVisionTransformerTIMM(vit=vit)
self.sequence_length = self.backbone.sequence_length
self.decoder = masked_autoencoder_timm.MAEDecoder(
self.decoder = MAEDecoderTIMM(
num_patches=vit.patch_embed.num_patches,
patch_size=self.patch_size,
embed_dim=vit.embed_dim,
decoder_embed_dim=decoder_dim,
decoder_depth=8,
decoder_depth=1,
decoder_num_heads=16,
mlp_ratio=4.0,
proj_drop_rate=0.0,
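For orientation, a minimal sketch of how the refactored example wires a timm ViT into the new modules; decoder_dim = 512 and the training/masking logic are assumptions, everything else mirrors the constructor arguments shown in the diff above:

from timm.models.vision_transformer import vit_base_patch32_224
from torch import nn

from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM


class MAE(nn.Module):
    def __init__(self, vit, decoder_dim=512):
        super().__init__()
        self.mask_ratio = 0.75
        self.patch_size = vit.patch_embed.patch_size[0]
        # The backbone wraps the timm ViT and handles the masked encoding.
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length
        # The decoder predicts pixel values for the masked patches.
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=decoder_dim,
            decoder_depth=1,
            decoder_num_heads=16,
            mlp_ratio=4.0,
            proj_drop_rate=0.0,
        )


model = MAE(vit_base_patch32_224())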
4 changes: 1 addition & 3 deletions examples/pytorch/msn.py
@@ -9,10 +9,8 @@

from lightly.loss import MSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.models.modules.masked_vision_transformer_torchvision import (
MaskedVisionTransformerTorchvision,
)
from lightly.transforms.msn_transform import MSNTransform


4 changes: 1 addition & 3 deletions examples/pytorch/pmsn.py
@@ -9,10 +9,8 @@

from lightly.loss import PMSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.models.modules.masked_vision_transformer_torchvision import (
MaskedVisionTransformerTorchvision,
)
from lightly.transforms import MSNTransform


23 changes: 5 additions & 18 deletions examples/pytorch_lightning/mae.py
@@ -1,25 +1,14 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import sys

import pytorch_lightning as pl
import torch
import torchvision
from timm.models.vision_transformer import vit_base_patch32_224
from torch import nn

from lightly.utils import dependency

if dependency.timm_vit_available():
from timm.models.vision_transformer import vit_base_patch32_224
else:
sys.exit(1)

from lightly.models import utils
from lightly.models.modules import (
masked_autoencoder_timm,
masked_vision_transformer_timm,
)
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms.mae_transform import MAETransform


@@ -31,16 +20,14 @@ def __init__(self):
vit = vit_base_patch32_224()
self.mask_ratio = 0.75
self.patch_size = vit.patch_embed.patch_size[0]
self.backbone = masked_vision_transformer_timm.MaskedVisionTransformerTIMM(
vit=vit
)
self.backbone = MaskedVisionTransformerTIMM(vit=vit)
self.sequence_length = self.backbone.sequence_length
self.decoder = masked_autoencoder_timm.MAEDecoder(
self.decoder = MAEDecoderTIMM(
num_patches=vit.patch_embed.num_patches,
patch_size=self.patch_size,
embed_dim=vit.embed_dim,
decoder_embed_dim=decoder_dim,
decoder_depth=8,
decoder_depth=1,
decoder_num_heads=16,
mlp_ratio=4.0,
proj_drop_rate=0.0,
4 changes: 1 addition & 3 deletions examples/pytorch_lightning/msn.py
@@ -10,10 +10,8 @@

from lightly.loss import MSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.models.modules.masked_vision_transformer_torchvision import (
MaskedVisionTransformerTorchvision,
)
from lightly.transforms.msn_transform import MSNTransform


4 changes: 1 addition & 3 deletions examples/pytorch_lightning/pmsn.py
@@ -10,10 +10,8 @@

from lightly.loss import PMSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.models.modules.masked_vision_transformer_torchvision import (
MaskedVisionTransformerTorchvision,
)
from lightly.transforms import MSNTransform


4 changes: 1 addition & 3 deletions examples/pytorch_lightning/simmim.py
@@ -4,9 +4,7 @@
from torch import nn

from lightly.models import utils
from lightly.models.modules.masked_vision_transformer_torchvision import (
MaskedVisionTransformerTorchvision,
)
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.transforms.mae_transform import MAETransform # Same transform as MAE


23 changes: 5 additions & 18 deletions examples/pytorch_lightning_distributed/mae.py
@@ -1,25 +1,14 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import sys

import pytorch_lightning as pl
import torch
import torchvision
from timm.models.vision_transformer import vit_base_patch32_224
from torch import nn

from lightly.utils import dependency

if dependency.timm_vit_available():
from timm.models.vision_transformer import vit_base_patch32_224
else:
sys.exit(1)

from lightly.models import utils
from lightly.models.modules import (
masked_autoencoder_timm,
masked_vision_transformer_timm,
)
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms.mae_transform import MAETransform


@@ -31,16 +20,14 @@ def __init__(self):
vit = vit_base_patch32_224()
self.mask_ratio = 0.75
self.patch_size = vit.patch_embed.patch_size[0]
self.backbone = masked_vision_transformer_timm.MaskedVisionTransformerTIMM(
vit=vit
)
self.backbone = MaskedVisionTransformerTIMM(vit=vit)
self.sequence_length = self.backbone.sequence_length
self.decoder = masked_autoencoder_timm.MAEDecoder(
self.decoder = MAEDecoderTIMM(
num_patches=vit.patch_embed.num_patches,
patch_size=self.patch_size,
embed_dim=vit.embed_dim,
decoder_embed_dim=decoder_dim,
decoder_depth=8,
decoder_depth=1,
decoder_num_heads=16,
mlp_ratio=4.0,
proj_drop_rate=0.0,
7 changes: 2 additions & 5 deletions examples/pytorch_lightning_distributed/msn.py
@@ -10,10 +10,8 @@

from lightly.loss import MSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.models.modules.masked_vision_transformer_torchvision import (
MaskedVisionTransformerTorchvision,
)
from lightly.transforms.msn_transform import MSNTransform


@@ -117,7 +115,6 @@ def configure_optimizers(self):
devices="auto",
accelerator="gpu",
strategy="ddp",
# use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
replace_sampler_ddp=True,
use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)
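The trainer change above replaces the pre-2.0 replace_sampler_ddp flag with its PyTorch Lightning 2.x equivalent; a minimal sketch of the resulting call, assuming model and dataloader are built as in the example:

import pytorch_lightning as pl

trainer = pl.Trainer(
    devices="auto",
    accelerator="gpu",
    strategy="ddp",
    use_distributed_sampler=True,  # replace_sampler_ddp=True on PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)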
7 changes: 2 additions & 5 deletions examples/pytorch_lightning_distributed/pmsn.py
@@ -10,10 +10,8 @@

from lightly.loss import PMSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.models.modules.masked_vision_transformer_torchvision import (
MaskedVisionTransformerTorchvision,
)
from lightly.transforms import MSNTransform


@@ -118,7 +116,6 @@ def configure_optimizers(self):
devices="auto",
accelerator="gpu",
strategy="ddp",
# use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
replace_sampler_ddp=True,
use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)
7 changes: 2 additions & 5 deletions examples/pytorch_lightning_distributed/simmim.py
@@ -4,9 +4,7 @@
from torch import nn

from lightly.models import utils
from lightly.models.modules.masked_vision_transformer_torchvision import (
MaskedVisionTransformerTorchvision,
)
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.transforms.mae_transform import MAETransform # Same transform as MAE


@@ -97,7 +95,6 @@ def configure_optimizers(self):
devices="auto",
accelerator="gpu",
strategy="ddp",
# use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
replace_sampler_ddp=True,
use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)
3 changes: 1 addition & 2 deletions lightly/models/modules/__init__.py
@@ -31,7 +31,6 @@
from lightly.utils import dependency as _dependency

if _dependency.torchvision_vit_available():
# Requires torchvision >=0.12
# Requires torchvision >=0.12
from lightly.models.modules.masked_autoencoder import (
MAEBackbone,
@@ -45,10 +44,10 @@
if _dependency.timm_vit_available():
# Requires timm >= 0.9.9
from lightly.models.modules.heads_timm import AIMPredictionHead
from lightly.models.modules.masked_autoencoder_timm import MAEDecoder
from lightly.models.modules.masked_autoencoder_timm import MAEDecoderTIMM
from lightly.models.modules.masked_causal_vision_transformer import (
MaskedCausalVisionTransformer,
)
from lightly.models.modules.masked_vision_transformer_timm import (
MaskedVisionTransformerTIMM,
)
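With this change the timm-backed modules are re-exported from lightly.models.modules whenever timm is available; a short sketch of the guarded import pattern used above (assumes timm >= 0.9.9 is installed):

from lightly.utils import dependency

if dependency.timm_vit_available():
    # Only importable when timm >= 0.9.9 is installed.
    from lightly.models.modules import (
        AIMPredictionHead,
        MAEDecoderTIMM,
        MaskedCausalVisionTransformer,
        MaskedVisionTransformerTIMM,
    )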
9 changes: 2 additions & 7 deletions lightly/models/modules/masked_autoencoder_timm.py
@@ -1,22 +1,17 @@
from __future__ import annotations

from functools import partial
from typing import Callable, Optional

from lightly.utils import dependency

if dependency.timm_vit_available():
    from timm.models import vision_transformer


import torch
import torch.nn as nn
from timm.models import vision_transformer
from torch.nn import LayerNorm, Linear, Module, Parameter, Sequential

from lightly.models import utils


class MAEDecoder(Module):
class MAEDecoderTIMM(Module):
    """Decoder for the Masked Autoencoder model [0].

    Decodes encoded patches and predicts pixel values for every patch.
@@ -53,7 +48,7 @@

    """

    def __init__(
        self,
        num_patches: int,
        patch_size: int,
@@ -68,21 +63,21 @@
        norm_layer: Callable[..., nn.Module] = partial(LayerNorm, eps=1e-6),
        mask_token: Optional[Parameter] = None,
    ):
        super().__init__()
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = (
            nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
            if mask_token is None
            else mask_token
        )

        # positional encoding of the decoder
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False
        ) # fixed sin-cos embedding

        self.decoder_blocks = Sequential(
            *[
                vision_transformer.Block(
                    decoder_embed_dim,
@@ -97,14 +92,14 @@
            ]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, patch_size**2 * in_chans, bias=True
        ) # decoder to patch

        self._initialize_weights()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Returns predicted pixel values from encoded tokens.

        Args:
@@ -115,11 +110,11 @@
            Tensor with shape (batch_size, seq_length, out_dim).

        """
        out = self.embed(input)
        out = self.decode(out)
        return self.predict(out)

    def embed(self, input: torch.Tensor) -> torch.Tensor:
        """Embeds encoded input tokens into decoder token dimension.

        This is a single linear layer that changes the token dimension from
@@ -135,10 +130,10 @@
            the embedded tokens.

        """
        out: torch.Tensor = self.decoder_embed(input)
        return out

    def decode(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass through the decoder transformer.

        Args:
@@ -151,12 +146,12 @@
            the decoded tokens.

        """
        output: torch.Tensor = input + self.decoder_pos_embed
        output = self.decoder_blocks(output)
        output = self.decoder_norm(output)
        return output

    def predict(self, input: torch.Tensor) -> torch.Tensor:
        """Predics pixel values from decoded tokens.

        Args:
@@ -169,35 +164,35 @@
            predictions for each token.

        """
        out: torch.Tensor = self.decoder_pred(input)
        return out

    def _initialize_weights(self) -> None:
        torch.nn.init.normal_(self.mask_token, std=0.02)
        _initialize_2d_sine_cosine_positional_embedding(self.decoder_pos_embed)
        self.apply(_init_weights)


def _initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) -> None:
    _, seq_length, hidden_dim = pos_embedding.shape
    grid_size = int((seq_length - 1) ** 0.5)
    sine_cosine_embedding = utils.get_2d_sine_cosine_positional_embedding(
        embed_dim=hidden_dim,
        grid_size=grid_size,
        cls_token=True,
    )
    pos_embedding.data.copy_(
        torch.from_numpy(sine_cosine_embedding).float().unsqueeze(0)
    )
    # Freeze positional embedding.
    pos_embedding.requires_grad = False


def _init_weights(module: Module) -> None:
    if isinstance(module, Linear):
        nn.init.xavier_uniform_(module.weight)
        if isinstance(module, Linear) and module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, LayerNorm):
        nn.init.constant_(module.bias, 0)
        nn.init.constant_(module.weight, 1.0)
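For orientation, a minimal usage sketch of the decoder's three-stage forward path defined above (embed, decode, predict); the batch size, token counts, and dimensions are illustrative, and the remaining constructor arguments are assumed to keep their defaults:

import torch

from lightly.models.modules import MAEDecoderTIMM

decoder = MAEDecoderTIMM(
    num_patches=49,        # e.g. a 224x224 image split into 32x32 patches
    patch_size=32,
    embed_dim=768,         # encoder token dimension
    decoder_embed_dim=512,
    decoder_depth=1,
    decoder_num_heads=16,
)

tokens = torch.rand(2, 50, 768)  # (batch, cls token + 49 patches, embed_dim)
x = decoder.embed(tokens)        # linear projection to decoder_embed_dim
x = decoder.decode(x)            # add positional embedding, run blocks, normalize
pred = decoder.predict(x)        # (2, 50, 32 * 32 * 3) pixel predictions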
1 change: 0 additions & 1 deletion lightly/utils/dependency.py
@@ -21,7 +21,6 @@ def timm_vit_available() -> bool:
import timm.models.vision_transformer # Requires timm >= 0.3.3
from timm.layers import LayerType # Requires timm >= 0.9.9
except ImportError:
print("TIMM is not available. Please install if you would like to use the MAE.")
return False
else:
return True
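With the print removed, timm_vit_available() stays silent and simply reports availability; a minimal sketch of how a caller might guard on it (the exit message is illustrative):

import sys

from lightly.utils import dependency

if not dependency.timm_vit_available():
    sys.exit("TIMM is required for the MAE modules: pip install timm")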