diff --git a/.github/workflows/run_CI.yml b/.github/workflows/run_CI.yml
index 97a93d03..b9a1ce31 100644
--- a/.github/workflows/run_CI.yml
+++ b/.github/workflows/run_CI.yml
@@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: true
matrix:
- python-version: ["3.8", "3.9"] # zoobot should support these (many academics not on 3.9)
+ python-version: ["3.9"] # zoobot should support these
experimental: [false]
include:
- python-version: "3.10" # test the next python version but allow it to fail
diff --git a/.gitignore b/.gitignore
index d7ae58f9..ff65996f 100755
--- a/.gitignore
+++ b/.gitignore
@@ -167,4 +167,5 @@ hparams.yaml
data/pretrained_models
-*.tar
\ No newline at end of file
+*.tar
+*.ckpt
\ No newline at end of file
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 2b64d49a..822a10af 100755
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -1,14 +1,17 @@
version: 2
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.9"
+
python:
- version: 3.8
install:
- method: pip
path: .
extra_requirements:
- docs
- pytorch_m1
- - tensorflow
sphinx:
fail_on_warning: true
\ No newline at end of file
diff --git a/Dockerfile.tf b/Dockerfile.tf
deleted file mode 100644
index e7cfa547..00000000
--- a/Dockerfile.tf
+++ /dev/null
@@ -1,14 +0,0 @@
-FROM tensorflow/tensorflow:2.8.0
-
-# if you have a supported nvidia GPU and https://github.com/NVIDIA/nvidia-docker
-# FROM tensorflow/tensorflow:2.8.0-gpu
-
-WORKDIR /usr/src/zoobot
-
-# install dependencies but remove tensorflow as it's in the base image
-COPY README.md .
-COPY setup.py .
-RUN pip install -U .[tensorflow]
-
-# install package
-COPY . .
diff --git a/README.md b/README.md
index ceb5d482..9a2ad02b 100755
--- a/README.md
+++ b/README.md
@@ -17,32 +17,30 @@ Zoobot is trained using millions of answers by Galaxy Zoo volunteers. This code
- [Install](#installation)
- [Quickstart](#quickstart)
- [Worked Examples](#worked-examples)
-- [Pretrained Weights](https://zoobot.readthedocs.io/en/latest/data_notes.html)
+- [Pretrained Weights](https://zoobot.readthedocs.io/en/latest/pretrained_models.html)
- [Datasets](https://www.github.com/mwalmsley/galaxy-datasets)
- [Documentation](https://zoobot.readthedocs.io/) (for understanding/reference)
## Installation
+
-You can retrain Zoobot in the cloud with a free GPU using this [Google Colab notebook](https://colab.research.google.com/drive/17bb_KbA2J6yrIm4p4Ue_lEBHMNC1I9Jd?usp=sharing). To install locally, keep reading.
+You can retrain Zoobot in the cloud with a free GPU using this [Google Colab notebook](https://colab.research.google.com/drive/1A_-M3Sz5maQmyfW2A7rEu-g_Zi0RMGz5?usp=sharing). To install locally, keep reading.
Download the code using git:
git clone git@github.com:mwalmsley/zoobot.git
-And then pick one of the three commands below to install Zoobot and either PyTorch (recommended) or TensorFlow:
+And then pick one of the three commands below to install Zoobot and PyTorch:
- # Zoobot with PyTorch and a GPU. Requires CUDA 11.3.
- pip install -e "zoobot[pytorch_cu113]" --extra-index-url https://download.pytorch.org/whl/cu113
+ # Zoobot with PyTorch and a GPU. Requires CUDA 12.1 (or CUDA 11.8, if you use `_cu118` instead)
+ pip install -e "zoobot[pytorch-cu121]" --extra-index-url https://download.pytorch.org/whl/cu121
# OR Zoobot with PyTorch and no GPU
- pip install -e "zoobot[pytorch_cpu]" --extra-index-url https://download.pytorch.org/whl/cpu
+ pip install -e "zoobot[pytorch-cpu]" --extra-index-url https://download.pytorch.org/whl/cpu
# OR Zoobot with PyTorch on Mac with M1 chip
- pip install -e "zoobot[pytorch_m1]"
-
- # OR Zoobot with TensorFlow. Works with and without a GPU, but if you have a GPU, you need CUDA 11.2.
- pip install -e "zoobot[tensorflow]
+ pip install -e "zoobot[pytorch-m1]"
This installs the downloaded Zoobot code using pip [editable mode](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs) so you can easily change the code locally. Zoobot is also available directly from pip (`pip install zoobot[option]`). Only use this if you are sure you won't be making changes to Zoobot itself. For Google Colab, use `pip install zoobot[pytorch_colab]`
@@ -50,13 +48,14 @@ To use a GPU, you must *already* have CUDA installed and matching the versions a
I share my install steps [here](#install_cuda). GPUs are optional - Zoobot will run retrain fine on CPU, just slower.
## Quickstart
+
-The [Colab notebook](https://colab.research.google.com/drive/17bb_KbA2J6yrIm4p4Ue_lEBHMNC1I9Jd?usp=sharing) is the quickest way to get started. Alternatively, the minimal example below illustrates how Zoobot works.
+The [Colab notebook](https://colab.research.google.com/drive/1A_-M3Sz5maQmyfW2A7rEu-g_Zi0RMGz5?usp=sharing) is the quickest way to get started. Alternatively, the minimal example below illustrates how Zoobot works.
Let's say you want to find ringed galaxies and you have a small labelled dataset of 500 ringed or not-ringed galaxies. You can retrain Zoobot to find rings like so:
-```python
+ ```python
import pandas as pd
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
@@ -77,11 +76,11 @@ Let's say you want to find ringed galaxies and you have a small labelled dataset
# retrain to find rings
trainer = finetune.get_trainer(save_dir)
trainer.fit(model, datamodule)
-```
+ ```
Then you can make predict if new galaxies have rings:
-```python
+ ```python
from zoobot.pytorch.predictions import predict_on_catalog
# csv with 'file_loc' column (path to image). Zoobot will predict the labels.
@@ -93,34 +92,31 @@ Then you can make predict if new galaxies have rings:
label_cols=['ring'], # only used for
save_loc='/your/path/finetuned_predictions.csv'
)
-```
+ ```
Zoobot includes many guides and working examples - see the [Getting Started](#getting-started) section below.
## Getting Started
+
-I suggest starting with the [Colab notebook](https://colab.research.google.com/drive/17bb_KbA2J6yrIm4p4Ue_lEBHMNC1I9Jd?usp=sharing) or the worked examples below, which you can copy and adapt.
+I suggest starting with the [Colab notebook](https://colab.research.google.com/drive/1A_-M3Sz5maQmyfW2A7rEu-g_Zi0RMGz5?usp=sharing) or the worked examples below, which you can copy and adapt.
For context and explanation, see the [documentation](https://zoobot.readthedocs.io/).
-For pretrained model weights, precalculated representations, catalogues, and so forth, see the [data notes](https://zoobot.readthedocs.io/en/latest/data_notes.html) in particular.
+Pretrained models are listed [here](https://zoobot.readthedocs.io/en/latest/pretrained_models.html) and available on [HuggingFace](https://huggingface.co/collections/mwalmsley/zoobot-encoders-65fa14ae92911b173712b874)
### Worked Examples
+
PyTorch (recommended):
+
- [pytorch/examples/finetuning/finetune_binary_classification.py](https://github.com/mwalmsley/zoobot/blob/main/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py)
- [pytorch/examples/finetuning/finetune_counts_full_tree.py](https://github.com/mwalmsley/zoobot/blob/main/zoobot/pytorch/examples/finetuning/finetune_counts_full_tree.py)
- [pytorch/examples/representations/get_representations.py](https://github.com/mwalmsley/zoobot/blob/main/zoobot/pytorch/examples/representations/get_representations.py)
- [pytorch/examples/train_model_on_catalog.py](https://github.com/mwalmsley/zoobot/blob/main/zoobot/pytorch/examples/train_model_on_catalog.py) (only necessary to train from scratch)
-TensorFlow:
-- [tensorflow/examples/train_model_on_catalog.py](https://github.com/mwalmsley/zoobot/blob/main/zoobot/tensorflow/examples/train_model_on_catalog.py) (only necessary to train from scratch)
-- [tensorflow/examples/make_predictions.py](https://github.com/mwalmsley/zoobot/blob/main/zoobot/tensorflow/examples/make_predictions.py)
-- [tensorflow/examples/finetune_minimal.py](https://github.com/mwalmsley/zoobot/blob/main/zoobot/tensorflow/examples/finetune_minimal.py)
-- [tensorflow/examples/finetune_advanced.py](https://github.com/mwalmsley/zoobot/blob/main/zoobot/tensorflow/examples/finetune_advanced.py)
-
There is more explanation and an API reference on the [docs](https://zoobot.readthedocs.io/).
I also [include](https://github.com/mwalmsley/zoobot/blob/main/benchmarks) the scripts used to create and benchmark our pretrained models. Many pretrained models are available [already](https://zoobot.readthedocs.io/en/latest/data_notes.html), but if you need one trained on e.g. different input image sizes or with a specific architecture, I can probably make it for you.
@@ -128,45 +124,34 @@ I also [include](https://github.com/mwalmsley/zoobot/blob/main/benchmarks) the s
When trained with a decision tree head (ZoobotTree, FinetuneableZoobotTree), Zoobot can learn from volunteer labels of varying confidence and predict posteriors for what the typical volunteer might say. Specifically, this Zoobot mode predicts the parameters for distributions, not simple class labels! For a demonstration of how to interpret these predictions, see the [gz_decals_data_release_analysis_demo.ipynb](https://github.com/mwalmsley/zoobot/blob/main/gz_decals_data_release_analysis_demo.ipynb).
+### (Optional) Install PyTorch with CUDA
-### (Optional) Install PyTorch or TensorFlow, with CUDA
-*If you're not using a GPU, skip this step. Use the pytorch_cpu or tensorflow_cpu options in the section below.*
+*If you're not using a GPU, skip this step. Use the pytorch-cpu option in the section below.*
-Install PyTorch 1.12.1 or Tensorflow 2.10.0 and compatible CUDA drivers. I highly recommend using [conda](https://docs.conda.io/en/latest/miniconda.html) to do this. Conda will handle both creating a new virtual environment (`conda create`) and installing CUDA (`cudatoolkit`, `cudnn`)
+Install PyTorch 2.1.0 or Tensorflow 2.10.0 and compatible CUDA drivers. I highly recommend using [conda](https://docs.conda.io/en/latest/miniconda.html) to do this. Conda will handle both creating a new virtual environment (`conda create`) and installing CUDA (`cudatoolkit`, `cudnn`)
-CUDA 11.3 for PyTorch:
+CUDA 12.1 for PyTorch 2.1.0:
- conda create --name zoobot38_torch python==3.8
- conda activate zoobot38_torch
- conda install -c conda-forge cudatoolkit=11.3
+ conda create --name zoobot39_torch python==3.9
+ conda activate zoobot39_torch
+ conda install -c conda-forge cudatoolkit=12.1
-CUDA 11.2 and CUDNN 8.1 for TensorFlow 2.10.0:
+### Recent release features (v2.0.0)
- conda create --name zoobot38_tf python==3.8
- conda activate zoobot38_tf
- conda install -c conda-forge cudatoolkit=11.2 cudnn=8.1.0
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/ # add this environment variable
-
-### Latest minor features (v1.0.4)
-
-- Now supports multi-class finetuning. See `pytorch/examples/finetuning/finetune_multiclass_classification.py`
-- Removed `simplejpeg` dependency due to M1 install issue.
-- Pinned `timm` version to ensure MaX-ViT models load correctly. Models supporting the latest `timm` will follow.
-- (internal until published) GZ Evo v2 now includes Cosmic Dawn (HSC). Significant performance improvement on HSC finetuning.
-
-### Latest major features (v1.0.0)
-
-v1.0.0 recognises that most of the complexity in this repo is training Zoobot from scratch, but most non-GZ users will probably simply want to load the pretrained Zoobot and finetune it on their data.
-
-- Adds new finetuning interface (`finetune.run_finetuning()`), examples.
-- Refocuses docs on finetuning rather than training from scratch.
-- Rework installation process to separate CUDA from Zoobot (simpler, easier)
-- Better wandb logging throughout, to monitor training
-- Remove need to make TFRecords. Now TF directly uses images.
-- Refactor out augmentations and datasets to `galaxy-datasets` repo. TF and Torch now use identical augmentations (via albumentations).
-- Many small quality-of-life improvements
+- New pretrained architectures: ConvNeXT, EfficientNetV2, MaxViT, and more. Each in several sizes.
+- Reworked finetuning procedure. All these architectures are finetuneable through a common method.
+- Reworked finetuning options. Batch norm finetuning removed. Cosine schedule option added.
+- Reworked finetuning saving/loading. Auto-downloads encoder from HuggingFace.
+- Now supports regression finetuning (as well as multi-class and binary). See `pytorch/examples/finetuning`
+- Updated `timm` to 0.9.10, allowing latest model architectures. Previously downloaded checkpoints may not load correctly!
+- (internal until published) GZ Evo v2 now includes Cosmic Dawn (HSC H2O). Significant performance improvement on HSC finetuning. Also now includes GZ UKIDSS (dragged from our archives).
+- Updated `pytorch` to `2.1.0`
+- Added support for webdatasets (only recommended for large-scale distributed training)
+- Improved per-question logging when training from scratch
+- Added option to compile encoder for max speed (not recommended for finetuning, only for pretraining).
+- Deprecates TensorFlow. The CS research community focuses on PyTorch and new frameworks like JAX.
Contributions are very welcome and will be credited in any future work. Please get in touch! See [CONTRIBUTING.md](https://github.com/mwalmsley/zoobot/blob/main/benchmarks) for more.
@@ -176,6 +161,8 @@ The [benchmarks](https://github.com/mwalmsley/zoobot/blob/main/benchmarks) folde
Training Zoobot using the GZ DECaLS dataset option will create models very similar to those used for the GZ DECaLS catalogue and shared with the early versions of this repo. The GZ DESI Zoobot model is trained on additional data (GZD-1, GZD-2), as the GZ Evo Zoobot model (GZD-1/2/5, Hubble, Candels, GZ2).
+**Pretraining is becoming increasingly complex and is now partially refactored out to a separate repository. We are gradually migrating this `zoobot` repository to focus on finetuning.**
+
### Citing
If you use this software, or otherwise wish to cite Zoobot as a software package, please use the [JOSS paper](https://doi.org/10.21105/joss.05312):
@@ -189,10 +176,10 @@ You might be interested in reading papers using Zoobot:
- [Practical Galaxy Morphology Tools from Deep Supervised Representation Learning](https://arxiv.org/abs/2110.12735) (2022)
- [Towards Foundation Models for Galaxy Morphology](https://arxiv.org/abs/2206.11927) (2022)
- [Harnessing the Hubble Space Telescope Archives: A Catalogue of 21,926 Interacting Galaxies](https://arxiv.org/abs/2303.00366) (2023)
-- [Astronomaly at Scale: Searching for Anomalies Amongst 4 Million Galaxies](https://arxiv.org/abs/2309.08660) (2023)
- [Galaxy Zoo DESI: Detailed morphology measurements for 8.7M galaxies in the DESI Legacy Imaging Surveys](https://academic.oup.com/mnras/advance-article/doi/10.1093/mnras/stad2919/7283169?login=false) (2023)
- [Galaxy mergers in Subaru HSC-SSP: A deep representation learning approach for identification, and the role of environment on merger incidence](https://doi.org/10.1051/0004-6361/202346743) (2023)
-
-
+- [Astronomaly at Scale: Searching for Anomalies Amongst 4 Million Galaxies](https://arxiv.org/abs/2309.08660) (2023, submitted)
+- [Transfer learning for galaxy feature detection: Finding Giant Star-forming Clumps in low redshift galaxies using Faster R-CNN](https://arxiv.org/abs/2312.03503) (2023)
+- [Euclid preparation. Measuring detailed galaxy morphologies for Euclid with Machine Learning](https://arxiv.org/abs/2402.10187) (2024, submitted)
Many other works use Zoobot indirectly via the [Galaxy Zoo DECaLS](https://arxiv.org/abs/2102.08414) catalog (and now via the new [Galaxy Zoo DESI](https://academic.oup.com/mnras/advance-article/doi/10.1093/mnras/stad2919/7283169?login=false) catalog).
diff --git a/benchmarks/pytorch/run_benchmarks.sh b/benchmarks/pytorch/run_benchmarks.sh
index 07094601..3ff5e946 100755
--- a/benchmarks/pytorch/run_benchmarks.sh
+++ b/benchmarks/pytorch/run_benchmarks.sh
@@ -13,11 +13,11 @@ SEED=$RANDOM
# GZ Evo i.e. all galaxies
-# effnet, greyscale and color
-# sbatch --job-name=evo_py_gr_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
-# sbatch --job-name=evo_py_gr_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
-# sbatch --job-name=evo_py_co_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
-# sbatch --job-name=evo_py_co_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=128,RESIZE_AFTER_CROP=300,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
+# effnet, greyscale and color, 224 and 300px
+sbatch --job-name=evo_py_gr_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
+sbatch --job-name=evo_py_gr_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
+sbatch --job-name=evo_py_co_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
+sbatch --job-name=evo_py_co_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=128,RESIZE_AFTER_CROP=300,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
# and resnet18
# sbatch --job-name=evo_py_gr_res18_224_$SEED --export=ARCHITECTURE=resnet18,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
@@ -26,7 +26,7 @@ SEED=$RANDOM
# sbatch --job-name=evo_py_gr_res50_224_$SEED --export=ARCHITECTURE=resnet50,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# sbatch --job-name=evo_py_gr_res50_300_$SEED --export=ARCHITECTURE=resnet50,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# color 224 version
-sbatch --job-name=evo_py_co_res50_224_$SEED --export=ARCHITECTURE=resnet50,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
+# sbatch --job-name=evo_py_co_res50_224_$SEED --export=ARCHITECTURE=resnet50,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# and with max-vit tiny because hey transformers are cool
# smaller batch size due to memory
@@ -35,11 +35,12 @@ sbatch --job-name=evo_py_co_res50_224_$SEED --export=ARCHITECTURE=resnet50,BATCH
# and max-vit small (works badly)
# sbatch --job-name=evo_py_gr_vitsmall_224_$SEED --export=ARCHITECTURE=maxvit_small_224,BATCH_SIZE=64,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
-# and convnext (works badly)
+# and convnext (works badly, would really like to try again but bigger)
# sbatch --job-name=evo_py_gr_$SEED --export=ARCHITECTURE=convnext_nano,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# and vit
# sbatch --job-name=evo_py_gr_vittinyp16_224_$SEED --export=ARCHITECTURE=vit_tiny_patch16_224,BATCH_SIZE=128,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
-
+# and swinv2
+# TODO
# and in color with no mixed precision, for specific project
# sbatch --job-name=evo_py_co_res50_224_fullprec_$SEED --export=ARCHITECTURE=resnet50,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
diff --git a/docker-compose-tf.yml b/docker-compose-tf.yml
deleted file mode 100644
index 7c3b5167..00000000
--- a/docker-compose-tf.yml
+++ /dev/null
@@ -1,11 +0,0 @@
-version: '3'
-
-services:
- zoobot:
- image: zoobot:tensorflow
- build:
- context: ./
- dockerfile: Dockerfile.tf
- volumes:
- # inject the code at run time to allow edits etc
- - ./:/usr/src/zoobot
diff --git a/docs/autodoc/api.rst b/docs/autodoc/api.rst
deleted file mode 100755
index e60ef207..00000000
--- a/docs/autodoc/api.rst
+++ /dev/null
@@ -1,15 +0,0 @@
-
-API
-====
-
-We encourage you to explore the code directly.
-There are many comments (and commented-out examples) which might be helpful.
-However, for convenience, you can check the docstrings directly here.
-
-
-.. toctree::
- :maxdepth: 2
-
- pytorch
- tensorflow
- shared
diff --git a/docs/autodoc/pytorch/training/finetune.rst b/docs/autodoc/pytorch/training/finetune.rst
index a23e767c..bd8a261c 100644
--- a/docs/autodoc/pytorch/training/finetune.rst
+++ b/docs/autodoc/pytorch/training/finetune.rst
@@ -7,6 +7,7 @@ See the `README `_ for a minimal example.
See zoobot/pytorch/examples for more worked examples.
.. autoclass:: zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract
+ :members: configure_optimizers
|
@@ -14,12 +15,27 @@ See zoobot/pytorch/examples for more worked examples.
|
+.. autoclass:: zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor
+
+|
+
.. autoclass:: zoobot.pytorch.training.finetune.FinetuneableZoobotTree
|
+.. autoclass:: zoobot.pytorch.training.finetune.LinearHead
+ :members: forward
+
+|
+
+.. autofunction:: zoobot.pytorch.training.finetune.load_pretrained_zoobot
+
+|
+
.. autofunction:: zoobot.pytorch.training.finetune.get_trainer
|
-.. autofunction:: zoobot.pytorch.training.finetune.load_pretrained_encoder
+.. autofunction:: zoobot.pytorch.training.finetune.download_from_name
+
+|
\ No newline at end of file
diff --git a/docs/autodoc/shared/schemas.rst b/docs/autodoc/shared/schemas.rst
index 7df8e0a9..afafe8a1 100755
--- a/docs/autodoc/shared/schemas.rst
+++ b/docs/autodoc/shared/schemas.rst
@@ -26,6 +26,5 @@ See :ref:`training_on_vote_counts` for full explanation.
|
.. autoclass:: zoobot.shared.schemas.Schema
- :members:
|
\ No newline at end of file
diff --git a/docs/autodoc/tensorflow.rst b/docs/autodoc/tensorflow.rst
deleted file mode 100644
index c36b0943..00000000
--- a/docs/autodoc/tensorflow.rst
+++ /dev/null
@@ -1,27 +0,0 @@
-tensorflow
-=============
-
-estimators
--------------
-
-.. toctree::
-
- tensorflow/estimators/define_model
- tensorflow/estimators/efficientnet_custom
-
-training
--------------
-
-.. toctree::
-
- tensorflow/training/finetune
- tensorflow/training/train_with_keras
- tensorflow/training/training_config
- tensorflow/training/losses
-
-predictions
--------------
-
-.. toctree::
-
- tensorflow/predictions/predict_on_dataset
diff --git a/docs/autodoc/tensorflow/estimators/define_model.rst b/docs/autodoc/tensorflow/estimators/define_model.rst
deleted file mode 100755
index 3bbe02ed..00000000
--- a/docs/autodoc/tensorflow/estimators/define_model.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-define_model
-===================
-
-This module contains functions for defining an EfficientNet model (:meth:`zoobot.estimators.define_model.get_model`),
-with or without the GZ DECaLS head, and optionally to load the weights of a pretrained model.
-
-Models are defined using functions in ``efficientnet_standard`` and ``efficientnet_custom``.
-
-.. autofunction:: zoobot.tensorflow.estimators.define_model.get_model
-
-|
-
-.. autofunction:: zoobot.tensorflow.estimators.define_model.load_weights
-
-|
-
-.. autofunction:: zoobot.tensorflow.estimators.define_model.load_model
diff --git a/docs/autodoc/tensorflow/estimators/efficientnet_custom.rst b/docs/autodoc/tensorflow/estimators/efficientnet_custom.rst
deleted file mode 100755
index ba656134..00000000
--- a/docs/autodoc/tensorflow/estimators/efficientnet_custom.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-efficientnet_custom
-===================
-
-.. autofunction:: zoobot.tensorflow.estimators.efficientnet_custom.define_headless_efficientnet
-
-|
-
-.. autofunction:: zoobot.tensorflow.estimators.efficientnet_custom.custom_top_dirichlet
-
diff --git a/docs/autodoc/tensorflow/predictions/predict_on_dataset.rst b/docs/autodoc/tensorflow/predictions/predict_on_dataset.rst
deleted file mode 100755
index eda2c76a..00000000
--- a/docs/autodoc/tensorflow/predictions/predict_on_dataset.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-predict_on_dataset
-===================
-
-This module includes utilities to make predictions with a trained model on a list of images.
-
-.. autofunction:: zoobot.tensorflow.predictions.predict_on_dataset.predict
-
-|
-
-.. autofunction:: zoobot.tensorflow.predictions.predict_on_dataset.paths_in_folder
diff --git a/docs/autodoc/tensorflow/training/finetune.rst b/docs/autodoc/tensorflow/training/finetune.rst
deleted file mode 100644
index 6d0ceee3..00000000
--- a/docs/autodoc/tensorflow/training/finetune.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-.. _tensorflow_finetune:
-
-finetune
-===================
-
-Functions to load and adapt a trained (TensorFlow) Zoobot model to a new problem.
-
-:.. warning:: PyTorch is recommended for new users. See :ref:`pytorch_or_tensorflow` for more.
-
-
-.. autofunction:: zoobot.tensorflow.training.finetune.run_finetuning
diff --git a/docs/autodoc/tensorflow/training/losses.rst b/docs/autodoc/tensorflow/training/losses.rst
deleted file mode 100755
index e744c44c..00000000
--- a/docs/autodoc/tensorflow/training/losses.rst
+++ /dev/null
@@ -1,19 +0,0 @@
-losses
-===================
-
-This module contains functions for calculating the custom Dirichlet-Multinomial loss used for Galaxy Zoo decision trees.
-
-
-.. autofunction:: zoobot.tensorflow.training.losses.get_multiquestion_loss
-
-|
-
-.. autofunction:: zoobot.tensorflow.training.losses.calculate_multiquestion_loss
-
-|
-
-.. autofunction:: zoobot.tensorflow.training.losses.dirichlet_loss
-
-|
-
-.. autofunction:: zoobot.tensorflow.training.losses.get_dirichlet_neg_log_prob
diff --git a/docs/autodoc/tensorflow/training/train_with_keras.rst b/docs/autodoc/tensorflow/training/train_with_keras.rst
deleted file mode 100644
index 2c7026b2..00000000
--- a/docs/autodoc/tensorflow/training/train_with_keras.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-train_with_keras
-===================
-
-This is the interface to train new tensorflow models from scratch.
-
-.. autofunction:: zoobot.tensorflow.training.train_with_keras.train
diff --git a/docs/autodoc/tensorflow/training/training_config.rst b/docs/autodoc/tensorflow/training/training_config.rst
deleted file mode 100755
index e12d4b69..00000000
--- a/docs/autodoc/tensorflow/training/training_config.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-.. _training_config:
-
-training_config
-===================
-
-This module creates the :class:`Trainer` class for training a Zoobot model (itself a tf.keras.Model).
-Implements common features training like early stopping and tensorboard logging.
-
-Follows the same idea as the PyTorch Lightning Trainer object.
-
-.. autoclass:: zoobot.tensorflow.training.training_config.Trainer
- :members:
diff --git a/docs/conf.py b/docs/conf.py
index 227ce95e..87c633f5 100755
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -19,11 +19,11 @@
# -- Project information -----------------------------------------------------
project = 'Zoobot'
-copyright = '2023, Mike Walmsley'
+copyright = '2024, Mike Walmsley'
author = 'Mike Walmsley'
# The full version, including alpha/beta/rc tags
-release = '0.0.4'
+release = '2.0'
# -- General configuration ---------------------------------------------------
@@ -33,7 +33,8 @@
# ones.
extensions = [
'sphinx.ext.autodoc', # import docs from code
- 'sphinx.ext.napoleon' # google docstrings
+ 'sphinx.ext.napoleon', # google docstrings
+ 'sphinxemoji.sphinxemoji', # emoji support https://sphinxemojicodes.readthedocs.io/en/stable/
]
# Add any paths that contain templates here, relative to this directory.
diff --git a/docs/data_notes.rst b/docs/data_notes.rst
deleted file mode 100755
index e6ce0a4f..00000000
--- a/docs/data_notes.rst
+++ /dev/null
@@ -1,127 +0,0 @@
-.. _datanotes:
-
-Pretrained Models
-=================
-
-Zoobot includes weights for the following pretrained models.
-
-.. list-table:: PyTorch Models
- :widths: 70 35 35 35 35
- :header-rows: 1
-
- * - Architecture
- - Input Size
- - Channels
- - Finetune
- - Link
- * - EfficientNetB0
- - 224px
- - 1
- - Yes
- - `Link `__
- * - EfficientNetB0
- - 300px
- - 1
- - Yes
- - `Link `__
- * - EfficientNetB0
- - 224px
- - 3
- - Yes
- - `Link `__
- * - ResNet50
- - 300px
- - 1
- - Yes
- - `Link `__
- * - ResNet50
- - 224px
- - 1
- - Yes
- - `Link `__
- * - ResNet18
- - 300px
- - 1
- - Yes
- - `Link `__
- * - ResNet18
- - 224px
- - 1
- - Yes
- - `Link `__
- * - Max-ViT Tiny
- - 224px
- - 1
- - Yes
- - `Link `__
- * - Max-ViT Tiny
- - 224px
- - 3
- - Yes
- - `Link `__
-
-
-
-.. list-table:: TensorFlow Models
- :widths: 70 35 35 35 35
- :header-rows: 1
-
- * - Architecture
- - Input Size
- - Channels
- - Finetune
- - Link
- * - EfficientNetB0
- - 300px
- - 1
- - Yes
- - `Link `__
- * - EfficientNetB0
- - 224px
- - 1
- - Yes
- - WIP
-
-
-.. note::
-
- Missing a model you need? Reach out! There's a good chance we can train any small-ish model supported by `timm `_.
-
-All models are trained on the GZ Evo dataset described in the `Towards Foundation Models paper `_.
-This dataset includes 550k galaxy images and 92M votes drawn from every major Galaxy Zoo campaign: GZ2, GZ Hubble, GZ CANDELS, and GZ DECaLS/DESI.
-
-All models are trained on the same images shown to Galaxy Zoo volunteers.
-These are typically 424 pixels across.
-The images are transformed using the galaxy-datasets default transforms (random off-center crop/zoom, flips, rotation) and then resized to the desired input size (224px or 300px) and, for 1-channel models, channel-averaged.
-
-We also include a few additional ad-hoc models `on Dropbox `_.
-
-- EfficientNetB0 models pretrained only on GZ DECaLS GZD-5. For reference/comparison.
-- EfficientNetB0 models pretrained with smaller images (128px and 64px). For debugging.
-
-
-Which model should I use?
---------------------------
-
-We suggest the PyTorch EfficientNetB0 224-pixel model for most users.
-
-Zoobot will prioritise PyTorch going forward. For more, see here.
-The TensorFlow models currently perform just as well as the PyTorch equivalents but will not benefit from any future updates.
-
-EfficientNetB0 is a small yet capable modern architecture.
-The ResNet50 models perform slightly worse than EfficientNet, but are a very common architecture and may be useful as benchmarks or as part of other frameworks (like detectron2, for segmentation).
-
-It's unclear if color information improves overall performance at predicting GZ votes.
-For CNNs, the change in performance is not significant. For ViT, it is measureable but small.
-We suggesst including color if it is expected to be important to your specific task, such as hunting green peas.
-
-Larger input images (300px vs 224px) may provide a small boost in performance at predicting GZ votes.
-However, the models require more memory and train/finetune slightly more slowly.
-You may want to start with a 224px model and experiment with "upgrading" once you're happy everything works.
-
-
-What about the images?
---------------------------
-
-You can find most of our datasets on the `galaxy-datasets repo `_.
-The datasets are self-downloading and have loading functions for both PyTorch and TensorFlow.
diff --git a/docs/guides/advanced_finetuning.rst b/docs/guides/advanced_finetuning.rst
index 59a59aff..767703c7 100644
--- a/docs/guides/advanced_finetuning.rst
+++ b/docs/guides/advanced_finetuning.rst
@@ -4,64 +4,34 @@ Advanced Finetuning
=====================
-Zoobot includes the :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier` and :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotTree`
-classes to help you finetune Zoobot on classification or decision tree problems, respectively.
-But what about other problems, like regression or object detection?
+Zoobot includes the :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier`, :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor`, and :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotTree`
+classes to help you finetune Zoobot on classification, regression, or decision tree problems, respectively.
+But what about other problems, like object detection?
Here's how to integrate pretrained Zoobot models into your own code.
Using Zoobot's Encoder Directly
------------------------------------
-To get Zoobot's encoder, load the model and access the .encoder attribute:
+To get Zoobot's encoder, load any Finetuneable class and grab the encoder attribute:
.. code-block:: python
- model = ZoobotTree.load_from_checkpoint(pretrained_checkpoint_loc)
+ model = FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano')
encoder = model.encoder
- model = FinetuneableZoobotClassifier.load_from_checkpoint(finetuned_checkpoint_loc)
- encoder = model.encoder
-
- # for ZoobotTree, there's also a utility function to do this in one line
- encoder = finetune.load_pretrained_encoder(pretrained_checkpoint_loc)
-
-:class:`zoobot.pytorch.estimators.define_model.ZoobotTree`, :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier` and :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotTree`
-all have ``.encoder`` and ``.head`` attributes. These are the plain PyTorch (Sequential) models used for encoding or task predictions.
-The Zoobot classes simply wrap these with instructions for training, logging, checkpointing, and so on.
-
-You can use the encoder separately like any PyTorch Sequential for any machine learning task. We did this to `add contrastive learning `_. Go nuts.
-
-
-Subclassing FinetuneableZoobotAbstract
----------------------------------------
-
-If you'd like to finetune Zoobot on a new task that isn't classification or vote counts,
-you could instead subclass :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract`.
-This is less general but avoids having to write out your own finetuning training code in e.g. PyTorch Lightning.
-
-For example, to make a regression version:
+or, because Zoobot encoders are `timm` models, you can just directly use `timm`:
.. code-block:: python
-
- class FinetuneableZoobotRegression(FinetuneableZoobotAbstract):
-
- def __init__(
- self,
- foo,
- **super_kwargs
- ):
+ import timm
- super().__init__(**super_kwargs)
+ encoder = timm.create_model('hf_hub:mwalmsley/zoobot-encoder-convnext_nano', pretrained=True, num_classes=0)
- self.foo = foo
- self.loss = torch.nn.MSELoss()
- self.head = torch.nn.Sequential(...)
- # see zoobot/pytorch/training/finetune.py for more examples and all methods required
+You can use it like any other `timm` model. For example, we did this to `add contrastive learning `_. Good luck!
-You can then finetune this new class just as with e.g. FinetuneableZoobotClassifier.
+If you don't need to change the encoder and just want representations, see below.
Extracting Frozen Representations
@@ -71,27 +41,21 @@ Once you've finetuned to your survey, or if you're using a pretrained survey, (S
the representations can be stored as frozen vectors and used as features.
We use this at Galaxy Zoo to power our upcoming similary search and anomaly-finding tools.
-As above, we can get Zoobot's encoder from the .encoder attribute:
-
-.. code-block:: python
-
- # can load from either ZoobotTree (if trained from scratch)
- # or FinetuneableZoobotTree (if finetuned)
- encoder = finetune.FinetuneableZoobotTree.load_from_checkpoint(checkpoint_loc).encoder
-
-``encoder`` is a PyTorch Sequential object, so we could use ``encoder.predict()`` to calculate our representations.
+As above, we can get Zoobot's encoder from the .encoder attribute. We could use ``encoder.forward()`` to calculate our representations.
But then we'd have to deal with batching, looping, etc.
To avoid this boilerplate, Zoobot includes a PyTorch Lightning class that lets you pass ``encoder`` to the same :func:`zoobot.pytorch.predictions.predict_on_catalog.predict`
utility function used for making predictions with a full Zoobot model.
.. code-block:: python
+ from zoobot.pytorch.training import representations
+
# convert to simple pytorch lightning model
- model = representations.ZoobotEncoder(encoder=encoder, pyramid=False)
+ lightning_encoder = ZoobotEncoder.load_from_name('hf_hub:mwalmsley/zoobot-encoder-convnext_nano')
predict_on_catalog.predict(
catalog,
- model,
+ lightning_encoder,
n_samples=1,
label_cols=label_cols,
save_loc=save_loc,
@@ -101,9 +65,41 @@ utility function used for making predictions with a full Zoobot model.
See `zoobot/pytorch/examples/representations `_ for a full working example.
-We plan on adding precalculated representations for all our DESI galaxies - but we haven't done it yet. Sorry.
-Please raise an issue if you really need these.
+We have precalculated representations for all our DESI galaxies, and soon for HSC as well.
+See :doc:`/science_data`.
-The representations are typically quite high-dimensional (1280 for EfficientNetB0) and therefore highly redundant.
+The representations are typically quite high-dimensional (e.g. 1280 for EfficientNetB0) and therefore highly redundant.
We suggest using PCA to compress them down to a more reasonable dimension (e.g. 40) while preserving most of the information.
This was our approach in the `Practical Morphology Tools paper `_.
+
+
+Subclassing FinetuneableZoobotAbstract
+---------------------------------------
+
+If you'd like to finetune Zoobot on a new task that isn't classification, regression, or vote counts,
+you could instead subclass :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract`.
+This lets you use our finetuning code with your own head and loss functions.
+
+Imagine there wasn't a regression version and you wanted to finetune Zoobot on a regression task. You could do:
+
+.. code-block:: python
+
+
+ class FinetuneableZoobotCustomRegression(FinetuneableZoobotAbstract):
+
+ def __init__(
+ self,
+ foo,
+ **super_kwargs
+ ):
+
+ super().__init__(**super_kwargs)
+
+ self.foo = foo
+ self.loss = torch.nn.SomeCrazyLoss()
+ self.head = torch.nn.Sequential(my_crazy_head)
+
+ # see zoobot/pytorch/training/finetune.py for more examples and all methods required
+
+You can then finetune this new class just as with e.g. :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor`.
+
diff --git a/docs/guides/choosing_parameters.rst b/docs/guides/choosing_parameters.rst
new file mode 100644
index 00000000..9cd4b337
--- /dev/null
+++ b/docs/guides/choosing_parameters.rst
@@ -0,0 +1,101 @@
+.. _choosing_parameters:
+
+Choosing Parameters
+=====================================
+
+All FinetuneableZoobot classes share a common set of parameters for controlling the finetuning process. These can have a big effect on performance.
+
+
+Finetuning is fast and easy to experiment with, so we recommend trying different parameters to see what works best for your dataset.
+This guide provides some explanation for each option.
+
+We list the key parameters below in rough order of importance.
+See :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract` for the full list of parameters.
+
+``learning_rate``
+...............................
+
+Learning rate sets how fast the model parameters are updated during training.
+Zoobot uses the adaptive optimizer ``AdamW``.
+Adaptive optimizers adjust the learning rate for each parameter based on the mean and variance of the previous gradients.
+This means you don't need to tune the learning rate as carefully as you would with a fixed learning rate optimizer like SGD.
+We find a learning of 1e-4 is a good starting point for most tasks.
+
+If you find the model is not learning, you can try increasing the learning rate.
+If you see the model loss is varying wildly, or the train loss decreases much faster than the validation loss (overfitting), you can try decreasing the learning rate.
+Increasing ``n_blocks`` (below) often requires a lower learning rate, as the model will adjust more parameters for each batch.
+
+
+``n_blocks``
+...............................
+
+Deep learning models are often divided into blocks of layers.
+For example, a ResNet model might have 4 blocks, each containing a number of convolutional layers.
+The ``n_blocks`` parameter specifies how many of these blocks to finetune.
+
+By default, ``n_blocks=0``, and so only the head is finetuned.
+This is a good choice when you have a small dataset or when you want to avoid overfitting.
+Finetuning only the head is sometimes called transfer learning.
+It's equivalent to calculating representations with the pretrained model and then training a new one-layer model on top of those representations.
+
+You can experiment with increasing ``n_blocks`` to finetune more of the model.
+This works best for larger datasets (typically more than 1k examples).
+To finetune the full model, keep increasing ``n_blocks``; Zoobot will raise an error if you try to finetune more blocks than the model has.
+Our recommended encoder, ``ConvNext``, has 5 blocks.
+
+
+``lr_decay``
+...............................
+
+The common intuition for deep learning is that lower blocks (near the input) learn simple general features and higher blocks (near the output) learn more complex features specific to your task.
+It is often useful to adjust the learning rate to be lower for lower blocks, which have already been pretrained to recognise simple galaxy features.
+
+Learning rate decay reduces the learning rate by block.
+For example, with ``learning_rate=1e-4`` and ``lr_decay=0.75`` (the default):
+
+* The highest block has a learning rate of 1e-4 * (0.75^0) = 1e-4
+* The second-highest block has a learning rate of 1e-4 * (0.75^1) = 7.5e-5
+* The third-highest block has a learning rate of 1e-4 * (0.75^2) = 5.6e-5
+
+and so on.
+
+Decreasing ``lr_decay`` will exponentially decrease the learning rate for lower blocks.
+
+In the extreme cases:
+
+* Setting ``learning_rate=0`` will disable learning in all blocks except the first block (0^0=1), equivalent to ``n_blocks=1``.
+* Setting ``lr_decay=1`` will give all blocks the same learning rate.
+
+The head always uses the full learning rate.
+
+``weight_decay``
+...............................
+
+Weight decay is a regularization term that penalizes large weights.
+When using Zoobot's default ``AdamW`` optimizer, it is closely related to L2 regularization, though there's some subtlety - see https://arxiv.org/abs/1711.05101.
+Increasing weight decay will increase the penalty on large weights, which can help prevent overfitting, but may slow or even stop training.
+By default, Zoobot uses a small weight decay of 0.05.
+
+
+``dropout_prob``
+...............................
+
+Dropout is a regularization technique that randomly sets some activations to zero during training.
+Similarly to weight decay, dropout can help prevent overfitting.
+Zoobot uses a dropout probability of 0.5 by default.
+
+
+``cosine_schedule`` and friends
+.................................
+
+Gradually reduce the learning rate during training can slightly improve results by finding a better minimum near convergence.
+This process is called learning rate scheduling.
+Zoobot includes a cosine learning rate schedule, which reduces the learning rate according to a cosine function.
+
+The cosine schedule is controlled by the following parameters:
+
+* ``cosine_schedule`` to enable the scheduler.
+* ``warmup_epochs`` to linearly increase the learning rate from 0 to ``learning_rate`` over the first ``warmup_epochs`` epochs, before applying cosine scheduling.
+* ``max_cosine_epochs`` sets how many epochs it takes to decay to the final learning rate (below). Warmup epochs don't count.
+* ``max_learning_rate_reduction_factor`` controls the final learning rate (``learning_rate`` * ``max_learning_rate_reduction_factor``).
+
\ No newline at end of file
diff --git a/docs/guides/finetuning.rst b/docs/guides/finetuning.rst
index 1ab59003..bce4fb56 100755
--- a/docs/guides/finetuning.rst
+++ b/docs/guides/finetuning.rst
@@ -30,12 +30,10 @@ Examples
Zoobot includes many working examples of finetuning:
-- `Google Colab notebook `__ (for binary classification in the cloud)
+- `Google Colab notebook `__ (recommended starting point)
- `finetune_binary_classification.py `__ (script version of the Colab notebook)
- `finetune_counts_full_tree.py `__ (for finetuning on a complicated GZ-style decision tree)
-There are also `examples `__ with the TensorFlow version of Zoobot. But this is no longer actively developed so we strongly suggest using the PyTorch version if possible.
-
Below, for less familiar readers, we walk through the ``finetune_binary_classification.py`` example in detail.
Background
@@ -60,12 +58,12 @@ These files are called checkpoints (like video game save files - computer scient
.. code-block:: python
model = finetune.FinetuneableZoobotClassifier(
- checkpoint_loc=checkpoint_loc, # loads weights from here
+ name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', # which pretrained model to download
num_classes=2,
n_layers=0
)
-You can download a checkpoint file from :ref:`datanotes`.
+You can see the list of pretrained models at :doc:`/pretrained_models`.
What about the other arguments?
When loading the checkpoint, FinetuneableZoobotClassifier will automatically change the head layer to suit a classification problem (hence, ``Classifier``).
diff --git a/docs/guides/guides.rst b/docs/guides/guides.rst
deleted file mode 100755
index e5ab3399..00000000
--- a/docs/guides/guides.rst
+++ /dev/null
@@ -1,18 +0,0 @@
-
-Guides
-======
-
-Below are some practical guides for tasks that we hope Zoobot will be helpful for.
-
-.. toctree::
- :maxdepth: 2
-
- /guides/finetuning
- /guides/advanced_finetuning
- /guides/training_on_vote_counts
- /guides/how_the_code_fits_together
- /guides/pytorch_or_tensorflow
-
-If you'd prefer worked examples, you can find those under `zoobot/pytorch/examples `_ and `zoobot/tensorflow/examples `_.
-
-There's also this `Colab notebook `_ demonstrating finetuning which you can run in the cloud (with free access to a powerful GPU, courtesy of Google Research)
diff --git a/docs/guides/how_the_code_fits_together.rst b/docs/guides/how_the_code_fits_together.rst
index 6bb8109e..437bcbc7 100644
--- a/docs/guides/how_the_code_fits_together.rst
+++ b/docs/guides/how_the_code_fits_together.rst
@@ -6,37 +6,48 @@ How the Code Fits Together
The Zoobot package has many classes and methods.
This guide aims to be a map summarising how they fit together.
-.. note:: For simplicity, we will only consider the PyTorch version (see :ref:`pytorch_or_tensorflow`).
-
-Defining PyTorch Models
+The Map
-------------------------
-The deep learning part is the simplest piece.
-``define_model.py`` has functions to that define pure PyTorch ``nn.Modules`` (a.k.a. models).
-
-Encoders (a.k.a. models that take an image and compress it to a representation vector) are defined using the third party library ``timm``.
-Specifically, ``timm.create_model(architecture_name)`` is used to get the EfficientNet, ResNet, ViT, etc. architectures used to encode our galaxy images.
-This is helpful because defining complicated architectures becomes someone else's job (thanks, Ross Wightman!)
+The Zoobot package has two roles:
-Heads (a.k.a. models that take a representation vector and make a prediction) are defined using ``torch.nn.Sequential``.
-The function :func:`zoobot.pytorch.estimators.define_model.get_pytorch_dirichlet_head`, for example, returns the custom head used to predict vote counts (see :ref:`training_on_vote_counts`).
+1. **Finetuning**: ``pytorch/training/finetune.py`` is the heart of the package. You will use these classes to load pretrained models and finetune them on new data.
+2. **Training from Scratch** ``pytorch/estimators/define_model.py`` and ``pytorch/training/train_with_pytorch_lightning.py`` create and train the Zoobot models from scratch. These are *not required* for finetuning and will eventually be migrated out.
-The encoders and heads in ``define_model.py`` are used for both training from scratch and finetuning
+Let's zoom in on the finetuning part.
-Training with PyTorch Lightning
+Finetuning with Zoobot Classes
--------------------------------
-PyTorch requires a lot of boilerplate code to train models, especially at scale (e.g. multi-node, multi-GPU).
-We use PyTorch Lightning, a third party wrapper API, to make this boilerplate code someone else's job as well.
-The core Zoobot classes you'll use - :class:`ZoobotTree `, :class:`FinetuneableZoobotClassifier ` and :class:`FinetuneableZoobotTree ` -
+There are three Zoobot classes for finetuning:
+
+1. :class:`FinetuneableZoobotClassifier ` for classification tasks (including multi-class).
+2. :class:`FinetuneableZoobotRegressor ` for regression tasks (including on a unit interval e.g. a fraction).
+3. :class:`FinetuneableZoobotTree ` for training on a tree of labels (e.g. Galaxy Zoo vote counts).
+
+Each user-facing class is actually a subclass of a non-user-facing abstract class, :class:`FinetuneableZoobotAbstract `.
+:class:`FinetuneableZoobotAbstract ` has specifying how to finetune a general PyTorch model,
+which the user-facing classes inherit.
+
+`FinetuneableZoobotAbstract ` controls the core finetuning process: loading a model, accepting arguments controlling the finetuning process, and running the finetuning.
+The user-facing class adds features specific to that type of task. For example, :class:`FinetuneableZoobotClassifier ` adds additional arguments like `num_classes`.
+It also specifies an appropriate head and a loss function.
+
+
+
+Finetuning with PyTorch Lightning
+-----------------------------------
+
+
are all "LightningModule" classes.
These classes have (custom) methods like ``training_step``, ``validation_step``, etc., which specify what should happen at each training stage.
-:class:`FinetuneableZoobotClassifier ` and :class:`FinetuneableZoobotTree `
-are actually subclasses of a non-user-facing abstract class, :class:`FinetuneableZoobotAbstract `.
-:class:`FinetuneableZoobotAbstract ` has specifying how to finetune a general PyTorch model,
-which `FinetuneableZoobotClassifier ` and :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotTree` inherit.
+
+Zoobot is written in PyTorch, a popular deep learning library for Python.
+PyTorch requires a lot of boilerplate code to train models, especially at scale (e.g. multi-node, multi-GPU).
+We use PyTorch Lightning, a third party wrapper API, to make this boilerplate code someone else's job.
+
:class:`ZoobotTree ` is similar to :class:`FinetuneableZoobotAbstract ` but has methods for training from scratch.
@@ -66,28 +77,17 @@ Slightly confusingly, Lightning's ``Trainer`` can also be used to make predictio
and that's how we make predictions with :func:`zoobot.pytorch.predictions.predict_on_catalog.predict`.
-Loading Data
---------------------------
-
-You might notice ``datamodule`` in the examples above.
-Zoobot often includes code like:
-
-.. code-block:: python
+As you can see, there's quite a few layers (pun intended) to training Zoobot models. But we hope this setup is both simple to use and easy to extend, whichever (PyTorch) frameworks you're using.
- from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
- datamodule = GalaxyDataModule(
- train_catalog=train_catalog,
- val_catalog=val_catalog,
- test_catalog=test_catalog,
- batch_size=batch_size,
- # ...
- )
+.. The deep learning part is the simplest piece.
+.. ``define_model.py`` has functions to that define pure PyTorch ``nn.Modules`` (a.k.a. models).
-Note the import - Zoobot actually doesn't have any code for loading data!
-That's in the separate repository `mwalmsley/galaxy-datasets `.
+.. Encoders (a.k.a. models that take an image and compress it to a representation vector) are defined using the third party library ``timm``.
+.. Specifically, ``timm.create_model(architecture_name)`` is used to get the EfficientNet, ResNet, ViT, etc. architectures used to encode our galaxy images.
+.. This is helpful because defining complicated architectures becomes someone else's job (thanks, Ross Wightman!)
-``galaxy-datasets`` has custom code to turn catalogs of galaxies into the ``LightningDataModule``s that Lightning `expects https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html<>`_.
-These ``LightningDataModule``s themselves have attributes like ``.train_dataloader()`` and ``.predict_dataloader()`` that Lightning's ``Trainer`` object uses to demand data when training, making predictions, and so forth.
+.. Heads (a.k.a. models that take a representation vector and make a prediction) are defined using ``torch.nn.Sequential``.
+.. The function :func:`zoobot.pytorch.estimators.define_model.get_pytorch_dirichlet_head`, for example, returns the custom head used to predict vote counts (see :ref:`training_on_vote_counts`).
-As you can see, there's quite a few layers (pun intended) to training Zoobot models. But we hope this setup is both simple to use and easy to extend, whichever (PyTorch) frameworks you're using.
+.. The encoders and heads in ``define_model.py`` are used for both training from scratch and finetuning
diff --git a/docs/guides/loading_data.rst b/docs/guides/loading_data.rst
new file mode 100644
index 00000000..c6c74857
--- /dev/null
+++ b/docs/guides/loading_data.rst
@@ -0,0 +1,52 @@
+
+Loading Data
+--------------------------
+
+Using GalaxyDataModule
+=========================
+
+Zoobot often includes code like:
+
+.. code-block:: python
+
+ from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
+
+ datamodule = GalaxyDataModule(
+ train_catalog=train_catalog,
+ val_catalog=val_catalog,
+ test_catalog=test_catalog,
+ batch_size=batch_size,
+ label_cols=['is_cool_galaxy']
+ # ...
+ )
+
+Note the import - Zoobot actually doesn't have any code for loading data!
+That's in the separate repository `mwalmsley/galaxy-datasets `_.
+
+``galaxy-datasets`` has custom code to turn catalogs of galaxies into the ``LightningDataModule`` that Lightning `expects `_.
+Each ``LightningDataModule`` has attributes like ``.train_dataloader()`` and ``.predict_dataloader()`` that Lightning's ``Trainer`` object uses to demand data when training, making predictions, and so forth.
+
+You can pass ``GalaxyDataModule`` train, val, test and predict catalogs. Each catalog needs the columns:
+
+* ``file_loc``: the path to the image file
+* ``id_str``: a unique identifier for the galaxy
+* plus any columns for labels, which you will specify with ``label_cols``. Setting ``label_cols=None`` will load the data without labels (returning batches of (image, id_str)).
+
+``GalaxyDataModule`` will load the images from disk and apply any transformations you specify. Specify transforms one of three ways:
+
+* through the `default arguments `_ of ``GalaxyDataModule`` (e.g. ``GalaxyDataModule(resize_after_crop=(128, 128))``)
+* through a torchvision or albumentations ``Compose`` object e.g. ``GalaxyDataModule(custom_torchvision_transforms=Compose([RandomHorizontalFlip(), RandomVerticalFlip()]))``
+* through a tuple of ``Compose`` objects. The first element will be used for the train dataloaders, and the second for the other dataloaders.
+
+Using the default arguments is simplest and should work well for loading Galaxy-Zoo-like ``jpg`` images. Passing Compose objects offers full customization (short of writing your own ``LightningDataModule``). On that note...
+
+I Want To Do It Myself
+========================
+
+Using ``galaxy-datasets`` is optional. Zoobot is designed to work with any PyTorch ``LightningDataModule`` that returns batches of (images, labels).
+And advanced users can pass data to Zoobot's encoder however they like (see :doc:`advanced_finetuning`).
+
+Images should be PyTorch tensors of shape (batch_size, channels, height, width).
+Values should be floats normalized from 0 to 1 (though in practice, Zoobot can handle other ranges provided you use end-to-end finetuning).
+If you are presenting flux values, you should apply a dynamic range rescaling like ``np.arcsinh`` before normalizing to [0, 1].
+Galaxies should appear large and centered in the image.
diff --git a/docs/guides/pytorch_or_tensorflow.rst b/docs/guides/pytorch_or_tensorflow.rst
deleted file mode 100644
index 9c5bb244..00000000
--- a/docs/guides/pytorch_or_tensorflow.rst
+++ /dev/null
@@ -1,40 +0,0 @@
-.. _pytorch_or_tensorflow:
-
-
-
-PyTorch or TensorFlow?
-===========================
-
-.. warning:: You should use the PyTorch version if possible. This is being actively developed and has the latest features.
-
-Zoobot is really two separate sets of code: `zoobot/pytorch `_ and `zoobot/tensorflow `_.
-They can both train the same EfficientNet model architecture on the same Galaxy Zoo data in the same way, for extracting representations and for finetuning - but they use different underlying deep learning frameworks to do so.
-
-We originally created two versions of Zoobot so that astronomers can use their preferred framework.
-But maintaining two almost entirely separate sets of code is too much work for our current resources (Mike's time, basically).
-Going forward, the PyTorch version will be actively developed and gain new features, while the TensorFlow version will be kept up-to-date but will not otherwise improve.
-
-Tell Me More About What's Different
--------------------------------------
-
-The TensorFlow version was the original version.
-It was used for the `GZ DECaLS catalog `_ and the `Practical Morphology Tools `_ paper.
-You can train EfficientNetB0 and achieve the same performance as with PyTorch (see the "benchmarks folder").
-You can also finetune the trained model, although the process is slightly clunkier.
-
-The PyTorch version was introduced to support other researchers and to integrate with Bootstrap Your Own Latent for the `Towards Foundation Models `_ paper.
-This version is actively developed and includes the latest features.
-
-PyTorch-specific features include:
-
-- Any architecture option from timm (including ResNet and Max-ViT)
-- Improved interface for easy finetuning
-- Layerwise learning rate decay during finetuning
-- Integration with AstroAugmentations (courtesy Micah Bowles) for custom astronomy image augmentations
-- Per-question loss tracking on WandB
-
-
-Can I have a JAX version?
-----------------------------
-
-Only if you build it yourself.
diff --git a/docs/index.rst b/docs/index.rst
index a3a4cfc1..4f062cf3 100755
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -5,7 +5,7 @@ Zoobot Documentation
====================
Zoobot makes it easy to finetune a state-of-the-art deep learning classifier to solve your galaxy morphology problem.
-For example, you can finetune a classifier to find ring galaxies with `just a few hundred examples `_.
+For example, you can finetune a classifier to find ring galaxies with `just a few hundred examples `_.
.. figure:: finetuning_rings.png
:alt: Ring galaxies found using Zoobot
@@ -15,19 +15,32 @@ For example, you can finetune a classifier to find ring galaxies with `just a fe
The easiest way to learn to use Zoobot is simply to use Zoobot.
We suggest you start with our worked examples.
-The `Colab notebook `_ is the fastest way to get started.
-See the README for many scripts that you can run and adapt locally.
+* This `Colab notebook `_ will walk you through using Zoobot to classify galaxy images.
+* There's a similar `notebook `_ for using Zoobot for regression on galaxy images.
-Guides
+For more explanation, read on.
+
+User Guides
-------------
-If you'd like more explanation and context, we've written these guides.
+These introductory guides add context to the demo Colab notebooks.
+
+.. toctree::
+ :maxdepth: 1
+
+ /guides/finetuning
+ /guides/choosing_parameters
+ /guides/loading_data
+ /guides/training_on_vote_counts
+
+These advanced guides explain how to integrate Zoobot into other ML projects.
.. toctree::
:maxdepth: 2
- /guides/guides
+ /guides/advanced_finetuning
+ /guides/how_the_code_fits_together
Pretrained Models
------------------
@@ -37,30 +50,37 @@ To choose and download a pretrained model, see here.
.. toctree::
:maxdepth: 2
- data_notes
+ pretrained_models
-API reference
---------------
+Science-Ready Data
+------------------
-Look here for information on a specific function, class or
-method.
+You can find our science outputs (e.g. morphology catalogs, precalculated representations) here.
.. toctree::
:maxdepth: 2
- autodoc/api
+ science_data
+
+We are working on releasing the compiled GZ Evo dataset and will update this page when it is available.
+Estimated public release is Q4 2024. Please reach out if you'd like early access.
-.. You do not need to be a machine learning expert to use Zoobot.
-.. Zoobot includes :ref:`components ` for common tasks like loading images, managing training, and making predictions.
-.. You simply need to assemble these together.
+API reference
+--------------
-.. .. toctree::
-.. :maxdepth: 2
+We've added docstrings to all the key methods you might use. Feel free to check the code or reach out if you have questions.
-.. components/overview
+.. toctree::
+ :maxdepth: 4
+
+ autodoc/pytorch
+.. different level to not expand schema too much
+.. toctree::
+ :maxdepth: 3
+ autodoc/shared
.. Indices
@@ -78,6 +98,7 @@ method.
.. To build:
.. install sphinx https://www.sphinx-doc.org/en/master/usage/installation.html is confusing, you can just use pip install -U sphinx
+.. and pip install furo
.. run from in docs folder: make html
.. can also check docs with
diff --git a/docs/pretrained_models.rst b/docs/pretrained_models.rst
new file mode 100755
index 00000000..84a1e34b
--- /dev/null
+++ b/docs/pretrained_models.rst
@@ -0,0 +1,122 @@
+.. pretrainedmodels:
+
+Pretrained Models
+------------------
+
+Loading Models
+==========================
+
+Pretrained models are available via HuggingFace (|:hugging:|) with
+
+.. code-block:: python
+
+ from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier
+ # or FinetuneableZoobotRegressor, or FinetuneableZoobotTree
+
+ model = FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano')
+
+For more options (e.g. loading the ``timm`` encoder directly) see :doc:`guides/advanced_finetuning`.
+
+Available Models
+==========================
+
+Zoobot includes weights for the following pretrained models:
+
+
+.. list-table::
+ :widths: 70 35 35 35 35
+ :header-rows: 1
+
+ * - Architecture
+ - Parameters
+ - Test loss
+ - Finetune
+ - HF |:hugging:|
+ * - ConvNeXT-Nano
+ - 15.6M
+ - 19.23
+ - Yes
+ - `Link `__
+ * - ConvNeXT-Small
+ - 58.5M
+ - 19.14
+ - Yes
+ - `Link `__
+ * - ConvNeXT-Base
+ - 88.6M
+ - **19.04**
+ - Yes
+ - `Link `__
+ * - ConvNeXT-Large
+ - 197.8M
+ - 19.09
+ - Yes
+ - `Link `__
+ * - MaxViT-Small
+ - 64.9M
+ - 19.20
+ - Yes
+ - `Link `__
+ * - MaxViT-Base
+ - 124.5
+ - 19.09
+ - Yes
+ - TODO
+ * - Max-ViT-Large
+ - 211.8M
+ - 19.18
+ - Yes
+ - `Link `__
+ * - EfficientNetB0
+ - 5.33M
+ - 19.48
+ - Yes
+ - `Link `__
+ * - EfficientNetV2-S
+ - 48.3M
+ - 19.33
+ - Yes
+ - `Link `__
+ * - ResNet18
+ - 11.7M
+ - 19.83
+ - Yes
+ - `Link `__
+ * - ResNet50
+ - 25.6M
+ - 19.43
+ - Yes
+ - `Link `__
+
+
+.. note::
+
+ Missing a model you need? Reach out! There's a good chance we can train any model supported by `timm `_.
+
+
+Which model should I use?
+===========================
+
+We suggest starting with ConvNeXT-Nano for most users.
+ConvNeXT-Nano performs very well while still being small enough to train on a single gaming GPU.
+You will be able to experiment quickly.
+
+For maximum performance, you could swap ConvNeXT-Nano for ConvNeXT-Small or ConvNeXT-Base.
+MaxViT-Base also performs well and includes an ingenious attention mechanism, if you're interested in that.
+All these models are much larger and need cluster-grade GPUs (e.g. V100 or above).
+
+Other models are included for reference or as benchmarks.
+EfficientNetB0 is equivalent to the model used in the GZ DECaLS and GZ DESI papers.
+ResNet18 and ResNet50 are classics of the genre and may be useful for comparison or as part of other frameworks (like as an `object detection backbone `_).
+
+
+How were the models trained?
+===============================
+
+The models were trained as part of the report `Scaling Laws for Galaxy Images `_.
+This report systematically investigates how increasing labelled galaxy data and model size improves performance
+and leads to adaptable models that generalise well to new tasks and new telescopes.
+
+All models are trained on the GZ Evo dataset,
+which includes 820k images and 100M+ volunteer votes drawn from every major Galaxy Zoo campaign: GZ2, GZ UKIDSS (unpublished), GZ Hubble, GZ CANDELS, GZ DECaLS/DESI, and GZ Cosmic Dawn (HSC, in prep.).
+They learn an adaptable representation of galaxy images by training to answer every Galaxy Zoo question at once.
diff --git a/docs/science_data.rst b/docs/science_data.rst
new file mode 100644
index 00000000..569a3110
--- /dev/null
+++ b/docs/science_data.rst
@@ -0,0 +1,59 @@
+.. sciencedata:
+
+Science Data
+-------------
+
+The goal of Zoobot is to do science. Here are some science-ready datasets created with Zoobot.
+
+Precalulated Representations
+=============================
+
+.. warning::
+
+ New for Zoobot v2! We're really excited to see what you build. Reach out for help.
+
+Zoobot v2 now includes precalculated representations for galaxies in the Galaxy Zoo DESI data release.
+Download `here `_ (2.5GB)
+
+You could use these to power a similarity search, anomaly recommendation system, the vision part of a multi-modal model,
+or really anything else that needs a short vector summarizing the morphology in a galaxy image.
+
+
+
+
+.. list-table::
+ :widths: 35 35 35 35 35 35
+ :header-rows: 1
+
+ * - id_str
+ - ra
+ - dec
+ - feat_pca_0
+ - feat_pca_1
+ - ...
+ * - 303240_2499
+ - 4.021870
+ - 3.512972
+ - 0.257407
+ - -7.414328
+ - ...
+
+``id_str`` is the unique identifier for the galaxy in the DESI Legacy Surveys DR8 release and can be crossmatched with the GZ DESI catalog (below) ``dr8_id`` key.
+It is formed with ``{brickid}_{objid}`` where brickid is the unique identifier for the brick in the Legacy Surveys and objid is the unique identifier for the object in the brick.
+``RA`` and ``Dec`` are in degrees.
+The PCA features are the first 40 principal components representation (which is otherwse impractically large to work with).
+
+
+Galaxy Zoo Morphology
+=======================
+
+Zoobot was used to create a detailed morphology catalog for every (extended, brighter than r=19) galaxy in the DESI Legacy Surveys (8.7M galaxies).
+The catalog and schema are available from `Zenodo `_.
+For new users, we suggest starting with the ``gz_desi_deep_learning_catalog_friendly.parquet`` catalog file.
+
+We previously used Zoobot to create a similar catalog for `DECaLS DR5 `_.
+This has now been superceded by the GZ DESI catalog above (which includes the same galaxies, and many more).
+
+We aim to provide both representations and an updated morphology catalog for DESI-LS DR10, but we need to redownload all the images first |:neutral_face:|.
+
+Future catalogs will include morphology measurements for HSC, JWST, and Euclid galaxies (likely in that order).
diff --git a/setup.py b/setup.py
index 0da18e1b..dde24b50 100755
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
setuptools.setup(
name="zoobot",
- version="1.0.5",
+ version="2.0.0",
author="Mike Walmsley",
author_email="walmsleymk1@gmail.com",
description="Galaxy morphology classifiers",
@@ -20,56 +20,67 @@
"Environment :: GPU :: NVIDIA CUDA"
],
packages=setuptools.find_packages(),
- python_requires=">=3.8", # recommend 3.9 for new users. TF needs >=3.7.2, torchvision>=3.8
+ python_requires=">=3.9", # bumped to 3.9 for typing
extras_require={
- 'pytorch_cpu': [
+ 'pytorch-cpu': [
# A100 GPU currently only seems to support cuda 11.3 on manchester cluster, let's stick with this version for now
# very latest version wants cuda 11.6
- 'torch == 1.12.1+cpu',
- 'torchvision == 0.13.1+cpu',
- 'torchaudio == 0.12.1',
- 'pytorch-lightning >= 2.0.0',
+ 'torch == 2.1.0+cpu',
+ 'torchvision == 0.16.0+cpu',
+ 'torchaudio >= 2.1.0',
+ 'lightning >= 2.0.0',
# 'simplejpeg',
'albumentations',
- 'pyro-ppl == 1.8.0',
+ 'pyro-ppl >= 1.8.6',
'torchmetrics == 0.11.0',
- 'timm == 0.6.12'
+ 'timm == 0.9.10'
],
- 'pytorch_m1': [
+ 'pytorch-m1': [
# as above but without the +cpu (and the extra-index-url in readme has no effect)
# all matching pytorch versions for an m1 system will be cpu
- 'torch == 1.12.1',
- 'torchvision == 0.13.1',
- 'torchaudio == 0.12.1',
- 'pytorch-lightning >= 2.0.0',
+ 'torch == 2.1.0',
+ 'torchvision == 0.16.0',
+ 'torchaudio >= 2.1.0',
+ 'lightning >= 2.0.0',
'albumentations',
- 'pyro-ppl == 1.8.0',
+ 'pyro-ppl >= 1.8.6',
'torchmetrics == 0.11.0',
- 'timm == 0.6.12'
+ 'timm >= 0.9.10'
],
# as above but without pytorch itself
# for GPU, you will also need e.g. cudatoolkit=11.3, 11.6
# https://pytorch.org/get-started/previous-versions/#v1121
- 'pytorch_cu113': [
- 'torch == 1.12.1+cu113',
- 'torchvision == 0.13.1+cu113',
- 'torchaudio == 0.12.1',
- 'pytorch-lightning >= 2.0.0',
+ 'pytorch-cu118': [
+ 'torch == 2.1.0+cu118',
+ 'torchvision == 0.16.0+cu118',
+ 'torchaudio >= 2.1.0',
+ 'lightning >= 2.0.0',
'albumentations',
- 'pyro-ppl == 1.8.0',
+ 'pyro-ppl >= 1.8.6',
'torchmetrics == 0.11.0',
- 'timm == 0.6.12'
- ],
- 'pytorch_colab': [
+ 'timm >= 0.9.10'
+ ], # exactly as above, but _cu121 for cuda 12.1 (the current default)
+ 'pytorch-cu121': [
+ 'torch == 2.1.0+cu121',
+ 'torchvision == 0.16.0+cu121',
+ 'torchaudio >= 2.1.0',
+ 'lightning >= 2.0.0',
+ 'albumentations',
+ 'pyro-ppl >= 1.8.6',
+ 'torchmetrics == 0.11.0',
+ 'timm >= 0.9.10'
+ ],
+ 'pytorch-colab': [
# colab includes pytorch already
- 'pytorch-lightning >= 2.0.0',
+ 'lightning >= 2.0.0',
'albumentations',
'pyro-ppl>=1.8.0',
'torchmetrics==0.11.0',
- 'timm == 0.6.12'
+ 'timm >= 0.9.10',
+ 'galaxy_datasets == 0.0.17'
],
# TODO may add narval/Digital Research Canada config
- 'tensorflow': [
+ 'tensorflow': [ # WARNING now deprecated
'tensorflow == 2.10.0', # 2.11.0 turns on XLA somewhere which then fails on multi-GPU...TODO
'keras_applications',
'tensorflow_probability == 0.18.0', # 0.19 requires tf 2.11
@@ -86,7 +97,8 @@
'Sphinx',
'sphinxcontrib-napoleon',
'furo',
- 'docutils<0.18'
+ 'docutils<0.18',
+ 'sphinxemoji'
]
},
install_requires=[
@@ -102,7 +114,9 @@
'pyarrow', # to read parquet, which is very handy for big datasets
# for saving metrics to weights&biases (cloud service, free within limits)
'wandb',
+ 'webdataset', # for reading webdataset files
+ 'huggingface_hub', # login may be required
'setuptools', # no longer pinned
- 'galaxy-datasets>=0.0.15' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
+ 'galaxy-datasets>=0.0.17' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
]
)
diff --git a/tests/pytorch/test_define_model.py b/tests/pytorch/test_define_model.py
index 3805777d..f7628d22 100644
--- a/tests/pytorch/test_define_model.py
+++ b/tests/pytorch/test_define_model.py
@@ -10,6 +10,7 @@ def schema():
def test_ZoobotTree_init(schema):
model = define_model.ZoobotTree(
output_dim=12,
- question_index_groups=schema.question_index_groups,
+ question_answer_pairs=schema.question_answer_pairs,
+ dependencies=schema.dependencies
)
diff --git a/tests/pytorch/test_finetune_classifier.py b/tests/pytorch/test_finetune_classifier.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_from_hub.py b/tests/test_from_hub.py
new file mode 100644
index 00000000..159f22c5
--- /dev/null
+++ b/tests/test_from_hub.py
@@ -0,0 +1,43 @@
+import pytest
+
+import timm
+import torch
+
+
+def test_get_encoder():
+ model = timm.create_model("hf_hub:mwalmsley/zoobot-encoder-efficientnet_b0", pretrained=True)
+ assert model(torch.rand(1, 3, 224, 224)).shape == (1, 1280)
+
+
+def test_get_finetuned():
+ # checkpoint_loc = 'https://huggingface.co/mwalmsley/zoobot-finetuned-is_tidal/resolve/main/3.ckpt' pickle problem via lightning
+ # checkpoint_loc = '/home/walml/Downloads/3.ckpt' # works when downloaded manually
+
+ from huggingface_hub import hf_hub_download
+
+ REPO_ID = "mwalmsley/zoobot-finetuned-is_tidal"
+ FILENAME = "FinetuneableZoobotClassifier.ckpt"
+
+ downloaded_loc = hf_hub_download(
+ repo_id=REPO_ID,
+ filename=FILENAME,
+ )
+ from zoobot.pytorch.training import finetune
+ model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(downloaded_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
+ assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)
+
+def test_get_finetuned_class_method():
+
+ from zoobot.pytorch.training import finetune
+
+ model = finetune.FinetuneableZoobotClassifier.load_from_name('mwalmsley/zoobot-finetuned-is_tidal', map_location='cpu')
+ assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)
+
+# def test_get_finetuned_from_local():
+# # checkpoint_loc = '/home/walml/repos/zoobot/tests/convnext_nano_finetuned_linear_is-lsb.ckpt'
+# checkpoint_loc = '/home/walml/repos/zoobot-foundation/results/finetune/is-lsb/debug/checkpoints/4.ckpt'
+
+# from zoobot.pytorch.training import finetune
+# # if originally trained with a direct in-memory checkpoint, must specify the hub name manually. otherwise it's saved as an hparam.
+# model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(checkpoint_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', )
+# assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)
\ No newline at end of file
diff --git a/zoobot/pytorch/datasets/__init__.py b/zoobot/pytorch/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/zoobot/pytorch/datasets/webdatamodule.py b/zoobot/pytorch/datasets/webdatamodule.py
new file mode 100644
index 00000000..3156d145
--- /dev/null
+++ b/zoobot/pytorch/datasets/webdatamodule.py
@@ -0,0 +1,250 @@
+import os
+from collections import defaultdict
+from typing import Callable
+import logging
+import torch.utils.data
+import numpy as np
+import pytorch_lightning as pl
+from itertools import islice
+
+import webdataset as wds
+
+from galaxy_datasets import transforms
+
+# https://github.com/webdataset/webdataset-lightning/blob/main/train.py
+class WebDataModule(pl.LightningDataModule):
+ def __init__(
+ self,
+ train_urls=None,
+ val_urls=None,
+ test_urls=None,
+ predict_urls=None,
+ label_cols=None,
+ # hardware
+ batch_size=64,
+ num_workers=4,
+ prefetch_factor=4,
+ cache_dir=None,
+ greyscale=False,
+ crop_scale_bounds=(0.7, 0.8),
+ crop_ratio_bounds=(0.9, 1.1),
+ resize_after_crop=224,
+ train_transform: Callable=None,
+ inference_transform: Callable=None
+ ):
+ super().__init__()
+
+ self.train_urls = train_urls
+ self.val_urls = val_urls
+ self.test_urls = test_urls
+ self.predict_urls = predict_urls
+
+ if train_urls is not None:
+ # assume the size of each shard is encoded in the filename as ..._{size}.tar
+ self.train_size = interpret_dataset_size_from_urls(train_urls)
+ if val_urls is not None:
+ self.val_size = interpret_dataset_size_from_urls(val_urls)
+ if test_urls is not None:
+ self.test_size = interpret_dataset_size_from_urls(test_urls)
+ if predict_urls is not None:
+ self.predict_size = interpret_dataset_size_from_urls(predict_urls)
+
+ self.label_cols = label_cols
+
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.prefetch_factor = prefetch_factor
+
+ self.cache_dir = cache_dir
+
+ # could use mixin
+ self.greyscale = greyscale
+ self.resize_after_crop = resize_after_crop
+ self.crop_scale_bounds = crop_scale_bounds
+ self.crop_ratio_bounds = crop_ratio_bounds
+
+ self.train_transform = train_transform
+ self.inference_transform = inference_transform
+
+ for url_name in ['train', 'val', 'test', 'predict']:
+ urls = getattr(self, f'{url_name}_urls')
+ if urls is not None:
+ logging.info(f"{url_name} (before hardware splits) = {len(urls)} e.g. {urls[0]}", )
+
+ logging.info(f"batch_size: {self.batch_size}, num_workers: {self.num_workers}")
+
+ def make_image_transform(self, mode="train"):
+
+ augmentation_transform = transforms.default_transforms(
+ crop_scale_bounds=self.crop_scale_bounds,
+ crop_ratio_bounds=self.crop_ratio_bounds,
+ resize_after_crop=self.resize_after_crop,
+ pytorch_greyscale=self.greyscale,
+ to_float=False # True was wrong, webdataset rgb decoder already converts to 0-1 float
+ # TODO now changed on dev branch will be different for new model training runs
+ ) # A.Compose object
+
+ # logging.warning('Minimal augmentations for speed test')
+ # augmentation_transform = transforms.fast_transforms(
+ # resize_after_crop=self.resize_after_crop,
+ # pytorch_greyscale=not self.color
+ # ) # A.Compose object
+
+
+ def do_transform(img):
+ # img is 0-1 np array, intended for albumentations
+ assert img.shape[2] < 4 # 1 or 3 channels in shape[2] dim, i.e. numpy/pil HWC convention
+ # if not, check decode mode is 'rgb' not 'torchrgb'
+ # TODO could likely use torch ToTensorV2 here instead of returning np float32
+ # TODO or could transform in uint8 as I do with torchvision
+ # TODO need to generally rationalise my transform options
+ return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)
+ return do_transform
+
+
+ def make_loader(self, urls, mode="train"):
+ logging.info('Making loader with mode {}'.format(mode))
+
+ dataset_size = getattr(self, f'{mode}_size')
+ if mode == "train":
+ shuffle = min(dataset_size, 5000)
+ else:
+ assert mode in ['val', 'test', 'predict'], mode
+ shuffle = 0
+
+ if self.train_transform is None:
+ logging.info('Using default transform')
+ decode_mode = 'rgb' # loads 0-1 np.array, for albumentations
+ transform_image = self.make_image_transform(mode=mode)
+ else:
+ logging.info('Ignoring hparams and using directly-passed transforms')
+ decode_mode = 'torchrgb' # tensor, for torchvision
+ transform_image = self.train_transform if mode == 'train' else self.inference_transform
+
+
+ transform_label = dict_to_label_cols_factory(self.label_cols)
+
+ dataset = wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0, nodesplitter=nodesplitter_func)
+ # https://webdataset.github.io/webdataset/multinode/
+ # WDS 'knows' which worker it is running on and selects a subset of urls accordingly
+
+ if shuffle > 0:
+ dataset = dataset.shuffle(shuffle)
+
+ dataset = dataset.decode(decode_mode)
+
+ if mode == 'predict':
+ if self.label_cols != ['id_str']:
+ logging.info('Will return images only')
+ # dataset = dataset.extract_keys('image.jpg').map(transform_image)
+ dataset = dataset.to_tuple('image.jpg').map_tuple(transform_image) # (im,) tuple. But map applied to all elements
+ # .map(get_first)
+ else:
+ logging.info('Will return id_str only')
+ dataset = dataset.to_tuple('__key__')
+ else:
+
+ dataset = (
+ dataset.to_tuple('image.jpg', 'labels.json')
+ .map_tuple(transform_image, transform_label)
+ )
+
+ # torch collate stacks dicts nicely while webdataset only lists them
+ # so use the torch collate instead
+ dataset = dataset.batched(self.batch_size, torch.utils.data.default_collate, partial=False)
+
+ # temp hack instead
+ if mode in ['train', 'val']:
+ assert dataset_size % self.batch_size == 0, (dataset_size, self.batch_size, dataset_size % self.batch_size)
+ # for test/predict, always single GPU anyway
+
+ # if mode == "train":
+ # ensure same number of batches in all clients
+ # loader = loader.ddp_equalize(dataset_size // self.batch_size)
+ # print("# loader length", len(loader))
+
+ loader = webdataset_to_webloader(dataset, self.num_workers, self.prefetch_factor)
+
+ return loader
+
+ def train_dataloader(self):
+ return self.make_loader(self.train_urls, mode="train")
+
+ def val_dataloader(self):
+ return self.make_loader(self.val_urls, mode="val")
+
+ def test_dataloader(self):
+ return self.make_loader(self.test_urls, mode="test")
+
+ def predict_dataloader(self):
+ return self.make_loader(self.predict_urls, mode="predict")
+
+def identity(x):
+ return x
+
+def nodesplitter_func(urls):
+ urls_to_use = list(wds.split_by_node(urls)) # rely on WDS for the hard work
+ rank, world_size, worker, num_workers = wds.utils.pytorch_worker_info()
+ logging.debug(
+ f'''
+ Splitting urls within webdatamodule with WORLD_SIZE:
+ {world_size}, RANK: {rank}, WORKER: {worker} of {num_workers}\n
+ URLS: {len(urls_to_use)} (e.g. {urls_to_use[0]})\n\n)
+ '''
+ )
+ return urls_to_use
+
+def interpret_shard_size_from_url(url):
+ assert isinstance(url, str), TypeError(url)
+ return int(url.rstrip('.tar').split('_')[-1])
+
+def interpret_dataset_size_from_urls(urls):
+ return sum([interpret_shard_size_from_url(url) for url in urls])
+
+def get_first(x):
+ return x[0]
+
+def custom_collate(x):
+ if isinstance(x, list) and len(x) == 1:
+ x = x[0]
+ return torch.utils.data.default_collate(x)
+
+
+def webdataset_to_webloader(dataset, num_workers, prefetch_factor):
+ loader = wds.WebLoader(
+ dataset,
+ batch_size=None, # already batched
+ shuffle=False, # already shuffled
+ num_workers=num_workers,
+ pin_memory=True,
+ prefetch_factor=prefetch_factor
+ )
+
+ # loader.length = dataset_size // batch_size
+ return loader
+
+
+def dict_to_label_cols_factory(label_cols=None):
+ if label_cols is not None:
+ def label_transform(label_dict):
+ return torch.from_numpy(np.array([label_dict.get(col, 0) for col in label_cols])).double() # gets cast to int in zoobot loss
+ return label_transform
+ else:
+ return identity # do nothing
+
+def dict_to_filled_dict_factory(label_cols):
+ logging.info(f'label cols: {label_cols}')
+ # might be a little slow, but very safe
+ def label_transform(label_dict: dict):
+
+ # modifies inplace with 0 iff key missing
+ # [label_dict.setdefault(col, 0) for col in label_cols]
+
+ for col in label_cols:
+ label_dict[col] = label_dict.get(col, 0)
+
+ # label_dict_with_default = defaultdict(0)
+ # label_dict_with_default.update(label_dict)
+
+ return label_dict
+ return label_transform
\ No newline at end of file
diff --git a/zoobot/pytorch/datasets/webdataset_utils.py b/zoobot/pytorch/datasets/webdataset_utils.py
new file mode 100644
index 00000000..96d9afbd
--- /dev/null
+++ b/zoobot/pytorch/datasets/webdataset_utils.py
@@ -0,0 +1,217 @@
+import logging
+from typing import Union, Callable
+import os
+import cv2
+import json
+from itertools import islice
+import glob
+
+
+import albumentations as A
+
+import tqdm
+import numpy as np
+import pandas as pd
+from PIL import Image # necessary to avoid PIL.Image error assumption in web_datasets
+
+from galaxy_datasets import gz2
+from galaxy_datasets.transforms import default_transforms
+
+import webdataset as wds
+
+import zoobot.pytorch.datasets.webdatamodule as webdatamodule
+
+def catalogs_to_webdataset(dataset_name, wds_dir, label_cols, train_catalog, test_catalog, sparse_label_df=None, divisor=2048, overwrite=False):
+ for (catalog_name, catalog) in [('train', train_catalog), ('test', test_catalog)]:
+ n_shards = len(catalog) // divisor
+ logging.info(n_shards)
+
+ catalog = catalog[:n_shards*divisor]
+ logging.info(len(catalog))
+
+ save_loc = f"{wds_dir}/{dataset_name}/{dataset_name}_{catalog_name}.tar" # .tar replace automatically
+
+ df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df, overwrite=overwrite)
+
+
+
+
+def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse_label_df=None, overwrite=False):
+
+ assert '.tar' in save_loc
+ df['id_str'] = df['id_str'].astype(str).str.replace('.', '_')
+ if sparse_label_df is not None:
+ logging.info(f'Using sparse label df: {len(sparse_label_df)}')
+ shard_dfs = np.array_split(df, n_shards)
+ logging.info(f'shards: {len(shard_dfs)}. Shard size: {len(shard_dfs[0])}')
+
+ transforms_to_apply = [
+ # below, for 224px fast training fast augs setup
+ # A.Resize(
+ # height=350, # now more aggressive, 65% crop effectively
+ # width=350, # now more aggressive, 65% crop effectively
+ # interpolation=cv2.INTER_AREA # slow and good interpolation
+ # ),
+ # A.CenterCrop(
+ # height=224,
+ # width=224,
+ # always_apply=True
+ # ),
+ # below, for standard training default augs
+ # small boundary trim and then resize expecting further 224px crop
+ # we want 0.7-0.8 effective crop
+ # in augs that could be 0.x-1.0, and here a pre-crop to 0.8 i.e. 340px
+ # but this would change the centering
+ # let's stick to small boundary crop and 0.75-0.85 in augs
+
+ # turn these off for current euclidized images, already 300x300
+ A.CenterCrop(
+ height=400,
+ width=400,
+ always_apply=True
+ ),
+ A.Resize(
+ height=300,
+ width=300,
+ interpolation=cv2.INTER_AREA # slow and good interpolation
+ )
+ ]
+ transform = A.Compose(transforms_to_apply)
+ # transform = None
+
+ for shard_n, shard_df in tqdm.tqdm(enumerate(shard_dfs), total=len(shard_dfs)):
+ shard_save_loc = save_loc.replace('.tar', f'_{shard_n}_{len(shard_df)}.tar')
+ if overwrite or not(os.path.isfile(shard_save_loc)):
+ if sparse_label_df is not None:
+ shard_df = pd.merge(shard_df, sparse_label_df, how='left', validate='one_to_one', suffixes=('', '_badlabelmerge')) # type: ignore # auto-merge
+
+ assert not any(shard_df[label_cols].isna().max()) # type: ignore
+
+ # logging.info(shard_save_loc)
+ sink = wds.TarWriter(shard_save_loc)
+ for _, galaxy in shard_df.iterrows(): # type: ignore
+ try:
+ sink.write(galaxy_to_wds(galaxy, label_cols, transform=transform))
+ except Exception as e:
+ logging.critical(galaxy)
+ raise(e)
+ sink.close()
+
+
+def galaxy_to_wds(galaxy: pd.Series, label_cols: Union[list[str],None]=None, metadata_cols: Union[list, None]=None, transform: Union[Callable, None]=None):
+
+ assert os.path.isfile(galaxy['file_loc']), galaxy['file_loc']
+ im = cv2.imread(galaxy['file_loc'])
+ # cv2 loads BGR for 'history', fix
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+ assert not np.any(np.isnan(np.array(im))), galaxy['file_loc']
+
+ id_str = str(galaxy['id_str'])
+
+ if transform is not None:
+ im = transform(image=im)['image']
+
+ if label_cols is None:
+ labels = json.dumps({})
+ else:
+ labels = json.dumps(galaxy[label_cols].to_dict())
+
+ if metadata_cols is None:
+ metadata = json.dumps({})
+ else:
+ metadata = json.dumps(galaxy[metadata_cols].to_dict())
+
+ return {
+ "__key__": id_str, # silly wds bug where if __key__ ends .jpg, all keys get jpg. prepended?! use id_str instead
+ "image.jpg": im,
+ "labels.json": labels,
+ "metadata.json": metadata
+ }
+
+
+# just for debugging
+def load_wds_directly(wds_loc, max_to_load=3):
+
+ dataset = wds.WebDataset(wds_loc) \
+ .decode("rgb")
+
+ if max_to_load is not None:
+ sample_iterator = islice(dataset, 0, max_to_load)
+ else:
+ sample_iterator = dataset
+ for sample in sample_iterator:
+ logging.info(sample['__key__'])
+ logging.info(sample['image.jpg'].shape) # .decode(jpg) converts to decoded to 0-1 RGB float, was 0-255
+ logging.info(type(sample['labels.json'])) # automatically decoded
+
+
+# just for debugging
+def load_wds_with_augmentation(wds_loc):
+
+ augmentation_transform = default_transforms() # A.Compose object
+ def do_transform(img):
+ return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)
+
+ dataset = wds.WebDataset(wds_loc) \
+ .decode("rgb") \
+ .to_tuple('image.jpg', 'labels.json') \
+ .map_tuple(do_transform, identity)
+
+ for sample in islice(dataset, 0, 3):
+ logging.info(sample[0].shape)
+ logging.info(sample[1])
+
+# just for debugging
+def load_wds_with_webdatamodule(save_loc, label_cols, batch_size=16, max_to_load=3):
+ wdm = webdatamodule.WebDataModule(
+ train_urls=save_loc,
+ val_urls=save_loc, # not used
+ # train_size=len(train_catalog),
+ # val_size=0,
+ label_cols=label_cols,
+ num_workers=1,
+ batch_size=batch_size
+ )
+ wdm.setup('fit')
+
+ if max_to_load is not None:
+ sample_iterator =islice(wdm.train_dataloader(), 0, max_to_load)
+ else:
+ sample_iterator = wdm.train_dataloader()
+ for sample in sample_iterator:
+ images, labels = sample
+ logging.info(images.shape)
+ # logging.info(len(labels)) # list of dicts
+ logging.info(labels.shape)
+
+
+def identity(x):
+ # no lambda to be pickleable
+ return x
+
+
+
+def make_mock_wds(save_dir: str, label_cols: list, n_shards: int, shard_size: int):
+ counter = 0
+ shards = [os.path.join(save_dir, f'mock_shard_{shard_n}_{shard_size}.tar') for shard_n in range(n_shards)]
+ for shard in shards:
+ sink = wds.TarWriter(shard)
+ for galaxy_n in range(shard_size):
+ data = {
+ "__key__": f'id_{galaxy_n}',
+ "image.jpg": (np.random.rand(424, 424)*255.).astype(np.uint8),
+ "labels.json": json.dumps(dict(zip(label_cols, [np.random.randint(low=0, high=10) for _ in range(len(label_cols))])))
+ }
+ sink.write(data)
+ counter += 1
+ print(counter)
+ return shards
+
+
+if __name__ == '__main__':
+
+ save_dir = '/home/walml/repos/temp'
+ from galaxy_datasets.shared import label_metadata
+ label_cols = label_metadata.decals_all_campaigns_ortho_label_cols
+
+ make_mock_wds(save_dir, label_cols, n_shards=4, shard_size=512)
diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py
index b55d9fed..f5980492 100755
--- a/zoobot/pytorch/estimators/define_model.py
+++ b/zoobot/pytorch/estimators/define_model.py
@@ -5,11 +5,12 @@
import torch
from torch import nn
import pytorch_lightning as pl
-from torchmetrics import Accuracy
+import torchmetrics
import timm
+from zoobot.shared import schemas
from zoobot.pytorch.estimators import efficientnet_custom, custom_layers
-from zoobot.pytorch.training import losses
+from zoobot.pytorch.training import losses, schedulers
# overall strategy
# timm for defining complicated pytorch modules
@@ -55,56 +56,109 @@ def __init__(
):
super().__init__()
self.save_hyperparameters() # saves all args by default
- self.setup_metrics()
- def setup_metrics(self):
- # these are ignored unless output dim = 2
- self.train_accuracy = Accuracy(task='binary')
- self.val_accuracy = Accuracy(task='binary')
- # self.log_on_step = False
- # self.log_on_step is useful for debugging, but slower - best when log_every_n_steps is fairly large
+ def setup_metrics(self, nan_strategy='error'): # may sometimes want to ignore nan even in main metrics
+ self.val_accuracy = torchmetrics.Accuracy(task='binary')
+
+ self.loss_metrics = torch.nn.ModuleDict({
+ 'train/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy),
+ 'validation/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy),
+ 'test/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy),
+ })
+
+ # TODO handle when schema doesn't exist
+ question_metric_dict = {}
+ for step_name in ['train', 'validation', 'test']:
+ question_metric_dict.update({
+ step_name + '/question_loss/' + question.text: torchmetrics.MeanMetric(nan_strategy='ignore')
+ for question in self.schema.questions
+ })
+ self.question_loss_metrics = torch.nn.ModuleDict(question_metric_dict)
+
+ campaigns = schema_to_campaigns(self.schema)
+ campaign_metric_dict = {}
+ for step_name in ['train', 'validation', 'test']:
+ campaign_metric_dict.update({
+ step_name + '/campaign_loss/' + campaign: torchmetrics.MeanMetric(nan_strategy='ignore')
+ for campaign in campaigns
+ })
+ self.campaign_loss_metrics = torch.nn.ModuleDict(campaign_metric_dict)
def forward(self, x):
+ assert x.shape[1] < 4 # torchlike BCHW
x = self.encoder(x)
return self.head(x)
- def make_step(self, batch, batch_idx, step_name):
+ def make_step(self, batch, step_name):
x, labels = batch
predictions = self(x) # by default, these are Dirichlet concentrations
- loss = self.calculate_and_log_loss(predictions, labels, step_name)
- return {'loss': loss, 'predictions': predictions, 'labels': labels}
-
- def calculate_and_log_loss(self, predictions, labels, step_name):
- raise NotImplementedError('Must be subclassed')
+ loss = self.calculate_loss_and_update_loss_metrics(predictions, labels, step_name)
+ outputs = {'loss': loss, 'predictions': predictions, 'labels': labels}
+ # self.update_other_metrics(outputs, step_name=step_name)
+ return outputs
def configure_optimizers(self):
raise NotImplementedError('Must be subclassed')
def training_step(self, batch, batch_idx):
- return self.make_step(batch, batch_idx, step_name='train')
-
- def on_train_batch_end(self, outputs, *args):
- self.log_outputs(outputs, step_name='train')
+ return self.make_step(batch, step_name='train')
def validation_step(self, batch, batch_idx):
- return self.make_step(batch, batch_idx, step_name='validation')
+ return self.make_step(batch, step_name='validation')
+
+ def test_step(self, batch, batch_idx):
+ return self.make_step(batch, step_name='test')
+
+ # def on_train_batch_end(self, outputs, *args):
+ # pass
+
+ # def on_validation_batch_end(self, outputs, *args):
+ # pass
- def on_validation_batch_end(self, outputs, *args):
- self.log_outputs(outputs, step_name='validation')
+ def on_train_epoch_end(self) -> None:
+ # called *after* on_validation_epoch_end, confusingly
+ # do NOT log_all_metrics here.
+ # logging a metric resets it, and on_validation_epoch_end just logged and reset everything, so you will only log nans
+ self.log_all_metrics(subset='train')
- def log_outputs(self, outputs, step_name):
+ def on_validation_epoch_end(self) -> None:
+ self.log_all_metrics(subset='validation')
+
+ def on_test_epoch_end(self) -> None:
+ # logging.info('start test epoch end')
+ self.log_all_metrics(subset='test')
+ # logging.info('end test epoch end')
+
+ def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name):
+ raise NotImplementedError('Must be subclassed')
+
+ def update_other_metrics(self, outputs, step_name):
raise NotImplementedError('Must be subclassed')
- def test_step(self, batch, batch_idx):
- return self.make_step(batch, batch_idx, step_name='test')
+ def log_all_metrics(self, subset=None):
+ if subset is not None:
+ for metric_collection in (self.loss_metrics, self.question_loss_metrics, self.campaign_loss_metrics):
+ prog_bar = metric_collection == self.loss_metrics
+ for name, metric in metric_collection.items():
+ if subset in name:
+ # logging.info(name)
+ self.log(name, metric, on_epoch=True, on_step=False, prog_bar=prog_bar, logger=True)
+ else: # just log everything
+ self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True)
+ self.log_dict(self.question_loss_metrics, on_step=False, on_epoch=True, logger=True)
+ self.log_dict(self.campaign_loss_metrics, on_step=False, on_epoch=True, logger=True)
+
- def on_test_batch_end(self, outputs, *args):
- self.log_outputs(outputs, step_name='test')
def predict_step(self, batch, batch_idx, dataloader_idx=0):
+ # I can't work out how to get webdataset to return a single item im, not a tuple (im,).
+ # this is fine for training but annoying for predict
+ # help welcome. meanwhile, this works around it
+ if isinstance(batch, list) and len(batch) == 1:
+ return self(batch[0])
# https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#inference
# this calls forward, while avoiding the need for e.g. model.eval(), torch.no_grad()
# x, y = batch # would be usual format, but here, batch does not include labels
@@ -123,10 +177,8 @@ class ZoobotTree(GenericLightningModule):
Args:
output_dim (int): Output dimension of model's head e.g. 34 for predicting a 34-answer decision tree.
- question_index_groups (List): Mapping of which label indices are part of the same question. See :ref:`training_on_vote_counts`.
architecture_name (str, optional): Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to "efficientnet_b0".
channels (int, optional): Num. input channels. Probably 3 or 1. Defaults to 1.
- use_imagenet_weights (bool, optional): Load weights pretrained on ImageNet (NOT galaxies!). Defaults to False.
test_time_dropout (bool, optional): Apply dropout at test time, to pretend to be Bayesian. Defaults to True.
timm_kwargs (dict, optional): passed to timm.create_model e.g. drop_path_rate=0.2 for effnet. Defaults to {}.
learning_rate (float, optional): AdamW learning rate. Defaults to 1e-3.
@@ -138,12 +190,19 @@ class ZoobotTree(GenericLightningModule):
def __init__(
self,
output_dim: int,
- question_index_groups: List,
+ # in the simplest case, this is all zoobot needs: grouping of label col indices as questions
+ # question_index_groups: List=None,
+ # BUT
+ # if you pass these, it enables better per-question and per-survey logging (because we have names)
+ # must be passed as simple dicts, not objects, so can't just pass schema in
+ question_answer_pairs: dict=None,
+ dependencies: dict=None,
# encoder args
architecture_name="efficientnet_b0",
channels=1,
- use_imagenet_weights=False,
+ # use_imagenet_weights=False,
test_time_dropout=True,
+ compile_encoder=False,
timm_kwargs={}, # passed to timm.create_model e.g. drop_path_rate=0.2 for effnet
# head args
dropout_rate=0.2,
@@ -157,11 +216,14 @@ def __init__(
# now, finally, can pass only standard variables as hparams to save
# will still need to actually use these variables later, this super init only saves them
super().__init__(
+ # these all do nothing, they are simply saved by lightning as hparams
output_dim,
- question_index_groups,
+ question_answer_pairs,
+ dependencies,
architecture_name,
channels,
timm_kwargs,
+ compile_encoder,
test_time_dropout,
dropout_rate,
learning_rate,
@@ -172,6 +234,15 @@ def __init__(
logging.info('Generic __init__ complete - moving to Zoobot __init__')
+ # logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__')
+ # assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
+ assert dependencies is not None
+ self.schema = schemas.Schema(question_answer_pairs, dependencies)
+ # replace with schema-derived version
+ question_index_groups = self.schema.question_index_groups
+
+ self.setup_metrics()
+
# set attributes for learning rate, betas, used by self.configure_optimizers()
# TODO refactor to optimizer params
self.learning_rate = learning_rate
@@ -182,11 +253,16 @@ def __init__(
self.encoder = get_pytorch_encoder(
architecture_name,
channels,
- use_imagenet_weights=use_imagenet_weights,
+ # use_imagenet_weights=use_imagenet_weights,
**timm_kwargs
)
+ if compile_encoder:
+ logging.warning('Using torch.compile on encoder')
+ self.encoder = torch.compile(self.encoder)
+
# bit lazy assuming 224 input size
- self.encoder_dim = get_encoder_dim(self.encoder, input_size=224, channels=channels)
+ # logging.warning(channels)
+ self.encoder_dim = get_encoder_dim(self.encoder, channels)
# typically encoder_dim=1280 for effnetb0
logging.info('encoder dim: {}'.format(self.encoder_dim))
@@ -203,15 +279,15 @@ def __init__(
logging.info('Zoobot __init__ complete')
- def calculate_and_log_loss(self, predictions, labels, step_name):
+ def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name):
# self.loss_func returns shape of (galaxy, question), mean to ()
multiq_loss = self.loss_func(predictions, labels, sum_over_questions=False)
- # if hasattr(self, 'schema'):
- self.log_loss_per_question(multiq_loss, prefix=step_name)
+ self.update_per_question_loss_metric(multiq_loss, step_name=step_name)
# sum over questions and take a per-device mean
# for DDP strategy, batch size is constant (batches are not divided, data pool is divided)
# so this will be the global per-example mean
loss = torch.mean(torch.sum(multiq_loss, axis=1))
+ self.loss_metrics[step_name + '/supervised_loss'](loss)
return loss
@@ -232,31 +308,59 @@ def configure_optimizers(self):
min_lr=1e-6,
patience=self.scheduler_params.get('patience', 5)
)
- return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'validation/epoch_loss'}
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'validation/loss'}
+ elif self.scheduler_params.get('cosine_schedule', False):
+ logging.info('Using cosine schedule')
+ scheduler = schedulers.CosineWarmupScheduler(
+ optimizer=optimizer,
+ warmup_epochs=self.scheduler_params['warmup_epochs'],
+ max_epochs=self.scheduler_params['max_cosine_epochs'],
+ start_value=self.learning_rate,
+ end_value=self.learning_rate * self.scheduler_params['max_learning_rate_reduction_factor']
+ )
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'validation/loss'}
else:
logging.info('No scheduler used')
return optimizer # no scheduler
- def log_outputs(self, outputs, step_name):
- self.log("{}/epoch_loss".format(step_name), outputs['loss'], on_epoch=True, on_step=False,prog_bar=True, logger=True, sync_dist=True)
- # if self.log_on_step:
- # # seperate call to allow for different name, to allow for consistency with TF.keras auto-names
- # self.log(
- # "{}/step_loss".format(step_name), outputs['loss'], on_epoch=False, on_step=True, prog_bar=True, logger=True, sync_dist=True)
- if outputs['predictions'].shape[1] == 2: # will only do for binary classifications
- # logging.info(predictions.shape, labels.shape)
- self.log(
- "{}_accuracy".format(step_name), self.train_accuracy(outputs['predictions'], torch.argmax(outputs['labels'], dim=1, keepdim=False)), prog_bar=True, sync_dist=True)
-
- def log_loss_per_question(self, multiq_loss, prefix):
+ def update_per_question_loss_metric(self, multiq_loss, step_name):
# log questions individually
# TODO need schema attribute or similar to have access to question names, this will do for now
# unlike Finetuneable..., does not use TorchMetrics, simply logs directly
# TODO could use TorchMetrics and for q in schema, self.q_metric loop
- for question_n in range(multiq_loss.shape[1]):
- self.log(f'{prefix}/epoch_questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, sync_dist=True)
+
+ # if hasattr(self, 'schema'):
+ # use schema metadata to log intelligently
+ # will have schema if question_answer_pairs and dependencies are passed to __init__
+ # assume that questions are named like smooth-or-featured-CAMPAIGN
+ for question_n, question in enumerate(self.schema.questions):
+ # for logging comparison, want to ignore loss on unlablled examples, i.e. take mean ignoring zeros
+ # could sum, but then this would vary with batch size
+ nontrivial_loss_mask = multiq_loss[:, question_n] > 0 # 'zero' seems to be ~5e-5 floor in practice
+
+ this_question_metric = self.question_loss_metrics[step_name + '/question_loss/' + question.text]
+ # raise ValueError
+ this_question_metric(torch.mean(multiq_loss[nontrivial_loss_mask, question_n]))
+
+ campaigns = schema_to_campaigns(self.schema)
+ for campaign in campaigns:
+ campaign_questions = [q for q in self.schema.questions if campaign in q.text]
+ campaign_q_indices = [self.schema.questions.index(q) for q in campaign_questions] # shape (num q in this campaign e.g. 10)
+
+ # similarly to per-question, only include in mean if (any) q in this campaign has a non-trivial loss
+ nontrivial_loss_mask = multiq_loss[:, campaign_q_indices].sum(axis=1) > 0 # shape batch size
+
+ this_campaign_metric = self.campaign_loss_metrics[step_name + '/campaign_loss/' + campaign]
+ this_campaign_metric(torch.mean(multiq_loss[nontrivial_loss_mask][:, campaign_q_indices]))
+
+ # else:
+ # # fallback to logging with question_n
+ # for question_n in range(multiq_loss.shape[1]):
+ # self.log(f'{step_name}/questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, sync_dist=True)
+
+
@@ -274,22 +378,34 @@ def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=Fals
# multiquestion_loss returns loss of shape (batch, question)
# torch.sum(multiquestion_loss, axis=1) gives loss of shape (batch). Equiv. to non-log product of question likelihoods.
- multiq_loss = losses.calculate_multiquestion_loss(labels, preds, question_index_groups)
+ multiq_loss = losses.calculate_multiquestion_loss(labels, preds, question_index_groups, careful=True)
if sum_over_questions:
return torch.sum(multiq_loss, axis=1)
else:
return multiq_loss
-def get_encoder_dim(encoder, input_size, channels):
- x = torch.randn(1, channels, input_size, input_size) # batch size of 1
- return encoder(x).shape[-1]
-
+# input_size doesn't matter as long as it's large enough to not be pooled to zero
+# channels doesn't matter at all but has to match encoder channels or shape error
+def get_encoder_dim(encoder, channels=3):
+ device = next(encoder.parameters()).device
+ try:
+ x = torch.randn(2, channels, 224, 224, device=device) # BCHW
+ return encoder(x).shape[-1]
+ except RuntimeError as e:
+ if 'channels instead' in str(e):
+ logging.info('encoder dim search failed on channels, trying with channels=1')
+ channels = 1
+ x = torch.randn(2, channels, 224, 224, device=device) # BCHW
+ return encoder(x).shape[-1]
+ else:
+ raise e
+
def get_pytorch_encoder(
architecture_name='efficientnet_b0',
channels=1,
- use_imagenet_weights=False,
+ # use_imagenet_weights=False,
**timm_kwargs
) -> nn.Module:
"""
@@ -316,12 +432,16 @@ def get_pytorch_encoder(
"""
# num_classes=0 gives pooled encoder
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/efficientnet.py
+
+ # if architecture_name == 'toy':
+ # logging.warning('Using toy encoder')
+ # return ToyEncoder()
# support older code that didn't specify effnet version
if architecture_name == 'efficientnet':
logging.warning('efficientnet variant not specified - please set architecture_name=efficientnet_b0 (or similar)')
architecture_name = 'efficientnet_b0'
- return timm.create_model(architecture_name, in_chans=channels, num_classes=0, pretrained=use_imagenet_weights, **timm_kwargs)
+ return timm.create_model(architecture_name, in_chans=channels, num_classes=0, **timm_kwargs)
def get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_dropout: bool, dropout_rate: float) -> torch.nn.Sequential:
@@ -357,3 +477,17 @@ def get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_drop
modules_to_use.append(efficientnet_custom.custom_top_dirichlet(encoder_dim, output_dim))
return nn.Sequential(*modules_to_use)
+
+
+def schema_to_campaigns(schema):
+ # e.g. [gz2, dr12, ...]
+ return [question.text.split('-')[-1] for question in schema.questions]
+
+
+if __name__ == '__main__':
+ encoder = get_pytorch_encoder(channels=1)
+ dim = get_encoder_dim(encoder, channels=1)
+ print(dim)
+
+
+ ZoobotTree.load_from_checkpoint
\ No newline at end of file
diff --git a/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py b/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py
index c5309e8b..4cf7efff 100644
--- a/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py
+++ b/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py
@@ -26,11 +26,6 @@
# For binary classification, the label column should have binary (0 or 1) labels for your classes
# To support more complicated labels, Zoobot expects a list of columns. A list with one element works fine.
- # load a pretrained checkpoint saved here
- # https://www.dropbox.com/s/7ixwo59imjfz4ay/effnetb0_greyscale_224px.ckpt?dl=0
- # see https://zoobot.readthedocs.io/en/latest/data_notes.html for more options
- checkpoint_loc = os.path.join(zoobot_dir, 'data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt')
-
# save the finetuning results here
save_dir = os.path.join(zoobot_dir, 'results/pytorch/finetune/finetune_binary_classification')
@@ -47,7 +42,7 @@
model = finetune.FinetuneableZoobotClassifier(
- checkpoint_loc=checkpoint_loc,
+ name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
num_classes=2,
n_layers=0 # only updating the head weights. Set e.g. 1, 2 to finetune deeper.
)
diff --git a/zoobot/pytorch/manchester.py b/zoobot/pytorch/manchester.py
index 2ed39e6a..dfab82fd 100644
--- a/zoobot/pytorch/manchester.py
+++ b/zoobot/pytorch/manchester.py
@@ -1,44 +1,58 @@
import logging
+import os
-from lightning_lite.plugins.environments import SLURMEnvironment
-# https://pytorch-lightning.readthedocs.io/en/stable/_modules/lightning_lite/plugins/environments/slurm.html#SLURMEnvironment
-# https://github.com/Lightning-AI/lightning/blob/9c20cad40e4142f8a5e945fe26e919e598f2bd56/src/lightning_lite/plugins/environments/slurm.py
-class ManchesterEnvironment(SLURMEnvironment):
+# from lightning_lite.plugins.environments import SLURMEnvironment
- def __init__(self, auto_requeue: bool = True, requeue_signal=None) -> None:
- logging.info('Using Manchester SLURM environment')
- super().__init__(auto_requeue, requeue_signal)
+# # https://pytorch-lightning.readthedocs.io/en/stable/_modules/lightning_lite/plugins/environments/slurm.html#SLURMEnvironment
+# # https://github.com/Lightning-AI/lightning/blob/9c20cad40e4142f8a5e945fe26e919e598f2bd56/src/lightning_lite/plugins/environments/slurm.py
+# class ManchesterEnvironment(SLURMEnvironment):
- # @staticmethod
- # def resolve_root_node_address(nodes: str) -> str:
- # root_node_address = super().resolve_root_node_address(nodes)
- # logging.info(f'root_node_address: {root_node_address}')
- # return root_node_address
+# def __init__(self, auto_requeue: bool = True, requeue_signal=None) -> None:
+# logging.info('Using Manchester SLURM environment')
+# super().__init__(auto_requeue, requeue_signal)
- @staticmethod
- def detect() -> bool:
- return True
+# # @staticmethod
+# # def resolve_root_node_address(nodes: str) -> str:
+# # root_node_address = super().resolve_root_node_address(nodes)
+# # logging.info(f'root_node_address: {root_node_address}')
+# # return root_node_address
+
+# @staticmethod
+# def detect() -> bool:
+# return True
- @property
- def main_port(self) -> int:
- main_port = super().main_port
- logging.info(f'main_port: {main_port}')
- return main_port
- # MASTER_PORT will override
+# @property
+# def main_port(self) -> int:
+# main_port = super().main_port
+# logging.info(f'main_port: {main_port}')
+# return main_port
+# # MASTER_PORT will override
+
+
+from pytorch_lightning.plugins.environments import SLURMEnvironment
+class GalahadEnvironment(SLURMEnvironment):
+ def __init__(self, **kwargs):
+ ntasks_per_node = os.environ["SLURM_TASKS_PER_NODE"].split("(")[0]
+ os.environ["SLURM_NTASKS_PER_NODE"] = ntasks_per_node
+ logging.warning(f'Within custom slurm environment, --n-tasks-per-node={ntasks_per_node}')
+ # os.environ["SLURM_NTASKS"] = str(os.environ["SLURM_NTASKS_PER_NODE"])
+ super().__init__(**kwargs)
+ self.nnodes = int(os.environ["SLURM_NNODES"])
+
-if __name__ == '__main__':
+# if __name__ == '__main__':
- logging.basicConfig(level=logging.INFO)
+# logging.basicConfig(level=logging.INFO)
- # slurm_nodelist = "compute-0-[0,9]" # 0,9 works
- slurm_nodelist = "compute-0-[0,11]" # 0,11 hangs
- # 70017 8-9 works
+# # slurm_nodelist = "compute-0-[0,9]" # 0,9 works
+# slurm_nodelist = "compute-0-[0,11]" # 0,11 hangs
+# # 70017 8-9 works
- env = ManchesterEnvironment()
- root = env.resolve_root_node_address(slurm_nodelist)
- print(root)
+# env = GalahadEnvironment()
+# root = env.resolve_root_node_address(slurm_nodelist)
+# print(root)
- print(env.detect())
+# print(env.detect())
- print(env.main_port)
\ No newline at end of file
+# print(env.main_port)
\ No newline at end of file
diff --git a/zoobot/pytorch/predictions/predict_on_catalog.py b/zoobot/pytorch/predictions/predict_on_catalog.py
index 3a68ab88..918593c2 100644
--- a/zoobot/pytorch/predictions/predict_on_catalog.py
+++ b/zoobot/pytorch/predictions/predict_on_catalog.py
@@ -38,14 +38,14 @@ def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, la
# crucial to specify the stage, or will error (as missing other catalogs)
predict_datamodule.setup(stage='predict')
# for images in predict_datamodule.predict_dataloader():
- # print(images)
- # print(images.shape)
+ # print(images)
+ # print(images.shape)
+ # exit()
# set up trainer (again)
trainer = pl.Trainer(
max_epochs=-1, # does nothing in this context, suppresses warning
- inference_mode=True, # no grads needed
**trainer_kwargs # e.g. gpus
)
@@ -85,3 +85,5 @@ def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, la
end = datetime.datetime.fromtimestamp(time.time())
logging.info('Completed at: {}'.format(end.strftime('%Y-%m-%d %H:%M:%S')))
logging.info('Time elapsed: {}'.format(end - start))
+
+ return predictions
diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index 6f3cf3e2..67ef1997 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -1,19 +1,21 @@
-# Based on Inigo's BYOL FT step
-# https://github.com/inigoval/finetune/blob/main/finetune.py
import logging
import os
+from typing import Any, Union, Optional
import warnings
from functools import partial
+import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks import LearningRateMonitor
import torch
import torch.nn.functional as F
import torchmetrics as tm
+import timm
-from zoobot.pytorch.training import losses
+from zoobot.pytorch.training import losses, schedulers
from zoobot.pytorch.estimators import define_model
from zoobot.shared import schemas
@@ -32,87 +34,128 @@ def freeze_batchnorm_layers(model):
class FinetuneableZoobotAbstract(pl.LightningModule):
"""
- Parent class of :class:`FinetuneableZoobotClassifier` and :class:`FinetuneableZoobotTree`.
+ Parent class of :class:`FinetuneableZoobotClassifier`, :class:`FinetuneableZoobotRegressor`, :class:`FinetuneableZoobotTree`.
You cannot use this class directly - you must use the child classes above instead.
- This class defines the finetuning methods that those child classes both use.
- For example: when provided `checkpoint_loc`, it will load the encoder from that checkpoint.
- Both :class:`FinetuneableZoobotClassifier` and :class:`FinetuneableZoobotTree`
- can (and should) be passed any of these arguments to customise finetuning.
+ This class defines the shared finetuning args and methods used by those child classes.
+ For example:
+ - When provided `name`, it will load the HuggingFace encoder with that name (see below for more).
+ - When provided `learning_rate` it will set the optimizer to use that learning rate.
- You could subclass this class to solve new finetuning tasks (like regression) - see :ref:`advanced_finetuning`.
+ Any FinetuneableZoobot model can be loaded in one of three ways:
+ - HuggingFace name e.g. `FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...)`. Recommended.
+ - Any PyTorch model in memory e.g. `FinetuneableZoobotX(encoder=some_model, ...)`
+ - ZoobotTree checkpoint e.g. `FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...)`
+
+ You could subclass this class to solve new finetuning tasks - see :ref:`advanced_finetuning`.
Args:
- checkpoint_loc (str, optional): Path to encoder checkpoint to load (likely a saved ZoobotTree). Defaults to None.
- encoder (pl.LightningModule, optional): Alternatively, pass an encoder directly. Load with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`.
- encoder_dim (int, optional): Output dimension of encoder. Defaults to 1280 (EfficientNetB0's encoder dim).
+ name (str, optional): Name of a model on HuggingFace Hub e.g.'hf_hub:mwalmsley/zoobot-encoder-convnext_nano'. Defaults to None.
+ encoder (torch.nn.Module, optional): A PyTorch model already loaded in memory
+ zoobot_checkpoint_loc (str, optional): Path to ZoobotTree lightning checkpoint to load. Loads with Load with :func:`zoobot.pytorch.training.finetune.load_pretrained_zoobot`. Defaults to None.
+
+ n_blocks (int, optional):
lr_decay (float, optional): For each layer i below the head, reduce the learning rate by lr_decay ^ i. Defaults to 0.75.
weight_decay (float, optional): AdamW weight decay arg (i.e. L2 penalty). Defaults to 0.05.
learning_rate (float, optional): AdamW learning rate arg. Defaults to 1e-4.
dropout_prob (float, optional): P of dropout before final output layer. Defaults to 0.5.
- freeze_batchnorm (bool, optional): If True, do not update batchnorm stats during finetuning. Defaults to True.
+ always_train_batchnorm (bool, optional): Temporarily deprecated. Previously, if True, do not update batchnorm stats during finetuning. Defaults to True.
+ cosine_schedule (bool, optional): Reduce the learning rate each epoch according to a cosine schedule, after warmup_epochs. Defaults to False.
+ warmup_epochs (int, optional): Linearly increase the learning rate from 0 to ``learning_rate`` over the first ``warmup_epochs`` epochs, before applying cosine schedule. No effect if cosine_schedule=False.
+ max_cosine_epochs (int, optional): Epochs for the scheduled learning rate to decay to final learning rate (below). Warmup epochs don't count. No effect if ``cosine_schedule=False``.
+ max_learning_rate_reduction_factor (float, optional): Set final learning rate as ``learning_rate`` * ``max_learning_rate_reduction_factor``. No effect if ``cosine_schedule=False``.
+ from_scratch (bool, optional): Ignore all settings above and train from scratch at ``learning_rate`` for all layers. Useful for a quick baseline. Defaults to False.
prog_bar (bool, optional): Print progress bar during finetuning. Defaults to True.
visualize_images (bool, optional): Upload example images to WandB. Good for debugging but slow. Defaults to False.
seed (int, optional): random seed to use. Defaults to 42.
+ n_layers: No effect, deprecated. Use n_blocks instead.
"""
def __init__(
self,
- # can provide either checkpoint_loc, and will load this model as encoder...
- checkpoint_loc=None,
- # ...or directly pass model to use as encoder
- encoder=None,
- encoder_dim=1280, # as per current Zooot. TODO Could get automatically?
- n_epochs=100, # TODO early stopping
+
+ # load a pretrained timm encoder saved on huggingface hub
+ # (aimed at most users, easiest way to load published models)
+ name=None,
+
+ # ...or directly pass any model to use as encoder (if you do this, you will need to keep it around for later)
+ # (aimed at tinkering with new architectures e.g. SSL)
+ encoder=None, # use any torch model already loaded in memory (must have .forward() method)
+
+ # load a pretrained zoobottree model and grab the encoder (a timm model)
+ # requires the exact same zoobot version used for training, not very portable
+ # (aimed at supervised experiments)
+ zoobot_checkpoint_loc=None,
+
+ # finetuning settings
n_blocks=0, # how many layers deep to FT
lr_decay=0.75,
weight_decay=0.05,
learning_rate=1e-4, # 10x lower than typical, you may like to experiment
dropout_prob=0.5,
- always_train_batchnorm=True,
+ always_train_batchnorm=False, # temporarily deprecated
+ # n_layers=0, # for backward compat., n_blocks preferred. Now removed in v2.
+ # these args are for the optional learning rate scheduler, best not to use unless you've tuned everything else already
+ cosine_schedule=False,
+ warmup_epochs=0,
+ max_cosine_epochs=100,
+ max_learning_rate_reduction_factor=0.01,
+ # escape hatch for 'from scratch' baselines
+ from_scratch=False,
+ # debugging utils
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42,
- n_layers=0 # for backward compat., n_blocks preferred
+ n_layers=None, # deprecated, no effect
):
super().__init__()
# adds every __init__ arg to model.hparams
# will also add to wandb if using logging=wandb, I think
# necessary if you want to reload!
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
+ # with warnings.catch_warnings():
+ # warnings.simplefilter("ignore")
# this raises a warning that encoder is already a Module hence saved in checkpoint hence no need to save as hparam
# true - except we need it to instantiate this class, so it's really handy to have saved as well
# therefore ignore the warning
- self.save_hyperparameters()
-
- if checkpoint_loc is not None:
- assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use'
- self.encoder = load_pretrained_encoder(checkpoint_loc)
+ self.save_hyperparameters(ignore=['encoder']) # never serialise the encoder, way too heavy
+ # if you need the encoder to recreate, pass when loading checkpoint e.g.
+ # FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder)
+
+ if name is not None:
+ assert encoder is None, 'Cannot pass both name and encoder to use'
+ self.encoder = timm.create_model(name, num_classes=0, pretrained=True)
+ self.encoder_dim = self.encoder.num_features
+
+ elif zoobot_checkpoint_loc is not None:
+ assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use'
+ self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc) # extracts the timm encoder
+ self.encoder_dim = self.encoder.num_features
else:
- assert checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use'
- assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
- self.encoder = encoder
+ assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use'
+ assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
+ self.encoder = encoder
+ # work out encoder dim 'manually'
+ self.encoder_dim = define_model.get_encoder_dim(self.encoder)
- self.encoder_dim = encoder_dim
self.n_blocks = n_blocks
- # for backwards compat.
- if n_layers:
- logging.warning('FinetuneableZoobot(n_layers) is now renamed to n_blocks, please update to pass n_blocks instead! For now, setting n_blocks=n_layers')
- self.n_blocks = n_layers
- logging.info('Layers to finetune: {}'.format(n_layers))
-
self.learning_rate = learning_rate
self.lr_decay = lr_decay
self.weight_decay = weight_decay
self.dropout_prob = dropout_prob
- self.n_epochs = n_epochs
+
+ self.cosine_schedule = cosine_schedule
+ self.warmup_epochs = warmup_epochs
+ self.max_cosine_epochs = max_cosine_epochs
+ self.max_learning_rate_reduction_factor = max_learning_rate_reduction_factor
+
+ self.from_scratch = from_scratch
self.always_train_batchnorm = always_train_batchnorm
if self.always_train_batchnorm:
- logging.info('always_train_batchnorm=True, so all batch norm layers will be finetuned')
+ raise NotImplementedError('Temporarily deprecated, always_train_batchnorm=True not supported')
+ # logging.info('always_train_batchnorm=True, so all batch norm layers will be finetuned')
self.train_loss_metric = tm.MeanMetric()
self.val_loss_metric = tm.MeanMetric()
@@ -122,48 +165,73 @@ def __init__(
self.prog_bar = prog_bar
self.visualize_images = visualize_images
- def configure_optimizers(self):
+ def configure_optimizers(self):
+ """
+ This controls which parameters get optimized
+
+ self.head is always optimized, with no learning rate decay
+ when self.n_blocks == 0, only self.head is optimized (i.e. frozen* encoder)
+
+ for self.encoder, we enumerate the blocks (groups of layers) to potentially finetune
+ and then pick the top self.n_blocks to finetune
+
+ weight_decay is applied to both the head and (if relevant) the encoder
+ learning rate decay is applied to the encoder only: lr x (lr_decay^block_n), ignoring the head (block 0)
+
+ What counts as a "block" is a bit fuzzy, but I generally use the self.encoder.stages from timm. I also count the stem as a block.
+
+ batch norm layers may optionally still have updated statistics using always_train_batchnorm
+ """
lr = self.learning_rate
params = [{"params": self.head.parameters(), "lr": lr}]
- if hasattr(self.encoder, 'blocks'):
- logging.info('Effnet detected')
- # TODO this actually excludes the first conv layer/bn
- encoder_blocks = self.encoder.blocks
- blocks_to_tune = list(encoder_blocks)
- elif hasattr(self.encoder, 'layer4'):
- logging.info('Resnet detected')
- # similarly, excludes first conv/bn
- blocks_to_tune = [
+ logging.info(f'Encoder architecture to finetune: {type(self.encoder)}')
+
+ if self.from_scratch:
+ logging.warning('self.from_scratch is True, training everything and ignoring all settings')
+ params += [{"params": self.encoder.parameters(), "lr": lr}]
+ return torch.optim.AdamW(params, weight_decay=self.weight_decay)
+
+ if isinstance(self.encoder, timm.models.EfficientNet): # includes v2
+ # TODO for now, these count as separate layers, not ideal
+ early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1]
+ encoder_blocks = list(self.encoder.blocks)
+ tuneable_blocks = early_tuneable_layers + encoder_blocks
+ elif isinstance(self.encoder, timm.models.ResNet):
+ # all timm resnets seem to have this structure
+ tuneable_blocks = [
+ # similarly
+ self.encoder.conv1,
+ self.encoder.bn1,
self.encoder.layer1,
self.encoder.layer2,
self.encoder.layer3,
self.encoder.layer4
]
- elif hasattr(self.encoder, 'stages'):
- logging.info('Max-ViT Tiny detected')
- blocks_to_tune = [
- # getattr as obj.0 is not allowed (why does timm call them 0!?)
- getattr(self.encoder.stages, '0'),
- getattr(self.encoder.stages, '1'),
- getattr(self.encoder.stages, '2'),
- getattr(self.encoder.stages, '3'),
- ]
+ elif isinstance(self.encoder, timm.models.MaxxVit):
+ tuneable_blocks = [self.encoder.stem] + [stage for stage in self.encoder.stages]
+ elif isinstance(self.encoder, timm.models.ConvNeXt): # stem + 4 blocks, for all sizes
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py#L264
+ tuneable_blocks = [self.encoder.stem] + [stage for stage in self.encoder.stages]
else:
- raise ValueError('Encoder architecture not automatically recognised')
-
+ raise ValueError(f'Encoder architecture not automatically recognised: {type(self.encoder)}')
+
assert self.n_blocks <= len(
- blocks_to_tune
- ), f"Network only has {len(blocks_to_tune)} tuneable blocks, {self.n_blocks} specified for finetuning"
+ tuneable_blocks
+ ), f"Network only has {len(tuneable_blocks)} tuneable blocks, {self.n_blocks} specified for finetuning"
# take n blocks, ordered highest layer to lowest layer
- blocks_to_tune.reverse()
+ tuneable_blocks.reverse()
+ logging.info('possible blocks to tune: {}'.format(len(tuneable_blocks)))
# will finetune all params in first N
- blocks_to_tune = blocks_to_tune[:self.n_blocks]
+ logging.info('blocks that will be tuned: {}'.format(self.n_blocks))
+ blocks_to_tune = tuneable_blocks[:self.n_blocks]
# optionally, can finetune batchnorm params in remaining layers
- remaining_blocks = blocks_to_tune[self.n_blocks:]
+ remaining_blocks = tuneable_blocks[self.n_blocks:]
+ logging.info('Remaining blocks: {}'.format(len(remaining_blocks)))
+ assert not any([block in remaining_blocks for block in blocks_to_tune]), 'Some blocks are in both tuneable and remaining'
# Append parameters of layers for finetuning along with decayed learning rate
for i, block in enumerate(blocks_to_tune): # _ is the block name e.g. '3'
@@ -172,29 +240,68 @@ def configure_optimizers(self):
"lr": lr * (self.lr_decay**i)
})
- logging.debug(params)
-
# optionally, for the remaining layers (not otherwise finetuned) you can choose to still FT the batchnorm layers
for i, block in enumerate(remaining_blocks):
if self.always_train_batchnorm:
- params.append({
- "params": get_batch_norm_params_lighting(block),
- "lr": lr * (self.lr_decay**i)
- })
-
- # TODO this actually breaks training because the generator only iterates once!
- # total_params = sum(p.numel() for param_set in params.copy() for p in param_set['params'])
- # logging.info('Total params to fit: {}'.format(total_params))
-
+ raise NotImplementedError
+ # _, block_batch_norm_params = get_batch_norm_params_lighting(block)
+ # params.append({
+ # "params": block_batch_norm_params,
+ # "lr": lr * (self.lr_decay**i)
+ # })
+
+
+ logging.info('param groups: {}'.format(len(params)))
+
+ # because it iterates through the generators, THIS BREAKS TRAINING so only uncomment to debug params
+ # for param_group_n, param_group in enumerate(params):
+ # shapes_within_param_group = [p.shape for p in list(param_group['params'])]
+ # logging.debug('param group {}: {}'.format(param_group_n, shapes_within_param_group))
+ # print('head params to optimize', [p.shape for p in params[0]['params']]) # head only
+ # print(list(param_group['params']) for param_group in params)
+ # exit()
# Initialize AdamW optimizer
- opt = torch.optim.AdamW(params, weight_decay=self.weight_decay) # lr included in params dict
+ opt = torch.optim.AdamW(params, weight_decay=self.weight_decay) # lr included in params dict
+ logging.info('Optimizer ready, configuring scheduler')
+
+ if self.cosine_schedule:
+ logging.info('Using lightly cosine schedule, warmup for {} epochs, max for {} epochs'.format(self.warmup_epochs, self.max_cosine_epochs))
+ # from lightly.utils.scheduler import CosineWarmupScheduler #copied from here to avoid dependency
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers
+ # Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
+ lr_scheduler = schedulers.CosineWarmupScheduler(
+ optimizer=opt,
+ warmup_epochs=self.warmup_epochs,
+ max_epochs=self.max_cosine_epochs,
+ start_value=self.learning_rate,
+ end_value=self.learning_rate * self.max_learning_rate_reduction_factor,
+ )
+
+ # logging.info('Using CosineAnnealingLR schedule, warmup not supported, max for {} epochs'.format(self.max_cosine_epochs))
+ # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ # optimizer=opt,
+ # T_max=self.max_cosine_epochs,
+ # eta_min=self.learning_rate * self.max_learning_rate_reduction_factor
+ # )
+
+ return {
+ "optimizer": opt,
+ "lr_scheduler": {
+ 'scheduler': lr_scheduler,
+ 'interval': 'epoch',
+ 'frequency': 1
+ }
+ }
+ else:
+ logging.info('Learning rate scheduler not used')
return opt
-
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
x = self.head(x)
+ # TODO encoder output shape changes with input shape (of course) so need to specify explicitly or skip
return x
def make_step(self, batch):
@@ -219,6 +326,14 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
def test_step(self, batch, batch_idx, dataloader_idx=0):
return self.make_step(batch)
+
+ def predict_step(self, batch, batch_idx) -> Any:
+ # I can't work out how to get webdataset to return a single item im, not a tuple (im,).
+ # this is fine for training but annoying for predict
+ # help welcome. meanwhile, this works around it
+ if isinstance(batch, list) and len(batch) == 1:
+ return self(batch[0])
+ return self(batch)
def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx=0):
# v2 docs currently do not show dataloader_idx as train argument so unclear if this will value be updated properly
@@ -234,7 +349,7 @@ def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx=0):
on_epoch=True
)
- def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx=0):
+ def on_validation_batch_end(self, outputs: dict, batch, batch_idx: int, dataloader_idx=0):
self.val_loss_metric(outputs['loss'])
self.log(
"finetuning/val_loss",
@@ -247,7 +362,7 @@ def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx
if self.visualize_images:
self.upload_images_to_wandb(outputs, batch, batch_idx)
- def on_test_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx=0):
+ def on_test_batch_end(self, outputs: dict, batch, batch_idx: int, dataloader_idx=0):
self.test_loss_metric(outputs['loss'])
self.log(
"finetuning/test_loss",
@@ -264,6 +379,13 @@ def on_test_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx=0):
def upload_images_to_wandb(self, outputs, batch, batch_idx):
raise NotImplementedError('Must be subclassed')
+
+ @classmethod
+ def load_from_name(cls, name: str, **kwargs):
+ downloaded_loc = download_from_name(cls.__name__, name)
+ return cls.load_from_checkpoint(downloaded_loc, **kwargs) # trained on GPU, may need map_location='cpu' if you get a device error
+
+
@@ -271,15 +393,17 @@ class FinetuneableZoobotClassifier(FinetuneableZoobotAbstract):
"""
Pretrained Zoobot model intended for finetuning on a classification problem.
- You must also pass either ``checkpoint_loc`` (to a saved encoder checkpoint)
- or `encoder` (to a pytorch model already loaded in memory).
- See :class:FinetuneableZoobotAbstract for more options.
+ Any args not listed below are passed to :class:``FinetuneableZoobotAbstract`` (for example, `learning_rate`).
+ These are shared between classifier, regressor, and tree models.
+ See the docstring of :class:``FinetuneableZoobotAbstract`` for more.
- Any args not in the list below are passed to :class:``FinetuneableZoobotAbstract`` (usually to specify how to carry out the finetuning)
+ Models can be loaded with `FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...)`.
+ See :class:``FinetuneableZoobotAbstract`` for other loading options (e.g. in-memory models or local checkpoints).
Args:
num_classes (int): num. of target classes (e.g. 2 for binary classification).
label_smoothing (float, optional): See torch cross_entropy_loss docs. Defaults to 0.
+ class_weights (arraylike, optional): See torch cross_entropy_loss docs. Defaults to None.
"""
@@ -293,7 +417,7 @@ def __init__(
super().__init__(**super_kwargs)
logging.info('Using classification head and cross-entropy loss')
- self.head = LinearClassifier(
+ self.head = LinearHead(
input_dim=self.encoder_dim,
output_dim=num_classes,
dropout_prob=self.dropout_prob
@@ -314,7 +438,7 @@ def __init__(
self.test_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
def step_to_dict(self, y, y_pred, loss):
- y_class_preds = torch.argmax(y_pred, axis=1)
+ y_class_preds = torch.argmax(y_pred, axis=1) # type: ignore
return {'loss': loss.mean(), 'predictions': y_pred, 'labels': y, 'class_predictions': y_class_preds}
def on_train_batch_end(self, step_output, *args):
@@ -354,8 +478,11 @@ def on_test_batch_end(self, step_output, *args) -> None:
)
- def predict_step(self, x, batch_idx):
- x = self.forward(x) # logits from LinearClassifier
+ def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx):
+ # see Abstract version
+ if isinstance(x, list) and len(x) == 1:
+ return self(x[0])
+ x = self.forward(x) # type: ignore # logits from LinearHead
# then applies softmax
return F.softmax(x, dim=1)
@@ -369,19 +496,129 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):
images = [img for img in x[:n_images]]
captions = [f'Ground Truth: {y_i} \nPrediction: {y_p_i}' for y_i, y_p_i in zip(
y[:n_images], y_pred_softmax[:n_images])]
- self.logger.log_image(
+ self.logger.log_image( # type: ignore
key='val_images',
images=images,
caption=captions)
+
+class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract):
+ """
+ Pretrained Zoobot model intended for finetuning on a regression problem.
+
+ Any args not listed below are passed to :class:``FinetuneableZoobotAbstract`` (for example, `learning_rate`).
+ These are shared between classifier, regressor, and tree models.
+ See the docstring of :class:``FinetuneableZoobotAbstract`` for more.
+
+ Models can be loaded with `FinetuneableZoobotRegressor(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...)`.
+ See :class:``FinetuneableZoobotAbstract`` for other loading options (e.g. in-memory models or local checkpoints).
+
+
+ Args:
+ loss (str, optional): Loss function to use. Must be one of 'mse', 'mae'. Defaults to 'mse'.
+ unit_interval (bool, optional): If True, use sigmoid activation for the final layer, ensuring predictions between 0 and 1. Defaults to False.
+
+ """
+
+ def __init__(
+ self,
+ loss:str='mse',
+ unit_interval:bool=False,
+ **super_kwargs) -> None:
+
+ super().__init__(**super_kwargs)
+
+ self.unit_interval = unit_interval
+ if self.unit_interval:
+ logging.info('unit_interval=True, using sigmoid activation for finetunng head')
+ head_activation = torch.nn.functional.sigmoid
+ else:
+ head_activation = None
+
+ logging.info('Using classification head and cross-entropy loss')
+ self.head = LinearHead(
+ input_dim=self.encoder_dim,
+ output_dim=1,
+ dropout_prob=self.dropout_prob,
+ activation=head_activation
+ )
+ if loss in ['mse', 'mean_squared_error']:
+ self.loss = mse_loss
+ elif loss in ['mae', 'mean_absolute_error', 'l1', 'l1_loss']:
+ self.loss = l1_loss
+ else:
+ raise ValueError(f'Loss {loss} not recognised. Must be one of mse, mae')
+
+ # rmse metrics. loss is mse already.
+ self.train_rmse = tm.MeanSquaredError(squared=False)
+ self.val_rmse = tm.MeanSquaredError(squared=False)
+ self.test_rmse = tm.MeanSquaredError(squared=False)
+
+ def step_to_dict(self, y, y_pred, loss):
+ return {'loss': loss.mean(), 'predictions': y_pred, 'labels': y}
+
+ def on_train_batch_end(self, step_output, *args):
+ super().on_train_batch_end(step_output, *args)
+
+ self.train_rmse(step_output['predictions'], step_output['labels'])
+ self.log(
+ 'finetuning/train_rmse',
+ self.train_rmse,
+ on_step=False,
+ on_epoch=True,
+ prog_bar=self.prog_bar
+ )
+
+ def on_validation_batch_end(self, step_output, *args):
+ super().on_validation_batch_end(step_output, *args)
+
+ self.val_rmse(step_output['predictions'], step_output['labels'])
+ self.log(
+ 'finetuning/val_rmse',
+ self.val_rmse,
+ on_step=False,
+ on_epoch=True,
+ prog_bar=self.prog_bar
+ )
+
+ def on_test_batch_end(self, step_output, *args) -> None:
+ super().on_test_batch_end(step_output, *args)
+
+ self.test_rmse(step_output['predictions'], step_output['labels'])
+ self.log(
+ "finetuning/test_rmse",
+ self.test_rmse,
+ on_step=False,
+ on_epoch=True,
+ prog_bar=self.prog_bar
+ )
+
+
+ def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx):
+ # see Abstract version
+ if isinstance(x, list) and len(x) == 1:
+ return self(x[0])
+ return self.forward(x)
+
+
class FinetuneableZoobotTree(FinetuneableZoobotAbstract):
"""
- Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem.
+ Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem.
+ Uses Dirichlet-Multinomial loss introduced in GZ DECaLS.
+ Briefly: predicts a Dirichlet distribution for the probability of a typical volunteer giving each answer,
+ and uses the Dirichlet-Multinomial loss to compare the predicted distribution of votes (given k volunteers were asked) to the true distribution.
+
+ Does not produce accuracy or MSE metrics, as these are not relevant for this task. Loss logging only.
+
+ If you're using this, you're probably working on a Galaxy Zoo catalog, and you should Slack Mike!
- You must also pass either ``checkpoint_loc`` (to a saved encoder checkpoint)
- or ``encoder`` (to a pytorch model already loaded in memory).
- See :class:FinetuneableZoobotAbstract for more options.
+ Any args not listed below are passed to :class:``FinetuneableZoobotAbstract`` (for example, `learning_rate`).
+ These are shared between classifier, regressor, and tree models.
+ See the docstring of :class:``FinetuneableZoobotAbstract`` for more.
+
+ Models can be loaded with `FinetuneableZoobotTree(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...)`.
+ See :class:``FinetuneableZoobotAbstract`` for other loading options (e.g. in-memory models or local checkpoints).
Args:
schema (schemas.Schema): description of the layout of the decision tree. See :class:`zoobot.shared.schemas.Schema`.
@@ -410,61 +647,133 @@ def __init__(
self.loss = define_model.get_dirichlet_loss_func(self.schema.question_index_groups)
def upload_images_to_wandb(self, outputs, batch, batch_idx):
- pass # not yet implemented
+ raise NotImplementedError
# other functions are simply inherited from FinetunedZoobotAbstract
-# https://github.com/inigoval/byol/blob/1da1bba7dc5cabe2b47956f9d7c6277decd16cc7/byol_main/networks/models.py#L29
-class LinearClassifier(torch.nn.Module):
- def __init__(self, input_dim, output_dim, dropout_prob=0.5):
+class LinearHead(torch.nn.Module):
+ def __init__(self, input_dim: int, output_dim: int, dropout_prob=0.5, activation=None):
+ """
+ Small utility class for a linear head with dropout and optional choice of activation.
+
+ - Apply dropout to features before the final linear layer.
+ - Apply a final linear layer
+ - Optionally, apply `activation` callable
+
+ Args:
+ input_dim (int): input dim of the linear layer (i.e. the encoder output dimension)
+ output_dim (int): output dim of the linear layer (often e.g. N for N classes, or 1 for regression)
+ dropout_prob (float, optional): Dropout probability. Defaults to 0.5.
+ activation (callable, optional): callable expecting tensor e.g. torch softmax. Defaults to None.
+ """
# input dim is representation dim, output_dim is num classes
- super(LinearClassifier, self).__init__()
+ super(LinearHead, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+
self.dropout = torch.nn.Dropout(p=dropout_prob)
self.linear = torch.nn.Linear(input_dim, output_dim)
+ self.activation = activation
def forward(self, x):
- # returns logits, as recommended for CrossEntropy loss
+ """returns logits, as recommended for CrossEntropy loss
+
+ Args:
+ x (torch.Tensor): encoded representation
+
+ Returns:
+ torch.Tensor: result (see docstring of LinearHead)
+ """
+ #
x = self.dropout(x)
x = self.linear(x)
- return x
+ if self.activation is not None:
+ x = self.activation(x)
+ if self.output_dim == 1:
+ return x.squeeze()
+ else:
+ return x
+
+
+
+def cross_entropy_loss(y_pred: torch.Tensor, y: torch.Tensor, label_smoothing: float=0., weight=None):
+ """
+ Calculate cross-entropy loss with optional label smoothing and class weights. No aggregation applied.
+ Trivial wrapper of torch.nn.functional.cross_entropy with reduction='none'.
+ Args:
+ y_pred (torch.Tensor): ints of shape (batch)
+ y (torch.Tensor): predictions of shape (batch, classes)
+ label_smoothing (float, optional): See docstring of torch.nn.functional.cross_entropy. Defaults to 0..
+ weight (arraylike, optional): See docstring of torch.nn.functional.cross_entropy. Defaults to None.
-def cross_entropy_loss(y_pred, y, label_smoothing=0., weight=None):
- # y should be shape (batch) and ints
- # y_pred should be shape (batch, classes)
- # returns loss of shape (batch)
- # will reduce myself
+ Returns:
+ torch.Tensor: unreduced cross-entropy loss
+ """
return F.cross_entropy(y_pred, y.long(), label_smoothing=label_smoothing, weight=weight, reduction='none')
-def dirichlet_loss(y_pred, y, question_index_groups):
- # aggregation equiv. to sum(axis=1).mean(), but fewer operations
- # returns loss of shape (batch)
- # my func uses sklearn convention y, y_pred
- return losses.calculate_multiquestion_loss(y, y_pred, question_index_groups).mean()*len(question_index_groups)
+def mse_loss(y_pred, y):
+ """
+ Trivial wrapper of torch.nn.functional.mse_loss with reduction='none'.
+
+ Args:
+ y_pred (torch.Tensor): See docstring of torch.nn.functional.mse_loss.
+ y (torch.Tensor): See docstring of torch.nn.functional.mse_loss.
+ Returns:
+ torch.Tensor: See docstring of torch.nn.functional.mse_loss.
+ """
+ return F.mse_loss(y_pred, y, reduction='none')
-class FinetunedZoobotClassifierBaseline(FinetuneableZoobotClassifier):
- # exactly as the Finetuned model above, but with a simple single learning rate
- # useful for training from-scratch model exactly as if it were finetuned, as a baseline
+def l1_loss(y_pred, y):
+ """
+ Trivial wrapper of torch.nn.functional.l1_loss with reduction='none'.
- def configure_optimizers(self):
- head_params = list(self.head.parameters())
- encoder_params = list(self.encoder.parameters())
- return torch.optim.AdamW(head_params + encoder_params, lr=self.learning_rate)
+ Args:
+ y_pred (torch.Tensor): See docstring of torch.nn.functional.l1_loss.
+ y (torch.Tensor): See docstring of torch.nn.functional.l1_loss.
+
+ Returns:
+ torch.Tensor: See docstring of torch.nn.functional.l1_loss.
+ """
+ return F.l1_loss(y_pred, y, reduction='none')
-def load_pretrained_encoder(checkpoint_loc: str) -> torch.nn.Sequential:
+def dirichlet_loss(y_pred: torch.Tensor, y: torch.Tensor, question_index_groups):
"""
+ Calculate Dirichlet-Multinomial loss for a batch of predictions and labels.
+ Returns a scalar loss (ready for gradient descent) by summing across answers and taking a mean across the batch.
+ Reduction equivalent to sum(axis=1).mean(), but with fewer operations.
+
Args:
- checkpoint_loc (str): path to saved LightningModule checkpoint, likely of :class:`ZoobotTree`, :class:`FinetuneableZoobotClassifier`, or :class:`FinetunabelZoobotTree`. Must have .encoder attribute.
+ y_pred (torch.Tensor): Predicted dirichlet distribution, of shape (batch, answers)
+ y (torch.Tensor): Count of volunteer votes for each answer, of shape (batch, answers)
+ question_index_groups (list): Answer indices for each question i.e. [(question.start_index, question.end_index), ...] for all questions. Useful for slicing model predictions by question. See :ref:`schemas`.
Returns:
- torch.nn.Sequential: pretrained PyTorch encoder within that LightningModule.
+ torch.Tensor: Dirichlet-Multinomial loss. Scalar, summing across answers and taking a mean across the batch i.e. sum(axis=1).mean())
+ """
+ # my func uses sklearn convention y, y_pred
+ return losses.calculate_multiquestion_loss(y, y_pred, question_index_groups).mean()*len(question_index_groups)
+
+
+
+def load_pretrained_zoobot(checkpoint_loc: str) -> torch.nn.Module:
"""
- return define_model.ZoobotTree.load_from_checkpoint(
- checkpoint_loc).encoder
+ Args:
+ checkpoint_loc (str): path to saved LightningModule checkpoint, likely of :class:`ZoobotTree`, :class:`FinetuneableZoobotClassifier`, or :class:`FinetunabelZoobotTree`. Must have .zoobot attribute.
+ Returns:
+ torch.nn.Module: pretrained PyTorch encoder within that LightningModule.
+ """
+ if torch.cuda.is_available():
+ map_location = None
+ else:
+ # necessary to load gpu-trained model on cpu
+ map_location = torch.device('cpu')
+ return define_model.ZoobotTree.load_from_checkpoint(checkpoint_loc, map_location=map_location).encoder # type: ignore
+
def get_trainer(
save_dir: str,
@@ -478,9 +787,18 @@ def get_trainer(
**trainer_kwargs
) -> pl.Trainer:
"""
- PyTorch Lightning Trainer that carries out the finetuning process.
+ Convenience wrapper to create a PyTorch Lightning Trainer that carries out the finetuning process.
Use like so: trainer.fit(model, datamodule)
+ `get_trainer` args are for common Trainer settings e.g. early stopping checkpointing, etc. By default:
+ - Saves the top-k models based on validation loss
+ - Uses early stopping with `patience` i.e. end training if validation loss does not improve after `patience` epochs.
+ - Monitors the learning rate (useful when using a learning rate scheduler)
+
+ Any extra args not listed below are passed directly to the PyTorch Lightning Trainer.
+ Use this to add any custom configuration not covered by the `get_trainer` args.
+ See https://lightning.ai/docs/pytorch/stable/common/trainer.html
+
Args:
save_dir (str): folder in which to save checkpoints and logs.
file_template (str, optional): custom naming for checkpoint files. See Lightning docs. Defaults to "{epoch}".
@@ -513,10 +831,12 @@ def get_trainer(
patience=patience
)
+ learning_rate_monitor_callback = LearningRateMonitor(logging_interval='epoch')
+
# Initialise pytorch lightning trainer
trainer = pl.Trainer(
logger=logger,
- callbacks=[checkpoint_callback, early_stopping_callback],
+ callbacks=[checkpoint_callback, early_stopping_callback, learning_rate_monitor_callback],
max_epochs=max_epochs,
accelerator=accelerator,
devices=devices,
@@ -525,51 +845,32 @@ def get_trainer(
return trainer
-# TODO check exactly which layers get FTd
-def is_tuneable(block_of_layers):
- if len(list(block_of_layers.parameters())) == 0:
- logging.info('Skipping block with no params')
- logging.info(block_of_layers)
- return False
- else:
- # currently, allowed to include batchnorm
- return True
-
-def get_batch_norm_params_lighting(parent_module, current_params=[]):
- for child_module in parent_module.children():
- if isinstance(child_module, torch.nn.BatchNorm2d):
- current_params += child_module.parameters()
- else:
- current_params = get_batch_norm_params_lighting(child_module, current_params)
- return current_params
-
-
- # when ready (don't peek often, you'll overfit)
- # trainer.test(model, dataloaders=datamodule)
-
- # return model, checkpoint_callback.best_model_path
- # trainer.callbacks[checkpoint_callback].best_model_path?
-
-# def investigate_structure():
-
-# from zoobot.pytorch.estimators import define_model
+def download_from_name(class_name: str, hub_name: str):
+ """
+ Download a finetuned model from the HuggingFace Hub by name.
+ Used to load pretrained Zoobot models by name, e.g. FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...).
+ Downloaded models are saved to the HuggingFace cache directory for later use (typically ~/.cache/huggingface).
-# model = define_model.get_plain_pytorch_zoobot_model(output_dim=1280, include_top=False)
+ You shouldn't need to call this; it's used internally by the FinetuneableZoobot classes.
-# # print(model)
-# # with include_top=False, first and only child is EffNet
-# effnet_with_pool = list(model.children())[0]
+ Args:
+ class_name (str): one of FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, FinetuneableZoobotTree
+ hub_name (str): e.g. mwalmsley/zoobot-encoder-convnext_nano
-# # 0th is actually EffNet, 1st and 2nd are AvgPool and Identity
-# effnet = list(effnet_with_pool.children())[0]
+ Returns:
+ str: path to downloaded model (in HuggingFace cache directory). Likely then loaded by Lightning.
+ """
+ from huggingface_hub import hf_hub_download
-# for layer_n, layer in enumerate(effnet.children()):
-# # first bunch are Sequential module wrapping e.g. 3 MBConv blocks
-# print('\n', layer_n)
-# if isinstance(layer, torch.nn.Sequential):
-# print(layer)
-# # so the blocks to finetune are each Sequential (repeated MBConv) block
-# # and other blocks can be left alone
-# # (also be careful to leave batch-norm alone)
+ if hub_name.startswith('hf_hub:'):
+ logging.info('Passed name with hf_hub: prefix, dropping prefix')
+ repo_id = hub_name.split('hf_hub:')[1]
+ else:
+ repo_id = hub_name
+ downloaded_loc = hf_hub_download(
+ repo_id=repo_id,
+ filename=f"{class_name}.ckpt"
+ )
+ return downloaded_loc
diff --git a/zoobot/pytorch/training/losses.py b/zoobot/pytorch/training/losses.py
index 39c521c9..712b6846 100755
--- a/zoobot/pytorch/training/losses.py
+++ b/zoobot/pytorch/training/losses.py
@@ -1,10 +1,11 @@
from typing import Tuple
+import logging
import torch
import pyro
-def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor, question_index_groups: Tuple) -> torch.Tensor:
+def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor, question_index_groups: Tuple, careful=True) -> torch.Tensor:
"""
The full decision tree loss used for training GZ DECaLS models
@@ -14,11 +15,21 @@ def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor
Args:
labels (torch.Tensor): (galaxy, k successes) where k successes dimension is indexed by question_index_groups.
predictions (torch.Tensor): Dirichlet concentrations, matching shape of labels
- question_index_groups (list): Paired (tuple) integers of (first, last) indices of answers to each question, listed for all questions. See :ref:`schemas`.
+ question_index_groups (list): Answer indices for each question i.e. [(question.start_index, question.end_index), ...] for all questions. Useful for slicing model predictions by question. See :ref:`schemas`.
Returns:
torch.Tensor: neg. log likelihood of shape (batch, question).
"""
+ if careful:
+ # some models give occasional nans for all predictions on a specific galaxy/row
+ # these are inputs to the loss and only happen many epochs in so probably not a case of bad labels, but rather some instability during training
+ # handle this by setting loss=0 for those rows and throwing a warning
+ nan_prediction = torch.isnan(predictions) | torch.isinf(predictions)
+ if nan_prediction.any():
+ logging.warning(f'Found nan values in predictions: {predictions}')
+ safety_value = torch.ones(1, device=predictions.device, dtype=predictions.dtype) # fill with 1 everywhere i.e. fully uncertain
+ predictions = torch.where(condition=nan_prediction, input=safety_value, other=predictions)
+
# very important that question_index_groups is fixed and discrete, else tf.function autograph will mess up
q_losses = []
# will give shape errors if model output dim is not labels dim, which can happen if losses.py substrings are missing an answer
@@ -26,7 +37,6 @@ def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor
q_indices = question_index_groups[q_n]
q_start = q_indices[0]
q_end = q_indices[1]
-
q_loss = dirichlet_loss(labels[:, q_start:q_end+1], predictions[:, q_start:q_end+1])
q_losses.append(q_loss)
@@ -54,7 +64,6 @@ def dirichlet_loss(labels_for_q, concentrations_for_q):
# you will get image batches of shape [N/4, 64, 64, 1] and hence have the wrong number of images vs. labels (and meaningless images)
# so check --shard-img-size carefully!
total_count = torch.sum(labels_for_q, axis=1)
- # logging.info(total_count)
# pytorch dirichlet multinomial implementation will not accept zero total votes, need to handle separately
return get_dirichlet_neg_log_prob(labels_for_q, total_count, concentrations_for_q)
@@ -105,5 +114,7 @@ def dirichlet_loss(labels_for_q, concentrations_for_q):
def get_dirichlet_neg_log_prob(labels_for_q, total_count, concentrations_for_q):
# https://docs.pyro.ai/en/stable/distributions.html#dirichletmultinomial
- dist = pyro.distributions.DirichletMultinomial(total_count=total_count, concentration=concentrations_for_q, is_sparse=False)
- return -dist.log_prob(labels_for_q) # important minus sign
+ # .int()s avoid rounding errors causing loss of around 1e-5 for questions with 0 votes
+ dist = pyro.distributions.DirichletMultinomial(
+ total_count=total_count.int(), concentration=concentrations_for_q, is_sparse=False, validate_args=True)
+ return -dist.log_prob(labels_for_q.int()) # important minus sign
diff --git a/zoobot/pytorch/training/representations.py b/zoobot/pytorch/training/representations.py
index f3f577cc..a350241b 100644
--- a/zoobot/pytorch/training/representations.py
+++ b/zoobot/pytorch/training/representations.py
@@ -1,15 +1,32 @@
+import logging
import pytorch_lightning as pl
+from timm import create_model
+
+
class ZoobotEncoder(pl.LightningModule):
- # very simple wrapper to turn pytorch model into lightning module
- # useful when we want to use lightning to make predictions with our encoder
- # (i.e. to get representations)
- def __init__(self, encoder, pyramid=False) -> None:
+ def __init__(self, encoder):
super().__init__()
+ logging.info('ZoobotEncoder: using provided in-memory encoder')
self.encoder = encoder # plain pytorch module e.g. Sequential
- if pyramid:
- raise NotImplementedError('Will eventually support resetting timm classifier to get FPN features')
+
def forward(self, x):
+ if isinstance(x, list) and len(x) == 1:
+ return self(x[0])
return self.encoder(x)
+
+ @classmethod
+ def load_from_name(cls, name: str):
+ """
+ e.g. ZoobotEncoder.load_from_name('hf_hub:mwalmsley/zoobot-encoder-convnext_nano')
+ Args:
+ name (str): huggingface hub name to load
+
+ Returns:
+ nn.Module: timm model
+ """
+ timm_model = create_model(name, pretrained=True)
+ return cls(timm_model)
+
\ No newline at end of file
diff --git a/zoobot/pytorch/training/schedulers.py b/zoobot/pytorch/training/schedulers.py
new file mode 100644
index 00000000..c93f844f
--- /dev/null
+++ b/zoobot/pytorch/training/schedulers.py
@@ -0,0 +1,147 @@
+import warnings
+
+import torch
+import numpy as np
+from typing import Optional
+
+
+def cosine_schedule(
+ step: int,
+ max_steps: int,
+ start_value: float,
+ end_value: float,
+ period: Optional[int] = None,
+) -> float:
+ """
+ Use cosine decay to gradually modify start_value to reach target end_value during
+ iterations.
+ Copied from lightly library (thank you for open sourcing)
+
+ Args:
+ step:
+ Current step number.
+ max_steps:
+ Total number of steps.
+ start_value:
+ Starting value.
+ end_value:
+ Target value.
+ period (optional):
+ The number of steps over which the cosine function completes a full cycle.
+ If not provided, it defaults to max_steps.
+
+ Returns:
+ Cosine decay value.
+
+ """
+ if step < 0:
+ raise ValueError("Current step number can't be negative")
+ if max_steps < 1:
+ raise ValueError("Total step number must be >= 1")
+ if period is None and step > max_steps:
+ warnings.warn(
+ f"Current step number {step} exceeds max_steps {max_steps}.",
+ category=RuntimeWarning,
+ )
+ if period is not None and period <= 0:
+ raise ValueError("Period must be >= 1")
+
+ decay: float
+ if period is not None: # "cycle" based on period, if provided
+ decay = (
+ end_value
+ - (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2
+ )
+ elif max_steps == 1:
+ # Avoid division by zero
+ decay = end_value
+ elif step == max_steps:
+ # Special case for Pytorch Lightning which updates LR scheduler also for epoch
+ # after last training epoch.
+ decay = end_value
+ else:
+ decay = (
+ end_value
+ - (end_value - start_value)
+ * (np.cos(np.pi * step / (max_steps - 1)) + 1)
+ / 2
+ )
+ return decay
+
+
+class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR):
+ """Cosine warmup scheduler for learning rate.
+
+ Args:
+ optimizer:
+ Optimizer object to schedule the learning rate.
+ warmup_epochs:
+ Number of warmup epochs or steps.
+ max_epochs:
+ Total number of training epochs or steps.
+ last_epoch:
+ The index of last epoch or step. Default: -1
+ start_value:
+ Starting learning rate scale. Default: 1.0
+ end_value:
+ Target learning rate scale. Default: 0.001
+ verbose:
+ If True, prints a message to stdout for each update. Default: False.
+
+ Note: The `epoch` arguments do not necessarily have to be epochs. Any step or index
+ can be used. The naming follows the Pytorch convention to use `epoch` for the steps
+ in the scheduler.
+ """
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_epochs: int,
+ max_epochs: int,
+ last_epoch: int = -1,
+ start_value: float = 1.0,
+ end_value: float = 0.001,
+ period: Optional[int] = None,
+ verbose: bool = False,
+ ) -> None:
+ self.warmup_epochs = warmup_epochs
+ self.max_epochs = max_epochs
+ self.start_value = start_value
+ self.end_value = end_value
+ self.period = period
+ super().__init__(
+ optimizer=optimizer,
+ lr_lambda=self.scale_lr,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+ def scale_lr(self, epoch: int) -> float:
+ """
+ Scale learning rate according to the current epoch number.
+
+ Args:
+ epoch:
+ Current epoch number.
+
+ Returns:
+ Scaled learning rate.
+
+ """
+ if epoch < self.warmup_epochs:
+ return self.start_value * (epoch + 1) / self.warmup_epochs
+ elif self.period is not None:
+ return cosine_schedule(
+ step=epoch - self.warmup_epochs,
+ max_steps=1,
+ start_value=self.start_value,
+ end_value=self.end_value,
+ period=self.period,
+ )
+ else:
+ return cosine_schedule(
+ step=epoch - self.warmup_epochs,
+ max_steps=self.max_epochs - self.warmup_epochs,
+ start_value=self.start_value,
+ end_value=self.end_value,
+ )
diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py
index 5690e0e1..2c9e7524 100644
--- a/zoobot/pytorch/training/train_with_pytorch_lightning.py
+++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -4,13 +4,17 @@
import torch
import pytorch_lightning as pl
+from pytorch_lightning.plugins import TorchSyncBatchNorm
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.loggers import CSVLogger
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
from zoobot.pytorch.estimators import define_model
+from zoobot.pytorch.datasets import webdatamodule
+
def train_default_zoobot_from_scratch(
@@ -22,14 +26,18 @@ def train_default_zoobot_from_scratch(
train_catalog=None,
val_catalog=None,
test_catalog=None,
+ train_urls=None,
+ val_urls=None,
+ test_urls=None,
+ cache_dir=None, # only works with webdataset urls
# training time parameters
epochs=1000,
patience=8,
# model hparams
- architecture_name='efficientnet_b0', # recently changed
+ architecture_name='efficientnet_b0',
+ timm_kwargs = {}, # e.g. {'drop_path_rate': 0.2, 'num_features': 1280}. Passed to timm model init method, depends on arch.
batch_size=128,
dropout_rate=0.2,
- drop_connect_rate=0.2,
learning_rate=1e-3,
betas=(0.9, 0.999),
weight_decay=0.01,
@@ -42,9 +50,11 @@ def train_default_zoobot_from_scratch(
# hardware parameters
nodes=1,
gpus=2,
+ sync_batchnorm=False,
num_workers=4,
prefetch_factor=4,
mixed_precision=False,
+ compile_encoder=False,
# checkpointing / logging
wandb_logger=None,
checkpoint_file_template=None,
@@ -56,42 +66,48 @@ def train_default_zoobot_from_scratch(
) -> Tuple[define_model.ZoobotTree, pl.Trainer]:
"""
Train Zoobot from scratch on a big galaxy catalog.
- Zoobot is a base deep learning model (anything from timm, typically a CNN) plus a dirichlet head.
- Images are augmented using the default transforms (flips, rotations, zooms)
- from `the galaxy-datasets repo `_.
- Once trained, Zoobot can be finetuned to new data.
- For finetuning, see zoobot/pytorch/training/finetune.py.
- Many pretrained models are already available - see :ref:`datanotes`.
+ **You don't need to use this**.
+ Training from scratch is becoming increasingly complicated (as you can see from the arguments) due to ongoing research on the best methods.
+ This will be refactored to a dedicated "foundation" repo.
Args:
save_dir (str): folder to save training logs and trained model checkpoints
+ schema (shared.schemas.Schema): Schema object with label_cols, question_answer_pairs, and dependencies
catalog (pd.DataFrame, optional): Galaxy catalog with columns `id_str` and `file_loc`. Will be automatically split to train and val (no test). Defaults to None.
train_catalog (pd.DataFrame, optional): As above, but already split by you for training. Defaults to None.
val_catalog (pd.DataFrame, optional): As above, for validation. Defaults to None.
test_catalog (pd.DataFrame, optional): As above, for testing. Defaults to None.
+ train_urls (list, optional): List of URLs to webdatasets for training. Defaults to None.
+ val_urls (list, optional): List of URLs to webdatasets for validation. Defaults to None.
+ test_urls (list, optional): List of URLs to webdatasets for testing. Defaults to None.
+ cache_dir (str, optional): Directory to cache webdatasets. Defaults to None.
epochs (int, optional): Max. number of epochs to train for. Defaults to 1000.
patience (int, optional): Max. number of epochs to wait for any loss improvement before ending training. Defaults to 8.
architecture_name (str, optional): Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to 'efficientnet_b0'.
+ timm_kwargs (dict, optional): Additional kwargs to pass to timm model init method, for example {'drop_connect_rate': 0.2}. Defaults to {}.
+ batch_size (int, optional): Batch size. Defaults to 128.
dropout_rate (float, optional): Randomly drop activations prior to the output layer, with this probability. Defaults to 0.2.
- drop_connect_rate (float, optional): Randomly drop blocks with this probability, for regularisation. For supported timm models only. Defaults to 0.2.
learning_rate (float, optional): Base learning rate for AdamW. Defaults to 1e-3.
betas (tuple, optional): Beta args (i.e. momentum) for adamW. Defaults to (0.9, 0.999).
weight_decay (float, optional): Weight decay arg (i.e. L2 penalty) for AdamW. Defaults to 0.01.
- scheduler_params (dict, optional): Specify a learning rate scheduler. See code. Not recommended with AdamW. Defaults to {}.
+ scheduler_params (dict, optional): Specify a learning rate scheduler. See code below. Defaults to {}.
color (bool, optional): Train on RGB images rather than channel-averaged. Defaults to False.
resize_after_crop (int, optional): Input image size. After all transforms, images will be resized to this size. Defaults to 224.
crop_scale_bounds (tuple, optional): Off-center crop fraction (<1 means zoom in). Defaults to (0.7, 0.8).
crop_ratio_bounds (tuple, optional): Aspect ratio of crop above. Defaults to (0.9, 1.1).
nodes (int, optional): Multi-node training Unlikely to work on your cluster without tinkering. Defaults to 1 (i.e. one node).
gpus (int, optional): Multi-GPU training. Uses distributed data parallel - essentially, full dataset is split by GPU. See torch docs. Defaults to 2.
+ sync_batchnorm (bool, optional): Use synchronized batchnorm. Defaults to False.
num_workers (int, optional): Processes for loading data. See torch dataloader docs. Should be < num cpus. Defaults to 4.
prefetch_factor (int, optional): Num. batches to queue in memory per dataloader. See torch dataloader docs. Defaults to 4.
mixed_precision (bool, optional): Use (mostly) half-precision to halve memory requirements. May cause instability. See Lightning Trainer docs. Defaults to False.
+ compile_encoder (bool, optional): Compile the encoder with torch.compile (new in torch v2). Defaults to False.
wandb_logger (pl.loggers.wandb.WandbLogger, optional): Logger to track experiments on Weights and Biases. Defaults to None.
checkpoint_file_template (str, optional): formatting for checkpoint filename. See Lightning docs. Defaults to None.
auto_insert_metric_name (bool, optional): escape "/" in metric names when naming checkpoints. See Lightning docs. Defaults to True.
save_top_k (int, optional): Keep the k best checkpoints. See Lightning docs. Defaults to 3.
+ extra_callbacks (list, optional): Additional callbacks to pass to the Trainer. Defaults to None.
random_state (int, optional): Seed. Defaults to 42.
Returns:
@@ -106,7 +122,11 @@ def train_default_zoobot_from_scratch(
assert save_dir is not None
if not os.path.isdir(save_dir):
- os.mkdir(save_dir)
+ try:
+ os.mkdir(save_dir)
+ except FileExistsError:
+ pass # another gpu process may have just made it
+ logging.info(f'Saving to {save_dir}')
if color:
logging.warning(
@@ -121,19 +141,18 @@ def train_default_zoobot_from_scratch(
if (gpus is not None) and (gpus > 1):
strategy = DDPStrategy(find_unused_parameters=False) # static_graph=True TODO
logging.info('Using multi-gpu training')
- if nodes > 1: # I assume nobody is doing multi-node cpu training...
- logging.info('Using multi-node training') # purely for your info
+ # if nodes > 1: # I assume nobody is doing multi-node cpu training...
+ # logging.info('Using multi-node training') # purely for your info
# this is only needed for multi-node training
# our cluster sets TASKS_PER_NODE not NTASKS_PER_NODE
# (with srun, SLURM_STEP_TASKS_PER_NODE)
# https://slurm.schedmd.com/srun.html#OPT_SLURM_STEP_TASKS_PER_NODE
- if 'SLURM_NTASKS_PER_NODE' not in os.environ.keys():
- os.environ['SLURM_NTASKS_PER_NODE'] = os.environ['SLURM_TASKS_PER_NODE']
- # from lightning_lite.plugins.environments import SLURMEnvironment
- from zoobot.pytorch import manchester
- logging.warning('Using custom slurm environment')
- # https://pytorch-lightning.readthedocs.io/en/stable/clouds/cluster_advanced.html#enable-auto-wall-time-resubmitions
- plugins = [manchester.ManchesterEnvironment(auto_requeue=False)]
+ if 'SLURM_NTASKS_PER_NODE' not in os.environ.keys():
+ os.environ['SLURM_NTASKS_PER_NODE'] = os.environ['SLURM_TASKS_PER_NODE']
+ from zoobot.pytorch import manchester
+ logging.warning(f'Using custom slurm environment, --n-tasks-per-node={os.environ["SLURM_NTASKS_PER_NODE"]}')
+ # https://pytorch-lightning.readthedocs.io/en/stable/clouds/cluster_advanced.html#enable-auto-wall-time-resubmitions
+ plugins = [manchester.GalahadEnvironment(auto_requeue=False)]
if gpus > 0:
accelerator = 'gpu'
@@ -167,22 +186,6 @@ def train_default_zoobot_from_scratch(
Suggest reducing num_workers."""
)
-
- if catalog is not None:
- assert train_catalog is None
- assert val_catalog is None
- assert test_catalog is None
- catalogs_to_use = {
- 'catalog': catalog
- }
- else:
- assert catalog is None
- catalogs_to_use = {
- 'train_catalog': train_catalog,
- 'val_catalog': val_catalog,
- 'test_catalog': test_catalog # may be None
- }
-
if wandb_logger is not None:
wandb_logger.log_hyperparams({
'random_state': random_state,
@@ -200,45 +203,115 @@ def train_default_zoobot_from_scratch(
'prefetch_factor': prefetch_factor,
'framework': 'pytorch'
})
+ else:
+ logging.warning('No wandb_logger passed. Using CSV logging only')
+ wandb_logger = CSVLogger(save_dir=save_dir)
+
+ # work out what dataset the user has passed
+ single_catalog = catalog is not None
+ split_catalogs = train_catalog is not None
+ webdatasets = train_urls is not None
+
+ if single_catalog or split_catalogs:
+ # this branch will use GalaxyDataModule to load catalogs
+ assert not webdatasets
+ if single_catalog:
+ assert not split_catalogs
+ data_to_use = {
+ 'catalog': catalog
+ }
+ else:
+ data_to_use = {
+ 'train_catalog': train_catalog,
+ 'val_catalog': val_catalog,
+ 'test_catalog': test_catalog # may be None
+ }
+ datamodule = GalaxyDataModule(
+ label_cols=schema.label_cols,
+ # can take either a catalog (and split it), or a pre-split catalog
+ **data_to_use,
+ # augmentations parameters
+ greyscale=not color,
+ crop_scale_bounds=crop_scale_bounds,
+ crop_ratio_bounds=crop_ratio_bounds,
+ resize_after_crop=resize_after_crop,
+ # hardware parameters
+ batch_size=batch_size, # on 2xA100s, 256 with DDP, 512 with distributed (i.e. split batch)
+ num_workers=num_workers,
+ prefetch_factor=prefetch_factor
+ )
+ else:
+ # this branch will use WebDataModule to load premade webdatasets
+
+ # temporary: use SSL-like transform
+ # from foundation.models import transforms
+ # train_transform_cfg = transforms.default_view_config()
+ # inference_transform_cfg = transforms.minimal_view_config()
+ # train_transform_cfg.output_size = resize_after_crop
+ # inference_transform_cfg.output_size = resize_after_crop
+
+ datamodule = webdatamodule.WebDataModule(
+ train_urls=train_urls,
+ val_urls=val_urls,
+ test_urls=test_urls,
+ label_cols=schema.label_cols,
+ # hardware
+ batch_size=batch_size,
+ num_workers=num_workers,
+ prefetch_factor=prefetch_factor,
+ cache_dir=cache_dir,
+ # augmentation args
+ greyscale=not color,
+ crop_scale_bounds=crop_scale_bounds,
+ crop_ratio_bounds=crop_ratio_bounds,
+ resize_after_crop=resize_after_crop,
+ # temporary: use SSL-like transform
+ # train_transform=transforms.GalaxyViewTransform(train_transform_cfg),
+ # inference_transform=transforms.GalaxyViewTransform(inference_transform_cfg),
+ )
- datamodule = GalaxyDataModule(
- label_cols=schema.label_cols,
- # can take either a catalog (and split it), or a pre-split catalog
- **catalogs_to_use,
- # augmentations parameters
- greyscale=not color,
- crop_scale_bounds=crop_scale_bounds,
- crop_ratio_bounds=crop_ratio_bounds,
- resize_after_crop=resize_after_crop,
- # hardware parameters
- batch_size=batch_size, # on 2xA100s, 256 with DDP, 512 with distributed (i.e. split batch)
- num_workers=num_workers,
- prefetch_factor=prefetch_factor
- )
+ # debug - check range of loaded images, should be 0-1
datamodule.setup(stage='fit')
+ for (images, _) in datamodule.train_dataloader():
+ logging.info(f'Using batches of {images.shape[0]} images for training')
+ logging.info('First batch image min/max: {}/{}'.format(images.min(), images.max()))
+ assert images.max() <= 1.0
+ assert images.min() >= 0.0
+ break
+ # exit()
# these args are automatically logged
lightning_model = define_model.ZoobotTree(
output_dim=len(schema.label_cols),
- question_index_groups=schema.question_index_groups,
+ # NEW - pass these from schema, for better logging
+ question_answer_pairs=schema.question_answer_pairs,
+ dependencies=schema.dependencies,
architecture_name=architecture_name,
channels=channels,
- use_imagenet_weights=False,
test_time_dropout=True,
dropout_rate=dropout_rate,
learning_rate=learning_rate,
- timm_kwargs={'drop_path_rate': drop_connect_rate},
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientnet.py#L75C9-L75C17
+ timm_kwargs=timm_kwargs,
+ compile_encoder=compile_encoder,
betas=betas,
weight_decay=weight_decay,
scheduler_params=scheduler_params
)
+
+ if sync_batchnorm:
+ logging.info('Using sync batchnorm')
+ lightning_model = TorchSyncBatchNorm().apply(lightning_model)
+
extra_callbacks = extra_callbacks if extra_callbacks else []
+ monitor_metric = 'validation/supervised_loss'
+
# used later for checkpoint_callback.best_model_path
checkpoint_callback = ModelCheckpoint(
dirpath=os.path.join(save_dir, 'checkpoints'),
- monitor="validation/epoch_loss",
+ monitor=monitor_metric,
save_weights_only=True,
mode='min',
# custom filename for checkpointing due to / in metric
@@ -249,12 +322,12 @@ def train_default_zoobot_from_scratch(
save_top_k=save_top_k
)
- early_stopping_callback = EarlyStopping(monitor='validation/epoch_loss', patience=patience, check_finite=True)
-
+ early_stopping_callback = EarlyStopping(monitor=monitor_metric, patience=patience, check_finite=True)
callbacks = [checkpoint_callback, early_stopping_callback] + extra_callbacks
trainer = pl.Trainer(
- log_every_n_steps=150, # at batch 512 (A100 MP max), DR5 has ~161 train steps
+ num_sanity_val_steps=0,
+ log_every_n_steps=150,
accelerator=accelerator,
devices=devices, # per node
num_nodes=nodes,
@@ -264,34 +337,21 @@ def train_default_zoobot_from_scratch(
callbacks=callbacks,
max_epochs=epochs,
default_root_dir=save_dir,
- plugins=plugins
+ plugins=plugins,
+ gradient_clip_val=.3 # reduced from 1 to .3, having some nan issues
)
- logging.info((trainer.strategy, trainer.world_size,
- trainer.local_rank, trainer.global_rank, trainer.node_rank))
-
trainer.fit(lightning_model, datamodule) # uses batch size of datamodule
- test_trainer = pl.Trainer(
- accelerator=accelerator,
- devices=1,
- precision=precision,
- logger=wandb_logger,
- default_root_dir=save_dir
- )
-
best_model_path = trainer.checkpoint_callback.best_model_path
# can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
# also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
- if test_catalog is not None:
+ if datamodule.test_dataloader is not None:
logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
- test_trainer.validate(
- model=lightning_model,
- datamodule=datamodule,
- ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
- )
- test_trainer.test(
+ datamodule.setup(stage='test')
+ # TODO with webdataset, no need for new trainer/datamodule (actually it breaks), but might still be needed with normal dataset?
+ trainer.test(
model=lightning_model,
datamodule=datamodule,
ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
diff --git a/zoobot/shared/load_predictions.py b/zoobot/shared/load_predictions.py
index 9373b488..874ae9f9 100644
--- a/zoobot/shared/load_predictions.py
+++ b/zoobot/shared/load_predictions.py
@@ -49,6 +49,7 @@ def load_hdf5s(hdf5_locs: List):
'id_str': f['id_str'].asstr()[:],
'hdf5_loc': [os.path.basename(loc) for _ in these_predictions]
}
+ assert len(these_predictions) == len(these_prediction_metadata['id_str']), (loc, len(these_predictions), len(these_prediction_metadata['id_str']) )
predictions.append(these_predictions) # will create a list where each element is 3D predictions stored in each hdf5
prediction_metadata.append(these_prediction_metadata) # also track id_str, similarly
@@ -72,7 +73,7 @@ def load_hdf5s(hdf5_locs: List):
'id_str': [p for metadata in prediction_metadata for p in metadata['id_str']],
'hdf5_loc': [l for metadata in prediction_metadata for l in metadata['hdf5_loc']]
}
- assert len(prediction_metadata['id_str']) == len(predictions)
+ assert len(prediction_metadata['id_str']) == len(predictions), (len(prediction_metadata['id_str']), len(predictions))
galaxy_id_df = pd.DataFrame(data=prediction_metadata)
@@ -163,10 +164,12 @@ def prediction_hdf5_to_summary_parquet(hdf5_loc: str, save_loc: str, schema: sch
upper_edge_cols = [col + '_90pc-upper' for col in label_cols]
proportion_asked_cols = [col + '_proportion-asked' for col in label_cols]
+
# make friendly dataframe with just masked fraction and description string
friendly_loc = save_loc.replace('.parquet', '_friendly.parquet')
fraction_df = pd.DataFrame(data=masked_fractions, columns=fraction_cols)
friendly_df = pd.concat([galaxy_id_df, fraction_df], axis=1)
+ friendly_df = convert_halfprecision_cols(friendly_df)
friendly_df.to_parquet(friendly_loc, index=False)
logging.info('Friendly summary table saved to {}'.format(friendly_loc))
@@ -177,11 +180,19 @@ def prediction_hdf5_to_summary_parquet(hdf5_loc: str, save_loc: str, schema: sch
upper_edge_df = pd.DataFrame(data=all_upper_edges, columns=upper_edge_cols)
proportion_df = pd.DataFrame(data=prob_of_asked_by_answer, columns=proportion_asked_cols)
advanced_df = pd.concat([galaxy_id_df, fraction_df, lower_edge_df, upper_edge_df, proportion_df], axis=1)
+ advanced_df = convert_halfprecision_cols(advanced_df)
advanced_df.to_parquet(advanced_loc, index=False)
logging.info('Advanced summary table saved to {}'.format(advanced_loc))
-def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False):
+def convert_halfprecision_cols(df):
+ # convert any half-precision columns, parquet can't save these
+ half_floats = df.select_dtypes(include="float16")
+ df[half_floats.columns] = half_floats.astype("float32")
+ return df
+
+
+def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False, subset_frac=None):
"""
Load predictions (or representations) saved as hdf5 into pd.DataFrame with id_str and label_cols columns
@@ -197,9 +208,11 @@ def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False):
_type_: _description_
"""
galaxy_id_df, predictions, label_cols = load_hdf5s(hdf5_locs)
+ logging.info('HDF5s loaded.')
predictions = predictions.squeeze()
-
+
+
if len(predictions.shape) > 2:
if drop_extra_dims:
predictions = predictions[:, :, 0]
@@ -210,9 +223,18 @@ def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False):
I suggest using load_hdf5s directly to work with np.arrays, not with DataFrame - see docstring'
)
prediction_df = pd.DataFrame(data=predictions, columns=label_cols)
+
+ if subset_frac is not None:
+ logging.warning('Selecting a random subset: {}'.format(subset_frac))
+ prediction_df = prediction_df.sample(frac=subset_frac, random_state=42)
+
+
+ del predictions
+ logging.info('Saving')
# copy over metadata (indices will align)
prediction_df['id_str'] = galaxy_id_df['id_str']
prediction_df['hdf5_loc'] = galaxy_id_df['hdf5_loc']
+ prediction_df = convert_halfprecision_cols(prediction_df)
return prediction_df
diff --git a/zoobot/shared/save_predictions.py b/zoobot/shared/save_predictions.py
index 5505ecc9..5e43e592 100755
--- a/zoobot/shared/save_predictions.py
+++ b/zoobot/shared/save_predictions.py
@@ -10,6 +10,9 @@
def predictions_to_hdf5(predictions, id_str, label_cols, save_loc, compression="gzip"):
logging.info(f'Saving predictions to {save_loc}')
assert save_loc.endswith('.hdf5')
+ if label_cols is None:
+ label_cols = get_default_label_cols(predictions)
+ # sometimes throws a "could not lock file" error but still saves fine. I don't understand why
with h5py.File(save_loc, "w") as f:
f.create_dataset(name='predictions', data=predictions, compression=compression)
# https://docs.h5py.org/en/stable/special.html#h5py.string_dtype
@@ -17,12 +20,13 @@ def predictions_to_hdf5(predictions, id_str, label_cols, save_loc, compression="
# predictions_dset.attrs['label_cols'] = label_cols # would be more conventional but is a little awkward
f.create_dataset(name='id_str', data=id_str, dtype=dt)
f.create_dataset(name='label_cols', data=label_cols, dtype=dt)
- # sometimes throws a "could not lock file" error but still saves fine. I don't understand why
-
def predictions_to_csv(predictions, id_str, label_cols, save_loc):
+
# not recommended - hdf5 is much more flexible and pretty easy to use once you check the package quickstart
assert save_loc.endswith('.csv')
+ if label_cols is None:
+ label_cols = get_default_label_cols(predictions)
data = [prediction_to_row(predictions[n], id_str[n], label_cols=label_cols) for n in range(len(predictions))]
predictions_df = pd.DataFrame(data)
# logging.info(predictions_df)
@@ -57,3 +61,8 @@ def prediction_to_row(prediction: np.ndarray, id_str: str, label_cols: List):
else:
row[answer + '_pred'] = json.dumps(list(answer_pred)) # it's not a scalar, write as json
return row
+
+def get_default_label_cols(predictions):
+ logging.warning('No label_cols passed - using default names e.g. feat_0, feat_1...')
+ label_cols = [f'feat_{n}' for n in range(predictions.shape[1])]
+ return label_cols
diff --git a/zoobot/shared/schemas.py b/zoobot/shared/schemas.py
index 8d32f878..3f85dbbe 100755
--- a/zoobot/shared/schemas.py
+++ b/zoobot/shared/schemas.py
@@ -130,7 +130,8 @@ def set_dependencies(questions, dependencies):
class Schema():
- def __init__(self, question_answer_pairs:dict, dependencies):
+
+ def __init__(self, question_answer_pairs:dict, dependencies: dict):
"""
Relate the df label columns tor question/answer groups and to tfrecod label indices
Requires that labels be continguous by question - easily satisfied
@@ -141,6 +142,23 @@ def __init__(self, question_answer_pairs:dict, dependencies):
- answers in between will be included: these are used to slice
- df columns must be contigious by question (e.g. not smooth_yes, bar_no, smooth_no) for this to work!
+ The following schemas are available via the module (e.g. `from zoobot.shared.schemas import decals_dr5_ortho_schema`):
+ - decals_dr5_ortho_schema
+ - decals_dr8_ortho_schema
+ - decals_all_campaigns_ortho_schema
+ - gz2_ortho_schema
+ - gz_candels_ortho_schema
+ - gz_hubble_ortho_schema
+ - cosmic_dawn_ortho_schema
+ - cosmic_dawn_schema
+ - gz_rings_schema
+ - desi_schema
+ - gz_evo_v1_schema (this is the schema currently used for pretraining)
+ - gz_ukidss_schema
+ - gz_jwst_schema
+
+ "ortho" refers to the orthogonal question suffix (-cd, -dr8, etc).
+
Args:
question_answer_pairs (dict): e.g. {'smooth-or-featured: ['_smooth, _featured-or-disk, ...], ...}
dependencies (dict): dict mapping each question (e.g. disk-edge-on) to the answer on which it depends (e.g. smooth-or-featured_featured-or-disk)
@@ -278,3 +296,6 @@ def answers(self):
# so don't log anything during Schema.__init__!
gz_evo_v1_schema = Schema(label_metadata.gz_evo_v1_pairs, label_metadata.gz_evo_v1_dependencies)
+
+gz_ukidss_schema = Schema(label_metadata.ukidss_ortho_pairs, label_metadata.ukidss_ortho_dependencies)
+gz_jwst_schema = Schema(label_metadata.jwst_ortho_pairs, label_metadata.jwst_ortho_dependencies)
diff --git a/zoobot/tensorflow/training/losses.py b/zoobot/tensorflow/training/losses.py
index 12e5efa4..443ef117 100755
--- a/zoobot/tensorflow/training/losses.py
+++ b/zoobot/tensorflow/training/losses.py
@@ -12,7 +12,7 @@ def get_multiquestion_loss(question_index_groups, sum_over_questions=True, reduc
tf.keras.losses.Reduction.SUM will simply add everything up, so divide by the global batch size externally with tf.reduce_sum
Args:
- question_index_groups (list): Answer indices for each question i.e. [(question.start_index, question.end_index), ...] for all questions. Useful for slicing model predictions by question.
+ question_index_groups (list): Answer indices for each question i.e. [(question.start_index, question.end_index), ...] for all questions. Useful for slicing model predictions by question. See :ref:`schemas`.
Returns:
MultiquestionLoss: see above.
@@ -36,7 +36,7 @@ def calculate_multiquestion_loss(labels, predictions, question_index_groups, sum
Args:
labels (tf.Tensor): (galaxy, k successes) where k successes dimension is indexed by question_index_groups.
predictions (tf.Tensor): Dirichlet concentrations, matching shape of labels
- question_index_groups (list): Paired (tuple) integers of (first, last) indices of answers to each question, listed for all questions.
+ question_index_groups (list): Answer indices for each question i.e. [(question.start_index, question.end_index), ...] for all questions. Useful for slicing model predictions by question. See :ref:`schemas`.
Returns:
tf.Tensor: neg. log likelihood of shape (batch, question).