From 1b169ab44f35a7d78b63d342a1961970e26c3f3c Mon Sep 17 00:00:00 2001
From: Mike Walmsley
Date: Thu, 23 Nov 2023 11:06:42 -0500
Subject: [PATCH] set up for 300px training runs

---
 only_for_me/narval/make_webdataset_script.py | 23 ++++----
 only_for_me/narval/train.py                  |  7 ++-
 only_for_me/narval/train.sh                  |  6 +-
 zoobot/pytorch/datasets/webdatamodule.py     | 18 +++---
 zoobot/pytorch/datasets/webdataset_utils.py  | 61 ++++++++++++++++++--
 5 files changed, 86 insertions(+), 29 deletions(-)

diff --git a/only_for_me/narval/make_webdataset_script.py b/only_for_me/narval/make_webdataset_script.py
index 9cb9e824..2113136a 100644
--- a/only_for_me/narval/make_webdataset_script.py
+++ b/only_for_me/narval/make_webdataset_script.py
@@ -23,7 +23,7 @@ def dataset_to_webdataset(dataset_name, dataset_func, label_cols, divisor=4096):
     catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, divisor=divisor)
 
 
-def catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, divisor=4096):
+def catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, sparse_label_df=None, divisor=4096):
     for (catalog_name, catalog) in [('train', train_catalog), ('test', test_catalog)]:
         n_shards = len(catalog) // divisor
         logging.info(n_shards)
@@ -33,7 +33,7 @@ def catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog
 
         save_loc = f"/home/walml/data/wds/{dataset_name}/{dataset_name}_{catalog_name}.tar"  # .tar replace automatically
 
-        webdataset_utils.df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards)
+        webdataset_utils.df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df)
         # webdataset_utils.load_wds_directly(save_loc)
 
@@ -53,7 +53,8 @@ def main():
 
     # for converting other catalogs e.g. DESI
-    dataset_name = 'desi_labelled'
+    dataset_name = 'desi_labelled_300px_2048'
+    # dataset_name = 'desi_all_2048'
     label_cols = label_metadata.decals_all_campaigns_ortho_label_cols
     columns = [
         'dr8_id', 'brickid', 'objid', 'ra', 'dec'
@@ -84,18 +85,20 @@ def main():
     # print(len(df_dedup2))
     # df_dedup.to_parquet('/home/walml/data/desi/master_all_file_index_labelled_dedup_20arcsec.parquet')
-
-    df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_all_dedup_20arcsec.parquet')
+    df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_labelled_dedup_20arcsec.parquet')
+    # df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_all_dedup_20arcsec.parquet')
+    df_dedup['id_str'] = df_dedup['dr8_id']
 
     # columns = ['id_str', 'smooth-or-featured-dr12_total-votes', 'smooth-or-featured-dr5_total-votes', 'smooth-or-featured-dr8_total-votes']
-    df_dedup_with_votes = pd.merge(df_dedup, votes, how='left', on='dr8_id')
+    # gets too big, need to only merge in label_df per shard
+    # df_dedup_with_votes = pd.merge(df_dedup, votes, how='left', on='dr8_id')
 
-    train_catalog, test_catalog = train_test_split(df_dedup_with_votes, test_size=0.2, random_state=42)
-    train_catalog.to_parquet('/home/walml/data/wds/desi_all/train_catalog_v1.parquet', index=False)
-    test_catalog.to_parquet('/home/walml/data/wds/desi_all/test_catalog_v1.parquet', index=False)
+    train_catalog, test_catalog = train_test_split(df_dedup, test_size=0.2, random_state=42)
+    train_catalog.to_parquet('/home/walml/data/wds/desi_labelled_300px_2048/train_catalog_v1.parquet', index=False)
+    test_catalog.to_parquet('/home/walml/data/wds/desi_labelled_300px_2048/test_catalog_v1.parquet', index=False)
 
-    catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, divisor=2048)
+    catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, divisor=2048, sparse_label_df=votes)
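Note on the change above: rather than joining the votes table onto the full catalog up front (the pd.merge now commented out as "gets too big"), the votes are passed through as sparse_label_df and merged shard-by-shard inside df_to_wds. A minimal standalone sketch of that pattern, assuming a catalog and votes DataFrame sharing an id column; the helper name is illustrative, not part of the patch:

    import numpy as np
    import pandas as pd

    def merge_labels_per_shard(catalog: pd.DataFrame, votes: pd.DataFrame, n_shards: int):
        # split first, merge per shard: only one shard's join is ever held in memory
        for shard_df in np.array_split(catalog, n_shards):
            # as in df_to_wds below: merge on shared columns, fail loudly on duplicates
            yield pd.merge(shard_df, votes, how='left', validate='one_to_one',
                           suffixes=('', '_badlabelmerge'))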
diff --git a/only_for_me/narval/train.py b/only_for_me/narval/train.py
index e9de4a0c..5f74f156 100644
--- a/only_for_me/narval/train.py
+++ b/only_for_me/narval/train.py
@@ -64,11 +64,11 @@ if os.path.isdir('/home/walml/repos/zoobot'):
         logging.warning('local mode')
-        search_str = '/home/walml/data/wds/desi_labelled_2048/desi_labelled_train_*.tar'
+        search_str = '/home/walml/data/wds/desi_labelled_300px_2048/desi_labelled_train_*.tar'
         cache_dir = None
 
     else:
-        search_str = '/home/walml/projects/def-bovy/walml/data/webdatasets/desi_labelled_2048/desi_labelled_train_*.tar'
+        search_str = '/home/walml/projects/def-bovy/walml/data/webdatasets/desi_labelled_300px_2048/desi_labelled_train_*.tar'
         cache_dir = os.environ['SLURM_TMPDIR'] + '/cache'
 
     all_urls = glob.glob(search_str)
@@ -122,7 +122,8 @@
         compile_encoder=args.compile_encoder,  # NEW
         random_state=random_state,
         learning_rate=1e-3,
-        cache_dir=cache_dir
+        cache_dir=cache_dir,
+        crop_scale_bounds=(0.75, 0.85)  # slightly increased to compensate for 424-400px crop when saving webdataset
         # cache_dir='/tmp/cache'  # /tmp for ramdisk (400GB total, vs 4TB total for nvme)
     )

diff --git a/only_for_me/narval/train.sh b/only_for_me/narval/train.sh
index f1ee9ecc..ec66aab8 100644
--- a/only_for_me/narval/train.sh
+++ b/only_for_me/narval/train.sh
@@ -2,7 +2,7 @@
 #SBATCH --time=23:30:0
 #SBATCH --nodes=1
 #SBATCH --ntasks-per-node=1
-#SBATCH --cpus-per-task=20
+#SBATCH --cpus-per-task=10
 #SBATCH --mem-per-cpu 4G
 #SBATCH --gres=gpu:v100:1
 
@@ -21,11 +21,11 @@ export NCCL_BLOCKING_WAIT=1 #Set this environment variable if you wish to use t
 
 REPO_DIR=/project/def-bovy/walml/zoobot
 srun $PYTHON $REPO_DIR/only_for_me/narval/train.py \
-    --save-dir $REPO_DIR/only_for_me/narval/desi_f128_2gpu \
+    --save-dir $REPO_DIR/only_for_me/narval/desi_300px_f128_1gpu \
     --batch-size 256 \
     --num-features 128 \
     --gpus 1 \
-    --num-workers 20 \
+    --num-workers 10 \
     --color --wandb --mixed-precision --compile-encoder
 
 # srun python $SLURM_TMPDIR/zoobot/only_for_me/narval/finetune.py
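Note on crop_scale_bounds=(0.75, 0.85) in train.py above: since the webdataset images are now center-cropped from 424px to 400px at save time, the train-time random-crop scale is raised slightly so the effective crop of the original image stays close to the previous behaviour. Rough arithmetic, assuming scale is an area fraction and 424px native images (both inferred from the patch comments, not stated elsewhere):

    # fraction of the original image area surviving the 424px -> 400px save-time crop
    pre_crop_area = (400 / 424) ** 2  # ~0.89
    for s in (0.75, 0.85):
        print(f'aug scale {s:.2f} -> ~{s * pre_crop_area:.2f} of the original image area')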
diff --git a/zoobot/pytorch/datasets/webdatamodule.py b/zoobot/pytorch/datasets/webdatamodule.py
index 18ad2296..f8fc1451 100644
--- a/zoobot/pytorch/datasets/webdatamodule.py
+++ b/zoobot/pytorch/datasets/webdatamodule.py
@@ -74,19 +74,19 @@ def make_image_transform(self, mode="train"):
         # if mode == "train":
         # elif mode == "val":
 
-        # augmentation_transform = transforms.default_transforms(
-        #     crop_scale_bounds=self.crop_scale_bounds,
-        #     crop_ratio_bounds=self.crop_ratio_bounds,
-        #     resize_after_crop=self.resize_after_crop,
-        #     pytorch_greyscale=not self.color
-        # )  # A.Compose object
-
-        logging.warning('Minimal augmentations for speed test')
-        augmentation_transform = transforms.minimal_transforms(
+        augmentation_transform = transforms.default_transforms(
+            crop_scale_bounds=self.crop_scale_bounds,
+            crop_ratio_bounds=self.crop_ratio_bounds,
             resize_after_crop=self.resize_after_crop,
             pytorch_greyscale=not self.color
         )  # A.Compose object
 
+        # logging.warning('Minimal augmentations for speed test')
+        # augmentation_transform = transforms.fast_transforms(
+        #     resize_after_crop=self.resize_after_crop,
+        #     pytorch_greyscale=not self.color
+        # )  # A.Compose object
+
         def do_transform(img):
             return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)
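Note on make_image_transform above: the minimal speed-test augmentations are commented out and the standard default_transforms pipeline is restored; do_transform then transposes albumentations' HWC output to the CHW float32 layout PyTorch expects. A sketch of that conversion with a placeholder pipeline (the augmentation values here are illustrative, not zoobot's defaults):

    import albumentations as A
    import numpy as np

    # placeholder standing in for transforms.default_transforms(...)
    augmentation_transform = A.Compose([
        A.RandomResizedCrop(height=224, width=224, scale=(0.75, 0.85), ratio=(0.9, 1.1)),
    ])

    img = np.random.randint(0, 255, (300, 300, 3), dtype=np.uint8)  # mock HWC galaxy image
    chw = np.transpose(augmentation_transform(image=img)['image'], axes=[2, 0, 1]).astype(np.float32)
    assert chw.shape == (3, 224, 224)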
340px + # but this would change the centering + # let's stick to small boundary crop and 0.75-0.85 in augs + A.CenterCrop( + height=400, + width=400, + always_apply=True + ), + A.Resize( + height=300, + width=300, + interpolation=cv2.INTER_AREA # slow and good interpolation + ) + ] + transform = A.Compose(transforms_to_apply) + # transform = None + for shard_n, shard_df in tqdm.tqdm(enumerate(shard_dfs), total=len(shard_dfs)): + if sparse_label_df is not None: + shard_df = pd.merge(shard_df, sparse_label_df, how='left', validate='one_to_one', suffixes=('', '_badlabelmerge')) # auto-merge shard_save_loc = save_loc.replace('.tar', f'_{shard_n}_{len(shard_df)}.tar') logging.info(shard_save_loc) sink = wds.TarWriter(shard_save_loc) for _, galaxy in shard_df.iterrows(): - sink.write(galaxy_to_wds(galaxy, label_cols)) + sink.write(galaxy_to_wds(galaxy, label_cols, transform=transform)) sink.close() -def galaxy_to_wds(galaxy: pd.Series, label_cols): +def galaxy_to_wds(galaxy: pd.Series, label_cols, transform=None): im = cv2.imread(galaxy['file_loc']) # cv2 loads BGR for 'history', fix im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + # if central_crop is not None: + # width, height, _ = im.shape + # # assert width == height, (width, height) + # mid = int(width/2) + # half_central_crop = int(central_crop/2) + # low_edge, high_edge = mid - half_central_crop, mid + half_central_crop + # im = im[low_edge:high_edge, low_edge:high_edge] + # assert im.shape == (central_crop, central_crop, 3) + + # apply albumentations + if transform is not None: + im = transform(image=im)['image'] + labels = json.dumps(galaxy[label_cols].to_dict()) id_str = str(galaxy['id_str']) return {