I-JEPA (#1273)
* Add experimental I-JEPA version
* Add IJEPABackbone
* Add IJEPAPredictor
* Add IJEPAMaskCollator
* Add IJEPATransform
* Add pytorch example
Natyren authored Jul 14, 2023
1 parent 2206729 commit a55cf44
Showing 8 changed files with 874 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@
lightning_logs/
**lightning_logs/
**/__MACOSX
datasets/
docs/source/tutorials/package/*
docs/source/tutorials/platform/*
docs/source/tutorials_source/platform/data
117 changes: 117 additions & 0 deletions examples/pytorch/ijepa.py
@@ -0,0 +1,117 @@
import copy

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from lightly.data.collate import IJEPAMaskCollator
from lightly.models import utils
from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor
from lightly.transforms.ijepa_transform import IJEPATransform


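# Minimal I-JEPA training wrapper: a trainable context encoder, an EMA-updated
# target encoder, and a predictor that regresses target patch embeddings from
# the visible context patches.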
class IJEPA(nn.Module):
def __init__(self, vit_encoder, vit_predictor, momentum_scheduler):
super().__init__()
self.encoder = IJEPABackbone.from_vit(vit_encoder)
self.predictor = IJEPAPredictor.from_vit_encoder(
vit_predictor.encoder,
(vit_predictor.image_size // vit_predictor.patch_size) ** 2,
)
self.target_encoder = copy.deepcopy(self.encoder)
self.momentum_scheduler = momentum_scheduler

def forward_target(self, imgs, masks_enc, masks_pred):
with torch.no_grad():
h = self.target_encoder(imgs)
h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim
B = len(h)
# -- create targets (masked regions of h)
h = utils.apply_masks(h, masks_pred)
h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc))
return h

def forward_context(self, imgs, masks_enc, masks_pred):
z = self.encoder(imgs, masks_enc)
z = self.predictor(z, masks_enc, masks_pred)
return z

def forward(self, imgs, masks_enc, masks_pred):
z = self.forward_context(imgs, masks_enc, masks_pred)
h = self.forward_target(imgs, masks_enc, masks_pred)
return z, h

    def update_target_encoder(self):
with torch.no_grad():
m = next(self.momentum_scheduler)
for param_q, param_k in zip(
self.encoder.parameters(), self.target_encoder.parameters()
):
param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)


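# The mask collator samples encoder and predictor block masks for every batch;
# with 224x224 inputs and a patch size of 32, the mask grid is 7x7 patches.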
collator = IJEPAMaskCollator(
input_size=(224, 224),
patch_size=32,
)

transform = IJEPATransform()

# We ignore the object detection annotations by making target_transform return 0.
# Alternatively, create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=lambda t: 0,
)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=collator, batch_size=10, persistent_workers=False
)

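# Linear momentum schedule for the EMA target encoder: ramps from 0.996 to 1.0
# over the total number of training steps (iterations per epoch * epochs).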
ema = (0.996, 1.0)
ipe_scale = 1.0
ipe = len(data_loader)
num_epochs = 10
momentum_scheduler = (
ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
for i in range(int(ipe * num_epochs * ipe_scale) + 1)
)

vit_for_predictor = torchvision.models.vit_b_32(weights=None)
vit_for_embedder = torchvision.models.vit_b_32(weights=None)
model = IJEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler)

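# Smooth L1 loss between predicted and target patch embeddings, as in the
# reference implementation.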
criterion = nn.SmoothL1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print("Starting Training")
for epoch in range(num_epochs):
total_loss = 0
    for udata, masks_enc, masks_pred in tqdm(data_loader):
        # Move the unsupervised images and both sets of masks to the device.
        imgs = udata[0].to(device, non_blocking=True)
        masks_enc = [m.to(device, non_blocking=True) for m in masks_enc]
        masks_pred = [m.to(device, non_blocking=True) for m in masks_pred]
z, h = model(imgs, masks_enc, masks_pred)
loss = criterion(z, h)
total_loss += loss.detach()
loss.backward()
optimizer.step()
optimizer.zero_grad()
model.update_target_encoder()

avg_loss = total_loss / len(data_loader)
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
1 change: 1 addition & 0 deletions examples/pytorch_lightning/ijepa.py
@@ -0,0 +1 @@
# TODO
1 change: 1 addition & 0 deletions examples/pytorch_lightning_distributed/ijepa.py
@@ -0,0 +1 @@
# TODO
172 changes: 172 additions & 0 deletions lightly/data/collate.py
Expand Up @@ -3,6 +3,8 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

import math
from multiprocessing import Value
from typing import List, Optional, Tuple, Union
from warnings import warn

@@ -1345,6 +1347,176 @@ def forward(
return (views_global, views_local, grids_global, grids_local), labels, fnames


class IJEPAMaskCollator:
    """Collator for the I-JEPA model [0].

    Experimental: Support for I-JEPA is experimental, there might be breaking
    changes in the future.

    Code inspired by [1].

    - [0]: Self-Supervised Learning from Images with a Joint-Embedding
      Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243
    - [1]: https://github.com/facebookresearch/ijepa
    """

def __init__(
self,
input_size=(224, 224),
patch_size=16,
enc_mask_scale=(0.2, 0.8),
pred_mask_scale=(0.2, 0.8),
aspect_ratio=(0.3, 3.0),
nenc=1,
npred=2,
min_keep=4,
allow_overlap=False,
):
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.patch_size = patch_size
self.height, self.width = (
input_size[0] // patch_size,
input_size[1] // patch_size,
)
self.enc_mask_scale = enc_mask_scale
self.pred_mask_scale = pred_mask_scale
self.aspect_ratio = aspect_ratio
self.nenc = nenc
self.npred = npred
self.min_keep = min_keep # minimum number of patches to keep
self.allow_overlap = (
allow_overlap # whether to allow overlap b/w enc and pred masks
)
self._itr_counter = Value("i", -1) # collator is shared across worker processes

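    # step() bumps a process-shared counter so that every batch, regardless of
    # which dataloader worker collates it, draws a fresh seed for block sampling.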
def step(self):
i = self._itr_counter
with i.get_lock():
i.value += 1
v = i.value
return v

def _sample_block_size(self, generator, scale, aspect_ratio_scale):
_rand = torch.rand(1, generator=generator).item()
# -- Sample block scale
min_s, max_s = scale
mask_scale = min_s + _rand * (max_s - min_s)
max_keep = int(self.height * self.width * mask_scale)
# -- Sample block aspect-ratio
min_ar, max_ar = aspect_ratio_scale
aspect_ratio = min_ar + _rand * (max_ar - min_ar)
# -- Compute block height and width (given scale and aspect-ratio)
h = int(round(math.sqrt(max_keep * aspect_ratio)))
w = int(round(math.sqrt(max_keep / aspect_ratio)))
while h >= self.height:
h -= 1
while w >= self.width:
w -= 1

return (h, w)

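    # _sample_block_mask returns (mask, mask_complement): `mask` holds the
    # flattened indices of patches inside the sampled block, while
    # `mask_complement` marks everything outside it and later constrains encoder
    # masks to avoid overlap with predictor masks when allow_overlap is False.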
def _sample_block_mask(self, b_size, acceptable_regions=None):
h, w = b_size

def constrain_mask(mask, tries=0):
"""Helper to restrict given mask to a set of acceptable regions"""
N = max(int(len(acceptable_regions) - tries), 0)
for k in range(N):
mask *= acceptable_regions[k]

# --
# -- Loop to sample masks until we find a valid one
tries = 0
timeout = og_timeout = 20
valid_mask = False
while not valid_mask:
# -- Sample block top-left corner
top = torch.randint(0, self.height - h, (1,))
left = torch.randint(0, self.width - w, (1,))
mask = torch.zeros((self.height, self.width), dtype=torch.int32)
mask[top : top + h, left : left + w] = 1
# -- Constrain mask to a set of acceptable regions
if acceptable_regions is not None:
constrain_mask(mask, tries)
mask = torch.nonzero(mask.flatten())
# -- If mask too small try again
valid_mask = len(mask) > self.min_keep
if not valid_mask:
timeout -= 1
if timeout == 0:
tries += 1
timeout = og_timeout
mask = mask.squeeze()
# --
mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
mask_complement[top : top + h, left : left + w] = 0
# --
return mask, mask_complement

    def __call__(self, batch):
        """Create encoder and predictor masks when collating images into a batch.

        1. Sample the predictor block size using the shared seed.
        2. Sample the encoder block size using the shared seed.
        3. Sample npred predictor block locations per image (without the seed).
        4. Sample nenc encoder block locations per image (without the seed).
        5. Return the collated batch together with encoder and predictor masks.
        """
B = len(batch)

collated_batch = torch.utils.data.default_collate(batch)

seed = self.step()
g = torch.Generator()
g.manual_seed(seed)
p_size = self._sample_block_size(
generator=g,
scale=self.pred_mask_scale,
aspect_ratio_scale=self.aspect_ratio,
)
e_size = self._sample_block_size(
generator=g, scale=self.enc_mask_scale, aspect_ratio_scale=(1.0, 1.0)
)

collated_masks_pred, collated_masks_enc = [], []
min_keep_pred = self.height * self.width
min_keep_enc = self.height * self.width
for _ in range(B):
masks_p, masks_C = [], []
for _ in range(self.npred):
mask, mask_C = self._sample_block_mask(p_size)
masks_p.append(mask)
masks_C.append(mask_C)
min_keep_pred = min(min_keep_pred, len(mask))
collated_masks_pred.append(masks_p)

acceptable_regions = masks_C

if self.allow_overlap:
acceptable_regions = None

masks_e = []
for _ in range(self.nenc):
mask, _ = self._sample_block_mask(
e_size, acceptable_regions=acceptable_regions
)
masks_e.append(mask)
min_keep_enc = min(min_keep_enc, len(mask))
collated_masks_enc.append(masks_e)

collated_masks_pred = [
[cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred
]
collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
# --
collated_masks_enc = [
[cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc
]
collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)

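        # At this point collated_masks_enc is a list of nenc tensors of shape
        # (B, min_keep_enc) and collated_masks_pred a list of npred tensors of
        # shape (B, min_keep_pred); each row holds flattened patch indices.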
return collated_batch, collated_masks_enc, collated_masks_pred


def _deprecation_warning_collate_functions() -> None:
warn(
"Collate functions are deprecated and will be removed in favor of transforms in v1.4.0.\n"