From edc58bc60959fafe6fa79f15078c1e78309bbe35 Mon Sep 17 00:00:00 2001
From: Mike Walmsley
Date: Thu, 14 Mar 2024 10:24:38 -0400
Subject: [PATCH] add from_scratch override

---
 zoobot/pytorch/training/finetune.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index 37d79c8d..f1ad8b5e 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -79,7 +79,8 @@ def __init__(
         cosine_schedule=False,
         warmup_epochs=10,
         max_cosine_epochs=100,
-        max_learning_rate_reduction_factor=0.01
+        max_learning_rate_reduction_factor=0.01,
+        from_scratch=False
     ):
         super().__init__()
 
@@ -123,6 +124,8 @@ def __init__(
         self.max_cosine_epochs = max_cosine_epochs
         self.max_learning_rate_reduction_factor = max_learning_rate_reduction_factor
 
+        self.from_scratch = from_scratch
+
         self.always_train_batchnorm = always_train_batchnorm
         if self.always_train_batchnorm:
             raise NotImplementedError('Temporarily deprecated, always_train_batchnorm=True not supported')
@@ -159,6 +162,11 @@ def configure_optimizers(self):
 
         logging.info(f'Encoder architecture to finetune: {type(self.encoder)}')
 
+        if self.from_scratch:
+            logging.warning('self.from_scratch is True, training everything and ignoring all settings')
+            params += [{"params": self.encoder.parameters(), "lr": lr}]
+            return torch.optim.AdamW(params, weight_decay=self.weight_decay)
+
         if isinstance(self.encoder, timm.models.EfficientNet): # includes v2
             # TODO for now, these count as separate layers, not ideal
             early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1]
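
The new flag short-circuits configure_optimizers(): when from_scratch is True, the layer-wise freezing and learning-rate handling further down the method is skipped and the whole encoder goes into a single AdamW parameter group at the base learning rate. Below is a minimal standalone sketch of that early-return pattern in plain PyTorch; the toy encoder/head modules, the lr and weight_decay values, and the assumption that head parameters were already collected into params are all illustrative, not taken from the patch.

import torch

# toy stand-ins for the patched module's encoder and head (assumed shapes)
encoder = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU())
head = torch.nn.Linear(16, 2)
lr, weight_decay, from_scratch = 1e-4, 0.05, True

# head parameters collected first, mirroring the params list assumed to exist
# before the patched hunk in configure_optimizers()
params = [{"params": head.parameters(), "lr": lr}]
if from_scratch:
    # train everything: hand the full encoder to AdamW as one parameter group,
    # bypassing any per-layer freezing or learning-rate decay
    params += [{"params": encoder.parameters(), "lr": lr}]
    optimizer = torch.optim.AdamW(params, weight_decay=weight_decay)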