set up for 300px training runs
mwalmsley committed Nov 23, 2023
1 parent b8d3fc0 commit 1b169ab
Showing 5 changed files with 86 additions and 29 deletions.
23 changes: 13 additions & 10 deletions only_for_me/narval/make_webdataset_script.py
@@ -23,7 +23,7 @@ def dataset_to_webdataset(dataset_name, dataset_func, label_cols, divisor=4096):
catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, divisor=divisor)


def catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, divisor=4096):
def catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, sparse_label_df=None, divisor=4096):
for (catalog_name, catalog) in [('train', train_catalog), ('test', test_catalog)]:
n_shards = len(catalog) // divisor
logging.info(n_shards)
@@ -33,7 +33,7 @@ def catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog

save_loc = f"/home/walml/data/wds/{dataset_name}/{dataset_name}_{catalog_name}.tar" # .tar replace automatically

webdataset_utils.df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards)
webdataset_utils.df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df)

# webdataset_utils.load_wds_directly(save_loc)

@@ -53,7 +53,8 @@ def main():


# for converting other catalogs e.g. DESI
dataset_name = 'desi_labelled'
dataset_name = 'desi_labelled_300px_2048'
# dataset_name = 'desi_all_2048'
label_cols = label_metadata.decals_all_campaigns_ortho_label_cols
columns = [
'dr8_id', 'brickid', 'objid', 'ra', 'dec'
@@ -84,18 +85,20 @@
# print(len(df_dedup2))
# df_dedup.to_parquet('/home/walml/data/desi/master_all_file_index_labelled_dedup_20arcsec.parquet')


df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_all_dedup_20arcsec.parquet')
df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_labelled_dedup_20arcsec.parquet')
# df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_all_dedup_20arcsec.parquet')
df_dedup['id_str'] = df_dedup['dr8_id']

# columns = ['id_str', 'smooth-or-featured-dr12_total-votes', 'smooth-or-featured-dr5_total-votes', 'smooth-or-featured-dr8_total-votes']

df_dedup_with_votes = pd.merge(df_dedup, votes, how='left', on='dr8_id')
# gets too big, need to only merge in label_df per shard
# df_dedup_with_votes = pd.merge(df_dedup, votes, how='left', on='dr8_id')

train_catalog, test_catalog = train_test_split(df_dedup_with_votes, test_size=0.2, random_state=42)
train_catalog.to_parquet('/home/walml/data/wds/desi_all/train_catalog_v1.parquet', index=False)
test_catalog.to_parquet('/home/walml/data/wds/desi_all/test_catalog_v1.parquet', index=False)
train_catalog, test_catalog = train_test_split(df_dedup, test_size=0.2, random_state=42)
train_catalog.to_parquet('/home/walml/data/wds/desi_labelled_300px_2048/train_catalog_v1.parquet', index=False)
test_catalog.to_parquet('/home/walml/data/wds/desi_labelled_300px_2048/test_catalog_v1.parquet', index=False)

catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, divisor=2048)
catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, divisor=2048, sparse_label_df=votes)



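The change above stops merging the full vote table into the catalog before sharding (the merged frame was too large in memory) and instead passes it as sparse_label_df, so df_to_wds can merge labels one shard at a time. A minimal sketch of that per-shard merge pattern, using small hypothetical dataframes and an explicit on='id_str' key rather than the repo's automatic merge:

import numpy as np
import pandas as pd

# toy stand-ins for the DESI catalog and the vote table (column names illustrative)
catalog = pd.DataFrame({
    'id_str': [f'gal_{i}' for i in range(8)],
    'file_loc': [f'/tmp/gal_{i}.jpg' for i in range(8)],
})
votes = pd.DataFrame({
    'id_str': [f'gal_{i}' for i in range(8)],
    'smooth-or-featured-dr8_total-votes': np.arange(8),
})

n_shards = 4
for shard_n, shard_df in enumerate(np.array_split(catalog, n_shards)):
    # merge labels for this shard only, so peak memory stays at shard size
    shard_df = pd.merge(shard_df, votes, how='left', on='id_str', validate='one_to_one')
    print(shard_n, len(shard_df))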
7 changes: 4 additions & 3 deletions only_for_me/narval/train.py
@@ -64,11 +64,11 @@

if os.path.isdir('/home/walml/repos/zoobot'):
logging.warning('local mode')
search_str = '/home/walml/data/wds/desi_labelled_2048/desi_labelled_train_*.tar'
search_str = '/home/walml/data/wds/desi_labelled_300px_2048/desi_labelled_train_*.tar'
cache_dir = None

else:
search_str = '/home/walml/projects/def-bovy/walml/data/webdatasets/desi_labelled_2048/desi_labelled_train_*.tar'
search_str = '/home/walml/projects/def-bovy/walml/data/webdatasets/desi_labelled_300px_2048/desi_labelled_train_*.tar'
cache_dir = os.environ['SLURM_TMPDIR'] + '/cache'

all_urls = glob.glob(search_str)
@@ -122,7 +122,8 @@
compile_encoder=args.compile_encoder, # NEW
random_state=random_state,
learning_rate=1e-3,
cache_dir=cache_dir
cache_dir=cache_dir,
crop_scale_bounds=(0.75, 0.85) # slightly increased to compensate for 424-400px crop when saving webdataset
# cache_dir='/tmp/cache'
# /tmp for ramdisk (400GB total, vs 4TB total for nvme)
)
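The new crop_scale_bounds=(0.75, 0.85) compensates for the 424→400px centre crop now applied when the webdataset is written. A rough back-of-the-envelope check (treating the bounds as a side-length fraction for simplicity; albumentations' scale argument is actually an area fraction, so the numbers are approximate):

original_px = 424   # DESI cutout size on disk
precrop_px = 400    # CenterCrop applied when writing the webdataset
pre_fraction = precrop_px / original_px   # ~0.943 of the field survives the pre-crop

for bound in (0.75, 0.85):
    print(f'aug bound {bound:.2f} -> ~{bound * pre_fraction:.2f} of the original 424px field')
# ~0.71 and ~0.80, close to the 0.7-0.8 effective-crop target noted in webdataset_utils.py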
6 changes: 3 additions & 3 deletions only_for_me/narval/train.sh
@@ -2,7 +2,7 @@
#SBATCH --time=23:30:0
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=20
#SBATCH --cpus-per-task=10
#SBATCH --mem-per-cpu 4G
#SBATCH --gres=gpu:v100:1

@@ -21,11 +21,11 @@ export NCCL_BLOCKING_WAIT=1 #Set this environment variable if you wish to use t

REPO_DIR=/project/def-bovy/walml/zoobot
srun $PYTHON $REPO_DIR/only_for_me/narval/train.py \
--save-dir $REPO_DIR/only_for_me/narval/desi_f128_2gpu \
--save-dir $REPO_DIR/only_for_me/narval/desi_300px_f128_1gpu \
--batch-size 256 \
--num-features 128 \
--gpus 1 \
--num-workers 20 \
--num-workers 10 \
--color --wandb --mixed-precision --compile-encoder

# srun python $SLURM_TMPDIR/zoobot/only_for_me/narval/finetune.py
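--cpus-per-task and --num-workers are halved together (20 → 10), keeping one dataloader worker per allocated CPU. Not part of the commit, but a small sketch of reading the SLURM allocation at runtime so the two values cannot drift apart (the dummy dataset is only a placeholder):

import os
import torch
from torch.utils.data import DataLoader, TensorDataset

# match dataloader workers to the SLURM allocation (falls back to 10 outside SLURM)
num_workers = int(os.environ.get('SLURM_CPUS_PER_TASK', 10))
dataset = TensorDataset(torch.zeros(64, 3, 300, 300))  # placeholder for the real webdataset
loader = DataLoader(dataset, batch_size=16, num_workers=num_workers)
print(f'using {num_workers} dataloader workers')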
18 changes: 9 additions & 9 deletions zoobot/pytorch/datasets/webdatamodule.py
@@ -74,19 +74,19 @@ def make_image_transform(self, mode="train"):
# if mode == "train":
# elif mode == "val":

# augmentation_transform = transforms.default_transforms(
# crop_scale_bounds=self.crop_scale_bounds,
# crop_ratio_bounds=self.crop_ratio_bounds,
# resize_after_crop=self.resize_after_crop,
# pytorch_greyscale=not self.color
# ) # A.Compose object

logging.warning('Minimal augmentations for speed test')
augmentation_transform = transforms.minimal_transforms(
augmentation_transform = transforms.default_transforms(
crop_scale_bounds=self.crop_scale_bounds,
crop_ratio_bounds=self.crop_ratio_bounds,
resize_after_crop=self.resize_after_crop,
pytorch_greyscale=not self.color
) # A.Compose object

# logging.warning('Minimal augmentations for speed test')
# augmentation_transform = transforms.fast_transforms(
# resize_after_crop=self.resize_after_crop,
# pytorch_greyscale=not self.color
# ) # A.Compose object


def do_transform(img):
return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)
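This swaps the speed-test minimal augmentations back for the standard default_transforms, which (per the comment above) returns an albumentations A.Compose object consumed by do_transform. A minimal sketch of that convention, assuming albumentations 1.x and using a hand-built A.Compose as a stand-in for the one default_transforms would return:

import albumentations as A
import numpy as np

# stand-in for the A.Compose returned by transforms.default_transforms
augmentation_transform = A.Compose([
    A.RandomResizedCrop(height=224, width=224, scale=(0.75, 0.85)),
    A.HorizontalFlip(),
])

def do_transform(img):
    # albumentations works on HWC uint8 arrays; PyTorch wants CHW float32
    return np.transpose(
        augmentation_transform(image=np.array(img))['image'], axes=[2, 0, 1]
    ).astype(np.float32)

dummy = np.random.randint(0, 255, (300, 300, 3), dtype=np.uint8)
print(do_transform(dummy).shape)  # (3, 224, 224)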
61 changes: 57 additions & 4 deletions zoobot/pytorch/datasets/webdataset_utils.py
@@ -6,6 +6,9 @@
from itertools import islice
import glob


import albumentations as A

import tqdm
import numpy as np
import pandas as pd
@@ -37,26 +40,76 @@ def make_mock_wds(save_dir: str, label_cols: List, n_shards: int, shard_size: in



def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int):
def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse_label_df=None):

assert '.tar' in save_loc
df['id_str'] = df['id_str'].astype(str).str.replace('.', '_')

if sparse_label_df is not None:
logging.info(f'Using sparse label df: {len(sparse_label_df)}')
shard_dfs = np.array_split(df, n_shards)
logging.info(f'shards: {len(shard_dfs)}. Shard size: {len(shard_dfs[0])}')

transforms_to_apply = [
# below: for the 224px fast-training (fast augs) setup
# A.Resize(
# height=350, # now more aggressive, 65% crop effectively
# width=350, # now more aggressive, 65% crop effectively
# interpolation=cv2.INTER_AREA # slow and good interpolation
# ),
# A.CenterCrop(
# height=224,
# width=224,
# always_apply=True
# ),
# below, for standard training default augs
# small boundary trim and then resize expecting further 224px crop
# we want 0.7-0.8 effective crop
# in augs that could be 0.x-1.0, and here a pre-crop to 0.8 i.e. 340px
# but this would change the centering
# let's stick to small boundary crop and 0.75-0.85 in augs
A.CenterCrop(
height=400,
width=400,
always_apply=True
),
A.Resize(
height=300,
width=300,
interpolation=cv2.INTER_AREA # slow and good interpolation
)
]
transform = A.Compose(transforms_to_apply)
# transform = None

for shard_n, shard_df in tqdm.tqdm(enumerate(shard_dfs), total=len(shard_dfs)):
if sparse_label_df is not None:
shard_df = pd.merge(shard_df, sparse_label_df, how='left', validate='one_to_one', suffixes=('', '_badlabelmerge')) # auto-merge
shard_save_loc = save_loc.replace('.tar', f'_{shard_n}_{len(shard_df)}.tar')
logging.info(shard_save_loc)
sink = wds.TarWriter(shard_save_loc)
for _, galaxy in shard_df.iterrows():
sink.write(galaxy_to_wds(galaxy, label_cols))
sink.write(galaxy_to_wds(galaxy, label_cols, transform=transform))
sink.close()


def galaxy_to_wds(galaxy: pd.Series, label_cols):
def galaxy_to_wds(galaxy: pd.Series, label_cols, transform=None):

im = cv2.imread(galaxy['file_loc'])
# cv2 loads BGR by default (for historical reasons); convert to RGB
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# if central_crop is not None:
# width, height, _ = im.shape
# # assert width == height, (width, height)
# mid = int(width/2)
# half_central_crop = int(central_crop/2)
# low_edge, high_edge = mid - half_central_crop, mid + half_central_crop
# im = im[low_edge:high_edge, low_edge:high_edge]
# assert im.shape == (central_crop, central_crop, 3)

# apply albumentations
if transform is not None:
im = transform(image=im)['image']

labels = json.dumps(galaxy[label_cols].to_dict())
id_str = str(galaxy['id_str'])
return {
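galaxy_to_wds now takes an optional albumentations transform, so each 424px DESI cutout is centre-cropped to 400px and area-interpolated down to 300px once, at write time, rather than on every training epoch. A quick sanity-check sketch of that write-time transform on a dummy cutout (assuming albumentations 1.x):

import albumentations as A
import cv2
import numpy as np

transform = A.Compose([
    A.CenterCrop(height=400, width=400, always_apply=True),
    A.Resize(height=300, width=300, interpolation=cv2.INTER_AREA),  # good quality for downscaling
])

im = np.random.randint(0, 255, (424, 424, 3), dtype=np.uint8)  # stand-in for cv2.imread output
print(transform(image=im)['image'].shape)  # (300, 300, 3)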
