From 8d9065cba1af49b97c63c4701789a7f7a1fbcd47 Mon Sep 17 00:00:00 2001 From: Eric Wulff <31319227+erwulff@users.noreply.github.com> Date: Mon, 4 Mar 2024 09:04:52 +0100 Subject: [PATCH] Feat val freq (#298) * WIP: validate every val_freq training steps * update count parameters function * feat: implement intermediate validation every val_freq training steps * fix: black formatting * make ffn_dist_num_layers configurable in pytorch training * Update pyg-clic-hits.yaml * Update pyg-clic.yaml * Update pyg-cms.yaml * Update pyg-delphes.yaml * update pt search space * add script to count model parameters given a config file --- mlpf/count_parameters.py | 38 +++++++++++ mlpf/pyg/mlpf.py | 2 + mlpf/pyg/training.py | 94 ++++++++++++++++++++++++--- mlpf/pyg/utils.py | 12 ++-- mlpf/pyg_pipeline.py | 1 + mlpf/raytune/pt_search_space.py | 55 ++++++++++------ parameters/pytorch/pyg-clic-hits.yaml | 10 ++- parameters/pytorch/pyg-clic.yaml | 20 ++++-- parameters/pytorch/pyg-cms.yaml | 6 ++ parameters/pytorch/pyg-delphes.yaml | 8 ++- 10 files changed, 201 insertions(+), 45 deletions(-) create mode 100644 mlpf/count_parameters.py diff --git a/mlpf/count_parameters.py b/mlpf/count_parameters.py new file mode 100644 index 000000000..78165168d --- /dev/null +++ b/mlpf/count_parameters.py @@ -0,0 +1,38 @@ +import sys +import yaml + +sys.path.append("../mlpf") + +from pyg.mlpf import MLPF +from pyg.utils import ( + CLASS_LABELS, + X_FEATURES, + count_parameters, +) + + +with open(sys.argv[1], "r") as stream: # load config (includes: which physics samples, model params) + config = yaml.safe_load(stream) + +model_kwargs = { + "input_dim": len(X_FEATURES[config["dataset"]]), + "num_classes": len(CLASS_LABELS[config["dataset"]]), + "pt_mode": config["model"]["pt_mode"], + "eta_mode": config["model"]["eta_mode"], + "sin_phi_mode": config["model"]["sin_phi_mode"], + "cos_phi_mode": config["model"]["cos_phi_mode"], + "energy_mode": config["model"]["energy_mode"], + "attention_type": config["model"]["attention"]["attention_type"], + **config["model"][config["conv_type"]], +} +model = MLPF(**model_kwargs) + +trainable_params, nontrainable_params, table = count_parameters(model) + +print(table) + +print("Model conv type:", model.conv_type) +print("conv_type HPs", config["model"][config["conv_type"]]) +print("Trainable parameters:", trainable_params) +print("Non-trainable parameters:", nontrainable_params) +print("Total parameters:", trainable_params + nontrainable_params) diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py index 513249a4a..f363920c8 100644 --- a/mlpf/pyg/mlpf.py +++ b/mlpf/pyg/mlpf.py @@ -139,6 +139,7 @@ def __init__( layernorm=True, num_node_messages=2, ffn_dist_hidden_dim=128, + ffn_dist_num_layers=2, # self-attention specific parameters num_heads=2, # mamba specific parameters @@ -208,6 +209,7 @@ def __init__( "num_node_messages": num_node_messages, "dropout": dropout, "ffn_dist_hidden_dim": ffn_dist_hidden_dim, + "ffn_dist_num_layers": ffn_dist_num_layers, } self.conv_id.append(CombinedGraphLayer(**gnn_conf)) self.conv_reg.append(CombinedGraphLayer(**gnn_conf)) diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py index 10f61a109..83592b5e0 100644 --- a/mlpf/pyg/training.py +++ b/mlpf/pyg/training.py @@ -10,6 +10,7 @@ from datetime import datetime import tqdm import yaml +import csv import numpy as np @@ -199,14 +200,17 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor: def train_and_valid( rank, world_size, + outdir, model, optimizer, - data_loader, + train_loader, + valid_loader, is_train=True, 
lr_schedule=None, comet_experiment=None, comet_step_freq=None, epoch=None, + val_freq=None, dtype=torch.float32, ): """ @@ -221,8 +225,10 @@ def train_and_valid( if is_train: model.train() + data_loader = train_loader else: model.eval() + data_loader = valid_loader # only show progress bar on rank 0 if (world_size > 1) and (rank != 0): @@ -234,6 +240,7 @@ def train_and_valid( device_type = "cuda" if isinstance(rank, int) else "cpu" + val_freq_time_0 = time.time() for itrain, batch in iterator: batch = batch.to(rank, non_blocking=True) @@ -282,6 +289,57 @@ def train_and_valid( comet_experiment.log_metrics(loss, prefix=f"{train_or_valid}", step=step) comet_experiment.log_metric("learning_rate", lr_schedule.get_last_lr(), step=step) + if val_freq is not None and is_train: + if itrain != 0 and itrain % val_freq == 0: + # time since last intermediate validation run + val_freq_time = torch.tensor(time.time() - val_freq_time_0, device=rank) + if world_size > 1: + torch.distributed.all_reduce(val_freq_time) + # compute intermediate training loss + intermediate_losses_t = {key: epoch_loss[key] for key in epoch_loss} + for loss_ in epoch_loss: + # sum up the losses from all workers and dicide by + if world_size > 1: + torch.distributed.all_reduce(intermediate_losses_t[loss_]) + intermediate_losses_t[loss_] = intermediate_losses_t[loss_].cpu().item() / itrain + + # compute intermediate validation loss + intermediate_losses_v = train_and_valid( + rank, + world_size, + outdir, + model, + optimizer, + train_loader, + valid_loader, + is_train=False, + epoch=epoch, + dtype=dtype, + ) + intermediate_metrics = dict( + loss=intermediate_losses_t["Total"], + reg_loss=intermediate_losses_t["Regression"], + cls_loss=intermediate_losses_t["Classification"], + charge_loss=intermediate_losses_t["Charge"], + val_loss=intermediate_losses_v["Total"], + val_reg_loss=intermediate_losses_v["Regression"], + val_cls_loss=intermediate_losses_v["Classification"], + val_charge_loss=intermediate_losses_v["Charge"], + inside_epoch=epoch, + step=(epoch - 1) * len(data_loader) + itrain, + val_freq_time=val_freq_time.cpu().item(), + ) + val_freq_log = os.path.join(outdir, "val_freq_log.csv") + if (rank == 0) or (rank == "cpu"): + with open(val_freq_log, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=intermediate_metrics.keys()) + if os.stat(val_freq_log).st_size == 0: # only write header if file is empty + writer.writeheader() + writer.writerow(intermediate_metrics) + if comet_experiment: + comet_experiment.log_metrics(intermediate_losses_v, prefix="valid", step=step) + val_freq_time_0 = time.time() # reset intermediate validation spacing timer + num_data = torch.tensor(len(data_loader), device=rank) # sum up the number of steps from all workers if world_size > 1: @@ -316,6 +374,7 @@ def train_mlpf( checkpoint_freq=None, comet_experiment=None, comet_step_freq=None, + val_freq=None, ): """ Will run a full training by calling train(). 
@@ -357,30 +416,45 @@ def train_mlpf( ) as prof: with record_function("model_train"): losses_t = train_and_valid( - rank, world_size, model, optimizer, train_loader, is_train=True, lr_schedule=lr_schedule, dtype=dtype + rank, + world_size, + outdir, + model, + optimizer, + train_loader=train_loader, + valid_loader=valid_loader, + is_train=True, + lr_schedule=lr_schedule, + val_freq=val_freq, + dtype=dtype, ) prof.export_chrome_trace("trace.json") else: losses_t = train_and_valid( rank, world_size, + outdir, model, optimizer, - train_loader, + train_loader=train_loader, + valid_loader=valid_loader, is_train=True, lr_schedule=lr_schedule, comet_experiment=comet_experiment, comet_step_freq=comet_step_freq, epoch=epoch, + val_freq=val_freq, dtype=dtype, ) losses_v = train_and_valid( rank, world_size, + outdir, model, optimizer, - valid_loader, + train_loader=train_loader, + valid_loader=valid_loader, is_train=False, lr_schedule=None, comet_experiment=comet_experiment, @@ -663,6 +737,7 @@ def run(rank, world_size, config, args, outdir, logfile): checkpoint_freq=config["checkpoint_freq"], comet_experiment=comet_experiment, comet_step_freq=config["comet_step_freq"], + val_freq=config["val_freq"], ) checkpoint = torch.load(f"{outdir}/best_weights.pth", map_location=torch.device(rank)) @@ -852,8 +927,13 @@ def train_ray_trial(config, args, outdir=None): optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"]) trainable_params, nontrainable_params, table = count_parameters(model) + print(table) if (rank == 0) or (rank == "cpu"): + with open(os.path.join(outdir, "num_params.csv"), "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["trainable_params", "nontrainable_params", "total_params"]) + writer.writerow([trainable_params, nontrainable_params, trainable_params + nontrainable_params]) _logger.info(model) _logger.info(f"Trainable parameters: {trainable_params}") _logger.info(f"Non-trainable parameters: {nontrainable_params}") @@ -926,6 +1006,8 @@ def train_ray_trial(config, args, outdir=None): checkpoint_freq=config["checkpoint_freq"], comet_experiment=comet_experiment, comet_step_freq=config["comet_step_freq"], + dtype=getattr(torch, config["dtype"]), + val_freq=config["val_freq"], ) @@ -1121,10 +1203,6 @@ def run_hpo(config, args): print(result_df) print(result_df.columns) - print("Number of errored trials: {}".format(result_grid.num_errors)) - print("Number of terminated (not errored) trials: {}".format(result_grid.num_terminated)) - print("Ray Tune experiment path: {}".format(result_grid.experiment_path)) - logging.info("Total time of Tuner.fit(): {}".format(end - start)) logging.info( "Best hyperparameters found according to {} were: {}".format(config["raytune"]["default_metric"], best_config) diff --git a/mlpf/pyg/utils.py b/mlpf/pyg/utils.py index 4d10e2e0b..327a4cda4 100644 --- a/mlpf/pyg/utils.py +++ b/mlpf/pyg/utils.py @@ -272,6 +272,7 @@ def get_lr_schedule(config, opt, epochs=None, steps_per_epoch=None, last_epoch=- steps_per_epoch=steps_per_epoch, epochs=epochs, last_epoch=last_batch, + pct_start=config["lr_schedule_config"]["onecycle"]["pct_start"] or 0.3, ) elif config["lr_schedule"] == "cosinedecay": lr_schedule = CosineAnnealingLR(opt, T_max=steps_per_epoch * epochs, last_epoch=last_batch) @@ -281,7 +282,8 @@ def get_lr_schedule(config, opt, epochs=None, steps_per_epoch=None, last_epoch=- def count_parameters(model): - table = pd.DataFrame(columns=["Modules", "Trainable params", "Non-tranable params"]) + column_names = ["Modules", "Trainable parameters", 
"Non-tranable parameters"] + table = pd.DataFrame(columns=column_names) trainable_params = 0 nontrainable_params = 0 for ii, (name, parameter) in enumerate(model.named_parameters()): @@ -290,9 +292,7 @@ def count_parameters(model): table = pd.concat( [ table, - pd.DataFrame( - {"Modules": name, "Trainable Parameters": "-", "Non-tranable Parameters": params}, index=[ii] - ), + pd.DataFrame({column_names[0]: name, column_names[1]: 0, column_names[2]: params}, index=[ii]), ] ) nontrainable_params += params @@ -300,9 +300,7 @@ def count_parameters(model): table = pd.concat( [ table, - pd.DataFrame( - {"Modules": name, "Trainable Parameters": params, "Non-tranable Parameters": "-"}, index=[ii] - ), + pd.DataFrame({column_names[0]: name, column_names[1]: params, column_names[2]: 0}, index=[ii]), ] ) trainable_params += params diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index a0fa2b0ad..c6533056b 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -60,6 +60,7 @@ parser.add_argument("--ntrain", type=int, default=None, help="training samples to use, if None use entire dataset") parser.add_argument("--ntest", type=int, default=None, help="training samples to use, if None use entire dataset") parser.add_argument("--nvalid", type=int, default=None, help="validation samples to use") +parser.add_argument("--val-freq", type=int, default=None, help="run extra validation every val_freq training steps") parser.add_argument("--checkpoint-freq", type=int, default=None, help="epoch frequency for checkpointing") parser.add_argument("--hpo", type=str, default=None, help="perform hyperparameter optimization, name of HPO experiment") parser.add_argument("--ray-train", action="store_true", help="run training using Ray Train") diff --git a/mlpf/raytune/pt_search_space.py b/mlpf/raytune/pt_search_space.py index 871adbd03..d73258464 100644 --- a/mlpf/raytune/pt_search_space.py +++ b/mlpf/raytune/pt_search_space.py @@ -1,25 +1,32 @@ -from ray.tune import choice # grid_search, choice, loguniform, quniform +from ray.tune import grid_search # grid_search, choice, loguniform, quniform -raytune_num_samples = 16 # Number of random samples to draw from search space. Set to 1 for grid search. -samp = choice +raytune_num_samples = 1 # Number of random samples to draw from search space. Set to 1 for grid search. 
+samp = grid_search # gnn scan search_space = { + # dataset parameters + "ntrain": samp([500]), + # "ntest": samp([10000]), + "nvalid": samp([500]), + "num_epochs": samp([10]), # optimizer parameters - "lr": samp([1e-4, 3e-4, 1e-3, 3e-3, 1e-2]), - "gpu_batch_multiplier": samp([1, 4, 8, 16]), - # model arch parameters - "activation": samp(["elu", "relu", "relu6", "leakyrelu"]), - "conv_type": samp(["gravnet"]), # can be "gnn_lsh", "gravnet", "attention" - "embedding_dim": samp([32, 64, 128, 252, 512, 1024]), - "width": samp([32, 64, 128, 256, 512, 1024]), - "num_convs": samp([1, 2, 3, 4, 5, 6]), - "dropout": samp([0.0, 0.01, 0.1, 0.4]), + "lr": samp([1e-4, 3e-4, 1e-3, 3e-3]), + "lr_schedule": samp(["onecycle"]), + "pct_start": samp([0.05]), + # "gpu_batch_multiplier": samp([1, 4, 8, 16]), # "patience": samp([9999]), + # model arch parameters + # "activation": samp(["elu", "relu", "relu6", "leakyrelu"]), + "conv_type": samp(["attention"]), # can be "gnn_lsh", "gravnet", "attention" + # "embedding_dim": samp([32, 64, 128, 252, 512, 1024]), + # "width": samp([32, 64, 128, 256, 512, 1024]), + # "num_convs": samp([1, 2, 3, 4, 5, 6]), + # "dropout": samp([0.0, 0.01, 0.1, 0.4]), # only for gravnet - "k": samp([8, 16, 32]), - "propagate_dimensions": samp([8, 16, 32, 64, 128]), - "space_dimensions": samp([4]), + # "k": samp([8, 16, 32]), + # "propagate_dimensions": samp([8, 16, 32, 64, 128]), + # "space_dimensions": samp([4]), # only for gnn-lsh # "bin_size": samp([160, 320, 640]), # "max_num_bins": samp([200]), @@ -27,15 +34,20 @@ # "layernorm": samp([True, False]), # "num_node_messages": samp([1, 2, 3, 4, 5]), # "ffn_dist_hidden_dim": samp([16, 32, 64, 128, 256]), - # mamba specific variables - "d_state": samp([16]), - "d_conv": samp([4]), - "expand": samp([2]), + # "ffn_dist_num_layers": samp([1, 2, 3, 4, 5, 6]), + # mamba specific parameters + # "d_state": samp([16]), + # "d_conv": samp([4]), + # "expand": samp([2]), + # "num_heads": samp([2, 4, 6, 8, 10, 12]), + # attention specifica parameters + "num_heads": samp([2, 4, 8, 16]), + # "attention_type": samp(["flash"]), # flash, efficient, math } def set_hps_from_search_space(search_space, config): - varaible_names = ["lr", "gpu_batch_multiplier"] + varaible_names = ["lr", "lr_schedule", "gpu_batch_multiplier", "ntrain", "ntest", "nvalid", "num_epochs", "patience"] for var in varaible_names: if var in search_space.keys(): config[var] = search_space[var] @@ -81,4 +93,7 @@ def set_hps_from_search_space(search_space, config): if var in search_space.keys(): config["model"][conv_type][var] = search_space[var] + if "pct_start" in search_space.keys(): + config["lr_schedule_config"]["onecycle"]["pct_start"] = search_space["pct_start"] + return config diff --git a/parameters/pytorch/pyg-clic-hits.yaml b/parameters/pytorch/pyg-clic-hits.yaml index 38279764b..67f0b2e59 100644 --- a/parameters/pytorch/pyg-clic-hits.yaml +++ b/parameters/pytorch/pyg-clic-hits.yaml @@ -12,7 +12,7 @@ lr_schedule: constant # constant, cosinedecay, onecycle conv_type: gnn_lsh ntrain: ntest: -nvalid: 500 +nvalid: num_workers: 0 prefetch_factor: checkpoint_freq: @@ -20,6 +20,7 @@ comet_name: particleflow-pt comet_offline: False comet_step_freq: 10 dtype: float32 +val_freq: # run an extra validation run every val_freq training steps model: pt_mode: linear @@ -42,6 +43,7 @@ model: layernorm: True num_node_messages: 1 ffn_dist_hidden_dim: 128 + ffn_dist_num_layers: 2 gravnet: conv_type: gravnet @@ -78,10 +80,14 @@ model: d_conv: 4 expand: 4 +lr_schedule_config: + onecycle: + 
pct_start: 0.3 + raytune: local_dir: # Note: please specify an absolute path sched: asha # asha, hyperband - search_alg: hyperopt # bayes, bohb, hyperopt, nevergrad, scikit + search_alg: # bayes, bohb, hyperopt, nevergrad, scikit default_metric: "val_loss" default_mode: "min" # Tune schedule specific parameters diff --git a/parameters/pytorch/pyg-clic.yaml b/parameters/pytorch/pyg-clic.yaml index 8452a216d..934df2aae 100644 --- a/parameters/pytorch/pyg-clic.yaml +++ b/parameters/pytorch/pyg-clic.yaml @@ -6,21 +6,22 @@ data_dir: gpus: 1 gpu_batch_multiplier: 1 load: -num_epochs: 2 +num_epochs: 10 patience: 20 lr: 0.0001 lr_schedule: constant # constant, cosinedecay, onecycle conv_type: gnn_lsh ntrain: ntest: -nvalid: 500 +nvalid: num_workers: 0 prefetch_factor: checkpoint_freq: comet_name: particleflow-pt comet_offline: False -comet_step_freq: 10 +comet_step_freq: 100 dtype: float32 +val_freq: # run an extra validation run every val_freq training steps model: pt_mode: linear @@ -43,6 +44,7 @@ model: layernorm: True num_node_messages: 2 ffn_dist_hidden_dim: 128 + ffn_dist_num_layers: 2 gravnet: conv_type: gravnet @@ -81,10 +83,14 @@ model: d_conv: 4 expand: 2 +lr_schedule_config: + onecycle: + pct_start: 0.3 + raytune: local_dir: # Note: please specify an absolute path sched: asha # asha, hyperband - search_alg: hyperopt # bayes, bohb, hyperopt, nevergrad, scikit + search_alg: # bayes, bohb, hyperopt, nevergrad, scikit default_metric: "val_loss" default_mode: "min" # Tune schedule specific parameters @@ -104,7 +110,7 @@ raytune: train_dataset: clic: physical: - batch_size: 100 + batch_size: 1 samples: clic_edm_qq_pf: version: 1.5.0 @@ -120,7 +126,7 @@ train_dataset: valid_dataset: clic: physical: - batch_size: 100 + batch_size: 1 samples: clic_edm_qq_pf: version: 1.5.0 @@ -128,7 +134,7 @@ valid_dataset: test_dataset: clic: physical: - batch_size: 100 + batch_size: 1 samples: clic_edm_qq_pf: version: 1.5.0 diff --git a/parameters/pytorch/pyg-cms.yaml b/parameters/pytorch/pyg-cms.yaml index 7ee03154d..7953cb9ef 100644 --- a/parameters/pytorch/pyg-cms.yaml +++ b/parameters/pytorch/pyg-cms.yaml @@ -21,6 +21,7 @@ comet_name: particleflow-pt comet_offline: False comet_step_freq: 10 dtype: bfloat16 +val_freq: # run an extra validation run every val_freq training steps model: pt_mode: linear @@ -43,6 +44,7 @@ model: layernorm: True num_node_messages: 2 ffn_dist_hidden_dim: 128 + ffn_dist_num_layers: 2 gravnet: conv_type: gravnet @@ -79,6 +81,10 @@ model: d_conv: 4 expand: 2 +lr_schedule_config: + onecycle: + pct_start: 0.3 + raytune: local_dir: # Note: please specify an absolute path sched: asha # asha, hyperband diff --git a/parameters/pytorch/pyg-delphes.yaml b/parameters/pytorch/pyg-delphes.yaml index 6746bcda6..ba0163643 100644 --- a/parameters/pytorch/pyg-delphes.yaml +++ b/parameters/pytorch/pyg-delphes.yaml @@ -13,7 +13,7 @@ lr_schedule: constant # constant, cosinedecay, onecycle conv_type: gnn_lsh ntrain: ntest: -nvalid: 500 +nvalid: num_workers: 0 prefetch_factor: checkpoint_freq: @@ -21,6 +21,7 @@ comet_name: particleflow-pt comet_offline: False comet_step_freq: 10 dtype: float32 +val_freq: # run an extra validation run every val_freq training steps model: pt_mode: linear @@ -43,6 +44,7 @@ model: layernorm: True num_node_messages: 2 ffn_dist_hidden_dim: 128 + ffn_dist_num_layers: 2 gravnet: conv_type: gravnet @@ -81,6 +83,10 @@ model: d_conv: 4 expand: 2 +lr_schedule_config: + onecycle: + pct_start: 0.3 + raytune: local_dir: # Note: please specify an absolute path sched: asha # asha, 
hyperband
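
Note (not part of the patch): the intermediate validation added in training.py appends one row every val_freq training steps to val_freq_log.csv in the training output directory, with the columns taken from the intermediate_metrics dict above (loss, reg_loss, cls_loss, charge_loss, val_loss, val_reg_loss, val_cls_loss, val_charge_loss, inside_epoch, step, val_freq_time). val_freq can be set either with the new --val-freq command-line flag or via the val_freq key added to the YAML configs. The new mlpf/count_parameters.py script takes a single config path as its only argument, presumably run from inside the mlpf/ directory so that its relative sys.path.append("../mlpf") and the pyg imports resolve. The snippet below is a minimal sketch of how the validation log could be inspected; the output-directory path is a placeholder and the reading code is an assumption, not part of this PR.

    import csv

    # Sketch only: read the intermediate-validation log written by train_and_valid().
    # "path/to/outdir" is a placeholder for the actual training output directory;
    # the column names match the intermediate_metrics dict in the patch above.
    with open("path/to/outdir/val_freq_log.csv", newline="") as f:
        for row in csv.DictReader(f):
            print(
                f"epoch {row['inside_epoch']} step {row['step']}: "
                f"train loss {row['loss']}, val loss {row['val_loss']}"
            )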