diff --git a/only_for_me/narval/finetune.py b/only_for_me/narval/finetune.py
index f56db02f..e4bf8580 100644
--- a/only_for_me/narval/finetune.py
+++ b/only_for_me/narval/finetune.py
@@ -2,6 +2,8 @@
 import os
 import shutil
 
+from pytorch_lightning.loggers import WandbLogger
+
 from zoobot.pytorch.training import finetune
 from galaxy_datasets import galaxy_mnist
 from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
@@ -19,9 +21,10 @@
     # logging.info(glob.glob(os.path.join(os.environ['SLURM_TMPDIR'], 'walml/finetune/data/galaxy_mnist')))
 
     import torch
+    torch.set_float32_matmul_precision('medium')
     assert torch.cuda.is_available()
 
-    batch_size = 128
+    batch_size = 256
     num_workers= 8
     n_blocks = 1  # EffnetB0 is divided into 7 blocks. set 0 to only fit the head weights. Set 1, 2, etc to finetune deeper.
     max_epochs = 6  # 6 epochs should get you ~93% accuracy. Set much higher (e.g. 1000) for harder problems, to use Zoobot's default early stopping.
@@ -36,6 +39,8 @@
     # load a pretrained checkpoint saved here
     # rsync -avz --no-g --no-p /home/walml/repos/zoobot/data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt walml@narval.alliancecan.ca:/project/def-bovy/walml/zoobot/data/pretrained_models/pytorch
     checkpoint_loc = '/project/def-bovy/walml/zoobot/data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt'
+
+    logger = WandbLogger(name='debug', save_dir='/project/def-bovy/walml/wandb/debug', project='narval', log_model=False, offline=True)
 
     datamodule = GalaxyDataModule(
         label_cols=label_cols,
@@ -49,6 +54,14 @@
         num_classes=num_classes,
         n_blocks=n_blocks
     )
-    trainer = finetune.get_trainer(os.path.join(os.environ['SLURM_TMPDIR'], 'walml/finetune/checkpoints'), accelerator='auto', max_epochs=max_epochs)
+    trainer = finetune.get_trainer(
+        os.path.join(os.environ['SLURM_TMPDIR'], 'walml/finetune/checkpoints'),
+        accelerator='gpu',
+        devices=2,
+        strategy='ddp',
+        precision='16-mixed',
+        max_epochs=max_epochs,
+        logger=logger
+    )
     trainer.fit(model, datamodule)
-    trainer.test(model, datamodule)
+    # trainer.test(model, datamodule)
diff --git a/only_for_me/narval/finetune.sh b/only_for_me/narval/finetune.sh
index e2cc3278..8f076053 100644
--- a/only_for_me/narval/finetune.sh
+++ b/only_for_me/narval/finetune.sh
@@ -1,9 +1,15 @@
 #!/bin/bash
-#SBATCH --mem=16G
+#SBATCH --mem=32G
 #SBATCH --nodes=1
-#SBATCH --ntasks-per-node=8
-#SBATCH --time=0:15:0
-#SBATCH --gres=gpu:a100:1
+#SBATCH --time=0:10:0
+#SBATCH --ntasks-per-node=16
+#SBATCH --gres=gpu:a100:2
+
+#### SBATCH --mem=16G
+#### SBATCH --nodes=1
+#### SBATCH --time=0:10:0
+#### SBATCH --ntasks-per-node=8
+#### SBATCH --gres=gpu:a100:1
 
 #### SBATCH --mail-user=
 #### SBATCH --mail-type=ALL
@@ -22,6 +28,10 @@ cp -r /project/def-bovy/walml/data/roots/galaxy_mnist $SLURM_TMPDIR/walml/finetu
 
 ls $SLURM_TMPDIR/walml/finetune/data/galaxy_mnist
 
+pip install --no-index wandb
+
+wandb offline  # only write metadata locally
+
 $PYTHON /project/def-bovy/walml/zoobot/only_for_me/narval/finetune.py
 
 ls $SLURM_TMPDIR/walml/finetune/checkpoints
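For context, the trainer change above is the standard PyTorch Lightning multi-GPU recipe rather than anything Zoobot-specific. Below is a minimal sketch of the equivalent bare pl.Trainer call, assuming pytorch_lightning 2.x with wandb installed, and assuming finetune.get_trainer forwards these keyword arguments to pl.Trainer; model and datamodule refer to the objects constructed in finetune.py above.

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

torch.set_float32_matmul_precision('medium')  # trade exact fp32 matmuls for tensor-core speed on A100s

# offline=True writes run files locally only; sync later with `wandb sync` if wanted
logger = WandbLogger(name='debug', project='narval', log_model=False, offline=True)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=2,             # one DDP process per GPU, matching --gres=gpu:a100:2
    strategy='ddp',        # one full model replica per process, gradients all-reduced each step
    precision='16-mixed',  # automatic mixed precision
    max_epochs=6,
    logger=logger,
)
# trainer.fit(model, datamodule)  # as in the script above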