diff --git a/.gitignore b/.gitignore
index 6199eb12..d7ae58f9 100755
--- a/.gitignore
+++ b/.gitignore
@@ -165,4 +165,6 @@
 results
 hparams.yaml
 
-data/pretrained_models
\ No newline at end of file
+data/pretrained_models
+
+*.tar
\ No newline at end of file
diff --git a/benchmarks/pytorch/run_benchmarks.sh b/benchmarks/pytorch/run_benchmarks.sh
index b44791e3..07094601 100755
--- a/benchmarks/pytorch/run_benchmarks.sh
+++ b/benchmarks/pytorch/run_benchmarks.sh
@@ -16,7 +16,7 @@ SEED=$RANDOM
 # effnet, greyscale and color
 # sbatch --job-name=evo_py_gr_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
 # sbatch --job-name=evo_py_gr_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
-sbatch --job-name=evo_py_co_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
+# sbatch --job-name=evo_py_co_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
 # sbatch --job-name=evo_py_co_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=128,RESIZE_AFTER_CROP=300,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
 
 # and resnet18
@@ -25,11 +25,13 @@ sbatch --job-name=evo_py_co_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,
 # and resnet50
 # sbatch --job-name=evo_py_gr_res50_224_$SEED --export=ARCHITECTURE=resnet50,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
 # sbatch --job-name=evo_py_gr_res50_300_$SEED --export=ARCHITECTURE=resnet50,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
 
-# and with max-vit tiny because hey transformers are cool
+# color 224 version
+sbatch --job-name=evo_py_co_res50_224_$SEED --export=ARCHITECTURE=resnet50,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
+# and with max-vit tiny because hey transformers are cool
 # smaller batch size due to memory
-sbatch --job-name=evo_py_gr_vittiny_224_$SEED --export=ARCHITECTURE=maxvit_tiny_224,BATCH_SIZE=128,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
-sbatch --job-name=evo_py_co_vittiny_224_$SEED --export=ARCHITECTURE=maxvit_tiny_224,BATCH_SIZE=128,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
+# sbatch --job-name=evo_py_gr_vittiny_224_$SEED --export=ARCHITECTURE=maxvit_tiny_224,BATCH_SIZE=128,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
+# sbatch --job-name=evo_py_co_vittiny_224_$SEED --export=ARCHITECTURE=maxvit_tiny_224,BATCH_SIZE=128,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
 
 # and max-vit small (works badly)
 # sbatch --job-name=evo_py_gr_vitsmall_224_$SEED --export=ARCHITECTURE=maxvit_small_224,BATCH_SIZE=64,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
diff --git a/docs/data_notes.rst b/docs/data_notes.rst
index 9daf9c6d..e6ce0a4f 100755
--- a/docs/data_notes.rst
+++ b/docs/data_notes.rst
@@ -24,11 +24,6 @@ Zoobot includes weights for the following pretrained models.
       - 1
       - Yes
       - `Link `__
-   * - EfficientNetB0
-     - 300px
-     - 3
-     - Yes
-     - WIP
    * - EfficientNetB0
      - 224px
      - 3
@@ -57,12 +52,12 @@ Zoobot includes weights for the following pretrained models.
    * - Max-ViT Tiny
      - 224px
      - 1
-     - Not yet
+     - Yes
      - `Link `__
    * - Max-ViT Tiny
      - 224px
      - 3
-     - Not yet
+     - Yes
      - `Link `__
 
 
@@ -108,7 +103,7 @@ We also include a few additional ad-hoc models `on Dropbox = 2.0.0',
             'albumentations',
             'pyro-ppl>=1.8.0',
             'torchmetrics==0.11.0',
             'timm == 0.6.12'
         ],
+        # TODO may add narval/Digital Research Canada config
         'tensorflow': [
             'tensorflow == 2.10.0',  # 2.11.0 turns on XLA somewhere which then fails on multi-GPU...TODO
             'keras_applications',
@@ -95,13 +97,12 @@
         'pandas',
         'scipy',
         'astropy',  # for reading fits
-        'scikit-image >= 0.19.2',
         'scikit-learn >= 1.0.2',
         'matplotlib',
         'pyarrow',  # to read parquet, which is very handy for big datasets
         # for saving metrics to weights&biases (cloud service, free within limits)
         'wandb',
-        'setuptools==59.5.0',  # wandb logger incompatibility
-        'galaxy-datasets==0.0.14'  # for dataset loading in both TF and Torch (renamed from pytorch-galaxy-datasets)
+        'setuptools',  # no longer pinned
+        'galaxy-datasets>=0.0.15'  # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
     ]
 )
diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py
index e05bd536..b55d9fed 100755
--- a/zoobot/pytorch/estimators/define_model.py
+++ b/zoobot/pytorch/estimators/define_model.py
@@ -85,7 +85,7 @@ def configure_optimizers(self):
     def training_step(self, batch, batch_idx):
         return self.make_step(batch, batch_idx, step_name='train')
 
-    def on_training_batch_end(self, outputs, *args):
+    def on_train_batch_end(self, outputs, *args):
         self.log_outputs(outputs, step_name='train')
 
     def validation_step(self, batch, batch_idx):
@@ -94,6 +94,9 @@ def validation_step(self, batch, batch_idx):
     def on_validation_batch_end(self, outputs, *args):
         self.log_outputs(outputs, step_name='validation')
 
+    def log_outputs(self, outputs, step_name):
+        raise NotImplementedError('Must be subclassed')
+
     def test_step(self, batch, batch_idx):
         return self.make_step(batch, batch_idx, step_name='test')
 
diff --git a/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py b/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py
index b596ee04..c5309e8b 100644
--- a/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py
+++ b/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py
@@ -27,9 +27,10 @@
     # To support more complicated labels, Zoobot expects a list of columns. A list with one element works fine.
 
     # load a pretrained checkpoint saved here
-    # checkpoint_loc = os.path.join(zoobot_dir, 'data/pretrained_models/temp/dr5_py_gr_2270/checkpoints/epoch=360-step=231762.ckpt')
-    checkpoint_loc = '/Users/user/repos/gz-decals-classifiers/results/benchmarks/pytorch/dr5/dr5_py_gr_15366/checkpoints/epoch=58-step=18939.ckpt'
-
+    # https://www.dropbox.com/s/7ixwo59imjfz4ay/effnetb0_greyscale_224px.ckpt?dl=0
+    # see https://zoobot.readthedocs.io/en/latest/data_notes.html for more options
+    checkpoint_loc = os.path.join(zoobot_dir, 'data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt')
+
     # save the finetuning results here
     save_dir = os.path.join(zoobot_dir, 'results/pytorch/finetune/finetune_binary_classification')
@@ -70,8 +71,9 @@
         finetuned_model,
         n_samples=1,
         label_cols=label_cols,
-        save_loc=os.path.join(save_dir, 'finetuned_predictions.csv')
-        # trainer_kwargs={'accelerator': 'gpu'}
+        save_loc=os.path.join(save_dir, 'finetuned_predictions.csv'),
+        datamodule_kwargs={'batch_size': 32},  # we also need to set batch size here, or you may run out of memory
+        trainer_kwargs={'accelerator': 'gpu'}
     )
 
     """
     Under the hood, this is essentially doing:
diff --git a/zoobot/pytorch/examples/finetuning/finetune_multiclass_classification.py b/zoobot/pytorch/examples/finetuning/finetune_multiclass_classification.py
index 98c8ca14..e58ee842 100644
--- a/zoobot/pytorch/examples/finetuning/finetune_multiclass_classification.py
+++ b/zoobot/pytorch/examples/finetuning/finetune_multiclass_classification.py
@@ -1,8 +1,10 @@
 import logging
 import os
 
+import pandas as pd
+
 from zoobot.pytorch.training import finetune
-from galaxy_datasets import demo_rings
+from galaxy_datasets import galaxy_mnist
 from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
 
 
@@ -10,31 +12,29 @@
 
     logging.basicConfig(level=logging.INFO)
 
-    zoobot_dir = '/Users/user/repos/zoobot'  # TODO set to directory where you cloned Zoobot
+    zoobot_dir = '/home/walml/repos/zoobot'  # TODO set to directory where you cloned Zoobot
+    data_dir = '/home/walml/repos/galaxy-datasets/roots/galaxy_mnist'  # TODO set to any directory. GalaxyMNIST dataset will be downloaded here
+    batch_size = 32
+    num_workers = 8
+    n_blocks = 1  # EffnetB0 is divided into 7 blocks. Set 0 to only fit the head weights. Set 1, 2, etc to finetune deeper.
+    max_epochs = 1  # 6 epochs should get you ~93% accuracy. Set much higher (e.g. 1000) for harder problems, to use Zoobot's default early stopping.
+    # the remaining key parameters for high accuracy are weight_decay, learning_rate, and lr_decay. You might like to tinker with these.
 
     # load in catalogs of images and labels to finetune on
     # each catalog should be a dataframe with columns of "id_str", "file_loc", and any labels
     # here I'm using galaxy-datasets to download some premade data - check it out for examples
-    data_dir = '/Users/user/repos/galaxy-datasets/roots/demo_rings'  # TODO set to any directory. rings dataset will be downloaded here
-    train_catalog, _ = demo_rings(root=data_dir, download=True, train=True)
-    test_catalog, _ = demo_rings(root=data_dir, download=True, train=False)
+
+    train_catalog, _ = galaxy_mnist(root=data_dir, download=True, train=True)
+    test_catalog, _ = galaxy_mnist(root=data_dir, download=True, train=False)
 
     # wondering about "label_cols"?
     # This is a list of catalog columns which should be used as labels
-    # Here:
-    # TODO should use Galaxy MNIST as my example here
-    label_cols = ['ring']
-    # For binary classification, the label column should have binary (0 or 1) labels for your classes
-    import numpy as np
-    # 0, 1, 2
-    train_catalog['ring'] = np.random.randint(low=0, high=3, size=len(train_catalog))
-
-    # TODO
-    # To support more complicated labels, Zoobot expects a list of columns. A list with one element works fine.
-
+    # Here, it's a single column, 'label', with values 0-3 (for each of the 4 classes)
+    label_cols = ['label']
+    num_classes = 4
+
     # load a pretrained checkpoint saved here
     checkpoint_loc = os.path.join(zoobot_dir, 'data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt')
-    # checkpoint_loc = '/Users/user/repos/gz-decals-classifiers/results/benchmarks/pytorch/dr5/dr5_py_gr_15366/checkpoints/epoch=58-step=18939.ckpt'
 
     # save the finetuning results here
     save_dir = os.path.join(zoobot_dir, 'results/pytorch/finetune/finetune_multiclass_classification')
@@ -42,9 +42,11 @@
     datamodule = GalaxyDataModule(
         label_cols=label_cols,
         catalog=train_catalog,  # very small, as a demo
-        batch_size=32
+        batch_size=batch_size,  # increase for faster training, decrease to avoid out-of-memory errors
+        num_workers=num_workers  # TODO set to a little less than num. CPUs
     )
-    # datamodule.setup()
+    datamodule.setup()
+    # optionally, check the data loads and looks okay
     # for images, labels in datamodule.train_dataloader():
     #     print(images.shape)
     #     print(labels.shape)
@@ -53,31 +55,38 @@
 
     model = finetune.FinetuneableZoobotClassifier(
        checkpoint_loc=checkpoint_loc,
-        num_classes=3,
-        n_layers=0  # only updating the head weights. Set e.g. 1, 2 to finetune deeper.
+        num_classes=num_classes,
+        n_blocks=n_blocks
    )
     # under the hood, this does:
     # encoder = finetune.load_pretrained_encoder(checkpoint_loc)
     # model = finetune.FinetuneableZoobotClassifier(encoder=encoder, ...)
 
-    # retrain to find rings
+    # finetune on the new labels
-    trainer = finetune.get_trainer(save_dir, accelerator='cpu', max_epochs=1)
+    trainer = finetune.get_trainer(save_dir, accelerator='auto', max_epochs=max_epochs)
     trainer.fit(model, datamodule)
     # can now use this model or saved checkpoint to make predictions on new data. Well done!
 
+    # see how well the model performs
+    # (don't do this all the time)
+    trainer.test(model, datamodule)
+
+
     # we can load the model later any time
     # pretending we want to load from scratch:
     best_checkpoint = trainer.checkpoint_callback.best_model_path
     finetuned_model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(best_checkpoint)
 
     from zoobot.pytorch.predictions import predict_on_catalog
 
+    predictions_save_loc = os.path.join(save_dir, 'finetuned_predictions.csv')
     predict_on_catalog.predict(
         test_catalog,
         finetuned_model,
         n_samples=1,
-        label_cols=label_cols,
-        save_loc=os.path.join(save_dir, 'finetuned_predictions.csv')
-        # trainer_kwargs={'accelerator': 'gpu'}
+        label_cols=['class_{}'.format(n) for n in range(num_classes)],  # TODO feel free to rename, it's just for the csv header
+        save_loc=predictions_save_loc,
+        trainer_kwargs={'accelerator': 'auto'},
+        datamodule_kwargs={'batch_size': batch_size, 'num_workers': num_workers}
     )
     """
     Under the hood, this is essentially doing:
@@ -91,4 +100,9 @@
     )
     preds = predict_trainer.predict(finetuned_model, predict_datamodule)
     print(preds)
-    """
\ No newline at end of file
+    """
+
+    predictions = pd.read_csv(predictions_save_loc)
+    print(predictions)
+
+    exit()  # now over to you!
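
For reference, the columns written by predict_on_catalog.predict follow whatever label_cols is passed, so the multiclass example above saves one softmax column per class (class_0 ... class_3). A minimal, hedged sketch of turning those columns into hard labels with pandas; the csv path and column names are taken from the example, but the exact csv layout (any id or extra columns) is an assumption:

import pandas as pd

# assumes the csv written by the example above has one softmax column per entry in label_cols
predictions = pd.read_csv('results/pytorch/finetune/finetune_multiclass_classification/finetuned_predictions.csv')
class_cols = ['class_{}'.format(n) for n in range(4)]  # matches label_cols in the example
# idxmax picks the column (i.e. the class) with the highest softmax value for each galaxy
predictions['predicted_class'] = predictions[class_cols].idxmax(axis=1)
print(predictions['predicted_class'].value_counts())
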
diff --git a/zoobot/pytorch/predictions/predict_on_catalog.py b/zoobot/pytorch/predictions/predict_on_catalog.py
index 7acac09d..3a68ab88 100644
--- a/zoobot/pytorch/predictions/predict_on_catalog.py
+++ b/zoobot/pytorch/predictions/predict_on_catalog.py
@@ -11,7 +11,7 @@
 from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
 
 
-def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, label_cols: List, save_loc: str, datamodule_kwargs={}, trainer_kwargs={}):
+def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, label_cols: List, save_loc: str, datamodule_kwargs={}, trainer_kwargs={}) -> None:
     """
     Use trained model to make predictions on a catalog of galaxies.
 
@@ -55,12 +55,19 @@ def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, la
     start = datetime.datetime.fromtimestamp(time.time())
     logging.info('Starting at: {}'.format(start.strftime('%Y-%m-%d %H:%M:%S')))
 
-    logging.info(len(trainer.predict(model, predict_datamodule)))
+    # logging.info(len(trainer.predict(model, predict_datamodule)))
 
     # trainer.predict gives list of tensors, each tensor being predictions for a batch. Concat on axis 0.
     # range(n_samples) list comprehension repeats this, for dropout-permuted predictions. Stack to create new last axis.
     # final shape (n_galaxies, n_answers, n_samples)
-    predictions = torch.stack([torch.concat(trainer.predict(model, predict_datamodule), dim=0) for n in range(n_samples)], dim=-1).numpy()
+    predictions = torch.stack(
+        [
+            # trainer.predict gives [(galaxy, answer), ...] list, batchwise
+            # concat batches
+            torch.concat(trainer.predict(model, predict_datamodule), dim=0)
+            for n in range(n_samples)
+        ],
+        dim=-1).numpy()  # now stack on final dim for (galaxy, answer, dropout) shape
 
     logging.info('Predictions complete - {}'.format(predictions.shape))
     logging.info(f'Saving predictions to {save_loc}')
diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index 57dbaabf..6f3cf3e2 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -24,7 +24,7 @@ def freeze_batchnorm_layers(model):
     for name, child in (model.named_children()):
         if isinstance(child, torch.nn.BatchNorm2d):
-            logging.debug('freezing {} {}'.format(child, name))
+            logging.debug('Freezing {} {}'.format(child, name))
             child.eval()  # no grads, no param updates, no statistic updates
         else:
             freeze_batchnorm_layers(child)  # recurse
@@ -64,15 +64,16 @@ def __init__(
         encoder=None,
         encoder_dim=1280,  # as per current Zoobot. TODO Could get automatically?
         n_epochs=100,  # TODO early stopping
-        n_layers=0,  # how many layers deep to FT
+        n_blocks=0,  # how many blocks deep to FT
         lr_decay=0.75,
         weight_decay=0.05,
-        learning_rate=1e-4,
+        learning_rate=1e-4,  # 10x lower than typical, you may like to experiment
         dropout_prob=0.5,
-        freeze_batchnorm=True,
+        always_train_batchnorm=True,
         prog_bar=True,
         visualize_images=False,  # upload examples to wandb, good for debugging
-        seed=42
+        seed=42,
+        n_layers=0  # for backward compat., n_blocks preferred
     ):
         super().__init__()
 
@@ -91,11 +92,17 @@ def __init__(
             self.encoder = load_pretrained_encoder(checkpoint_loc)
         else:
             assert checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use'
+            assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
             self.encoder = encoder
 
         self.encoder_dim = encoder_dim
-        self.n_layers = n_layers
-        self.freeze = True if n_layers == 0 else False
+        self.n_blocks = n_blocks
+
+        # for backwards compat.
+        if n_layers:
+            logging.warning('FinetuneableZoobot(n_layers) is now renamed to n_blocks, please update to pass n_blocks instead! For now, setting n_blocks=n_layers')
+            self.n_blocks = n_layers
+        logging.info('Blocks to finetune: {}'.format(self.n_blocks))
 
         self.learning_rate = learning_rate
         self.lr_decay = lr_decay
@@ -103,60 +110,86 @@ def __init__(
         self.dropout_prob = dropout_prob
         self.n_epochs = n_epochs
 
-        self.freeze_batchnorm = freeze_batchnorm
+        self.always_train_batchnorm = always_train_batchnorm
+        if self.always_train_batchnorm:
+            logging.info('always_train_batchnorm=True, so all batch norm layers will be finetuned')
 
         self.train_loss_metric = tm.MeanMetric()
         self.val_loss_metric = tm.MeanMetric()
         self.test_loss_metric = tm.MeanMetric()
-
-        if self.freeze_batchnorm:
-            freeze_batchnorm_layers(self.encoder)  # inplace
-
         self.seed = seed
         self.prog_bar = prog_bar
         self.visualize_images = visualize_images
 
     def configure_optimizers(self):
-        if self.freeze:
-            params = self.head.parameters()
-            return torch.optim.AdamW(params, betas=(0.9, 0.999), lr=self.learning_rate)
+        lr = self.learning_rate
+        params = [{"params": self.head.parameters(), "lr": lr}]
+
+        if hasattr(self.encoder, 'blocks'):
+            logging.info('Effnet detected')
+            # TODO this actually excludes the first conv layer/bn
+            encoder_blocks = self.encoder.blocks
+            blocks_to_tune = list(encoder_blocks)
+        elif hasattr(self.encoder, 'layer4'):
+            logging.info('Resnet detected')
+            # similarly, excludes first conv/bn
+            blocks_to_tune = [
+                self.encoder.layer1,
+                self.encoder.layer2,
+                self.encoder.layer3,
+                self.encoder.layer4
+            ]
+        elif hasattr(self.encoder, 'stages'):
+            logging.info('Max-ViT Tiny detected')
+            blocks_to_tune = [
+                # getattr as obj.0 is not allowed (why does timm call them 0!?)
+                getattr(self.encoder.stages, '0'),
+                getattr(self.encoder.stages, '1'),
+                getattr(self.encoder.stages, '2'),
+                getattr(self.encoder.stages, '3'),
+            ]
         else:
-            lr = self.learning_rate
-            params = [{"params": self.head.parameters(), "lr": lr}]
-
-            # this bit is specific to Zoobot EffNet
-            # TODO check these are blocks not individual layers
-            encoder_blocks = list(self.encoder.children())
-
-            # for n, l in enumerate(encoder_blocks):
-            #     print('\n')
-            #     print(n)
-            #     print(l)
-
-            # layers with no parameters don't count
-            # TODO double-check is_tuneable
-            tuneable_blocks = [b for b in encoder_blocks if is_tuneable(b)]
-
-            assert self.n_layers <= len(
-                tuneable_blocks
-            ), f"Network only has {len(tuneable_blocks)} tuneable blocks, {self.n_layers} specified for finetuning"
-
-            # Append parameters of layers for finetuning along with decayed learning rate
-            blocks_to_tune = tuneable_blocks[:self.n_layers]
-            blocks_to_tune.reverse()  # highest block to lowest block
-            for i, layer in enumerate(blocks_to_tune):
+            raise ValueError('Encoder architecture not automatically recognised')
+
+        assert self.n_blocks <= len(
+            blocks_to_tune
+        ), f"Network only has {len(blocks_to_tune)} tuneable blocks, {self.n_blocks} specified for finetuning"
+
+
+        # take n blocks, ordered highest block to lowest block
+        blocks_to_tune.reverse()
+        # set the remaining blocks aside first (before truncating), or this list would always be empty
+        remaining_blocks = blocks_to_tune[self.n_blocks:]
+        # will finetune all params in the first n_blocks
+        blocks_to_tune = blocks_to_tune[:self.n_blocks]
+
+        # Append parameters of blocks for finetuning along with decayed learning rate
+        for i, block in enumerate(blocks_to_tune):
+            params.append({
+                "params": block.parameters(),
+                "lr": lr * (self.lr_decay**i)
+            })
+
+        logging.debug(params)
+
+        # optionally, for the remaining blocks (not otherwise finetuned), you can choose to still finetune the batchnorm layers
+        for i, block in enumerate(remaining_blocks):
+            if self.always_train_batchnorm:
                 params.append({
-                    "params": layer.parameters(),
+                    "params": get_batch_norm_params_lighting(block),
                     "lr": lr * (self.lr_decay**i)
                 })
 
-        # Initialize AdamW optimizer
-        opt = torch.optim.AdamW(
-            params, weight_decay=self.weight_decay, betas=(0.9, 0.999))  # higher weight decay is typically good
+        # TODO this actually breaks training because the generator only iterates once!
+        # total_params = sum(p.numel() for param_set in params.copy() for p in param_set['params'])
+        # logging.info('Total params to fit: {}'.format(total_params))
 
-        return opt
+        # Initialize AdamW optimizer
+        opt = torch.optim.AdamW(params, weight_decay=self.weight_decay)  # lr included in params dict
+
+        return opt
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -324,14 +357,14 @@ def on_test_batch_end(self, step_output, *args) -> None:
 
     def predict_step(self, x, batch_idx):
         x = self.forward(x)  # logits from LinearClassifier
         # then applies softmax
-        return F.softmax(x, dim=1)[:, 1]
+        return F.softmax(x, dim=1)
 
     def upload_images_to_wandb(self, outputs, batch, batch_idx):
         # self.logger is set by pl.Trainer(logger=) argument
         if (self.logger is not None) and (batch_idx == 0):
             x, y = batch
-            y_pred_softmax = F.softmax(outputs['predictions'], dim=1)[:, 1]  # odds of class 1 (assumed binary)
+            y_pred_softmax = F.softmax(outputs['predictions'], dim=1)
             n_images = 5
             images = [img for img in x[:n_images]]
             captions = [f'Ground Truth: {y_i} \nPrediction: {y_p_i}' for y_i, y_p_i in zip(
@@ -501,6 +534,16 @@ def is_tuneable(block_of_layers):
     else:
         # currently, allowed to include batchnorm
         return True
+
+
+def get_batch_norm_params_lighting(parent_module, current_params=None):
+    # use a None default to avoid a shared mutable default list accumulating params across calls
+    if current_params is None:
+        current_params = []
+    for child_module in parent_module.children():
+        if isinstance(child_module, torch.nn.BatchNorm2d):
+            current_params += child_module.parameters()
+        else:
+            current_params = get_batch_norm_params_lighting(child_module, current_params)
+    return current_params
+
+
 # when ready (don't peek often, you'll overfit)
 # trainer.test(model, dataloaders=datamodule)
diff --git a/zoobot/shared/schemas.py b/zoobot/shared/schemas.py
index 960253cd..8d32f878 100755
--- a/zoobot/shared/schemas.py
+++ b/zoobot/shared/schemas.py
@@ -277,9 +277,4 @@ def answers(self):
 # trigger basicConfig() and prevent user setting their own logging.
 # so don't log anything during Schema.__init__!
 
-# temp for debugging
-# print(label_metadata.desi_pairs)
-# print(label_metadata.desi_dependencies)
-
-# print(desi_schema.questions)
-# print(desi_schema.answers)
+gz_evo_v1_schema = Schema(label_metadata.gz_evo_v1_pairs, label_metadata.gz_evo_v1_dependencies)
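
The rewritten configure_optimizers above builds AdamW parameter groups with a per-block learning-rate decay: the head trains at the base rate, the highest encoder block at lr * lr_decay^0, the next at lr * lr_decay^1, and so on. A minimal standalone sketch of that pattern, assuming a generic torch.nn.Sequential encoder and a stand-in head (the names here are illustrative, not Zoobot's API):

import torch

# stand-in encoder split into three 'blocks'; Zoobot instead uses the timm EfficientNet/ResNet/Max-ViT blocks
encoder = torch.nn.Sequential(
    torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU()),
    torch.nn.Sequential(torch.nn.Conv2d(8, 16, 3), torch.nn.ReLU()),
    torch.nn.Sequential(torch.nn.Conv2d(16, 32, 3), torch.nn.ReLU()),
)
head = torch.nn.Linear(32, 4)  # stand-in classification head

lr, lr_decay, n_blocks = 1e-4, 0.75, 2
params = [{'params': head.parameters(), 'lr': lr}]  # head always trains at the base lr

blocks = list(encoder.children())
blocks.reverse()  # highest (most task-specific) block first
for i, block in enumerate(blocks[:n_blocks]):
    # earlier (lower) blocks get progressively smaller learning rates
    params.append({'params': block.parameters(), 'lr': lr * (lr_decay ** i)})

opt = torch.optim.AdamW(params, weight_decay=0.05)  # per-group 'lr' overrides the optimizer default

In finetune.py above, the blocks not selected for full finetuning can still contribute just their batchnorm parameters (via get_batch_norm_params_lighting) when always_train_batchnorm=True.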