Commit
Showing 23 changed files with 1,690 additions and 111 deletions.
@@ -0,0 +1,161 @@
from typing import List, Optional, Tuple, Union

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import MSELoss
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer

from lightly.models import utils
from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer
from lightly.models.utils import get_2d_sincos_pos_embed, random_prefix_mask
from lightly.transforms import AIMTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler


class AIM(LightningModule):
    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        img_size = 224
        self.patch_size = 14
        self.num_patches = (img_size // self.patch_size) ** 2

        vit = MaskedCausalVisionTransformer(
            img_size=img_size,
            patch_size=self.patch_size,
            num_classes=num_classes,
            embed_dim=1536,
            depth=24,
            num_heads=12,
            qk_norm=False,
            class_token=False,
            no_embed_class=True,
        )
        # Use absolute positional embedding.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim=vit.embed_dim,
            grid_size=int(self.num_patches**0.5),
            cls_token=False,
        )
        vit.pos_embed.requires_grad = False
        vit.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        self.backbone = vit
        self.projection_head = AIMPredictionHead(
            input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2
        )

        self.criterion = MSELoss()

        self.online_classifier = OnlineLinearClassifier(
            feature_dim=vit.embed_dim, num_classes=num_classes
        )

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        features = self.backbone.forward_features(x, mask=mask)
        # TODO: We use mean aggregation for simplicity. The paper uses
        # AttentionPoolingClassifier to get the class features. But this is not great
        # as it requires training an additional head.
        # https://github.com/apple/ml-aim/blob/1eaedecc4d584f2eb7c6921212d86a3a694442e1/aim/torch/layers.py#L337
        return features.mean(dim=1).flatten(start_dim=1)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        images, targets = batch[0], batch[1]
        images = images[0]  # images is a list containing only one view
        batch_size = images.shape[0]

        mask = random_prefix_mask(
            size=(batch_size, self.num_patches),
            max_prefix_length=self.num_patches - 1,
            device=images.device,
        )
        features = self.backbone.forward_features(images, mask=mask)
        # Add positional embedding before head.
        features = self.backbone._pos_embed(features)
        predictions = self.projection_head(features)

        # Convert images to patches and normalize them.
        patches = utils.patchify(images, self.patch_size)
        mean = patches.mean(dim=-1, keepdim=True)
        var = patches.var(dim=-1, keepdim=True)
        patches = (patches - mean) / (var + 1.0e-6) ** 0.5

        loss = self.criterion(predictions, patches)

        self.log(
            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
        )

        # TODO: We could use AttentionPoolingClassifier instead of mean aggregation:
        # https://github.com/apple/ml-aim/blob/1eaedecc4d584f2eb7c6921212d86a3a694442e1/aim/torch/layers.py#L337
        cls_features = features.mean(dim=1).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.training_step(
            (cls_features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss
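    # Note (not part of the original commit): with img_size=224 and patch_size=14,
    # utils.patchify should yield 16 * 16 = 256 patches per image, each flattened to
    # 3 * 14 * 14 = 588 values, matching the projection head's output_dim of
    # 3 * patch_size ** 2. The MSE loss above therefore compares prediction and
    # patch tensors of shape (batch_size, 256, 588).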

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        images, targets = batch[0], batch[1]
        cls_features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (cls_features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = utils.get_weight_decay_parameters(
            [self.backbone, self.projection_head]
        )
        optimizer = AdamW(
            [
                {"name": "aim", "params": params},
                {
                    "name": "aim_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            lr=0.001 * self.batch_size_per_device * self.trainer.world_size / 4096,
            weight_decay=0.05,
            betas=(0.9, 0.95),
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=31250 / 125000 * self.trainer.estimated_stepping_batches,
                max_epochs=self.trainer.estimated_stepping_batches,
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]
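    # Note (not part of the original commit): the learning rate scales linearly with
    # the global batch size, lr = 1e-3 * batch_size_per_device * world_size / 4096,
    # and 31250 / 125000 = 25% of the estimated training steps are used for warmup
    # before the cosine decay.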

    def configure_gradient_clipping(
        self,
        optimizer: Optimizer,
        gradient_clip_val: Union[int, float, None] = None,
        gradient_clip_algorithm: Union[str, None] = None,
    ) -> None:
        self.clip_gradients(
            optimizer=optimizer, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
        )


transform = AIMTransform()
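The snippet below is not part of the commit. It is a minimal sketch of how the AIM module and transform above could be wired into a pretraining run, following the dataset/dataloader/trainer pattern used in the second file of this diff; the dataset path and trainer settings are placeholder assumptions.

# Usage sketch (not from the commit); paths and trainer settings are placeholders.
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from lightly.data import LightlyDataset

model = AIM(batch_size_per_device=128, num_classes=1000)
dataset = LightlyDataset(input_dir="datasets/imagenet/train", transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)
trainer = Trainer(max_epochs=100, accelerator="gpu", devices=1, precision="16-mixed")
trainer.fit(model, train_dataloaders=dataloader)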
@@ -0,0 +1,214 @@
from pathlib import Path
from typing import Dict, Tuple

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from timm.data import create_transform
from timm.data.mixup import Mixup
from torch import Tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision import transforms as T

from lightly.data import LightlyDataset
from lightly.models import utils
from lightly.models.utils import add_stochastic_depth_to_blocks
from lightly.transforms.utils import IMAGENET_NORMALIZE
from lightly.utils.benchmarking import LinearClassifier, MetricCallback
from lightly.utils.benchmarking.topk import mean_topk_accuracy
from lightly.utils.scheduler import CosineWarmupScheduler


class FinetuneEvalClassifier(LinearClassifier):
    # Parameters follow MAE settings.
    # Adapt initialization to include mixup.
    def __init__(
        self,
        model: Module,
        batch_size_per_device: int,
        feature_dim: int,
        num_classes: int,
        topk: Tuple[int, ...] = (1, 5),
        freeze_model: bool = False,
    ) -> None:
        super().__init__(
            model, batch_size_per_device, feature_dim, num_classes, topk, freeze_model
        )
        # Add path dropout.
        add_stochastic_depth_to_blocks(self.model, prob=0.1)
        # Add mixup and cutmix.
        self.mixup = Mixup(
            mixup_alpha=0.8,
            cutmix_alpha=1.0,
            label_smoothing=0.1,
            num_classes=num_classes,
        )

    # Adapt step to include mixup.
    def shared_step(self, batch, batch_idx) -> Tuple[Tensor, Dict[int, Tensor]]:
        images, targets = batch[0], batch[1]
        if self.trainer.state.stage == "train":
            images, targets = self.mixup(images, targets)
        predictions = self.forward(images)
        loss = self.criterion(predictions, targets)
        _, predicted_labels = predictions.topk(max(self.topk))
        # Pass targets without mixup for topk accuracy calculation.
        topk = mean_topk_accuracy(predicted_labels, batch[1], k=self.topk)
        return loss, topk
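    # Note (not part of the original commit): in shared_step, Mixup replaces the
    # integer targets with soft, mixed label vectors during training, so the loss is
    # computed against mixed targets while the top-k accuracy uses the original hard
    # labels from batch[1].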

    # Adapt optimizer to match MAE settings. Parameters follow the original code from
    # the authors: https://github.com/facebookresearch/mae/blob/main/FINETUNE.md#fine-tuning
    # Note that lr and layerwise_lr_decay for ViT-B/16 are 1e-3 and 0.75 in the paper
    # but 5e-4 and 0.65 in the code.
    def configure_optimizers(self):
        lr = 5e-4 * self.batch_size_per_device * self.trainer.world_size / 256
        layerwise_lr_decay = 0.65

        # Group parameters by weight decay and learning rate.
        param_groups = {}
        for name, module in utils.get_named_leaf_modules(self.model).items():
            if "encoder_layer_" in name:
                layer_index = int(name.split("encoder_layer_")[1].split(".")[0])
                group_name = f"vit_layer_{layer_index}"
                # ViT-B has 12 layers. LR increases from first layer with index 0 to
                # last layer with index 11.
                group_lr = lr * (layerwise_lr_decay ** (11 - layer_index))
            else:
                group_name = "vit"
                group_lr = lr
            params, params_no_weight_decay = utils.get_weight_decay_parameters([module])
            group = param_groups.setdefault(
                group_name,
                {
                    "name": group_name,
                    "params": [],
                    "lr": group_lr,
                    "weight_decay": 0.05,
                },
            )
            group["params"].extend(params)
            group_no_weight_decay = param_groups.setdefault(
                f"{group_name}_no_weight_decay",
                {
                    "name": f"{group_name}_no_weight_decay",
                    "params": [],
                    "lr": group_lr,
                    "weight_decay": 0.0,
                },
            )
            group_no_weight_decay["params"].extend(params_no_weight_decay)
        param_groups["classification_head"] = {
            "name": "classification_head",
            "params": self.classification_head.parameters(),
            "weight_decay": 0.0,
        }
        optimizer = AdamW(
            list(param_groups.values()),
            betas=(0.9, 0.999),
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 5
                ),
                max_epochs=self.trainer.estimated_stepping_batches,
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]
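    # Note (not part of the original commit): with layerwise_lr_decay = 0.65, encoder
    # block i (of the 12 ViT-B blocks) gets lr * 0.65 ** (11 - i), so earlier blocks
    # train with smaller learning rates than later ones. The classification head group
    # sets no "lr" key and therefore falls back to the AdamW default of 1e-3.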


def finetune_eval(
    model: Module,
    train_dir: Path,
    val_dir: Path,
    log_dir: Path,
    batch_size_per_device: int,
    num_workers: int,
    accelerator: str,
    devices: int,
    precision: str,
    num_classes: int,
) -> None:
    """Runs fine-tune evaluation on the given model.

    Parameters follow MAE settings.
    """
    print("Running fine-tune evaluation...")

    # Setup training data.
    # NOTE: We use transforms from the timm library here as they are the default in MAE
    # and torchvision does not provide all required parameters.
    train_transform = create_transform(
        input_size=224,
        is_training=True,
        auto_augment="rand-m9-mstd0.5-inc1",
        interpolation="bicubic",
        re_prob=0.25,
        re_mode="pixel",
        re_count=1,
        mean=IMAGENET_NORMALIZE["mean"],
        std=IMAGENET_NORMALIZE["std"],
    )
    train_dataset = LightlyDataset(input_dir=str(train_dir), transform=train_transform)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size_per_device,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        persistent_workers=True,
    )

    # Setup validation data.
    val_transform = T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_NORMALIZE["mean"], std=IMAGENET_NORMALIZE["std"]),
        ]
    )
    val_dataset = LightlyDataset(input_dir=str(val_dir), transform=val_transform)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size_per_device,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=True,
    )

    # Fine-tune the model with a classification head.
    metric_callback = MetricCallback()
    trainer = Trainer(
        max_epochs=100,
        accelerator=accelerator,
        devices=devices,
        callbacks=[
            LearningRateMonitor(),
            DeviceStatsMonitor(),
            metric_callback,
        ],
        logger=TensorBoardLogger(save_dir=str(log_dir), name="finetune_eval"),
        precision=precision,
        strategy="ddp_find_unused_parameters_true",
    )
    classifier = FinetuneEvalClassifier(
        model=model,
        batch_size_per_device=batch_size_per_device,
        feature_dim=768,
        num_classes=num_classes,
        freeze_model=False,
    )
    trainer.fit(
        model=classifier,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )
    for metric in ["val_top1", "val_top5"]:
        print(f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}")