Skip to content

Commit

Permalink
ssl changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Jan 4, 2024
1 parent 7d1f379 commit 8d15167
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
17 changes: 13 additions & 4 deletions zoobot/pytorch/datasets/webdatamodule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import types
from typing import Callable
import logging
import torch.utils.data
import numpy as np
Expand Down Expand Up @@ -27,7 +27,8 @@ def __init__(
color=False,
crop_scale_bounds=(0.7, 0.8),
crop_ratio_bounds=(0.9, 1.1),
resize_after_crop=224
resize_after_crop=224,
transform: Callable=None
):
super().__init__()

Expand Down Expand Up @@ -60,6 +61,8 @@ def __init__(
self.crop_scale_bounds = crop_scale_bounds
self.crop_ratio_bounds = crop_ratio_bounds

self.transform = transform

for url_name in ['train', 'val', 'test', 'predict']:
urls = getattr(self, f'{url_name}_urls')
if urls is not None:
Expand Down Expand Up @@ -98,7 +101,12 @@ def make_loader(self, urls, mode="train"):
assert mode in ['val', 'test', 'predict'], mode
shuffle = 0

transform_image = self.make_image_transform(mode=mode)
if self.transform is None:
logging.info('Using default transform')
transform_image = self.make_image_transform(mode=mode)
else:
logging.info('Ignoring hparams and using directly-passed transform')
transform_image = self.transform

transform_label = dict_to_label_cols_factory(self.label_cols)

Expand All @@ -109,7 +117,8 @@ def make_loader(self, urls, mode="train"):
if shuffle > 0:
dataset = dataset.shuffle(shuffle)

dataset = dataset.decode("rgb")
# dataset = dataset.decode("rgb") # np.array, for albumentations
dataset = dataset.decode("pilrgb") # PIL Image, for torchvision

if mode == 'predict':
if self.label_cols != ['id_str']:
Expand Down
26 changes: 22 additions & 4 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,23 @@ def train_default_zoobot_from_scratch(
)
else:
# this branch will use WebDataModule to load premade webdatasets

# temporary: use SSL-like transform
from foundation.models import transforms
# from omegaconf import DictConfig
# cfg = DictConfig({
# 'aug': {
# 'global_transform_0': {
# 'interpolation': 'bilinear',
# 'random_affine': {} # etc
# }

# }
# })
cfg = transforms.default_view_config()
cfg.output_size = resize_after_crop
transform = transforms.GalaxyViewTransform(cfg)

datamodule = webdatamodule.WebDataModule(
train_urls=train_urls,
val_urls=val_urls,
Expand All @@ -246,10 +263,11 @@ def train_default_zoobot_from_scratch(
prefetch_factor=prefetch_factor,
cache_dir=cache_dir,
# augmentation args
color=color,
crop_scale_bounds=crop_scale_bounds,
crop_ratio_bounds=crop_ratio_bounds,
resize_after_crop=resize_after_crop
transform=transform,
# color=color,
# crop_scale_bounds=crop_scale_bounds,
# crop_ratio_bounds=crop_ratio_bounds,
# resize_after_crop=resize_after_crop
)

datamodule.setup(stage='fit')
Expand Down

0 comments on commit 8d15167

Please sign in to comment.