From 372bdb1d421ddf78668d7101426650db068d5119 Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Mon, 15 Jan 2024 12:57:51 -0500 Subject: [PATCH] revert datamodule --- .../training/train_with_pytorch_lightning.py | 35 +++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py index 9abc5225..8caea736 100644 --- a/zoobot/pytorch/training/train_with_pytorch_lightning.py +++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py @@ -237,21 +237,11 @@ def train_default_zoobot_from_scratch( # this branch will use WebDataModule to load premade webdatasets # temporary: use SSL-like transform - from foundation.models import transforms - # from omegaconf import DictConfig - # cfg = DictConfig({ - # 'aug': { - # 'global_transform_0': { - # 'interpolation': 'bilinear', - # 'random_affine': {} # etc - # } - - # } - # }) - train_transform_cfg = transforms.default_view_config() - inference_transform_cfg = transforms.minimal_view_config() - train_transform_cfg.output_size = resize_after_crop - inference_transform_cfg.output_size = resize_after_crop + # from foundation.models import transforms + # train_transform_cfg = transforms.default_view_config() + # inference_transform_cfg = transforms.minimal_view_config() + # train_transform_cfg.output_size = resize_after_crop + # inference_transform_cfg.output_size = resize_after_crop datamodule = webdatamodule.WebDataModule( train_urls=train_urls, @@ -264,12 +254,13 @@ def train_default_zoobot_from_scratch( prefetch_factor=prefetch_factor, cache_dir=cache_dir, # augmentation args - train_transform=transforms.GalaxyViewTransform(train_transform_cfg), - inference_transform=transforms.GalaxyViewTransform(inference_transform_cfg), - # color=color, - # crop_scale_bounds=crop_scale_bounds, - # crop_ratio_bounds=crop_ratio_bounds, - # resize_after_crop=resize_after_crop + color=color, + crop_scale_bounds=crop_scale_bounds, + crop_ratio_bounds=crop_ratio_bounds, + resize_after_crop=resize_after_crop, + # temporary: use SSL-like transform + # train_transform=transforms.GalaxyViewTransform(train_transform_cfg), + # inference_transform=transforms.GalaxyViewTransform(inference_transform_cfg), ) datamodule.setup(stage='fit') @@ -352,7 +343,7 @@ def train_default_zoobot_from_scratch( # can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs. # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting - if test_catalog is not None: + if datamodule.test_dataloader is not None: logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...') test_trainer.validate( model=lightning_model,