diff --git a/zoobot/pytorch/datasets/webdatamodule.py b/zoobot/pytorch/datasets/webdatamodule.py index 075f4201..2968bd2b 100644 --- a/zoobot/pytorch/datasets/webdatamodule.py +++ b/zoobot/pytorch/datasets/webdatamodule.py @@ -1,5 +1,5 @@ import os -import types +from typing import Callable import logging import torch.utils.data import numpy as np @@ -27,7 +27,8 @@ def __init__( color=False, crop_scale_bounds=(0.7, 0.8), crop_ratio_bounds=(0.9, 1.1), - resize_after_crop=224 + resize_after_crop=224, + transform: Callable=None ): super().__init__() @@ -60,6 +61,8 @@ def __init__( self.crop_scale_bounds = crop_scale_bounds self.crop_ratio_bounds = crop_ratio_bounds + self.transform = transform + for url_name in ['train', 'val', 'test', 'predict']: urls = getattr(self, f'{url_name}_urls') if urls is not None: @@ -98,7 +101,12 @@ def make_loader(self, urls, mode="train"): assert mode in ['val', 'test', 'predict'], mode shuffle = 0 - transform_image = self.make_image_transform(mode=mode) + if self.transform is None: + logging.info('Using default transform') + transform_image = self.make_image_transform(mode=mode) + else: + logging.info('Ignoring hparams and using directly-passed transform') + transform_image = self.transform transform_label = dict_to_label_cols_factory(self.label_cols) @@ -109,7 +117,8 @@ def make_loader(self, urls, mode="train"): if shuffle > 0: dataset = dataset.shuffle(shuffle) - dataset = dataset.decode("rgb") + # dataset = dataset.decode("rgb") # np.array, for albumentations + dataset = dataset.decode("pilrgb") # PIL Image, for torchvision if mode == 'predict': if self.label_cols != ['id_str']: diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py index ce3e0529..da734dec 100644 --- a/zoobot/pytorch/training/train_with_pytorch_lightning.py +++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py @@ -235,6 +235,23 @@ def train_default_zoobot_from_scratch( ) else: # this branch will use WebDataModule to load premade webdatasets + + # temporary: use SSL-like transform + from foundation.models import transforms + # from omegaconf import DictConfig + # cfg = DictConfig({ + # 'aug': { + # 'global_transform_0': { + # 'interpolation': 'bilinear', + # 'random_affine': {} # etc + # } + + # } + # }) + cfg = transforms.default_view_config() + cfg.output_size = resize_after_crop + transform = transforms.GalaxyViewTransform(cfg) + datamodule = webdatamodule.WebDataModule( train_urls=train_urls, val_urls=val_urls, @@ -246,10 +263,11 @@ def train_default_zoobot_from_scratch( prefetch_factor=prefetch_factor, cache_dir=cache_dir, # augmentation args - color=color, - crop_scale_bounds=crop_scale_bounds, - crop_ratio_bounds=crop_ratio_bounds, - resize_after_crop=resize_after_crop + transform=transform, + # color=color, + # crop_scale_bounds=crop_scale_bounds, + # crop_ratio_bounds=crop_ratio_bounds, + # resize_after_crop=resize_after_crop ) datamodule.setup(stage='fit')