Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ersi lig 3912 refactor mae to use timm vit #1461

Merged
merged 59 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
951d43e
Add MAE evaluation
guarin May 30, 2023
503cc44
Add stochastic depth dropout
guarin May 31, 2023
ac43499
Add MAE
guarin May 31, 2023
15bfe3a
Drop assertion
guarin May 31, 2023
49c85c0
Fix smooth cross entropy loss and mixup
guarin May 31, 2023
9d95783
Update comments
guarin May 31, 2023
0bb601d
Add layer lr decay and weight decay
guarin Jun 5, 2023
d7d69af
Update comment
guarin Jun 5, 2023
ec05437
Add test for MAE images_to_tokens
guarin Jun 5, 2023
923a606
Disable BN update
guarin Jun 5, 2023
bdce8a6
Add BN before classification head
guarin Jun 6, 2023
316f918
Format
guarin Jun 6, 2023
a6943fd
Fix BN freezing
guarin Jun 6, 2023
1a2b454
Cleanup
guarin Jun 6, 2023
bc066ae
Use torch.no_grad instead of deactivating gradients manually
guarin Jun 6, 2023
d56e340
Create new stochastic depth instances
guarin Jun 6, 2023
5ed6803
Add mask token to learnable params
guarin Jun 6, 2023
4f0baf1
Add sine-cosine positional embedding
guarin Jun 6, 2023
9c4a8cf
Initialize parameters as in paper
guarin Jun 6, 2023
9904c10
Merge branch 'master' into guarin-lig-3056-add-mae-imagenet-benchmark
guarin Dec 6, 2023
83edd1c
Fix types
guarin Dec 6, 2023
e27946e
Format
guarin Dec 6, 2023
0672b0a
Merge branch 'guarin-lig-3056-add-mae-imagenet-benchmark' of github.c…
ersi-lightly Dec 17, 2023
45433c5
adjusted to existing interface
ersi-lightly Dec 18, 2023
c5cab9e
draft
ersi-lightly Dec 19, 2023
017168e
remove
ersi-lightly Dec 19, 2023
278423b
added modifications
ersi-lightly Jan 4, 2024
fde116c
added mae implementation with timm and example
ersi-lightly Jan 5, 2024
f008645
formatted
ersi-lightly Jan 5, 2024
c97112e
fixed import
ersi-lightly Jan 5, 2024
2e55d6b
removed
ersi-lightly Jan 5, 2024
484add1
fixed typing
ersi-lightly Jan 5, 2024
971c19a
addressed comments
ersi-lightly Jan 9, 2024
1ec7470
fixed typing and formatted
ersi-lightly Jan 9, 2024
76ee356
addressed comments
ersi-lightly Jan 9, 2024
edb2d42
added docstring and formatted
ersi-lightly Jan 9, 2024
f00d320
removed images to tokens method
ersi-lightly Jan 10, 2024
cc263fe
Ersi lig 3910 update mae benchmark code (#1468)
ersi-lightly Feb 23, 2024
f7a532b
resolved conflict
ersi-lightly Feb 23, 2024
cc9d4ac
resolved conflicts
ersi-lightly Feb 23, 2024
eda5ac0
formatted
ersi-lightly Feb 23, 2024
0993ec5
adjusted examples
ersi-lightly Feb 25, 2024
4842901
removed comment
ersi-lightly Feb 25, 2024
bc9c6c3
added test
ersi-lightly Feb 26, 2024
73cb2ec
added message in case of ImportError
ersi-lightly Feb 26, 2024
12eeb14
fixed skipping of test
ersi-lightly Feb 26, 2024
ab15124
removed example
ersi-lightly Feb 26, 2024
c22b2ff
handling the TIMM dependency
ersi-lightly Feb 26, 2024
3fb9e86
added note to docs for MAE installation
ersi-lightly Feb 26, 2024
17760fd
added unit tests for MAE with torchvision
ersi-lightly Feb 26, 2024
44cf4e8
removed unnecessary mask token definition
ersi-lightly Feb 26, 2024
07fdae4
addressed comments
ersi-lightly Feb 29, 2024
a0f87ac
moved test to separate file
ersi-lightly Feb 29, 2024
f61b708
added typing
ersi-lightly Feb 29, 2024
0f8a927
fixed import
ersi-lightly Feb 29, 2024
a18bccd
fixes typing
ersi-lightly Feb 29, 2024
b16bba8
fixed typing
ersi-lightly Feb 29, 2024
90a2ee0
fixed typing
ersi-lightly Feb 29, 2024
d89808f
Ersi lig 4471 cleanup and merge mae branch (#1510)
ersi-lightly Mar 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion benchmarks/imagenet/vitb16/finetune_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __init__(
super().__init__(
model, batch_size_per_device, feature_dim, num_classes, topk, freeze_model
)
# TODO(Ersi, 2/24): Add path dropout for TIMM.

# Add path dropout.
add_stochastic_depth_to_blocks(self.model, prob=0.1)
# Add mixup and cutmix.
Expand Down Expand Up @@ -140,7 +142,6 @@ def finetune_eval(
Parameters follow MAE settings.
"""
print("Running fine-tune evaluation...")

# Setup training data.
# NOTE: We use transforms from the timm library here as they are the default in MAE
# and torchvision does not provide all required parameters.
Expand Down
155 changes: 155 additions & 0 deletions benchmarks/imagenet/vitb16/mae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import sys
from typing import Any, Dict, List, Optional, Tuple

import torch
from pytorch_lightning import LightningModule
from timm.models.vision_transformer import vit_base_patch16_224
from torch import Tensor
from torch.nn import MSELoss, Parameter
from torch.optim import AdamW

from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms import MAETransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler


class MAE(LightningModule):
    """Masked autoencoder pre-training module built on a timm ViT-B/16.

    A ``vit_base_patch16_224`` is wrapped in ``MaskedVisionTransformerTIMM`` and
    used as the encoder, ``MAEDecoderTIMM`` predicts the pixels of the masked
    patches, and an ``OnlineLinearClassifier`` is trained alongside on detached
    encoder features.

    Args:
        batch_size_per_device: Batch size on a single device. Used to scale the
            learning rate in ``configure_optimizers``.
        num_classes: Number of classes for the online linear classifier.
    """

    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        decoder_dim = 512
        vit = vit_base_patch16_224()

        self.mask_ratio = 0.75
        self.patch_size = vit.patch_embed.patch_size[0]
        # Tokens per image: all patch tokens plus the ViT's prefix tokens.
        self.sequence_length = vit.patch_embed.num_patches + vit.num_prefix_tokens
        mask_token = Parameter(torch.zeros(1, 1, decoder_dim))
        torch.nn.init.normal_(mask_token, std=0.02)
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=decoder_dim,
            decoder_depth=8,
            decoder_num_heads=16,
            mlp_ratio=4.0,
            proj_drop_rate=0.0,
            attn_drop_rate=0.0,
            mask_token=mask_token,
        )
        self.criterion = MSELoss()

        # Take the classifier input width from the ViT (as the decoder does)
        # rather than hard-coding it, so both stay in sync with the backbone.
        self.online_classifier = OnlineLinearClassifier(
            feature_dim=vit.embed_dim, num_classes=num_classes
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the backbone output for the images ``x``."""
        return self.backbone(images=x)

    def forward_encoder(
        self, images: Tensor, idx_keep: Optional[Tensor] = None
    ) -> Tensor:
        """Encode ``images`` with the backbone.

        Args:
            images: Batch of input images.
            idx_keep: Optional token indices forwarded to the backbone's
                ``encode`` method.

        Returns:
            The output of ``self.backbone.encode``.
        """
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(
        self, x_encoded: Tensor, idx_keep: Tensor, idx_mask: Tensor
    ) -> Tensor:
        """Predict pixel values for the masked tokens.

        Args:
            x_encoded: Encoder output for the kept tokens.
            idx_keep: Indices of the tokens that were kept.
            idx_mask: Indices of the tokens that were masked.

        Returns:
            The decoder predictions at the ``idx_mask`` positions.
        """
        # build decoder input
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(
            self.decoder.mask_token, (batch_size, self.sequence_length)
        )
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Compute the reconstruction loss plus the online classifier loss.

        Args:
            batch: Tuple whose first entry is the list of views and whose second
                entry holds the classification targets.
            batch_idx: Index of the batch, forwarded to the online classifier.

        Returns:
            Sum of the MSE reconstruction loss and the online classifier loss.
        """
        images, targets = batch[0], batch[1]
        images = images[0]  # images is a list containing only one view
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        features = self.forward_encoder(images, idx_keep)
        predictions = self.forward_decoder(features, idx_keep, idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for missing class token
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(predictions, target)
        self.log(
            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
        )

        # The features are detached so that the classifier loss does not
        # propagate gradients into the encoder.
        cls_features = features[:, 0]
        cls_loss, cls_log = self.online_classifier.training_step(
            (cls_features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Evaluate the online classifier on the backbone features.

        Args:
            batch: Tuple whose first entry is the image batch and whose second
                entry holds the classification targets.
            batch_idx: Index of the batch, forwarded to the online classifier.

        Returns:
            The online classifier's validation loss.
        """
        images, targets = batch[0], batch[1]
        cls_features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (cls_features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self) -> Tuple[List[AdamW], List[Dict[str, Any]]]:
        """Create the AdamW optimizer and the step-wise cosine warmup schedule.

        Returns:
            A one-element list with the optimizer and a one-element list with
            the scheduler configuration.
        """
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = utils.get_weight_decay_parameters(
            [self.backbone, self.decoder]
        )
        optimizer = AdamW(
            [
                {"name": "mae", "params": params},
                {
                    "name": "mae_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            # Base learning rate of 1.5e-4 scaled linearly with the global batch
            # size (per-device batch size * world size) relative to 256.
            lr=1.5e-4 * self.batch_size_per_device * self.trainer.world_size / 256,
            weight_decay=0.05,
            betas=(0.9, 0.95),
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                # The scheduler is stepped per batch (see "interval" below), so
                # the warmup is given in steps: steps per epoch times 40.
                warmup_epochs=(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 40
                ),
                max_epochs=self.trainer.estimated_stepping_batches,
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]


# MAE transform created with its default arguments and exposed at module level;
# main.py registers it as `mae.transform` next to the `mae.MAE` model class.
transform = MAETransform()
3 changes: 3 additions & 0 deletions benchmarks/imagenet/vitb16/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import finetune_eval
import knn_eval
import linear_eval
import mae
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor, LearningRateMonitor
Expand Down Expand Up @@ -38,7 +39,9 @@
parser.add_argument("--float32-matmul-precision", type=str, default="high")
parser.add_argument("--strategy", default="ddp_find_unused_parameters_true")


# Registry of the available methods: maps each method name to the model class
# and the transform object defined in that method's module.
METHODS = {
"mae": {"model": mae.MAE, "transform": mae.transform},
"aim": {"model": aim.AIM, "transform": aim.transform},
}

Expand Down
Loading
Loading