Skip to content

Commit

Permalink
Merge branch 'master' into v4
Browse files Browse the repository at this point in the history
  • Loading branch information
japrescott committed Jan 30, 2024
2 parents 5cb9e2d + 25894a9 commit 1403ff9
Show file tree
Hide file tree
Showing 23 changed files with 1,690 additions and 111 deletions.
4 changes: 2 additions & 2 deletions benchmarks/imagenet/resnet50/finetune_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from lightly.utils.scheduler import CosineWarmupScheduler


class FinetuneLinearClassifier(LinearClassifier):
class FinetuneEvalClassifier(LinearClassifier):
def configure_optimizers(self):
parameters = list(self.classification_head.parameters())
parameters += self.model.parameters()
Expand Down Expand Up @@ -119,7 +119,7 @@ def finetune_eval(
strategy="ddp_find_unused_parameters_true",
num_sanity_val_steps=0,
)
classifier = FinetuneLinearClassifier(
classifier = FinetuneEvalClassifier(
model=model,
batch_size_per_device=batch_size_per_device,
feature_dim=2048,
Expand Down
161 changes: 161 additions & 0 deletions benchmarks/imagenet/vitb16/aim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from typing import List, Optional, Tuple, Union

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import MSELoss
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer

from lightly.models import utils
from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer
from lightly.models.utils import get_2d_sincos_pos_embed, random_prefix_mask
from lightly.transforms import AIMTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler


class AIM(LightningModule):
def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
super().__init__()
self.save_hyperparameters()
self.batch_size_per_device = batch_size_per_device

img_size = 224
self.patch_size = 14
self.num_patches = (img_size // self.patch_size) ** 2

vit = MaskedCausalVisionTransformer(
img_size=img_size,
patch_size=self.patch_size,
num_classes=num_classes,
embed_dim=1536,
depth=24,
num_heads=12,
qk_norm=False,
class_token=False,
no_embed_class=True,
)
# Use absolute positional embedding.
pos_embed = get_2d_sincos_pos_embed(
embed_dim=vit.embed_dim,
grid_size=int(self.num_patches**0.5),
cls_token=False,
)
vit.pos_embed.requires_grad = False
vit.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

self.backbone = vit
self.projection_head = AIMPredictionHead(
input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2
)

self.criterion = MSELoss()

self.online_classifier = OnlineLinearClassifier(
feature_dim=vit.embed_dim, num_classes=num_classes
)

def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
features = self.backbone.forward_features(x, mask=mask)
# TODO: We use mean aggregation for simplicity. The paper uses
# AttentionPoolingClassifier to get the class features. But this is not great
# as it requires training an additional head.
# https://github.com/apple/ml-aim/blob/1eaedecc4d584f2eb7c6921212d86a3a694442e1/aim/torch/layers.py#L337
return features.mean(dim=1).flatten(start_dim=1)

def training_step(
self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
) -> Tensor:
images, targets = batch[0], batch[1]
images = images[0] # images is a list containing only one view
batch_size = images.shape[0]

mask = random_prefix_mask(
size=(batch_size, self.num_patches),
max_prefix_length=self.num_patches - 1,
device=images.device,
)
features = self.backbone.forward_features(images, mask=mask)
# Add positional embedding before head.
features = self.backbone._pos_embed(features)
predictions = self.projection_head(features)

# Convert images to patches and normalize them.
patches = utils.patchify(images, self.patch_size)
mean = patches.mean(dim=-1, keepdim=True)
var = patches.var(dim=-1, keepdim=True)
patches = (patches - mean) / (var + 1.0e-6) ** 0.5

loss = self.criterion(predictions, patches)

self.log(
"train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
)

# TODO: We could use AttentionPoolingClassifier instead of mean aggregation:
# https://github.com/apple/ml-aim/blob/1eaedecc4d584f2eb7c6921212d86a3a694442e1/aim/torch/layers.py#L337
cls_features = features.mean(dim=1).flatten(start_dim=1)
cls_loss, cls_log = self.online_classifier.training_step(
(cls_features.detach(), targets), batch_idx
)
self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
return loss + cls_loss

def validation_step(
self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
) -> Tensor:
images, targets = batch[0], batch[1]
cls_features = self.forward(images).flatten(start_dim=1)
cls_loss, cls_log = self.online_classifier.validation_step(
(cls_features.detach(), targets), batch_idx
)
self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
return cls_loss

def configure_optimizers(self):
# Don't use weight decay for batch norm, bias parameters, and classification
# head to improve performance.
params, params_no_weight_decay = utils.get_weight_decay_parameters(
[self.backbone, self.projection_head]
)
optimizer = AdamW(
[
{"name": "aim", "params": params},
{
"name": "aim_no_weight_decay",
"params": params_no_weight_decay,
"weight_decay": 0.0,
},
{
"name": "online_classifier",
"params": self.online_classifier.parameters(),
"weight_decay": 0.0,
},
],
lr=0.001 * self.batch_size_per_device * self.trainer.world_size / 4096,
weight_decay=0.05,
betas=(0.9, 0.95),
)
scheduler = {
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=31250 / 125000 * self.trainer.estimated_stepping_batches,
max_epochs=self.trainer.estimated_stepping_batches,
),
"interval": "step",
}
return [optimizer], [scheduler]

def configure_gradient_clipping(
self,
optimizer: Optimizer,
gradient_clip_val: Union[int, float, None] = None,
gradient_clip_algorithm: Union[str, None] = None,
) -> None:
self.clip_gradients(
optimizer=optimizer, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
)


transform = AIMTransform()
214 changes: 214 additions & 0 deletions benchmarks/imagenet/vitb16/finetune_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
from pathlib import Path
from typing import Dict, Tuple

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from timm.data import create_transform
from timm.data.mixup import Mixup
from torch import Tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision import transforms as T

from lightly.data import LightlyDataset
from lightly.models import utils
from lightly.models.utils import add_stochastic_depth_to_blocks
from lightly.transforms.utils import IMAGENET_NORMALIZE
from lightly.utils.benchmarking import LinearClassifier, MetricCallback
from lightly.utils.benchmarking.topk import mean_topk_accuracy
from lightly.utils.scheduler import CosineWarmupScheduler


class FinetuneEvalClassifier(LinearClassifier):
# Parameters follow MAE settings.
# Adapt initialization to include mixup.
def __init__(
self,
model: Module,
batch_size_per_device: int,
feature_dim: int,
num_classes: int,
topk: Tuple[int, ...] = (1, 5),
freeze_model: bool = False,
) -> None:
super().__init__(
model, batch_size_per_device, feature_dim, num_classes, topk, freeze_model
)
# Add path dropout.
add_stochastic_depth_to_blocks(self.model, prob=0.1)
# Add mixup and cutmix.
self.mixup = Mixup(
mixup_alpha=0.8,
cutmix_alpha=1.0,
label_smoothing=0.1,
num_classes=num_classes,
)

# Adapt step to include mixup.
def shared_step(self, batch, batch_idx) -> Tuple[Tensor, Dict[int, Tensor]]:
images, targets = batch[0], batch[1]
if self.trainer.state.stage == "train":
images, targets = self.mixup(images, targets)
predictions = self.forward(images)
loss = self.criterion(predictions, targets)
_, predicted_labels = predictions.topk(max(self.topk))
# Pass targets without mixup for topk accuracy calculation.
topk = mean_topk_accuracy(predicted_labels, batch[1], k=self.topk)
return loss, topk

# Adapt optimizer to match MAE settings. Parameters follow the original code from
# the authors: https://github.com/facebookresearch/mae/blob/main/FINETUNE.md#fine-tuning
# Note that lr and layerwise_lr_decay for ViT-B/16 are 1e-3 and 0.75 in the paper
# but 5e-4 and 0.65 in the code.
def configure_optimizers(self):
lr = 5e-4 * self.batch_size_per_device * self.trainer.world_size / 256
layerwise_lr_decay = 0.65

# Group parameters by weight decay and learning rate.
param_groups = {}
for name, module in utils.get_named_leaf_modules(self.model).items():
if "encoder_layer_" in name:
layer_index = int(name.split("encoder_layer_")[1].split(".")[0])
group_name = f"vit_layer_{layer_index}"
# ViT-B has 12 layers. LR increases from first layer with index 0 to
# last layer with index 11.
group_lr = lr * (layerwise_lr_decay ** (11 - layer_index))
else:
group_name = "vit"
group_lr = lr
params, params_no_weight_decay = utils.get_weight_decay_parameters([module])
group = param_groups.setdefault(
group_name,
{
"name": group_name,
"params": [],
"lr": group_lr,
"weight_decay": 0.05,
},
)
group["params"].extend(params)
group_no_weight_decay = param_groups.setdefault(
f"{group_name}_no_weight_decay",
{
"name": f"{group_name}_no_weight_decay",
"params": [],
"lr": group_lr,
"weight_decay": 0.0,
},
)
group_no_weight_decay["params"].extend(params_no_weight_decay)
param_groups["classification_head"] = {
"name": "classification_head",
"params": self.classification_head.parameters(),
"weight_decay": 0.0,
}
optimizer = AdamW(
list(param_groups.values()),
betas=(0.9, 0.999),
)
scheduler = {
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=(
self.trainer.estimated_stepping_batches
/ self.trainer.max_epochs
* 5
),
max_epochs=self.trainer.estimated_stepping_batches,
),
"interval": "step",
}
return [optimizer], [scheduler]


def finetune_eval(
model: Module,
train_dir: Path,
val_dir: Path,
log_dir: Path,
batch_size_per_device: int,
num_workers: int,
accelerator: str,
devices: int,
precision: str,
num_classes: int,
) -> None:
"""Runs fine-tune evaluation on the given model.
Parameters follow MAE settings.
"""
print("Running fine-tune evaluation...")

# Setup training data.
# NOTE: We use transforms from the timm library here as they are the default in MAE
# and torchvision does not provide all required parameters.
train_transform = create_transform(
input_size=224,
is_training=True,
auto_augment="rand-m9-mstd0.5-inc1",
interpolation="bicubic",
re_prob=0.25,
re_mode="pixel",
re_count=1,
mean=IMAGENET_NORMALIZE["mean"],
std=IMAGENET_NORMALIZE["std"],
)
train_dataset = LightlyDataset(input_dir=str(train_dir), transform=train_transform)
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size_per_device,
shuffle=True,
num_workers=num_workers,
drop_last=True,
persistent_workers=True,
)

# Setup validation data.
val_transform = T.Compose(
[
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=IMAGENET_NORMALIZE["mean"], std=IMAGENET_NORMALIZE["std"]),
]
)
val_dataset = LightlyDataset(input_dir=str(val_dir), transform=val_transform)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size_per_device,
shuffle=False,
num_workers=num_workers,
persistent_workers=True,
)

# Train linear classifier.
metric_callback = MetricCallback()
trainer = Trainer(
max_epochs=100,
accelerator=accelerator,
devices=devices,
callbacks=[
LearningRateMonitor(),
DeviceStatsMonitor(),
metric_callback,
],
logger=TensorBoardLogger(save_dir=str(log_dir), name="finetune_eval"),
precision=precision,
strategy="ddp_find_unused_parameters_true",
)
classifier = FinetuneEvalClassifier(
model=model,
batch_size_per_device=batch_size_per_device,
feature_dim=768,
num_classes=num_classes,
freeze_model=False,
)
trainer.fit(
model=classifier,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
for metric in ["val_top1", "val_top5"]:
print(f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}")
Loading

0 comments on commit 1403ff9

Please sign in to comment.