Commit: try wandb
mwalmsley committed Oct 18, 2023
1 parent da92e41 commit b063b49
Showing 2 changed files with 30 additions and 7 deletions.
19 changes: 16 additions & 3 deletions only_for_me/narval/finetune.py
@@ -2,6 +2,8 @@
 import os
 import shutil
 
+from pytorch_lightning.loggers import WandbLogger
+
 from zoobot.pytorch.training import finetune
 from galaxy_datasets import galaxy_mnist
 from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
@@ -19,9 +21,10 @@
 # logging.info(glob.glob(os.path.join(os.environ['SLURM_TMPDIR'], 'walml/finetune/data/galaxy_mnist')))
 
 import torch
+torch.set_float32_matmul_precision('medium')
 assert torch.cuda.is_available()
 
-batch_size = 128
+batch_size = 256
 num_workers= 8
 n_blocks = 1 # EffnetB0 is divided into 7 blocks. set 0 to only fit the head weights. Set 1, 2, etc to finetune deeper.
 max_epochs = 6 # 6 epochs should get you ~93% accuracy. Set much higher (e.g. 1000) for harder problems, to use Zoobot's default early stopping. \
@@ -36,6 +39,8 @@
 # load a pretrained checkpoint saved here
 # rsync -avz --no-g --no-p /home/walml/repos/zoobot/data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt [email protected]:/project/def-bovy/walml/zoobot/data/pretrained_models/pytorch
 checkpoint_loc = '/project/def-bovy/walml/zoobot/data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt'
+
+logger = WandbLogger(name='debug', save_dir='/project/def-bovy/walml/wandb/debug', project='narval', log_model=False, offline=True)
 
 datamodule = GalaxyDataModule(
     label_cols=label_cols,
@@ -49,6 +54,14 @@
     num_classes=num_classes,
     n_blocks=n_blocks
 )
-trainer = finetune.get_trainer(os.path.join(os.environ['SLURM_TMPDIR'], 'walml/finetune/checkpoints'), accelerator='auto', max_epochs=max_epochs)
+trainer = finetune.get_trainer(
+    os.path.join(os.environ['SLURM_TMPDIR'], 'walml/finetune/checkpoints'),
+    accelerator='gpu',
+    devices=2,
+    strategy='ddp',
+    precision='16-mixed',
+    max_epochs=max_epochs,
+    logger=logger
+)
 trainer.fit(model, datamodule)
-trainer.test(model, datamodule)
+# trainer.test(model, datamodule)
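
The Python change above replaces the single-device 'auto' trainer with an explicit two-GPU DDP trainer using 16-bit mixed precision, plus an offline WandbLogger. Below is a minimal self-contained sketch of how these pieces fit together in plain PyTorch Lightning; it is not the repo's finetune.get_trainer, and the save_dir path, model and datamodule are placeholders:

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import WandbLogger

    # trade some float32 matmul precision for tensor-core speed on A100s
    torch.set_float32_matmul_precision('medium')

    # offline=True keeps run files in save_dir only; upload later with `wandb sync <run dir>`
    logger = WandbLogger(name='debug', save_dir='/tmp/wandb/debug',
                         project='narval', log_model=False, offline=True)

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=2,              # one DDP process per GPU
        strategy='ddp',
        precision='16-mixed',   # fp16 autocast with fp32 master weights
        max_epochs=6,
        logger=logger,
    )
    # trainer.fit(model, datamodule)  # model and datamodule as constructed above

Under DDP, Lightning launches one process per device, so the global batch size is devices × batch_size (here 2 × 256 = 512).
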
18 changes: 14 additions & 4 deletions only_for_me/narval/finetune.sh
@@ -1,9 +1,15 @@
 #!/bin/bash
-#SBATCH --mem=16G
+#SBATCH --mem=32G
 #SBATCH --nodes=1
-#SBATCH --ntasks-per-node=8
-#SBATCH --time=0:15:0
-#SBATCH --gres=gpu:a100:1
+#SBATCH --time=0:10:0
+#SBATCH --ntasks-per-node=16
+#SBATCH --gres=gpu:a100:2
 
+#### SBATCH --mem=16G
+#### SBATCH --nodes=1
+#### SBATCH --time=0:10:0
+#### SBATCH --ntasks-per-node=8
+#### SBATCH --gres=gpu:a100:1
+
 #### SBATCH --mail-user=<[email protected]>
 #### SBATCH --mail-type=ALL
@@ -22,6 +28,10 @@ cp -r /project/def-bovy/walml/data/roots/galaxy_mnist $SLURM_TMPDIR/walml/finetu
 
 ls $SLURM_TMPDIR/walml/finetune/data/galaxy_mnist
 
+pip install --no-index wandb
+
+wandb offline # only write metadata locally
+
 $PYTHON /project/def-bovy/walml/zoobot/only_for_me/narval/finetune.py
 
 ls $SLURM_TMPDIR/walml/finetune/checkpoints
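
The job script installs wandb from the cluster's local wheelhouse (pip install --no-index) and switches it to offline mode, presumably because Narval compute nodes have no internet access. A sketch of a broadly equivalent per-process setting from Python, using the documented WANDB_MODE environment variable (the project and run names are placeholders):

    import os

    # assumption for illustration: set before wandb is imported/initialised anywhere
    os.environ['WANDB_MODE'] = 'offline'  # same effect as running `wandb offline` for this process

    import wandb
    run = wandb.init(project='narval', name='debug')  # writes to ./wandb locally, no network calls
    run.log({'loss': 0.0})
    run.finish()
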
