Merge pull request #113 from mwalmsley/narval-migration
Latest finetuning changes
mwalmsley authored Mar 21, 2024
2 parents 15be580 + 31d03ba commit a1ce097
Showing 13 changed files with 172 additions and 70 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_CI.yml
@@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ["3.8", "3.9"] # zoobot should support these (many academics not on 3.9)
python-version: ["3.9"] # zoobot should support these
experimental: [false]
include:
- python-version: "3.10" # test the next python version but allow it to fail
3 changes: 2 additions & 1 deletion .gitignore
@@ -167,4 +167,5 @@ hparams.yaml

data/pretrained_models

*.tar
*.tar
*.ckpt
14 changes: 0 additions & 14 deletions Dockerfile.tf

This file was deleted.

11 changes: 0 additions & 11 deletions docker-compose-tf.yml

This file was deleted.

6 changes: 4 additions & 2 deletions setup.py
@@ -20,7 +20,7 @@
"Environment :: GPU :: NVIDIA CUDA"
],
packages=setuptools.find_packages(),
python_requires=">=3.8", # recommend 3.9 for new users. TF needs >=3.7.2, torchvision>=3.8
python_requires=">=3.9", # bumped to 3.9 for typing
extras_require={
'pytorch-cpu': [
# A100 GPU currently only seems to support cuda 11.3 on manchester cluster, let's stick with this version for now
@@ -112,7 +112,9 @@
'pyarrow', # to read parquet, which is very handy for big datasets
# for saving metrics to weights&biases (cloud service, free within limits)
'wandb',
'webdataset', # for reading webdataset files
'huggingface_hub', # login may be required
'setuptools', # no longer pinned
'galaxy-datasets>=0.0.15' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
'galaxy-datasets>=0.0.17' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
]
)
3 changes: 2 additions & 1 deletion tests/pytorch/test_define_model.py
@@ -10,6 +10,7 @@ def schema():
def test_ZoobotTree_init(schema):
model = define_model.ZoobotTree(
output_dim=12,
question_index_groups=schema.question_index_groups,
question_answer_pairs=schema.question_answer_pairs,
dependencies=schema.dependencies
)

43 changes: 43 additions & 0 deletions tests/test_from_hub.py
@@ -0,0 +1,43 @@
import pytest

import timm
import torch


def test_get_encoder():
model = timm.create_model("hf_hub:mwalmsley/zoobot-encoder-efficientnet_b0", pretrained=True)
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 1280)


def test_get_finetuned():
# checkpoint_loc = 'https://huggingface.co/mwalmsley/zoobot-finetuned-is_tidal/resolve/main/3.ckpt' pickle problem via lightning
# checkpoint_loc = '/home/walml/Downloads/3.ckpt' # works when downloaded manually

from huggingface_hub import hf_hub_download

REPO_ID = "mwalmsley/zoobot-finetuned-is_tidal"
FILENAME = "FinetuneableZoobotClassifier.ckpt"

downloaded_loc = hf_hub_download(
repo_id=REPO_ID,
filename=FILENAME,
)
from zoobot.pytorch.training import finetune
model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(downloaded_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)

def test_get_finetuned_class_method():

from zoobot.pytorch.training import finetune

model = finetune.FinetuneableZoobotClassifier.load_from_name('mwalmsley/zoobot-finetuned-is_tidal', map_location='cpu')
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)

# def test_get_finetuned_from_local():
# # checkpoint_loc = '/home/walml/repos/zoobot/tests/convnext_nano_finetuned_linear_is-lsb.ckpt'
# checkpoint_loc = '/home/walml/repos/zoobot-foundation/results/finetune/is-lsb/debug/checkpoints/4.ckpt'

# from zoobot.pytorch.training import finetune
# # if originally trained with a direct in-memory checkpoint, must specify the hub name manually. otherwise it's saved as an hparam.
# model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(checkpoint_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', )
# assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)
4 changes: 2 additions & 2 deletions zoobot/pytorch/datasets/webdatamodule.py
@@ -80,8 +80,8 @@ def make_image_transform(self, mode="train"):
crop_ratio_bounds=self.crop_ratio_bounds,
resize_after_crop=self.resize_after_crop,
pytorch_greyscale=not self.color,
to_float=True # wrong, webdataset rgb decoder already converts to 0-1 float
# TODO this must be changed! will be different for new model training runs
to_float=False # True was wrong, webdataset rgb decoder already converts to 0-1 float
# TODO now changed on dev branch will be different for new model training runs
) # A.Compose object

# logging.warning('Minimal augmentations for speed test')
2 changes: 2 additions & 0 deletions zoobot/pytorch/datasets/webdataset_utils.py
@@ -63,6 +63,8 @@ def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse
# in augs that could be 0.x-1.0, and here a pre-crop to 0.8 i.e. 340px
# but this would change the centering
# let's stick to small boundary crop and 0.75-0.85 in augs

# turn these off for current euclidized images, already 300x300
A.CenterCrop(
height=400,
width=400,
22 changes: 11 additions & 11 deletions zoobot/pytorch/estimators/define_model.py
@@ -177,7 +177,6 @@ class ZoobotTree(GenericLightningModule):
Args:
output_dim (int): Output dimension of model's head e.g. 34 for predicting a 34-answer decision tree.
question_index_groups (List): Mapping of which label indices are part of the same question. See :ref:`training_on_vote_counts`.
architecture_name (str, optional): Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to "efficientnet_b0".
channels (int, optional): Num. input channels. Probably 3 or 1. Defaults to 1.
test_time_dropout (bool, optional): Apply dropout at test time, to pretend to be Bayesian. Defaults to True.
@@ -192,7 +191,7 @@ def __init__(
self,
output_dim: int,
# in the simplest case, this is all zoobot needs: grouping of label col indices as questions
question_index_groups: List=None,
# question_index_groups: List=None,
# BUT
# if you pass these, it enables better per-question and per-survey logging (because we have names)
# must be passed as simple dicts, not objects, so can't just pass schema in
@@ -219,7 +218,6 @@ def __init__(
super().__init__(
# these all do nothing, they are simply saved by lightning as hparams
output_dim,
question_index_groups,
question_answer_pairs,
dependencies,
architecture_name,
@@ -236,13 +234,12 @@ def __init__(

logging.info('Generic __init__ complete - moving to Zoobot __init__')

if question_answer_pairs is not None:
logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__')
# assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
assert dependencies is not None
self.schema = schemas.Schema(question_answer_pairs, dependencies)
# replace with schema-derived version
question_index_groups = self.schema.question_index_groups
# logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__')
# assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
assert dependencies is not None
self.schema = schemas.Schema(question_answer_pairs, dependencies)
# replace with schema-derived version
question_index_groups = self.schema.question_index_groups

self.setup_metrics()

@@ -480,4 +477,7 @@ def schema_to_campaigns(schema):
if __name__ == '__main__':
encoder = get_pytorch_encoder(channels=1)
dim = get_encoder_dim(encoder, channels=1)
print(dim)
print(dim)


ZoobotTree.load_from_checkpoint
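
For orientation, a minimal sketch of the updated ZoobotTree call, mirroring the test change above. The toy question/answer dict and dependencies are assumptions for illustration only, not values taken from this commit.

# Editor's sketch (not part of the diff): the new ZoobotTree signature
from zoobot.pytorch.estimators import define_model

# toy schema dicts, purely illustrative
question_answer_pairs = {'smooth-or-featured': ['_smooth', '_featured-or-disk', '_artifact']}
dependencies = {'smooth-or-featured': None}

model = define_model.ZoobotTree(
    output_dim=3,  # one output per answer column
    question_answer_pairs=question_answer_pairs,  # replaces the old question_index_groups arg
    dependencies=dependencies
)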
96 changes: 76 additions & 20 deletions zoobot/pytorch/training/finetune.py
@@ -43,11 +43,17 @@ class FinetuneableZoobotAbstract(pl.LightningModule):
Both :class:`FinetuneableZoobotClassifier` and :class:`FinetuneableZoobotTree`
can (and should) be passed any of these arguments to customise finetuning.
You could subclass this class to solve new finetuning tasks (like regression) - see :ref:`advanced_finetuning`.
Any FinetuneableZoobot model can be loaded in one of three ways:
- HuggingFace name e.g. FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended.
- Any PyTorch model in memory e.g. FinetuneableZoobotX(encoder=some_model, ...)
- ZoobotTree checkpoint e.g. FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...)
You could subclass this class to solve new finetuning tasks - see :ref:`advanced_finetuning`.
Args:
checkpoint_loc (str, optional): Path to encoder checkpoint to load (likely a saved ZoobotTree). Defaults to None.
encoder (pl.LightningModule, optional): Alternatively, pass an encoder directly. Load with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`.
name (str, optional): Name of a model on HuggingFace Hub e.g.'hf_hub:mwalmsley/zoobot-encoder-convnext_nano'. Defaults to None.
encoder (torch.nn.Module, optional): A PyTorch model already loaded in memory
zoobot_checkpoint_loc (str, optional): Path to ZoobotTree lightning checkpoint to load. Loads with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`. Defaults to None.
encoder_dim (int, optional): Output dimension of encoder. Defaults to 1280 (EfficientNetB0's encoder dim).
lr_decay (float, optional): For each layer i below the head, reduce the learning rate by lr_decay ^ i. Defaults to 0.75.
weight_decay (float, optional): AdamW weight decay arg (i.e. L2 penalty). Defaults to 0.05.
@@ -61,25 +67,39 @@ class FinetuneableZoobotAbstract(pl.LightningModule):

def __init__(
self,
# can provide either zoobot_checkpoint_loc, and will load this model as encoder...
zoobot_checkpoint_loc=None,

# load a pretrained timm encoder saved on huggingface hub
# (aimed at most users, easiest way to load published models)
name=None,

# ...or directly pass any model to use as encoder (if you do this, you will need to keep it around for later)
encoder=None,
# (aimed at tinkering with new architectures e.g. SSL)
encoder=None, # use any torch model already loaded in memory (must have .forward() method)

# load a pretrained zoobottree model and grab the encoder (a timm model)
# requires the exact same zoobot version used for training, not very portable
# (aimed at supervised experiments)
zoobot_checkpoint_loc=None,

# finetuning settings
n_blocks=0, # how many layers deep to FT
lr_decay=0.75,
weight_decay=0.05,
learning_rate=1e-4, # 10x lower than typical, you may like to experiment
dropout_prob=0.5,
always_train_batchnorm=False, # temporarily deprecated
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42,
n_layers=0, # for backward compat., n_blocks preferred
# these args are for the optional learning rate scheduler, best not to use unless you've tuned everything else already
cosine_schedule=False,
warmup_epochs=10,
max_cosine_epochs=100,
max_learning_rate_reduction_factor=0.01
max_learning_rate_reduction_factor=0.01,
# escape hatch for 'from scratch' baselines
from_scratch=False,
# debugging utils
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42
):
super().__init__()

@@ -94,17 +114,22 @@ def __init__(
self.save_hyperparameters(ignore=['encoder']) # never serialise the encoder, way too heavy
# if you need the encoder to recreate, pass when loading checkpoint e.g.
# FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder)

if zoobot_checkpoint_loc is not None:
assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use'
self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc)

if name is not None:
assert encoder is None, 'Cannot pass both name and encoder to use'
self.encoder = timm.create_model(name, pretrained=True)
self.encoder_dim = self.encoder.num_features

elif zoobot_checkpoint_loc is not None:
assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use'
self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc) # extracts the timm encoder
self.encoder_dim = self.encoder.num_features
else:
assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use'
assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
self.encoder = encoder

# TODO read as encoder property
self.encoder_dim = define_model.get_encoder_dim(self.encoder)
assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use'
assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
self.encoder = encoder
# work out encoder dim 'manually'
self.encoder_dim = define_model.get_encoder_dim(self.encoder)

# for backwards compat.
if n_layers:
@@ -123,6 +148,8 @@ def __init__(
self.max_cosine_epochs = max_cosine_epochs
self.max_learning_rate_reduction_factor = max_learning_rate_reduction_factor

self.from_scratch = from_scratch

self.always_train_batchnorm = always_train_batchnorm
if self.always_train_batchnorm:
raise NotImplementedError('Temporarily deprecated, always_train_batchnorm=True not supported')
@@ -159,6 +186,11 @@ def configure_optimizers(self):

logging.info(f'Encoder architecture to finetune: {type(self.encoder)}')

if self.from_scratch:
logging.warning('self.from_scratch is True, training everything and ignoring all settings')
params += [{"params": self.encoder.parameters(), "lr": lr}]
return torch.optim.AdamW(params, weight_decay=self.weight_decay)

if isinstance(self.encoder, timm.models.EfficientNet): # includes v2
# TODO for now, these count as separate layers, not ideal
early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1]
@@ -345,6 +377,13 @@ def on_test_batch_end(self, outputs: dict, batch, batch_idx: int, dataloader_idx

def upload_images_to_wandb(self, outputs, batch, batch_idx):
raise NotImplementedError('Must be subclassed')

@classmethod
def load_from_name(cls, name: str, **kwargs):
downloaded_loc = download_from_name(cls.__name__, name, **kwargs)
return cls.load_from_checkpoint(downloaded_loc, **kwargs) # trained on GPU, may need map_location='cpu' if you get a device error





@@ -364,6 +403,8 @@ class FinetuneableZoobotClassifier(FinetuneableZoobotAbstract):
"""



def __init__(
self,
num_classes: int,
@@ -730,3 +771,18 @@ def get_trainer(
)

return trainer


def download_from_name(class_name: str, hub_name: str, **kwargs):
from huggingface_hub import hf_hub_download

if hub_name.startswith('hf_hub:'):
logging.info('Passed name with hf_hub: prefix, dropping prefix')
repo_id = hub_name.split('hf_hub:')[1]
else:
repo_id = hub_name
downloaded_loc = hf_hub_download(
repo_id=repo_id,
filename=f"{class_name}.ckpt"
)
return downloaded_loc
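
As a rough usage sketch of the three loading routes described in the docstring above (the hub name, checkpoint path and num_classes value are placeholders, not taken from this commit):

# Editor's sketch (not part of the diff): the three FinetuneableZoobot loading routes
import timm
from zoobot.pytorch.training import finetune

# 1. by HuggingFace Hub name (recommended)
model = finetune.FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', num_classes=2)

# 2. from any timm/torch encoder already in memory
encoder = timm.create_model('hf_hub:mwalmsley/zoobot-encoder-convnext_nano', pretrained=True)
model = finetune.FinetuneableZoobotClassifier(encoder=encoder, num_classes=2)

# 3. from a saved ZoobotTree lightning checkpoint (hypothetical path)
model = finetune.FinetuneableZoobotClassifier(
    zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', num_classes=2)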
28 changes: 21 additions & 7 deletions zoobot/pytorch/training/representations.py
@@ -1,17 +1,31 @@
import logging
import pytorch_lightning as pl

from timm import create_model


class ZoobotEncoder(pl.LightningModule):
# very simple wrapper to turn pytorch model into lightning module
# useful when we want to use lightning to make predictions with our encoder
# (i.e. to get representations)

def __init__(self, encoder, pyramid=False) -> None:
super().__init__()
def __init__(self, encoder):
logging.info('ZoobotEncoder: using provided in-memory encoder')
self.encoder = encoder # plain pytorch module e.g. Sequential
if pyramid:
raise NotImplementedError('Will eventually support resetting timm classifier to get FPN features')


def forward(self, x):
if isinstance(x, list) and len(x) == 1:
return self(x[0])
return self.encoder(x)

@classmethod
def load_from_name(cls, name: str):
"""
e.g. ZoobotEncoder.load_from_name('hf_hub:mwalmsley/zoobot-encoder-convnext_nano')
Args:
name (str): huggingface hub name to load
Returns:
nn.Module: timm model
"""
timm_model = create_model(name)
return cls(timm_model)
