I-JEPA #1273

Merged 40 commits on Jul 14, 2023 (diff shows changes from 35 commits)

Commits:
5c6d01a  very dirty draft (Jun 1, 2023)
a1c0c15  little refactoring (Jun 1, 2023)
c85ee97  refactoring (Jun 1, 2023)
0ece1d3  refactoring (Jun 1, 2023)
6a3f7fe  + encoder. TODO: decoder based on causal attention (Jun 2, 2023)
7102910  fix imports (Jun 2, 2023)
0c5bac6  add Decoder class to consistency between code (Jun 2, 2023)
7c41f59  change naming; change class structure (Jun 3, 2023)
d99994d  change module; add example (Jun 6, 2023)
963c7f1  Merge branch 'lightly-ai:master' into master (Natyren, Jun 7, 2023)
4e585e7  few refactoring (Jun 7, 2023)
57f9020  Merge branch 'master' of https://github.com/Natyren/lightly (Jun 7, 2023)
9034061  del line (Jun 7, 2023)
035ed5a  del comments (Jun 7, 2023)
3b41342  del comment (Jun 7, 2023)
4646786  del comment (Jun 7, 2023)
28fd7db  del line (Natyren, Jun 7, 2023)
de1106c  pass (Natyren, Jun 7, 2023)
9995df6  pass (Natyren, Jun 7, 2023)
c23de39  add model itself, todo: train loop and debug (Natyren, Jul 9, 2023)
ca171bf  add collator; (Natyren, Jul 9, 2023)
e1b97ec  added template to train code and transforms (Natyren, Jul 9, 2023)
612f7cc  add train in pure pytorch (Natyren, Jul 12, 2023)
d9a9fd8  little fix (Natyren, Jul 12, 2023)
a77045d  Merge branch 'lightly-ai:master' into master (Natyren, Jul 12, 2023)
adac38b  little fix (Natyren, Jul 12, 2023)
800aed6  little fix (Natyren, Jul 12, 2023)
3a92cda  fix classmethod (Natyren, Jul 13, 2023)
485f9fc  fix collator (Natyren, Jul 13, 2023)
80cc077  fixes (Natyren, Jul 13, 2023)
45510e2  fix in collator (Natyren, Jul 13, 2023)
263c221  fix collators, added imports, fix models (Natyren, Jul 13, 2023)
6939ff9  add ijepa backbone (Natyren, Jul 13, 2023)
0fd4639  finish pure torch impelementation (Natyren, Jul 13, 2023)
72ca5cb  docstring fix (Natyren, Jul 13, 2023)
605cdc2  fixes of name and references to original paper (Natyren, Jul 14, 2023)
1be6453  Format (guarin, Jul 14, 2023)
d35b7a9  Add note about experimental support (guarin, Jul 14, 2023)
d262f12  Add datasets to gitignore (guarin, Jul 14, 2023)
3dafc05  Cleanup imports (guarin, Jul 14, 2023)
110 changes: 110 additions & 0 deletions examples/pytorch/i_jepa.py
@@ -0,0 +1,110 @@
import copy

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from lightly.data.collate import IJEPAMaskCollator
from lightly.models import utils
from lightly.models.modules import i_jepa
from lightly.transforms.ijepa_transform import IJEPATransform


class I_JEPA(nn.Module):
    def __init__(self, vit_encoder, vit_predictor, momentum_scheduler):
        super().__init__()
        self.encoder = i_jepa.IJEPA_Backbone.from_vit(vit_encoder)
        self.predictor = i_jepa.IJEPA_predictor.from_vit_encoder(
            vit_predictor.encoder,
            (vit_predictor.image_size // vit_predictor.patch_size) ** 2,
        )
        self.target_encoder = copy.deepcopy(self.encoder)
        self.momentum_scheduler = momentum_scheduler

    def forward_target(self, imgs, masks_enc, masks_pred):
        with torch.no_grad():
            h = self.target_encoder(imgs)
            h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
            B = len(h)
            # -- create targets (masked regions of h)
            h = utils.apply_masks(h, masks_pred)
            h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc))
            return h

    def forward_context(self, imgs, masks_enc, masks_pred):
        z = self.encoder(imgs, masks_enc)
        z = self.predictor(z, masks_enc, masks_pred)
        return z

    def forward(self, imgs, masks_enc, masks_pred):
        z = self.forward_context(imgs, masks_enc, masks_pred)
        h = self.forward_target(imgs, masks_enc, masks_pred)
        return z, h

    def update_target_encoder(self):
        with torch.no_grad():
            m = next(self.momentum_scheduler)
            # EMA update: target = m * target + (1 - m) * online.
            for param_q, param_k in zip(
                self.encoder.parameters(), self.target_encoder.parameters()
            ):
                param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)


collator = IJEPAMaskCollator(
    input_size=(224, 224),
    patch_size=32,
)
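# Note: the collator's patch_size must match the patch size of the ViT
# backbones created below (vit_b_32 -> patch_size=32).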

transform = IJEPATransform()

# We ignore the object detection annotations by setting target_transform to
# return 0. Alternatively, create a dataset from a folder containing images
# or videos:
# dataset = LightlyDataset("path/to/folder")
dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=lambda t: 0,
)
data_loader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=collator,
    batch_size=10,
    persistent_workers=False,
)
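# Each batch from this loader is a tuple of (collated images and labels,
# encoder masks, predictor masks), as produced by IJEPAMaskCollator.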

# Linear momentum schedule for the EMA update of the target encoder,
# ramping from 0.996 to 1.0 over the full training run.
ema = (0.996, 1.0)
ipe_scale = 1.0
ipe = len(data_loader)  # iterations per epoch
num_epochs = 10
momentum_scheduler = (
    ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
    for i in range(int(ipe * num_epochs * ipe_scale) + 1)
)

vit_for_predictor = torchvision.models.vit_b_32(pretrained=False)
vit_for_embedder = torchvision.models.vit_b_32(pretrained=False)
model = I_JEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler)
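# Both backbones are randomly initialized ViT-B/32 models; one becomes the
# context/target encoder, the other initializes the predictor.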

criterion = nn.SmoothL1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print("Starting Training")
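# For every batch: predict the masked target embeddings from the context
# blocks, regress them with a Smooth L1 loss, then update the target encoder
# as an EMA of the online encoder.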
for epoch in range(num_epochs):
    total_loss = 0
    for udata, masks_enc, masks_pred in tqdm(data_loader):
        # Move the unsupervised images and both mask sets to the device.
        imgs = udata[0].to(device, non_blocking=True)
        masks_enc = [m.to(device, non_blocking=True) for m in masks_enc]
        masks_pred = [m.to(device, non_blocking=True) for m in masks_pred]
        z, h = model(imgs, masks_enc, masks_pred)
        loss = criterion(z, h)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        model.update_target_encoder()

    avg_loss = total_loss / len(data_loader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
152 changes: 152 additions & 0 deletions lightly/data/collate.py
@@ -16,6 +16,8 @@
from lightly.transforms.random_crop_and_flip_with_grid import RandomResizedCropAndFlip
from lightly.transforms.rotation import random_rotation_transform
from lightly.transforms.utils import IMAGENET_NORMALIZE
from multiprocessing import Value
import math

imagenet_normalize = IMAGENET_NORMALIZE
# Kept for backwards compatibility
@@ -1345,6 +1347,156 @@ def forward(
        return (views_global, views_local, grids_global, grids_local), labels, fnames


class IJEPAMaskCollator:
    """Collator that samples encoder and predictor block masks for I-JEPA
    while collating images into a batch."""

    def __init__(
        self,
        input_size=(224, 224),
        patch_size=16,
        enc_mask_scale=(0.2, 0.8),
        pred_mask_scale=(0.2, 0.8),
        aspect_ratio=(0.3, 3.0),
        nenc=1,
        npred=2,
        min_keep=4,
        allow_overlap=False,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.patch_size = patch_size
        # Mask grid size in patches.
        self.height = input_size[0] // patch_size
        self.width = input_size[1] // patch_size
        self.enc_mask_scale = enc_mask_scale
        self.pred_mask_scale = pred_mask_scale
        self.aspect_ratio = aspect_ratio
        self.nenc = nenc
        self.npred = npred
        self.min_keep = min_keep  # minimum number of patches to keep
        self.allow_overlap = allow_overlap  # allow overlap b/w enc and pred masks
        self._itr_counter = Value("i", -1)  # collator is shared across worker processes

    def step(self):
        # Atomically advance the counter shared across dataloader workers; it
        # is used to seed the RNG so all workers agree on the block sizes.
        i = self._itr_counter
        with i.get_lock():
            i.value += 1
            v = i.value
        return v

    def _sample_block_size(self, generator, scale, aspect_ratio_scale):
        # Sample a single (height, width) block size in patch units from the
        # given scale and aspect-ratio ranges.
        _rand = torch.rand(1, generator=generator).item()
        # -- Sample block scale
        min_s, max_s = scale
        mask_scale = min_s + _rand * (max_s - min_s)
        max_keep = int(self.height * self.width * mask_scale)
        # -- Sample block aspect-ratio
        min_ar, max_ar = aspect_ratio_scale
        aspect_ratio = min_ar + _rand * (max_ar - min_ar)
        # -- Compute block height and width (given scale and aspect-ratio)
        h = int(round(math.sqrt(max_keep * aspect_ratio)))
        w = int(round(math.sqrt(max_keep / aspect_ratio)))
        # Shrink the block until it fits inside the patch grid.
        while h >= self.height:
            h -= 1
        while w >= self.width:
            w -= 1

        return (h, w)

    def _sample_block_mask(self, b_size, acceptable_regions=None):
        # Sample one rectangular block mask; returns the flat indices of the
        # masked patches and the complement mask used to constrain later blocks.
        h, w = b_size

        def constrain_mask(mask, tries=0):
            """Helper to restrict a given mask to a set of acceptable regions."""
            N = max(int(len(acceptable_regions) - tries), 0)
            for k in range(N):
                mask *= acceptable_regions[k]

        # -- Loop to sample masks until we find a valid one
        tries = 0
        timeout = og_timeout = 20
        valid_mask = False
        while not valid_mask:
            # -- Sample block top-left corner
            top = torch.randint(0, self.height - h, (1,))
            left = torch.randint(0, self.width - w, (1,))
            mask = torch.zeros((self.height, self.width), dtype=torch.int32)
            mask[top : top + h, left : left + w] = 1
            # -- Constrain mask to a set of acceptable regions
            if acceptable_regions is not None:
                constrain_mask(mask, tries)
            mask = torch.nonzero(mask.flatten())
            # -- If mask too small try again
            valid_mask = len(mask) > self.min_keep
            if not valid_mask:
                timeout -= 1
                if timeout == 0:
                    # Relax the overlap constraint after too many failures.
                    tries += 1
                    timeout = og_timeout
        mask = mask.squeeze()
        # -- Complement mask marks everything outside the sampled block.
        mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
        mask_complement[top : top + h, left : left + w] = 0
        return mask, mask_complement

    def __call__(self, batch):
        """Create encoder and predictor masks when collating imgs into a batch.

        1. sample enc block (size + location) using seed
        2. sample pred block (size) using seed
        3. sample several enc block locations for each image (w/o seed)
        4. sample several pred block locations for each image (w/o seed)
        5. return enc mask and pred mask
        """
        B = len(batch)

        collated_batch = torch.utils.data.default_collate(batch)

        # The shared counter seeds the generator so that every dataloader
        # worker samples the same block sizes for a given batch.
        seed = self.step()
        g = torch.Generator()
        g.manual_seed(seed)
        p_size = self._sample_block_size(
            generator=g, scale=self.pred_mask_scale, aspect_ratio_scale=self.aspect_ratio
        )
        e_size = self._sample_block_size(
            generator=g, scale=self.enc_mask_scale, aspect_ratio_scale=(1.0, 1.0)
        )

        collated_masks_pred, collated_masks_enc = [], []
        min_keep_pred = self.height * self.width
        min_keep_enc = self.height * self.width
        for _ in range(B):
            masks_p, masks_C = [], []
            for _ in range(self.npred):
                mask, mask_C = self._sample_block_mask(p_size)
                masks_p.append(mask)
                masks_C.append(mask_C)
                min_keep_pred = min(min_keep_pred, len(mask))
            collated_masks_pred.append(masks_p)

            # Encoder blocks may not overlap predictor blocks unless overlap
            # is explicitly allowed.
            acceptable_regions = None if self.allow_overlap else masks_C

            masks_e = []
            for _ in range(self.nenc):
                mask, _ = self._sample_block_mask(
                    e_size, acceptable_regions=acceptable_regions
                )
                masks_e.append(mask)
                min_keep_enc = min(min_keep_enc, len(mask))
            collated_masks_enc.append(masks_e)

        # Trim all masks to the shortest one in the batch so that they can be
        # collated into rectangular tensors.
        collated_masks_pred = [
            [cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred
        ]
        collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
        collated_masks_enc = [
            [cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc
        ]
        collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)

        return collated_batch, collated_masks_enc, collated_masks_pred
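
# Example (sketch, assuming the rest of this PR): how the collator plugs
# into a DataLoader; see examples/pytorch/i_jepa.py for the full script.
#
#     collator = IJEPAMaskCollator(input_size=(224, 224), patch_size=32)
#     loader = DataLoader(dataset, collate_fn=collator, batch_size=10)
#     batch, masks_enc, masks_pred = next(iter(loader))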


def _deprecation_warning_collate_functions() -> None:
    warn(
        "Collate functions are deprecated and will be removed in favor of transforms in v1.4.0.\n"