From 56fbf57dedaac8e51f33270f33fbd8fa6dead450 Mon Sep 17 00:00:00 2001 From: Eric Wulff <31319227+erwulff@users.noreply.github.com> Date: Mon, 11 Dec 2023 09:51:00 +0100 Subject: [PATCH] Learning rate schedules and Mamba layer (#282) * fix: update parameter files * fix: better comet-ml logging * update flatiron Ray Train submission scripts * update sbatch script * log overridden config to comet-ml instead of original * fix: checkpoint loading specify full path to checkpoint using --load-checkpoint * feat: implement LR schedules in the PyTorch training code * update sbatch scripts * feat: LR schedules support checkpointing and resuming training * update sbatch scripts * update ray tune search space * fix: dropout parameter not taking effect on torch gnn-lsh model * make more gnn-lsh parameters configurable * make activation function configurable * update raytune search space * feat: add MambaLayer * update raytune search space * update pyg-cms.yaml * fix loading of checkpoint in testing with raytrain based run --- mlpf/pyg/mlpf.py | 78 ++++++++-- mlpf/pyg/training.py | 134 ++++++++++++------ mlpf/pyg/utils.py | 35 +++++ mlpf/pyg_pipeline.py | 26 ++-- mlpf/raytune/pt_search_space.py | 86 +++++++---- parameters/pyg-clic.yaml | 32 ++++- parameters/pyg-cms-physical.yaml | 30 +++- parameters/pyg-cms-small-highqcd.yaml | 30 +++- parameters/pyg-cms-small.yaml | 30 +++- parameters/pyg-cms-test-qcdhighpt.yaml | 32 ++++- parameters/pyg-cms.yaml | 73 +++++++++- parameters/pyg-delphes.yaml | 32 ++++- parameters/pyg-workflow-test.yaml | 32 ++++- scripts/flatiron/pt_raytrain_a100.slurm | 4 +- scripts/flatiron/pt_raytrain_h100.slurm | 2 +- .../flatiron/pt_raytune_a100_1GPUperTrial.sh | 5 +- .../flatiron/pt_raytune_a100_4GPUsperTrial.sh | 5 +- .../flatiron/pt_raytune_h100_1GPUperTrial.sh | 7 +- .../flatiron/pt_raytune_h100_2GPUsperTrial.sh | 3 +- .../flatiron/pt_raytune_h100_4GPUsperTrial.sh | 3 +- .../flatiron/pt_raytune_h100_8GPUsperTrial.sh | 3 +- 21 files changed, 558 insertions(+), 124 deletions(-) diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py index bd2e7cb76..4274bdb47 100644 --- a/mlpf/pyg/mlpf.py +++ b/mlpf/pyg/mlpf.py @@ -41,6 +41,32 @@ def forward(self, x, mask): return x +class MambaLayer(nn.Module): + def __init__(self, embedding_dim=128, num_heads=2, width=128, dropout=0.1, d_state=16, d_conv=4, expand=2): + super(MambaLayer, self).__init__() + self.act = nn.ELU + from mamba_ssm import Mamba + + self.mamba = Mamba( + d_model=embedding_dim, + d_state=d_state, + d_conv=d_conv, + expand=expand, + ) + self.norm0 = torch.nn.LayerNorm(embedding_dim) + self.seq = torch.nn.Sequential( + nn.Linear(embedding_dim, width), self.act(), nn.Linear(width, embedding_dim), self.act() + ) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, x, mask): + x = self.mamba(x) + x = self.norm0(x + self.seq(x)) + x = self.dropout(x) + x = x * (~mask.unsqueeze(-1)) + return x + + def ffn(input_dim, output_dim, width, act, dropout): return nn.Sequential( nn.Linear(input_dim, width), @@ -59,22 +85,45 @@ def __init__( embedding_dim=128, width=128, num_convs=2, + dropout=0.0, + activation="elu", + # gravnet specific parameters k=32, propagate_dimensions=32, space_dimensions=4, - dropout=0.4, conv_type="gravnet", + # gnn-lsh specific parameters + bin_size=640, + max_num_bins=200, + distance_dim=128, + layernorm=True, + num_node_messages=2, + ffn_dist_hidden_dim=128, + # self-attention specific parameters + num_heads=2, + # mamba specific parameters + d_state=16, + d_conv=4, + expand=2, ): super(MLPF, 
self).__init__() self.conv_type = conv_type - self.act = nn.ELU + if activation == "elu": + self.act = nn.ELU + elif activation == "relu": + self.act = nn.ReLU + elif activation == "relu6": + self.act = nn.ReLU6 + elif activation == "leakyrelu": + self.act = nn.LeakyReLU + self.dropout = dropout self.input_dim = input_dim self.num_convs = num_convs - self.bin_size = 640 + self.bin_size = bin_size # embedding of the inputs if num_convs != 0: @@ -89,8 +138,14 @@ def __init__( self.conv_id = nn.ModuleList() self.conv_reg = nn.ModuleList() for i in range(num_convs): - self.conv_id.append(SelfAttentionLayer(embedding_dim)) - self.conv_reg.append(SelfAttentionLayer(embedding_dim)) + self.conv_id.append(SelfAttentionLayer(embedding_dim, num_heads, width, dropout)) + self.conv_reg.append(SelfAttentionLayer(embedding_dim, num_heads, width, dropout)) + elif self.conv_type == "mamba": + self.conv_id = nn.ModuleList() + self.conv_reg = nn.ModuleList() + for i in range(num_convs): + self.conv_id.append(MambaLayer(embedding_dim, num_heads, width, dropout, d_state, d_conv, expand)) + self.conv_reg.append(MambaLayer(embedding_dim, num_heads, width, dropout, d_state, d_conv, expand)) elif self.conv_type == "gnn_lsh": self.conv_id = nn.ModuleList() self.conv_reg = nn.ModuleList() @@ -98,12 +153,12 @@ def __init__( gnn_conf = { "inout_dim": embedding_dim, "bin_size": self.bin_size, - "max_num_bins": 200, - "distance_dim": 128, - "layernorm": True, - "num_node_messages": 2, - "dropout": 0.0, - "ffn_dist_hidden_dim": 128, + "max_num_bins": max_num_bins, + "distance_dim": distance_dim, + "layernorm": layernorm, + "num_node_messages": num_node_messages, + "dropout": dropout, + "ffn_dist_hidden_dim": ffn_dist_hidden_dim, } self.conv_id.append(CombinedGraphLayer(**gnn_conf)) self.conv_reg.append(CombinedGraphLayer(**gnn_conf)) @@ -123,7 +178,6 @@ def __init__( self.nn_charge = ffn(decoding_dim + num_classes, 3, width, self.act, dropout) def forward(self, X_features, batch_or_mask): - embeddings_id, embeddings_reg = [], [] if self.num_convs != 0: embedding = self.nn0(X_features) diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py index ba69a650d..b98631cfe 100644 --- a/mlpf/pyg/training.py +++ b/mlpf/pyg/training.py @@ -34,6 +34,7 @@ CLASS_LABELS, X_FEATURES, save_HPs, + get_lr_schedule, ) @@ -156,7 +157,16 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor: def train_and_valid( - rank, world_size, model, optimizer, data_loader, is_train, comet_experiment=None, comet_step_freq=None, epoch=None + rank, + world_size, + model, + optimizer, + data_loader, + is_train, + lr_schedule=None, + comet_experiment=None, + comet_step_freq=None, + epoch=None, ): """ Performs training over a given epoch. Will run a validation step every N_STEPS and after the last training batch. 
@@ -206,6 +216,8 @@ def train_and_valid( if is_train: loss["Total"].backward() optimizer.step() + if lr_schedule: + lr_schedule.step() for loss_ in epoch_loss: epoch_loss[loss_] += loss[loss_].detach() @@ -213,7 +225,9 @@ def train_and_valid( if comet_experiment and is_train: if itrain % comet_step_freq == 0: # this loss is not normalized to batch size - comet_experiment.log_metrics(loss, prefix=f"{train_or_valid}", step=(epoch - 1) * len(data_loader) + itrain) + step = (epoch - 1) * len(data_loader) + itrain + comet_experiment.log_metrics(loss, prefix=f"{train_or_valid}", step=step) + comet_experiment.log_metric("learning_rate", lr_schedule.get_last_lr(), step=step) num_data = torch.tensor(len(data_loader), device=rank) # sum up the number of steps from all workers @@ -242,6 +256,8 @@ def train_mlpf( num_epochs, patience, outdir, + start_epoch=1, + lr_schedule=None, use_ray=False, checkpoint_freq=None, comet_experiment=None, @@ -274,19 +290,6 @@ def train_mlpf( losses["train"][loss], losses["valid"][loss] = [], [] stale_epochs, best_val_loss = torch.tensor(0, device=rank), float("inf") - start_epoch = 1 - - if use_ray: - import ray - from ray.train import Checkpoint - - checkpoint = ray.train.get_checkpoint() - if checkpoint: - with checkpoint.as_directory() as checkpoint_dir: - with checkpoint.as_directory() as checkpoint_dir: - checkpoint = torch.load(Path(checkpoint_dir) / "checkpoint.pth", map_location=torch.device(rank)) - model, optimizer = load_checkpoint(checkpoint, model, optimizer) - start_epoch = checkpoint["extra_state"]["epoch"] + 1 for epoch in range(start_epoch, num_epochs + 1): t0 = time.time() @@ -297,24 +300,25 @@ def train_mlpf( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, with_stack=True ) as prof: with record_function("model_train"): - losses_t = train_and_valid(rank, world_size, model, optimizer, train_loader, True) + losses_t = train_and_valid(rank, world_size, model, optimizer, train_loader, True, lr_schedule) prof.export_chrome_trace("trace.json") else: losses_t = train_and_valid( - rank, world_size, model, optimizer, train_loader, True, comet_experiment, comet_step_freq, epoch + rank, world_size, model, optimizer, train_loader, True, lr_schedule, comet_experiment, comet_step_freq, epoch ) losses_v = train_and_valid( - rank, world_size, model, optimizer, valid_loader, False, comet_experiment, comet_step_freq, epoch + rank, world_size, model, optimizer, valid_loader, False, None, comet_experiment, comet_step_freq, epoch ) if comet_experiment: comet_experiment.log_metrics(losses_t, prefix="epoch_train_loss", epoch=epoch) comet_experiment.log_metrics(losses_v, prefix="epoch_valid_loss", epoch=epoch) + comet_experiment.log_metric("learning_rate", lr_schedule.get_last_lr(), epoch=epoch) comet_experiment.log_epoch_end(epoch) if (rank == 0) or (rank == "cpu"): - extra_state = {"epoch": epoch} + extra_state = {"epoch": epoch, "lr_schedule_state_dict": lr_schedule.state_dict()} if losses_v["Total"] < best_val_loss: best_val_loss = losses_v["Total"] stale_epochs = 0 @@ -333,6 +337,9 @@ def train_mlpf( save_checkpoint(checkpoint_path, model, optimizer, extra_state) if use_ray: + import ray + from ray.train import Checkpoint + # save model, optimizer and epoch number for HPO-supported checkpointing # Ray automatically syncs the checkpoint to persistent storage metrics = dict( @@ -439,31 +446,34 @@ def run(rank, world_size, config, args, outdir, logfile): if (rank == 0) or (rank == "cpu"): # keep writing the logs _configLogger("mlpf", 
filename=logfile) + start_epoch = 1 + if config["load"]: # load a pre-trained model - loaddir = config["load"] # in case both --load and --train are provided + if Path(config["load"]).name == "checkpoint.pth": + # the checkpoint is likely from a Ray Train run and we need to step one dir higher up + loaddir = str(Path(config["load"]).parent.parent.parent) + else: + # the checkpoint is likely from a DDP run and we need to step up one dir less + loaddir = str(Path(config["load"]).parent.parent) with open(f"{loaddir}/model_kwargs.pkl", "rb") as f: model_kwargs = pkl.load(f) - model = MLPF(**model_kwargs) + model = MLPF(**model_kwargs).to(torch.device(rank)) optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"]) - if args.load_checkpoint: - checkpoint = torch.load(f"{args.load_checkpoint}", map_location=torch.device(rank)) - if (rank == 0) or (rank == "cpu"): - _logger.info(f"Loaded model weights from {loaddir}/checkpoints/{args.load_checkpoint}") - else: - checkpoint = torch.load(f"{loaddir}/best_weights.pth", map_location=torch.device(rank)) - if (rank == 0) or (rank == "cpu"): - _logger.info(f"Loaded model weights from {loaddir}/best_weights.pth") + checkpoint = torch.load(config["load"], map_location=torch.device(rank)) + testdir_name = "_" + Path(config["load"]).name + if (rank == 0) or (rank == "cpu"): + _logger.info("Loaded model weights from {}".format(config["load"]), color="bold") model, optimizer = load_checkpoint(checkpoint, model, optimizer) - - if args.load_checkpoint: - testdir_name = f"_{args.load_checkpoint[:13]}" - else: - testdir_name = "_bestweights" - + elif args.resume_training: + raise NotImplementedError( + "Resuming an interrupted training is only supported in our \ + Ray Train-based training. Consider using `--load` instead, \ + which starts a new training using model weights from a pre-trained checkpoint." 
+ ) else: # instantiate a new model in the outdir created testdir_name = "_bestweights" @@ -494,7 +504,7 @@ def run(rank, world_size, config, args, outdir, logfile): comet_experiment = create_comet_experiment( config["comet_name"], comet_offline=config["comet_offline"], outdir=outdir ) - comet_experiment.set_name(f"rank_{rank}") + comet_experiment.set_name(f"rank_{rank}_{Path(outdir).name}") comet_experiment.log_parameter("run_id", Path(outdir).name) comet_experiment.log_parameter("world_size", world_size) comet_experiment.log_parameter("rank", rank) @@ -518,6 +528,9 @@ def run(rank, world_size, config, args, outdir, logfile): pad_3d, use_ray=False, ) + steps_per_epoch = len(loaders["train"]) + last_epoch = -1 if start_epoch == 1 else start_epoch - 1 + lr_schedule = get_lr_schedule(config, optimizer, config["num_epochs"], steps_per_epoch, last_epoch) train_mlpf( rank, @@ -529,6 +542,8 @@ def run(rank, world_size, config, args, outdir, logfile): config["num_epochs"], config["patience"], outdir, + start_epoch=start_epoch, + lr_schedule=lr_schedule, use_ray=False, checkpoint_freq=config["checkpoint_freq"], comet_experiment=comet_experiment, @@ -544,7 +559,7 @@ def run(rank, world_size, config, args, outdir, logfile): assert args.train, "Please train a model before testing, or load a model with --load" assert outdir is not None, "Error: no outdir to evaluate model from" else: - outdir = config["load"] + outdir = str(Path(config["load"]).parent.parent) for type_ in config["test_dataset"][config["dataset"]]: # will be "physical", "gun" batch_size = config["test_dataset"][config["dataset"]][type_]["batch_size"] * config["gpu_batch_multiplier"] @@ -645,7 +660,7 @@ def device_agnostic_run(config, args, world_size, outdir): logfile = f"{outdir}/train.log" _configLogger("mlpf", filename=logfile) else: - outdir = args.load + outdir = str(Path(args.load).parent.parent) logfile = f"{outdir}/test.log" _configLogger("mlpf", filename=logfile) @@ -719,7 +734,7 @@ def train_ray_trial(config, args, outdir=None): comet_experiment = create_comet_experiment( config["comet_name"], comet_offline=config["comet_offline"], outdir=outdir ) - comet_experiment.set_name(f"world_rank_{world_rank}") + comet_experiment.set_name(f"world_rank_{world_rank}_{Path(outdir).name}") comet_experiment.log_parameter("run_id", Path(outdir).name) comet_experiment.log_parameter("world_size", world_size) comet_experiment.log_parameter("rank", rank) @@ -737,6 +752,24 @@ def train_ray_trial(config, args, outdir=None): else: comet_experiment = None + steps_per_epoch = len(loaders["train"]) + start_epoch = 1 + lr_schedule = get_lr_schedule(config, optimizer, config["num_epochs"], steps_per_epoch, last_epoch=-1) + + checkpoint = ray.train.get_checkpoint() + if checkpoint: + with checkpoint.as_directory() as checkpoint_dir: + with checkpoint.as_directory() as checkpoint_dir: + checkpoint = torch.load(Path(checkpoint_dir) / "checkpoint.pth", map_location=torch.device(rank)) + if args.resume_training: + model, optimizer = load_checkpoint(checkpoint, model, optimizer) + start_epoch = checkpoint["extra_state"]["epoch"] + 1 + lr_schedule = get_lr_schedule( + config, optimizer, config["num_epochs"], steps_per_epoch, last_epoch=start_epoch - 1 + ) + else: # start a new training with model weights loaded from a pre-trained model + model = load_checkpoint(checkpoint, model) + train_mlpf( rank, world_size, @@ -747,6 +780,8 @@ def train_ray_trial(config, args, outdir=None): config["num_epochs"], config["patience"], outdir, + start_epoch=start_epoch, + 
lr_schedule=lr_schedule, use_ray=True, checkpoint_freq=config["checkpoint_freq"], comet_experiment=comet_experiment, @@ -767,6 +802,9 @@ def run_ray_training(config, args, outdir): if not args.local: ray.init(address="auto") + if args.resume_training: + outdir = args.resume_training # continue training in the same directory + _configLogger("mlpf", filename=f"{outdir}/train.log") num_workers = args.gpus @@ -785,9 +823,21 @@ def run_ray_training(config, args, outdir): sync_config=ray.train.SyncConfig(sync_artifacts=True), ) trainable = tune.with_parameters(train_ray_trial, args=args, outdir=outdir) - trainer = TorchTrainer( - train_loop_per_worker=trainable, train_loop_config=config, scaling_config=scaling_config, run_config=run_config - ) + # Resume from checkpoint if a checkpoint is found in outdir + if TorchTrainer.can_restore(outdir): + _logger.info(f"Restoring Ray Trainer from {outdir}", color="bold") + trainer = TorchTrainer.restore(outdir, train_loop_per_worker=trainable, scaling_config=scaling_config) + else: + resume_from_checkpoint = ray.train.Checkpoint(config["load"]) if config["load"] else None + if resume_from_checkpoint: + _logger.info("Loading checkpoint {}".format(config["load"]), color="bold") + trainer = TorchTrainer( + train_loop_per_worker=trainable, + train_loop_config=config, + scaling_config=scaling_config, + run_config=run_config, + resume_from_checkpoint=resume_from_checkpoint, + ) result = trainer.fit() _logger.info("Final loss: {}".format(result.metrics["loss"]), color="bold") diff --git a/mlpf/pyg/utils.py b/mlpf/pyg/utils.py index 42fecfe89..386ec1f9e 100644 --- a/mlpf/pyg/utils.py +++ b/mlpf/pyg/utils.py @@ -3,6 +3,8 @@ import torch import torch.utils.data +from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR, ConstantLR + # https://github.com/ahlinist/cmssw/blob/1df62491f48ef964d198f574cdfcccfd17c70425/DataFormats/ParticleFlowReco/interface/PFBlockElement.h#L33 # https://github.com/cms-sw/cmssw/blob/master/DataFormats/ParticleFlowCandidate/src/PFCandidate.cc#L254 @@ -205,3 +207,36 @@ def save_checkpoint(checkpoint_path, model, optimizer=None, extra_state=None): }, checkpoint_path, ) + + +def load_lr_schedule(lr_schedule, checkpoint): + "Loads the lr_schedule's state dict from checkpoint" + if "lr_schedule_state_dict" in checkpoint["extra_state"].keys(): + lr_schedule.load_state_dict(checkpoint["extra_state"]["lr_schedule_state_dict"]) + return lr_schedule + else: + raise KeyError( + "Couldn't find LR schedule state dict in checkpoint. 
extra_state contains: {}".format( + checkpoint["extra_state"].keys() + ) + ) + + +def get_lr_schedule(config, opt, epochs=None, steps_per_epoch=None, last_epoch=-1): + # we step teh schedule every mini-batch so need to multiply by steps_per_epoch + last_batch = last_epoch * steps_per_epoch - 1 if last_epoch != -1 else -1 + if config["lr_schedule"] == "constant": + lr_schedule = ConstantLR(opt, factor=1.0, total_iters=steps_per_epoch * epochs) + elif config["lr_schedule"] == "onecycle": + lr_schedule = OneCycleLR( + opt, + max_lr=config["lr"], + steps_per_epoch=steps_per_epoch, + epochs=epochs, + last_epoch=last_batch, + ) + elif config["lr_schedule"] == "cosinedecay": + lr_schedule = CosineAnnealingLR(opt, T_max=steps_per_epoch * epochs, last_epoch=last_batch) + else: + raise ValueError("Supported values for lr_schedule are 'constant', 'onecycle' and 'cosinedecay'.") + return lr_schedule diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index 92f36b7f6..0ec9fc08d 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -33,14 +33,22 @@ ) parser.add_argument("--num-workers", type=int, default=None, help="number of processes to load the data") parser.add_argument("--prefetch-factor", type=int, default=None, help="number of samples to fetch & prefetch at every call") -parser.add_argument("--load", type=str, default=None, help="dir from which to load a saved model") +parser.add_argument( + "--resume-training", type=str, default=None, help="training dir containing the checkpointed training to resume" +) +parser.add_argument("--load", type=str, default=None, help="load checkpoint and start new training from epoch 1") + parser.add_argument("--train", action="store_true", default=None, help="initiates a training") parser.add_argument("--test", action="store_true", default=None, help="tests the model") parser.add_argument("--num-epochs", type=int, default=None, help="number of training epochs") parser.add_argument("--patience", type=int, default=None, help="patience before early stopping") parser.add_argument("--lr", type=float, default=None, help="learning rate") parser.add_argument( - "--conv-type", type=str, default=None, help="which graph layer to use", choices=["gravnet", "attention", "gnn_lsh"] + "--conv-type", + type=str, + default=None, + help="which graph layer to use", + choices=["gravnet", "attention", "gnn_lsh", "mamba"], ) parser.add_argument("--make-plots", action="store_true", default=None, help="make plots of the test predictions") parser.add_argument("--export-onnx", action="store_true", default=None, help="exports the model to onnx") @@ -53,9 +61,6 @@ parser.add_argument("--local", action="store_true", default=None, help="perform HPO locally, without a Ray cluster") parser.add_argument("--ray-cpus", type=int, default=None, help="CPUs per trial for HPO") parser.add_argument("--ray-gpus", type=int, default=None, help="GPUs per trial for HPO") -parser.add_argument( - "--load-checkpoint", type=str, default=None, help="which checkpoint to load. 
if None then will load best weights" -) parser.add_argument("--comet", action="store_true", help="use comet ml logging") parser.add_argument("--comet-offline", action="store_true", help="save comet logs locally") parser.add_argument("--comet-step-freq", type=int, default=None, help="step frequency for saving comet metrics") @@ -77,10 +82,13 @@ def main(): if args.hpo: run_hpo(config, args) else: - outdir = create_experiment_dir( - prefix=(args.prefix or "") + Path(args.config).stem + "_", - experiments_dir=args.experiments_dir if args.experiments_dir else "experiments", - ) + if args.resume_training: + outdir = args.resume_training + else: + outdir = create_experiment_dir( + prefix=(args.prefix or "") + Path(args.config).stem + "_", + experiments_dir=args.experiments_dir if args.experiments_dir else "experiments", + ) # Save config for later reference. Note that saving happens after parameters are overwritten by cmd line args. config_filename = "train-config.yaml" if args.train else "test-config.yaml" with open((Path(outdir) / config_filename), "w") as file: diff --git a/mlpf/raytune/pt_search_space.py b/mlpf/raytune/pt_search_space.py index e8076c491..871adbd03 100644 --- a/mlpf/raytune/pt_search_space.py +++ b/mlpf/raytune/pt_search_space.py @@ -1,48 +1,84 @@ from ray.tune import choice # grid_search, choice, loguniform, quniform -raytune_num_samples = 8 # Number of random samples to draw from search space. Set to 1 for grid search. +raytune_num_samples = 16 # Number of random samples to draw from search space. Set to 1 for grid search. samp = choice # gnn scan search_space = { # optimizer parameters - "lr": samp([1e-4, 1e-3, 1e-2]), - # "gpu_batch_multiplier": samp([10, 20, 40]), + "lr": samp([1e-4, 3e-4, 1e-3, 3e-3, 1e-2]), + "gpu_batch_multiplier": samp([1, 4, 8, 16]), # model arch parameters - "conv_type": samp(["gnn_lsh"]), - "embedding_dim": samp([128, 252, 512]), - # "width": samp([512]), - # "num_convs": samp([3]), - # "dropout": samp([0.0]), - # "patience": samp([20]) + "activation": samp(["elu", "relu", "relu6", "leakyrelu"]), + "conv_type": samp(["gravnet"]), # can be "gnn_lsh", "gravnet", "attention" + "embedding_dim": samp([32, 64, 128, 252, 512, 1024]), + "width": samp([32, 64, 128, 256, 512, 1024]), + "num_convs": samp([1, 2, 3, 4, 5, 6]), + "dropout": samp([0.0, 0.01, 0.1, 0.4]), + # "patience": samp([9999]), + # only for gravnet + "k": samp([8, 16, 32]), + "propagate_dimensions": samp([8, 16, 32, 64, 128]), + "space_dimensions": samp([4]), + # only for gnn-lsh + # "bin_size": samp([160, 320, 640]), + # "max_num_bins": samp([200]), + # "distance_dim": samp([16, 32, 64, 128, 256]), + # "layernorm": samp([True, False]), + # "num_node_messages": samp([1, 2, 3, 4, 5]), + # "ffn_dist_hidden_dim": samp([16, 32, 64, 128, 256]), + # mamba specific variables + "d_state": samp([16]), + "d_conv": samp([4]), + "expand": samp([2]), } def set_hps_from_search_space(search_space, config): - if "lr" in search_space.keys(): - config["lr"] = search_space["lr"] - - if "gpu_batch_multiplier" in search_space.keys(): - config["gpu_batch_multiplier"] = search_space["gpu_batch_multiplier"] + varaible_names = ["lr", "gpu_batch_multiplier"] + for var in varaible_names: + if var in search_space.keys(): + config[var] = search_space[var] if "conv_type" in search_space.keys(): conv_type = search_space["conv_type"] config["conv_type"] = conv_type - if conv_type == "gnn_lsh" or conv_type == "transformer": - if "embedding_dim" in search_space.keys(): - config["model"][conv_type]["embedding_dim"] = 
search_space["embedding_dim"] + common_varaible_names = ["embedding_dim", "width", "num_convs", "activation"] + if conv_type == "gnn_lsh" or conv_type == "gravnet" or conv_type == "attention": + for var in common_varaible_names: + if var in search_space.keys(): + config["model"][conv_type][var] = search_space[var] - if "width" in search_space.keys(): - config["model"][conv_type]["width"] = search_space["width"] + gravnet_variable_names = ["k", "propagate_dimensions", "space_dimensions"] + if conv_type == "gravnet": + for var in gravnet_variable_names: + if var in search_space.keys(): + config["model"][conv_type][var] = search_space[var] - if "num_convs" in search_space.keys(): - config["model"][conv_type]["num_convs"] = search_space["num_convs"] + attention_variables = ["num_heads"] + if conv_type == "attention": + for var in attention_variables: + if var in search_space.keys(): + config["model"][conv_type][var] = search_space[var] - if "num_convs" in search_space.keys(): - config["model"][conv_type]["num_convs"] = search_space["num_convs"] + mamba_variables = ["num_heads", "d_state", "d_conv", "expand"] + if conv_type == "mamba": + for var in mamba_variables: + if var in search_space.keys(): + config["model"][conv_type][var] = search_space[var] - if "embedding_dim" in search_space.keys(): - config["embedding_dim"] = search_space["embedding_dim"] + gnn_lsh_varaible_names = [ + "bin_size", + "max_num_bins", + "distance_dim", + "layernorm", + "num_node_messages", + "ffn_dist_hidden_dim", + ] + if conv_type == "gnn_lsh": + for var in gnn_lsh_varaible_names: + if var in search_space.keys(): + config["model"][conv_type][var] = search_space[var] return config diff --git a/parameters/pyg-clic.yaml b/parameters/pyg-clic.yaml index e01815514..e2eb3598a 100644 --- a/parameters/pyg-clic.yaml +++ b/parameters/pyg-clic.yaml @@ -8,6 +8,7 @@ load: num_epochs: 2 patience: 20 lr: 0.0001 +lr_schedule: constant # constant, cosinedecay, onecycle conv_type: gnn_lsh ntrain: ntest: @@ -26,23 +27,50 @@ model: width: 512 num_convs: 3 dropout: 0.0 + activation: "elu" + # gnn-lsh specific parameters + bin_size: 640 + max_num_bins: 200 + distance_dim: 128 + layernorm: True + num_node_messages: 2 + ffn_dist_hidden_dim: 128 gravnet: conv_type: gravnet embedding_dim: 512 width: 512 num_convs: 3 + dropout: 0.0 + activation: "elu" + # gravnet specific parameters k: 16 - propagate_dimensions: 22 + propagate_dimensions: 32 space_dimensions: 4 - dropout: 0.0 attention: conv_type: attention + embedding_dim: 256 + width: 256 + num_convs: 3 + dropout: 0.0 + activation: "elu" + # attention specific paramters + num_heads: 2 + + mamba: + conv_type: mamba embedding_dim: 128 width: 128 num_convs: 2 dropout: 0.0 + activation: "elu" + # transformer specific paramters + num_heads: 2 + # mamba specific paramters + d_state: 16 + d_conv: 4 + expand: 2 raytune: local_dir: # Note: please specify an absolute path diff --git a/parameters/pyg-cms-physical.yaml b/parameters/pyg-cms-physical.yaml index 59e0ef92c..c40f392e3 100644 --- a/parameters/pyg-cms-physical.yaml +++ b/parameters/pyg-cms-physical.yaml @@ -8,6 +8,7 @@ load: num_epochs: 2 patience: 20 lr: 0.0001 +lr_schedule: constant # constant, cosinedecay, onecycle conv_type: gnn_lsh ntrain: ntest: @@ -26,16 +27,26 @@ model: width: 512 num_convs: 3 dropout: 0.0 + activation: "elu" + # gnn-lsh specific parameters + bin_size: 640 + max_num_bins: 200 + distance_dim: 128 + layernorm: True + num_node_messages: 2 + ffn_dist_hidden_dim: 128 gravnet: conv_type: gravnet embedding_dim: 512 width: 
512 num_convs: 3 + dropout: 0.0 + activation: "elu" + # gravnet specific parameters k: 16 propagate_dimensions: 32 space_dimensions: 4 - dropout: 0.0 attention: conv_type: attention @@ -43,6 +54,23 @@ model: width: 256 num_convs: 3 dropout: 0.0 + activation: "elu" + # attention specific paramters + num_heads: 2 + + mamba: + conv_type: mamba + embedding_dim: 128 + width: 128 + num_convs: 2 + dropout: 0.0 + activation: "elu" + # transformer specific paramters + num_heads: 2 + # mamba specific paramters + d_state: 16 + d_conv: 4 + expand: 2 raytune: local_dir: # Note: please specify an absolute path diff --git a/parameters/pyg-cms-small-highqcd.yaml b/parameters/pyg-cms-small-highqcd.yaml index bcd3c155b..c6f1c3af8 100644 --- a/parameters/pyg-cms-small-highqcd.yaml +++ b/parameters/pyg-cms-small-highqcd.yaml @@ -8,6 +8,7 @@ load: num_epochs: 2 patience: 20 lr: 0.0001 +lr_schedule: constant # constant, cosinedecay, onecycle conv_type: gnn_lsh ntrain: ntest: @@ -26,16 +27,26 @@ model: width: 512 num_convs: 3 dropout: 0.0 + activation: "elu" + # gnn-lsh specific parameters + bin_size: 640 + max_num_bins: 200 + distance_dim: 128 + layernorm: True + num_node_messages: 2 + ffn_dist_hidden_dim: 128 gravnet: conv_type: gravnet embedding_dim: 512 width: 512 num_convs: 3 + dropout: 0.0 + activation: "elu" + # gravnet specific parameters k: 16 propagate_dimensions: 32 space_dimensions: 4 - dropout: 0.0 attention: conv_type: attention @@ -43,6 +54,23 @@ model: width: 256 num_convs: 3 dropout: 0.0 + activation: "elu" + # attention specific paramters + num_heads: 2 + + mamba: + conv_type: mamba + embedding_dim: 128 + width: 128 + num_convs: 2 + dropout: 0.0 + activation: "elu" + # transformer specific paramters + num_heads: 2 + # mamba specific paramters + d_state: 16 + d_conv: 4 + expand: 2 raytune: local_dir: # Note: please specify an absolute path diff --git a/parameters/pyg-cms-small.yaml b/parameters/pyg-cms-small.yaml index 2642af156..d7677f906 100644 --- a/parameters/pyg-cms-small.yaml +++ b/parameters/pyg-cms-small.yaml @@ -8,6 +8,7 @@ load: num_epochs: 10 patience: 20 lr: 0.0001 +lr_schedule: constant # constant, cosinedecay, onecycle conv_type: gravnet ntrain: 500 ntest: 500 @@ -26,16 +27,26 @@ model: width: 512 num_convs: 3 dropout: 0.0 + activation: "elu" + # gnn-lsh specific parameters + bin_size: 640 + max_num_bins: 200 + distance_dim: 128 + layernorm: True + num_node_messages: 2 + ffn_dist_hidden_dim: 128 gravnet: conv_type: gravnet embedding_dim: 512 width: 512 num_convs: 3 + dropout: 0.0 + activation: "elu" + # gravnet specific parameters k: 16 propagate_dimensions: 32 space_dimensions: 4 - dropout: 0.0 attention: conv_type: attention @@ -43,6 +54,23 @@ model: width: 256 num_convs: 3 dropout: 0.0 + activation: "elu" + # attention specific paramters + num_heads: 2 + + mamba: + conv_type: mamba + embedding_dim: 128 + width: 128 + num_convs: 2 + dropout: 0.0 + activation: "elu" + # transformer specific paramters + num_heads: 2 + # mamba specific paramters + d_state: 16 + d_conv: 4 + expand: 2 raytune: local_dir: # Note: please specify an absolute path diff --git a/parameters/pyg-cms-test-qcdhighpt.yaml b/parameters/pyg-cms-test-qcdhighpt.yaml index 637e2af81..ae8eb5680 100644 --- a/parameters/pyg-cms-test-qcdhighpt.yaml +++ b/parameters/pyg-cms-test-qcdhighpt.yaml @@ -8,6 +8,7 @@ load: num_epochs: 2 patience: 20 lr: 0.0001 +lr_schedule: constant # constant, cosinedecay, onecycle conv_type: gnn_lsh ntrain: ntest: @@ -26,23 +27,50 @@ model: width: 512 num_convs: 3 dropout: 0.0 + activation: 
"elu" + # gnn-lsh specific parameters + bin_size: 640 + max_num_bins: 200 + distance_dim: 128 + layernorm: True + num_node_messages: 2 + ffn_dist_hidden_dim: 128 gravnet: conv_type: gravnet embedding_dim: 512 width: 512 num_convs: 3 + dropout: 0.0 + activation: "elu" + # gravnet specific parameters k: 16 - propagate_dimensions: 22 + propagate_dimensions: 32 space_dimensions: 4 - dropout: 0.0 attention: conv_type: attention + embedding_dim: 256 + width: 256 + num_convs: 3 + dropout: 0.0 + activation: "elu" + # attention specific paramters + num_heads: 2 + + mamba: + conv_type: mamba embedding_dim: 128 width: 128 num_convs: 2 dropout: 0.0 + activation: "elu" + # transformer specific paramters + num_heads: 2 + # mamba specific paramters + d_state: 16 + d_conv: 4 + expand: 2 raytune: local_dir: # Note: please specify an absolute path diff --git a/parameters/pyg-cms.yaml b/parameters/pyg-cms.yaml index 742c1863d..5d01c7c1f 100644 --- a/parameters/pyg-cms.yaml +++ b/parameters/pyg-cms.yaml @@ -5,13 +5,14 @@ data_dir: gpus: 1 gpu_batch_multiplier: 1 load: -num_epochs: 2 +num_epochs: 50 patience: 20 lr: 0.0001 +lr_schedule: cosinedecay # constant, cosinedecay, onecycle conv_type: gnn_lsh ntrain: ntest: -nvalid: 500 +nvalid: num_workers: 0 prefetch_factor: checkpoint_freq: @@ -26,23 +27,50 @@ model: width: 512 num_convs: 3 dropout: 0.0 + activation: "elu" + # gnn-lsh specific parameters + bin_size: 640 + max_num_bins: 200 + distance_dim: 128 + layernorm: True + num_node_messages: 2 + ffn_dist_hidden_dim: 128 gravnet: conv_type: gravnet embedding_dim: 512 width: 512 num_convs: 3 + dropout: 0.0 + activation: "elu" + # gravnet specific parameters k: 16 - propagate_dimensions: 22 + propagate_dimensions: 32 space_dimensions: 4 - dropout: 0.0 attention: conv_type: attention + embedding_dim: 256 + width: 256 + num_convs: 3 + dropout: 0.0 + activation: "elu" + # attention specific paramters + num_heads: 2 + + mamba: + conv_type: mamba embedding_dim: 128 width: 128 num_convs: 2 dropout: 0.0 + activation: "elu" + # transformer specific paramters + num_heads: 2 + # mamba specific paramters + d_state: 16 + d_conv: 4 + expand: 2 raytune: local_dir: # Note: please specify an absolute path @@ -99,7 +127,7 @@ train_dataset: cms_pf_single_proton: version: 1.6.0 multiparticlegun: - batch_size: 2 + batch_size: 4 samples: cms_pf_multi_particle_gun: version: 1.6.0 @@ -109,8 +137,41 @@ valid_dataset: physical: batch_size: 1 samples: + cms_pf_ttbar: + version: 1.6.0 + cms_pf_qcd: + version: 1.6.0 + cms_pf_ztt: + version: 1.6.0 cms_pf_qcd_high_pt: version: 1.6.0 + cms_pf_sms_t1tttt: + version: 1.6.0 + gun: + batch_size: 20 + samples: + cms_pf_single_electron: + version: 1.6.0 + cms_pf_single_gamma: + version: 1.6.0 + cms_pf_single_pi0: + version: 1.6.0 + cms_pf_single_neutron: + version: 1.6.0 + cms_pf_single_pi: + version: 1.6.0 + cms_pf_single_tau: + version: 1.6.0 + cms_pf_single_mu: + version: 1.6.0 + cms_pf_single_proton: + version: 1.6.0 + multiparticlegun: + batch_size: 4 + samples: + cms_pf_multi_particle_gun: + version: 1.6.0 + test_dataset: cms: @@ -148,7 +209,7 @@ test_dataset: cms_pf_single_proton: version: 1.6.0 multiparticlegun: - batch_size: 2 + batch_size: 4 samples: cms_pf_multi_particle_gun: version: 1.6.0 diff --git a/parameters/pyg-delphes.yaml b/parameters/pyg-delphes.yaml index ace86b05c..cdd19ca0b 100644 --- a/parameters/pyg-delphes.yaml +++ b/parameters/pyg-delphes.yaml @@ -8,6 +8,7 @@ load: num_epochs: 2 patience: 20 lr: 0.0001 +lr_schedule: constant # constant, cosinedecay, onecycle conv_type: 
gnn_lsh ntrain: ntest: @@ -26,23 +27,50 @@ model: width: 512 num_convs: 3 dropout: 0.0 + activation: "elu" + # gnn-lsh specific parameters + bin_size: 640 + max_num_bins: 200 + distance_dim: 128 + layernorm: True + num_node_messages: 2 + ffn_dist_hidden_dim: 128 gravnet: conv_type: gravnet embedding_dim: 512 width: 512 num_convs: 3 + dropout: 0.0 + activation: "elu" + # gravnet specific parameters k: 16 - propagate_dimensions: 22 + propagate_dimensions: 32 space_dimensions: 4 - dropout: 0.0 attention: conv_type: attention + embedding_dim: 256 + width: 256 + num_convs: 3 + dropout: 0.0 + activation: "elu" + # attention specific paramters + num_heads: 2 + + mamba: + conv_type: mamba embedding_dim: 128 width: 128 num_convs: 2 dropout: 0.0 + activation: "elu" + # transformer specific paramters + num_heads: 2 + # mamba specific paramters + d_state: 16 + d_conv: 4 + expand: 2 raytune: local_dir: # Note: please specify an absolute path diff --git a/parameters/pyg-workflow-test.yaml b/parameters/pyg-workflow-test.yaml index d3a179fd4..197c7c96c 100644 --- a/parameters/pyg-workflow-test.yaml +++ b/parameters/pyg-workflow-test.yaml @@ -8,6 +8,7 @@ load: num_epochs: 2 patience: 20 lr: 0.0001 +lr_schedule: constant # constant, cosinedecay, onecycle conv_type: gravnet ntrain: ntest: @@ -26,23 +27,50 @@ model: width: 512 num_convs: 3 dropout: 0.0 + activation: "elu" + # gnn-lsh specific parameters + bin_size: 640 + max_num_bins: 200 + distance_dim: 128 + layernorm: True + num_node_messages: 2 + ffn_dist_hidden_dim: 128 gravnet: conv_type: gravnet embedding_dim: 512 width: 512 num_convs: 3 + dropout: 0.0 + activation: "elu" + # gravnet specific parameters k: 16 - propagate_dimensions: 22 + propagate_dimensions: 32 space_dimensions: 4 - dropout: 0.0 attention: conv_type: attention + embedding_dim: 256 + width: 256 + num_convs: 3 + dropout: 0.0 + activation: "elu" + # attention specific paramters + num_heads: 2 + + mamba: + conv_type: mamba embedding_dim: 128 width: 128 num_convs: 2 dropout: 0.0 + activation: "elu" + # transformer specific paramters + num_heads: 2 + # mamba specific paramters + d_state: 16 + d_conv: 4 + expand: 2 train_dataset: cms: diff --git a/scripts/flatiron/pt_raytrain_a100.slurm b/scripts/flatiron/pt_raytrain_a100.slurm index ecd68ba6b..54ca8516d 100644 --- a/scripts/flatiron/pt_raytrain_a100.slurm +++ b/scripts/flatiron/pt_raytrain_a100.slurm @@ -2,7 +2,7 @@ # Walltime limit #SBATCH -t 1:00:00 -#SBATCH -N 2 +#SBATCH -N 1 #SBATCH --exclusive #SBATCH --tasks-per-node=1 #SBATCH -p gpu @@ -11,7 +11,7 @@ #SBATCH --constraint=a100-80gb,ib # Job name -#SBATCH -J pt_train +#SBATCH -J pt_raytrain # Output and error logs #SBATCH -o logs_slurm/log_%x_%j.out diff --git a/scripts/flatiron/pt_raytrain_h100.slurm b/scripts/flatiron/pt_raytrain_h100.slurm index 0cb7d731e..f02781f98 100644 --- a/scripts/flatiron/pt_raytrain_h100.slurm +++ b/scripts/flatiron/pt_raytrain_h100.slurm @@ -11,7 +11,7 @@ #SBATCH --constraint=h100,ib # Job name -#SBATCH -J pt_train +#SBATCH -J pt_raytrain # Output and error logs #SBATCH -o logs_slurm/log_%x_%j.out diff --git a/scripts/flatiron/pt_raytune_a100_1GPUperTrial.sh b/scripts/flatiron/pt_raytune_a100_1GPUperTrial.sh index 0bf6da5e8..d0b6bd84a 100755 --- a/scripts/flatiron/pt_raytune_a100_1GPUperTrial.sh +++ b/scripts/flatiron/pt_raytune_a100_1GPUperTrial.sh @@ -1,7 +1,7 @@ #!/bin/bash #SBATCH -t 168:00:00 -#SBATCH -N 6 +#SBATCH -N 4 #SBATCH --tasks-per-node=1 #SBATCH -p gpu #SBATCH --constraint=a100-80gb,ib @@ -75,8 +75,7 @@ python3 -u mlpf/pyg_pipeline.py --train 
\ --config $1 \ --hpo $2 \ --ray-cpus $((SLURM_CPUS_PER_TASK/4)) \ - --ray-gpus 1 \ - --gpus "0" \ + --gpus 1 \ --gpu-batch-multiplier 4 \ --num-workers 1 \ --prefetch-factor 2 diff --git a/scripts/flatiron/pt_raytune_a100_4GPUsperTrial.sh b/scripts/flatiron/pt_raytune_a100_4GPUsperTrial.sh index 92d47bafc..d073672f2 100755 --- a/scripts/flatiron/pt_raytune_a100_4GPUsperTrial.sh +++ b/scripts/flatiron/pt_raytune_a100_4GPUsperTrial.sh @@ -1,7 +1,7 @@ #!/bin/bash #SBATCH -t 168:00:00 -#SBATCH -N 6 +#SBATCH -N 2 #SBATCH --tasks-per-node=1 #SBATCH -p gpu #SBATCH --constraint=a100-80gb,ib @@ -75,8 +75,7 @@ python3 -u mlpf/pyg_pipeline.py --train \ --config $1 \ --hpo $2 \ --ray-cpus $((SLURM_CPUS_PER_TASK)) \ - --ray-gpus $num_gpus \ - --gpus "0,1,2,3" \ + --gpus $num_gpus \ --gpu-batch-multiplier 4 \ --num-workers 1 \ --prefetch-factor 2 diff --git a/scripts/flatiron/pt_raytune_h100_1GPUperTrial.sh b/scripts/flatiron/pt_raytune_h100_1GPUperTrial.sh index 9bb71447e..2d6210c45 100755 --- a/scripts/flatiron/pt_raytune_h100_1GPUperTrial.sh +++ b/scripts/flatiron/pt_raytune_h100_1GPUperTrial.sh @@ -1,7 +1,7 @@ #!/bin/bash #SBATCH -t 168:00:00 -#SBATCH -N 3 +#SBATCH -N 2 #SBATCH --tasks-per-node=1 #SBATCH -p gpu #SBATCH --constraint=h100,ib @@ -31,7 +31,7 @@ which python3 python3 --version export CUDA_VISIBLE_DEVICES=0,1,2,3 -num_gpus=8 +num_gpus=${SLURM_GPUS_PER_TASK} # gpus per compute node ################# DON NOT CHANGE THINGS HERE UNLESS YOU KNOW WHAT YOU ARE DOING ############### @@ -75,8 +75,7 @@ python3 -u mlpf/pyg_pipeline.py --train \ --config $1 \ --hpo $2 \ --ray-cpus $((SLURM_CPUS_PER_TASK/8)) \ - --ray-gpus 1 \ - --gpus "0" \ + --gpus 1 \ --gpu-batch-multiplier 4 \ --num-workers 1 \ --prefetch-factor 2 diff --git a/scripts/flatiron/pt_raytune_h100_2GPUsperTrial.sh b/scripts/flatiron/pt_raytune_h100_2GPUsperTrial.sh index 99b4d5e4b..8607755d3 100755 --- a/scripts/flatiron/pt_raytune_h100_2GPUsperTrial.sh +++ b/scripts/flatiron/pt_raytune_h100_2GPUsperTrial.sh @@ -75,8 +75,7 @@ python3 -u mlpf/pyg_pipeline.py --train \ --config $1 \ --hpo $2 \ --ray-cpus $((SLURM_CPUS_PER_TASK/4)) \ - --ray-gpus 2 \ - --gpus "0,1" \ + --gpus 2 \ --gpu-batch-multiplier 4 \ --num-workers 1 \ --prefetch-factor 2 diff --git a/scripts/flatiron/pt_raytune_h100_4GPUsperTrial.sh b/scripts/flatiron/pt_raytune_h100_4GPUsperTrial.sh index 4171b9715..5a62e515d 100755 --- a/scripts/flatiron/pt_raytune_h100_4GPUsperTrial.sh +++ b/scripts/flatiron/pt_raytune_h100_4GPUsperTrial.sh @@ -75,8 +75,7 @@ python3 -u mlpf/pyg_pipeline.py --train \ --config $1 \ --hpo $2 \ --ray-cpus $((SLURM_CPUS_PER_TASK/2)) \ - --ray-gpus 4 \ - --gpus "0,1,2,3" \ + --gpus 4 \ --gpu-batch-multiplier 4 \ --num-workers 1 \ --prefetch-factor 2 diff --git a/scripts/flatiron/pt_raytune_h100_8GPUsperTrial.sh b/scripts/flatiron/pt_raytune_h100_8GPUsperTrial.sh index a40f47cc6..50aff995e 100755 --- a/scripts/flatiron/pt_raytune_h100_8GPUsperTrial.sh +++ b/scripts/flatiron/pt_raytune_h100_8GPUsperTrial.sh @@ -75,8 +75,7 @@ python3 -u mlpf/pyg_pipeline.py --train \ --config $1 \ --hpo $2 \ --ray-cpus $((SLURM_CPUS_PER_TASK)) \ - --ray-gpus $num_gpus \ - --gpus "0,1,2,3,4,5,6,7" \ + --gpus 8 \ --gpu-batch-multiplier 4 \ --num-workers 1 \ --prefetch-factor 2
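
Note on the LR schedules added above: get_lr_schedule in mlpf/pyg/utils.py builds a per-step (per mini-batch) schedule, which is why it takes steps_per_epoch and why train_and_valid calls lr_schedule.step() right after optimizer.step(). The following is a minimal sketch of how the helper is driven; the config values, the stand-in Linear model, and steps_per_epoch=100 are toy assumptions, not values from the parameter files.

import torch
from mlpf.pyg.utils import get_lr_schedule  # path assumes the repo root is on PYTHONPATH

config = {"lr": 1e-4, "lr_schedule": "cosinedecay", "num_epochs": 2}  # toy values for illustration
model = torch.nn.Linear(8, 1)  # stand-in for MLPF
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

steps_per_epoch = 100  # in the real pipeline this is len(loaders["train"])
lr_schedule = get_lr_schedule(config, optimizer, config["num_epochs"], steps_per_epoch, last_epoch=-1)

for epoch in range(1, config["num_epochs"] + 1):
    for _ in range(steps_per_epoch):
        optimizer.step()    # loss computation and backward pass omitted in this sketch
        lr_schedule.step()  # stepped once per mini-batch, as in train_and_valid

When a training is resumed, the patch rebuilds the schedule with last_epoch set to start_epoch - 1 (see train_ray_trial), so the schedule continues from the point it had reached rather than restarting from the initial learning rate.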