Feat val freq (jpata#298)
* WIP: validate every val_freq training steps

* update count parameters function

* feat: implement intermediate validation every val_freq training steps

* fix: black formatting

* make ffn_dist_num_layers configurable in pytorch training

* Update pyg-clic-hits.yaml

* Update pyg-clic.yaml

* Update pyg-cms.yaml

* Update pyg-delphes.yaml

* update pt search space

* add script to count model parameters given a config file
erwulff authored Mar 4, 2024
1 parent 1200034 commit 8d9065c
Showing 10 changed files with 201 additions and 45 deletions.
38 changes: 38 additions & 0 deletions mlpf/count_parameters.py
@@ -0,0 +1,38 @@
import sys
import yaml

sys.path.append("../mlpf")

from pyg.mlpf import MLPF
from pyg.utils import (
CLASS_LABELS,
X_FEATURES,
count_parameters,
)


with open(sys.argv[1], "r") as stream: # load config (includes: which physics samples, model params)
config = yaml.safe_load(stream)

model_kwargs = {
"input_dim": len(X_FEATURES[config["dataset"]]),
"num_classes": len(CLASS_LABELS[config["dataset"]]),
"pt_mode": config["model"]["pt_mode"],
"eta_mode": config["model"]["eta_mode"],
"sin_phi_mode": config["model"]["sin_phi_mode"],
"cos_phi_mode": config["model"]["cos_phi_mode"],
"energy_mode": config["model"]["energy_mode"],
"attention_type": config["model"]["attention"]["attention_type"],
**config["model"][config["conv_type"]],
}
model = MLPF(**model_kwargs)

trainable_params, nontrainable_params, table = count_parameters(model)

print(table)

print("Model conv type:", model.conv_type)
print("conv_type HPs", config["model"][config["conv_type"]])
print("Trainable parameters:", trainable_params)
print("Non-trainable parameters:", nontrainable_params)
print("Total parameters:", trainable_params + nontrainable_params)
2 changes: 2 additions & 0 deletions mlpf/pyg/mlpf.py
@@ -139,6 +139,7 @@ def __init__(
layernorm=True,
num_node_messages=2,
ffn_dist_hidden_dim=128,
ffn_dist_num_layers=2,
# self-attention specific parameters
num_heads=2,
# mamba specific parameters
@@ -208,6 +209,7 @@ def __init__(
"num_node_messages": num_node_messages,
"dropout": dropout,
"ffn_dist_hidden_dim": ffn_dist_hidden_dim,
"ffn_dist_num_layers": ffn_dist_num_layers,
}
self.conv_id.append(CombinedGraphLayer(**gnn_conf))
self.conv_reg.append(CombinedGraphLayer(**gnn_conf))
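For reference, a sketch of the gnn_lsh layer configuration that now includes the exposed ffn_dist_num_layers knob; the key names match the gnn_conf dict in this diff, while the values are placeholders rather than defaults from any shipped config:

# Illustrative only; in the real model these values come from the YAML config.
gnn_conf_subset = {
    "num_node_messages": 2,
    "dropout": 0.0,
    "ffn_dist_hidden_dim": 128,
    "ffn_dist_num_layers": 3,  # newly configurable; the constructor default added here is 2
}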
94 changes: 86 additions & 8 deletions mlpf/pyg/training.py
@@ -10,6 +10,7 @@
from datetime import datetime
import tqdm
import yaml
import csv

import numpy as np

@@ -199,14 +200,17 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
def train_and_valid(
rank,
world_size,
outdir,
model,
optimizer,
data_loader,
train_loader,
valid_loader,
is_train=True,
lr_schedule=None,
comet_experiment=None,
comet_step_freq=None,
epoch=None,
val_freq=None,
dtype=torch.float32,
):
"""
@@ -221,8 +225,10 @@

if is_train:
model.train()
data_loader = train_loader
else:
model.eval()
data_loader = valid_loader

# only show progress bar on rank 0
if (world_size > 1) and (rank != 0):
@@ -234,6 +240,7 @@

device_type = "cuda" if isinstance(rank, int) else "cpu"

val_freq_time_0 = time.time()
for itrain, batch in iterator:
batch = batch.to(rank, non_blocking=True)

@@ -282,6 +289,57 @@
comet_experiment.log_metrics(loss, prefix=f"{train_or_valid}", step=step)
comet_experiment.log_metric("learning_rate", lr_schedule.get_last_lr(), step=step)

if val_freq is not None and is_train:
if itrain != 0 and itrain % val_freq == 0:
# time since last intermediate validation run
val_freq_time = torch.tensor(time.time() - val_freq_time_0, device=rank)
if world_size > 1:
torch.distributed.all_reduce(val_freq_time)
# compute intermediate training loss
intermediate_losses_t = {key: epoch_loss[key] for key in epoch_loss}
for loss_ in epoch_loss:
# sum up the losses from all workers and divide by the number of steps so far
if world_size > 1:
torch.distributed.all_reduce(intermediate_losses_t[loss_])
intermediate_losses_t[loss_] = intermediate_losses_t[loss_].cpu().item() / itrain

# compute intermediate validation loss
intermediate_losses_v = train_and_valid(
rank,
world_size,
outdir,
model,
optimizer,
train_loader,
valid_loader,
is_train=False,
epoch=epoch,
dtype=dtype,
)
intermediate_metrics = dict(
loss=intermediate_losses_t["Total"],
reg_loss=intermediate_losses_t["Regression"],
cls_loss=intermediate_losses_t["Classification"],
charge_loss=intermediate_losses_t["Charge"],
val_loss=intermediate_losses_v["Total"],
val_reg_loss=intermediate_losses_v["Regression"],
val_cls_loss=intermediate_losses_v["Classification"],
val_charge_loss=intermediate_losses_v["Charge"],
inside_epoch=epoch,
step=(epoch - 1) * len(data_loader) + itrain,
val_freq_time=val_freq_time.cpu().item(),
)
val_freq_log = os.path.join(outdir, "val_freq_log.csv")
if (rank == 0) or (rank == "cpu"):
with open(val_freq_log, "a", newline="") as f:
writer = csv.DictWriter(f, fieldnames=intermediate_metrics.keys())
if os.stat(val_freq_log).st_size == 0: # only write header if file is empty
writer.writeheader()
writer.writerow(intermediate_metrics)
if comet_experiment:
comet_experiment.log_metrics(intermediate_losses_v, prefix="valid", step=step)
val_freq_time_0 = time.time() # reset intermediate validation spacing timer

num_data = torch.tensor(len(data_loader), device=rank)
# sum up the number of steps from all workers
if world_size > 1:
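The hunk above is the core of the feature: every val_freq training steps, the trainer runs a full validation pass (by calling train_and_valid recursively with is_train=False), reduces the losses across workers, and appends one row of train/validation metrics to val_freq_log.csv on rank 0. A self-contained toy sketch of the same write-a-row-every-N-steps pattern, with an invented loop body and no model or DDP:

import csv
import os
import time

def toy_train(train_steps, val_freq, outdir="."):
    log_path = os.path.join(outdir, "val_freq_log.csv")
    val_freq_time_0 = time.time()
    running_loss = 0.0
    for step in range(1, train_steps + 1):
        running_loss += 1.0 / step  # stand-in for the per-step training loss
        if val_freq is not None and step % val_freq == 0:
            row = {
                "loss": running_loss / step,  # mean training loss so far
                "val_loss": 0.0,  # stand-in for a validation pass
                "step": step,
                "val_freq_time": time.time() - val_freq_time_0,  # time since the last intermediate validation
            }
            with open(log_path, "a", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=row.keys())
                if os.stat(log_path).st_size == 0:  # header only when the file is empty
                    writer.writeheader()
                writer.writerow(row)
            val_freq_time_0 = time.time()  # reset the spacing timer

toy_train(train_steps=10, val_freq=5)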
@@ -316,6 +374,7 @@ def train_mlpf(
checkpoint_freq=None,
comet_experiment=None,
comet_step_freq=None,
val_freq=None,
):
"""
Will run a full training by calling train().
@@ -357,30 +416,45 @@
) as prof:
with record_function("model_train"):
losses_t = train_and_valid(
rank, world_size, model, optimizer, train_loader, is_train=True, lr_schedule=lr_schedule, dtype=dtype
rank,
world_size,
outdir,
model,
optimizer,
train_loader=train_loader,
valid_loader=valid_loader,
is_train=True,
lr_schedule=lr_schedule,
val_freq=val_freq,
dtype=dtype,
)
prof.export_chrome_trace("trace.json")
else:
losses_t = train_and_valid(
rank,
world_size,
outdir,
model,
optimizer,
train_loader,
train_loader=train_loader,
valid_loader=valid_loader,
is_train=True,
lr_schedule=lr_schedule,
comet_experiment=comet_experiment,
comet_step_freq=comet_step_freq,
epoch=epoch,
val_freq=val_freq,
dtype=dtype,
)

losses_v = train_and_valid(
rank,
world_size,
outdir,
model,
optimizer,
valid_loader,
train_loader=train_loader,
valid_loader=valid_loader,
is_train=False,
lr_schedule=None,
comet_experiment=comet_experiment,
@@ -663,6 +737,7 @@ def run(rank, world_size, config, args, outdir, logfile):
checkpoint_freq=config["checkpoint_freq"],
comet_experiment=comet_experiment,
comet_step_freq=config["comet_step_freq"],
val_freq=config["val_freq"],
)

checkpoint = torch.load(f"{outdir}/best_weights.pth", map_location=torch.device(rank))
@@ -852,8 +927,13 @@ def train_ray_trial(config, args, outdir=None):
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

trainable_params, nontrainable_params, table = count_parameters(model)
print(table)

if (rank == 0) or (rank == "cpu"):
with open(os.path.join(outdir, "num_params.csv"), "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["trainable_params", "nontrainable_params", "total_params"])
writer.writerow([trainable_params, nontrainable_params, trainable_params + nontrainable_params])
_logger.info(model)
_logger.info(f"Trainable parameters: {trainable_params}")
_logger.info(f"Non-trainable parameters: {nontrainable_params}")
@@ -926,6 +1006,8 @@ def train_ray_trial(config, args, outdir=None):
checkpoint_freq=config["checkpoint_freq"],
comet_experiment=comet_experiment,
comet_step_freq=config["comet_step_freq"],
dtype=getattr(torch, config["dtype"]),
val_freq=config["val_freq"],
)


@@ -1121,10 +1203,6 @@ def run_hpo(config, args):
print(result_df)
print(result_df.columns)

print("Number of errored trials: {}".format(result_grid.num_errors))
print("Number of terminated (not errored) trials: {}".format(result_grid.num_terminated))
print("Ray Tune experiment path: {}".format(result_grid.experiment_path))

logging.info("Total time of Tuner.fit(): {}".format(end - start))
logging.info(
"Best hyperparameters found according to {} were: {}".format(config["raytune"]["default_metric"], best_config)
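One possible way to consume the log written during training, e.g. to compare intermediate training and validation loss; the column names are the ones written by the trainer above, while the path is illustrative:

import pandas as pd

df = pd.read_csv("experiments/my_run/val_freq_log.csv")  # illustrative path
print(df[["step", "loss", "val_loss", "val_freq_time"]].tail())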
12 changes: 5 additions & 7 deletions mlpf/pyg/utils.py
@@ -272,6 +272,7 @@ def get_lr_schedule(config, opt, epochs=None, steps_per_epoch=None, last_epoch=-
steps_per_epoch=steps_per_epoch,
epochs=epochs,
last_epoch=last_batch,
pct_start=config["lr_schedule_config"]["onecycle"]["pct_start"] or 0.3,
)
elif config["lr_schedule"] == "cosinedecay":
lr_schedule = CosineAnnealingLR(opt, T_max=steps_per_epoch * epochs, last_epoch=last_batch)
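The new pct_start pass-through maps onto the warm-up fraction of PyTorch's OneCycleLR, and the `or 0.3` fallback keeps the PyTorch default when the config leaves pct_start unset. A minimal standalone sketch with placeholder optimizer and step counts:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=1e-3)
lr_schedule = torch.optim.lr_scheduler.OneCycleLR(
    opt,
    max_lr=1e-3,
    steps_per_epoch=100,  # placeholder
    epochs=10,  # placeholder
    pct_start=0.05,  # fraction of total steps spent ramping the learning rate up
)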
@@ -281,7 +282,8 @@ def get_lr_schedule(config, opt, epochs=None, steps_per_epoch=None, last_epoch=-


def count_parameters(model):
table = pd.DataFrame(columns=["Modules", "Trainable params", "Non-tranable params"])
column_names = ["Modules", "Trainable parameters", "Non-trainable parameters"]
table = pd.DataFrame(columns=column_names)
trainable_params = 0
nontrainable_params = 0
for ii, (name, parameter) in enumerate(model.named_parameters()):
@@ -290,19 +292,15 @@
table = pd.concat(
[
table,
pd.DataFrame(
{"Modules": name, "Trainable Parameters": "-", "Non-tranable Parameters": params}, index=[ii]
),
pd.DataFrame({column_names[0]: name, column_names[1]: 0, column_names[2]: params}, index=[ii]),
]
)
nontrainable_params += params
else:
table = pd.concat(
[
table,
pd.DataFrame(
{"Modules": name, "Trainable Parameters": params, "Non-tranable Parameters": "-"}, index=[ii]
),
pd.DataFrame({column_names[0]: name, column_names[1]: params, column_names[2]: 0}, index=[ii]),
]
)
trainable_params += params
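The refactor above defines the column names once and uses numeric zeros instead of "-" placeholders, so the row dictionaries and the table columns stay consistent and the parameter columns remain summable. A small sketch of the resulting row format, with invented module names:

import pandas as pd

column_names = ["Modules", "Trainable parameters", "Non-trainable parameters"]
rows = [
    pd.DataFrame({column_names[0]: "conv_id.0.weight", column_names[1]: 4096, column_names[2]: 0}, index=[0]),
    pd.DataFrame({column_names[0]: "frozen.bias", column_names[1]: 0, column_names[2]: 64}, index=[1]),
]
table = pd.concat(rows)
print(table)
print(table[column_names[1:]].sum())  # column totals work because the placeholders are numeric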
1 change: 1 addition & 0 deletions mlpf/pyg_pipeline.py
@@ -60,6 +60,7 @@
parser.add_argument("--ntrain", type=int, default=None, help="training samples to use, if None use entire dataset")
parser.add_argument("--ntest", type=int, default=None, help="training samples to use, if None use entire dataset")
parser.add_argument("--nvalid", type=int, default=None, help="validation samples to use")
parser.add_argument("--val-freq", type=int, default=None, help="run extra validation every val_freq training steps")
parser.add_argument("--checkpoint-freq", type=int, default=None, help="epoch frequency for checkpointing")
parser.add_argument("--hpo", type=str, default=None, help="perform hyperparameter optimization, name of HPO experiment")
parser.add_argument("--ray-train", action="store_true", help="run training using Ray Train")
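argparse turns the dashed flag into args.val_freq. A minimal standalone sketch of just this option; the real parser in mlpf/pyg_pipeline.py defines many more arguments:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--val-freq", type=int, default=None, help="run extra validation every val_freq training steps")
args = parser.parse_args(["--val-freq", "500"])  # illustrative value
assert args.val_freq == 500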
55 changes: 35 additions & 20 deletions mlpf/raytune/pt_search_space.py
@@ -1,41 +1,53 @@
from ray.tune import choice # grid_search, choice, loguniform, quniform
from ray.tune import grid_search # grid_search, choice, loguniform, quniform

raytune_num_samples = 16 # Number of random samples to draw from search space. Set to 1 for grid search.
samp = choice
raytune_num_samples = 1 # Number of random samples to draw from search space. Set to 1 for grid search.
samp = grid_search

# gnn scan
search_space = {
# dataset parameters
"ntrain": samp([500]),
# "ntest": samp([10000]),
"nvalid": samp([500]),
"num_epochs": samp([10]),
# optimizer parameters
"lr": samp([1e-4, 3e-4, 1e-3, 3e-3, 1e-2]),
"gpu_batch_multiplier": samp([1, 4, 8, 16]),
# model arch parameters
"activation": samp(["elu", "relu", "relu6", "leakyrelu"]),
"conv_type": samp(["gravnet"]), # can be "gnn_lsh", "gravnet", "attention"
"embedding_dim": samp([32, 64, 128, 252, 512, 1024]),
"width": samp([32, 64, 128, 256, 512, 1024]),
"num_convs": samp([1, 2, 3, 4, 5, 6]),
"dropout": samp([0.0, 0.01, 0.1, 0.4]),
"lr": samp([1e-4, 3e-4, 1e-3, 3e-3]),
"lr_schedule": samp(["onecycle"]),
"pct_start": samp([0.05]),
# "gpu_batch_multiplier": samp([1, 4, 8, 16]),
# "patience": samp([9999]),
# model arch parameters
# "activation": samp(["elu", "relu", "relu6", "leakyrelu"]),
"conv_type": samp(["attention"]), # can be "gnn_lsh", "gravnet", "attention"
# "embedding_dim": samp([32, 64, 128, 252, 512, 1024]),
# "width": samp([32, 64, 128, 256, 512, 1024]),
# "num_convs": samp([1, 2, 3, 4, 5, 6]),
# "dropout": samp([0.0, 0.01, 0.1, 0.4]),
# only for gravnet
"k": samp([8, 16, 32]),
"propagate_dimensions": samp([8, 16, 32, 64, 128]),
"space_dimensions": samp([4]),
# "k": samp([8, 16, 32]),
# "propagate_dimensions": samp([8, 16, 32, 64, 128]),
# "space_dimensions": samp([4]),
# only for gnn-lsh
# "bin_size": samp([160, 320, 640]),
# "max_num_bins": samp([200]),
# "distance_dim": samp([16, 32, 64, 128, 256]),
# "layernorm": samp([True, False]),
# "num_node_messages": samp([1, 2, 3, 4, 5]),
# "ffn_dist_hidden_dim": samp([16, 32, 64, 128, 256]),
# mamba specific variables
"d_state": samp([16]),
"d_conv": samp([4]),
"expand": samp([2]),
# "ffn_dist_num_layers": samp([1, 2, 3, 4, 5, 6]),
# mamba specific parameters
# "d_state": samp([16]),
# "d_conv": samp([4]),
# "expand": samp([2]),
# "num_heads": samp([2, 4, 6, 8, 10, 12]),
# attention specific parameters
"num_heads": samp([2, 4, 8, 16]),
# "attention_type": samp(["flash"]), # flash, efficient, math
}


def set_hps_from_search_space(search_space, config):
varaible_names = ["lr", "gpu_batch_multiplier"]
varaible_names = ["lr", "lr_schedule", "gpu_batch_multiplier", "ntrain", "ntest", "nvalid", "num_epochs", "patience"]
for var in varaible_names:
if var in search_space.keys():
config[var] = search_space[var]
@@ -81,4 +93,7 @@ def set_hps_from_search_space(search_space, config):
if var in search_space.keys():
config["model"][conv_type][var] = search_space[var]

if "pct_start" in search_space.keys():
config["lr_schedule_config"]["onecycle"]["pct_start"] = search_space["pct_start"]

return config
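With samp switched from choice to grid_search, every listed value is expanded into its own trial and the entries combine as a Cartesian product, so raytune_num_samples = 1 already covers the full grid; with choice, each of the N samples draws one value per key at random. A minimal sketch of the two modes with illustrative values:

from ray import tune

grid_space = {
    "lr": tune.grid_search([1e-4, 3e-4, 1e-3, 3e-3]),
    "num_heads": tune.grid_search([2, 4, 8, 16]),
}  # with num_samples=1 this yields 4 x 4 = 16 trials, one per grid point

random_space = {
    "lr": tune.choice([1e-4, 3e-4, 1e-3, 3e-3]),
    "num_heads": tune.choice([2, 4, 8, 16]),
}  # with num_samples=N this yields N trials, each sampling one value per key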