diff --git a/.gitignore b/.gitignore index d7ae58f9..ff65996f 100755 --- a/.gitignore +++ b/.gitignore @@ -167,4 +167,5 @@ hparams.yaml data/pretrained_models -*.tar \ No newline at end of file +*.tar +*.ckpt \ No newline at end of file diff --git a/setup.py b/setup.py index 1bf92be6..4f9d7188 100755 --- a/setup.py +++ b/setup.py @@ -112,6 +112,7 @@ 'pyarrow', # to read parquet, which is very handy for big datasets # for saving metrics to weights&biases (cloud service, free within limits) 'wandb', + 'huggingface_hub', # login may be required 'setuptools', # no longer pinned 'galaxy-datasets>=0.0.15' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets) ] diff --git a/tests/test_from_hub.py b/tests/test_from_hub.py new file mode 100644 index 00000000..9bc1c115 --- /dev/null +++ b/tests/test_from_hub.py @@ -0,0 +1,38 @@ +import pytest + +import timm +import torch + + +def test_get_encoder(): + model = timm.create_model("hf_hub:mwalmsley/zoobot-encoder-efficientnet_b0", pretrained=True) + assert model(torch.rand(1, 3, 224, 224)).shape == (1, 1280) + + +def test_get_finetuned(): + # checkpoint_loc = 'https://huggingface.co/mwalmsley/zoobot-finetuned-is_tidal/resolve/main/3.ckpt' pickle problem via lightning + # checkpoint_loc = '/home/walml/Downloads/3.ckpt' # works when downloaded manually + + from huggingface_hub import hf_hub_download + + REPO_ID = "mwalmsley/zoobot-finetuned-is_tidal" + FILENAME = "4.ckpt" + + downloaded_loc = hf_hub_download( + repo_id=REPO_ID, + filename=FILENAME, + ) + from zoobot.pytorch.training import finetune + model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(downloaded_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', + assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2) + + + +# def test_get_finetuned_from_local(): +# # checkpoint_loc = '/home/walml/repos/zoobot/tests/convnext_nano_finetuned_linear_is-lsb.ckpt' +# checkpoint_loc = '/home/walml/repos/zoobot-foundation/results/finetune/is-lsb/debug/checkpoints/4.ckpt' + +# from zoobot.pytorch.training import finetune +# # if originally trained with a direct in-memory checkpoint, must specify the hub name manually. otherwise it's saved as an hparam. +# model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(checkpoint_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ) +# assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2) \ No newline at end of file diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index d35dc151..d04ab746 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -480,4 +480,7 @@ def schema_to_campaigns(schema): if __name__ == '__main__': encoder = get_pytorch_encoder(channels=1) dim = get_encoder_dim(encoder, channels=1) - print(dim) \ No newline at end of file + print(dim) + + + ZoobotTree.load_from_checkpoint \ No newline at end of file diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py index f1ad8b5e..0ef638fb 100644 --- a/zoobot/pytorch/training/finetune.py +++ b/zoobot/pytorch/training/finetune.py @@ -43,11 +43,17 @@ class FinetuneableZoobotAbstract(pl.LightningModule): Both :class:`FinetuneableZoobotClassifier` and :class:`FinetuneableZoobotTree` can (and should) be passed any of these arguments to customise finetuning. - You could subclass this class to solve new finetuning tasks (like regression) - see :ref:`advanced_finetuning`. + Any FinetuneableZoobot model can be loaded in one of three ways: + - HuggingFace name e.g. FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended. + - Any PyTorch model in memory e.g. FinetuneableZoobotX(encoder=some_model, ...) + - ZoobotTree checkpoint e.g. FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...) + + You could subclass this class to solve new finetuning tasks - see :ref:`advanced_finetuning`. Args: - checkpoint_loc (str, optional): Path to encoder checkpoint to load (likely a saved ZoobotTree). Defaults to None. - encoder (pl.LightningModule, optional): Alternatively, pass an encoder directly. Load with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`. + name (str, optional): Name of a model on HuggingFace Hub e.g.'hf_hub:mwalmsley/zoobot-encoder-convnext_nano'. Defaults to None. + encoder (torch.nn.Module, optional): A PyTorch model already loaded in memory + zoobot_checkpoint_loc (str, optional): Path to ZoobotTree lightning checkpoint to load. Loads with Load with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`. Defaults to None. encoder_dim (int, optional): Output dimension of encoder. Defaults to 1280 (EfficientNetB0's encoder dim). lr_decay (float, optional): For each layer i below the head, reduce the learning rate by lr_decay ^ i. Defaults to 0.75. weight_decay (float, optional): AdamW weight decay arg (i.e. L2 penalty). Defaults to 0.05. @@ -61,26 +67,39 @@ class FinetuneableZoobotAbstract(pl.LightningModule): def __init__( self, - # can provide either zoobot_checkpoint_loc, and will load this model as encoder... - zoobot_checkpoint_loc=None, + + # load a pretrained timm encoder saved on huggingface hub + # (aimed at most users, easiest way to load published models) + name=None, + # ...or directly pass any model to use as encoder (if you do this, you will need to keep it around for later) - encoder=None, + # (aimed at tinkering with new architectures e.g. SSL) + encoder=None, # use any torch model already loaded in memory (must have .forward() method) + + # load a pretrained zoobottree model and grab the encoder (a timm model) + # requires the exact same zoobot version used for training, not very portable + # (aimed at supervised experiments) + zoobot_checkpoint_loc=None, + + # finetuning settings n_blocks=0, # how many layers deep to FT lr_decay=0.75, weight_decay=0.05, learning_rate=1e-4, # 10x lower than typical, you may like to experiment dropout_prob=0.5, always_train_batchnorm=False, # temporarily deprecated - prog_bar=True, - visualize_images=False, # upload examples to wandb, good for debugging - seed=42, n_layers=0, # for backward compat., n_blocks preferred # these args are for the optional learning rate scheduler, best not to use unless you've tuned everything else already cosine_schedule=False, warmup_epochs=10, max_cosine_epochs=100, max_learning_rate_reduction_factor=0.01, - from_scratch=False + # escape hatch for 'from scratch' baselines + from_scratch=False, + # debugging utils + prog_bar=True, + visualize_images=False, # upload examples to wandb, good for debugging + seed=42 ): super().__init__() @@ -95,17 +114,22 @@ def __init__( self.save_hyperparameters(ignore=['encoder']) # never serialise the encoder, way too heavy # if you need the encoder to recreate, pass when loading checkpoint e.g. # FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder) - - if zoobot_checkpoint_loc is not None: - assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use' - self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc) + + if name is not None: + assert encoder is None, 'Cannot pass both name and encoder to use' + self.encoder = timm.create_model(name, pretrained=True) + self.encoder_dim = self.encoder.num_features + + elif zoobot_checkpoint_loc is not None: + assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use' + self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc) # extracts the timm encoder + self.encoder_dim = self.encoder.num_features else: - assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use' - assert encoder is not None, 'Must pass either checkpoint to load or encoder to use' - self.encoder = encoder - - # TODO read as encoder property - self.encoder_dim = define_model.get_encoder_dim(self.encoder) + assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use' + assert encoder is not None, 'Must pass either checkpoint to load or encoder to use' + self.encoder = encoder + # work out encoder dim 'manually' + self.encoder_dim = define_model.get_encoder_dim(self.encoder) # for backwards compat. if n_layers: diff --git a/zoobot/pytorch/training/representations.py b/zoobot/pytorch/training/representations.py index dd8912a1..1c5fab19 100644 --- a/zoobot/pytorch/training/representations.py +++ b/zoobot/pytorch/training/representations.py @@ -1,17 +1,53 @@ +import logging import pytorch_lightning as pl +from timm import create_model + + class ZoobotEncoder(pl.LightningModule): - # very simple wrapper to turn pytorch model into lightning module - # useful when we want to use lightning to make predictions with our encoder - # (i.e. to get representations) - def __init__(self, encoder, pyramid=False) -> None: - super().__init__() + def __init__(self, encoder): + logging.info('ZoobotEncoder: using provided in-memory encoder') self.encoder = encoder # plain pytorch module e.g. Sequential - if pyramid: - raise NotImplementedError('Will eventually support resetting timm classifier to get FPN features') + def forward(self, x): if isinstance(x, list) and len(x) == 1: return self(x[0]) return self.encoder(x) + + @classmethod + def load_from_name(cls, name: str): + """ + e.g. ZoobotEncoder.load_from_name('hf_hub:mwalmsley/zoobot-encoder-convnext_nano') + Args: + name (str): huggingface hub name to load + + Returns: + nn.Module: timm model + """ + timm_model = create_model(name) + return cls(timm_model) + + + + + +class ZoobotEncoder(pl.LightningModule): + # very simple wrapper to turn pytorch model into lightning module + # useful when we want to use lightning to make predictions with our encoder + # (i.e. to get representations) + + # pretrained_cfg, pretrained_cfg_overlay=timm_kwargs + def __init__(self, architecture_name=None, channels=None, timm_kwargs={}) -> None: + super().__init__() + + logging.info('ZoobotEncoder: using timm encoder') + self.encoder = + + # if pyramid: + # raise NotImplementedError('Will eventually support resetting timm classifier to get FPN features') + + +# def save_timm_encoder(): +