From 2579e768c71a32611c809c72ed8bce30f59fedc0 Mon Sep 17 00:00:00 2001
From: Joosep Pata
Date: Fri, 16 Aug 2024 14:28:19 +0300
Subject: [PATCH] CMS dataset relabel, generate v2.1.0 with more stats,
 separate binary classifier (#340)

* dataset relabeling

---
 mlpf/data_cms/postprocessing2.py              |   4 +
 mlpf/heptfds/clic_pf_edm4hep/qq.py            |   3 +-
 mlpf/heptfds/clic_pf_edm4hep/ttbar.py         |   3 +-
 mlpf/heptfds/cms_pf/cms_utils.py              |   6 +-
 mlpf/heptfds/cms_pf/qcd.py                    |   3 +-
 mlpf/heptfds/cms_pf/ttbar.py                  |   3 +-
 mlpf/heptfds/cms_pf/vbf.py                    |  66 ++++++++++
 mlpf/plotting/plot_utils.py                   |  63 +++++++++
 mlpf/pyg/PFDataset.py                         |  29 ++++-
 mlpf/pyg/inference.py                         |   5 +-
 mlpf/pyg/mlpf.py                              |  16 ++-
 mlpf/pyg/training.py                          |  69 +++++++++-
 mlpf/pyg/utils.py                             |  13 +-
 notebooks/cms/cms-simvalidation.ipynb         |  55 +++++++-
 parameters/pytorch/pyg-cms-ttbar-nopu.yaml    | 122 ++++++++++++++++++
 parameters/pytorch/pyg-cms.yaml               |  39 +-----
 scripts/clic/postprocessing_jobs.py           |   4 +-
 .../tallinn/a100/pytorch-small-eval-clic.sh   |   4 +-
 .../tallinn/a100/pytorch-small-eval-cms.sh    |  12 +-
 scripts/tallinn/a100/pytorch.sh               |   8 +-
 20 files changed, 442 insertions(+), 85 deletions(-)
 create mode 100644 mlpf/heptfds/cms_pf/vbf.py
 create mode 100644 parameters/pytorch/pyg-cms-ttbar-nopu.yaml

diff --git a/mlpf/data_cms/postprocessing2.py b/mlpf/data_cms/postprocessing2.py
index 423a39aab..1a36e4fac 100644
--- a/mlpf/data_cms/postprocessing2.py
+++ b/mlpf/data_cms/postprocessing2.py
@@ -885,6 +885,10 @@ def process(args):
         "genmet": genmet,
     }

+    # print("trk", ygen[Xelem["typ"] == 1]["typ"])
+    # print("ecal", ygen[Xelem["typ"] == 4]["typ"])
+    # print("hcal", ygen[Xelem["typ"] == 5]["typ"])
+
     if args.save_full_graph:
         data["full_graph"] = g

diff --git a/mlpf/heptfds/clic_pf_edm4hep/qq.py b/mlpf/heptfds/clic_pf_edm4hep/qq.py
index 5d7149439..89c22ef5f 100644
--- a/mlpf/heptfds/clic_pf_edm4hep/qq.py
+++ b/mlpf/heptfds/clic_pf_edm4hep/qq.py
@@ -26,7 +26,7 @@

 class ClicEdmQqPf(tfds.core.GeneratorBasedBuilder):
-    VERSION = tfds.core.Version("2.0.0")
+    VERSION = tfds.core.Version("2.1.0")
     RELEASE_NOTES = {
         "1.0.0": "Initial release.",
         "1.1.0": "update stats, move to 380 GeV",
@@ -36,6 +36,7 @@ class ClicEdmQqPf(tfds.core.GeneratorBasedBuilder):
         "1.4.0": "Fix ycand matching",
         "1.5.0": "Regenerate with ARRAY_RECORD",
         "2.0.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
+        "2.1.0": "Bump dataset size",
     }
     MANUAL_DOWNLOAD_INSTRUCTIONS = """
     For the raw input files in ROOT EDM4HEP format, please see the citation above.
diff --git a/mlpf/heptfds/clic_pf_edm4hep/ttbar.py b/mlpf/heptfds/clic_pf_edm4hep/ttbar.py
index 47af2aade..9a01aa81f 100644
--- a/mlpf/heptfds/clic_pf_edm4hep/ttbar.py
+++ b/mlpf/heptfds/clic_pf_edm4hep/ttbar.py
@@ -26,7 +26,7 @@

 class ClicEdmTtbarPf(tfds.core.GeneratorBasedBuilder):
-    VERSION = tfds.core.Version("2.0.0")
+    VERSION = tfds.core.Version("2.1.0")
     RELEASE_NOTES = {
         "1.0.0": "Initial release.",
         "1.1.0": "update stats, move to 380 GeV",
@@ -35,6 +35,7 @@ class ClicEdmTtbarPf(tfds.core.GeneratorBasedBuilder):
         "1.4.0": "Fix ycand matching",
         "1.5.0": "Regenerate with ARRAY_RECORD",
         "2.0.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
+        "2.1.0": "Bump dataset size",
     }
     MANUAL_DOWNLOAD_INSTRUCTIONS = """
     For the raw input files in ROOT EDM4HEP format, please see the citation above.
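Note on the version bumps above: TFDS resolves a builder by its `name:version` string, so configs pinned to `2.0.0` are unaffected by regenerating `2.1.0`. A hedged sketch of loading one regenerated split; the `data_dir` value and the availability of `as_data_source` depend on the local tensorflow_datasets setup and are assumptions, not part of this patch:

```python
# Sketch: load a pinned version of the regenerated cms_pf_ttbar dataset.
# Assumes a recent tensorflow_datasets with ARRAY_RECORD support and that
# the dataset was already built under the given data_dir.
import tensorflow_datasets as tfds

builder = tfds.builder("cms_pf_ttbar:2.1.0", data_dir="~/tensorflow_datasets")
ds = builder.as_data_source(split="train")  # ARRAY_RECORD files allow random access by index
print(builder.info.version, len(ds))
```
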
diff --git a/mlpf/heptfds/cms_pf/cms_utils.py b/mlpf/heptfds/cms_pf/cms_utils.py
index 1f154592b..db0f8fe67 100644
--- a/mlpf/heptfds/cms_pf/cms_utils.py
+++ b/mlpf/heptfds/cms_pf/cms_utils.py
@@ -115,7 +115,7 @@
 ]


-def prepare_data_cms(fn, with_jet_idx=False):
+def prepare_data_cms(fn):
     Xs = []
     ygens = []
     ycands = []
@@ -147,10 +147,6 @@ def prepare_data_cms(fn, with_jet_idx=False):
         ygen["typ_idx"] = np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ygen["typ"]], dtype=np.float32)
         ycand["typ_idx"] = np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ycand["typ"]], dtype=np.float32)

-        if with_jet_idx:
-            ygen["jet_idx"] = np.zeros(len(ygen["typ"]), dtype=np.float32)
-            ycand["jet_idx"] = np.zeros(len(ycand["typ"]), dtype=np.float32)
-
         Xelem_flat = ak.to_numpy(
             np.stack(
                 [Xelem[k] for k in X_FEATURES],
diff --git a/mlpf/heptfds/cms_pf/qcd.py b/mlpf/heptfds/cms_pf/qcd.py
index eaf2d21ca..f772bead8 100644
--- a/mlpf/heptfds/cms_pf/qcd.py
+++ b/mlpf/heptfds/cms_pf/qcd.py
@@ -21,7 +21,7 @@
 class CmsPfQcd(tfds.core.GeneratorBasedBuilder):
     """DatasetBuilder for cms_pf_qcd dataset."""

-    VERSION = tfds.core.Version("2.0.0")
+    VERSION = tfds.core.Version("2.1.0")
     RELEASE_NOTES = {
         "1.3.0": "12_2_0_pre2 generation with updated caloparticle/trackingparticle",
         "1.3.1": "Remove PS again",
@@ -32,6 +32,7 @@ class CmsPfQcd(tfds.core.GeneratorBasedBuilder):
         "1.7.0": "Add cluster shape vars",
         "1.7.1": "Increase stats to 400k events",
         "2.0.0": "New truth def based primarily on CaloParticles",
+        "2.1.0": "Additional stats",
     }
     MANUAL_DOWNLOAD_INSTRUCTIONS = """
     rsync -r --progress lxplus.cern.ch:/eos/user/j/jpata/mlpf/tensorflow_datasets/cms/cms_pf_qcd ~/tensorflow_datasets/
diff --git a/mlpf/heptfds/cms_pf/ttbar.py b/mlpf/heptfds/cms_pf/ttbar.py
index 4a2e1933b..2eec180d7 100644
--- a/mlpf/heptfds/cms_pf/ttbar.py
+++ b/mlpf/heptfds/cms_pf/ttbar.py
@@ -21,7 +21,7 @@
 class CmsPfTtbar(tfds.core.GeneratorBasedBuilder):
     """DatasetBuilder for cms_pf dataset."""

-    VERSION = tfds.core.Version("2.0.0")
+    VERSION = tfds.core.Version("2.1.0")
     RELEASE_NOTES = {
         "1.0.0": "Initial release.",
         "1.1.0": "Add muon type, fix electron GSF association",
@@ -36,6 +36,7 @@ class CmsPfTtbar(tfds.core.GeneratorBasedBuilder):
         "1.7.1": "Increase stats to 400k events",
         "1.8.0": "Add ispu, genjets, genmet; disable genjet_idx; improved merging",
         "2.0.0": "New truth def based primarily on CaloParticles",
+        "2.1.0": "Additional stats",
     }
     MANUAL_DOWNLOAD_INSTRUCTIONS = """
     rsync -r --progress lxplus.cern.ch:/eos/user/j/jpata/mlpf/tensorflow_datasets/cms/cms_pf_ttbar ~/tensorflow_datasets/
diff --git a/mlpf/heptfds/cms_pf/vbf.py b/mlpf/heptfds/cms_pf/vbf.py
new file mode 100644
index 000000000..ee05cbc1c
--- /dev/null
+++ b/mlpf/heptfds/cms_pf/vbf.py
@@ -0,0 +1,66 @@
+"""CMS PF VBF dataset."""
+import cms_utils
+import tensorflow as tf
+
+import tensorflow_datasets as tfds
+
+X_FEATURES = cms_utils.X_FEATURES
+Y_FEATURES = cms_utils.Y_FEATURES
+
+_DESCRIPTION = """
+Dataset generated with CMSSW and full detector sim.
+
+VBF events with PU 55-75 in a Run3 setup.
+""" + +# TODO(cms_pf): BibTeX citation +_CITATION = """ +""" + + +class CmsPfVbf(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for cms_pf_vbf dataset.""" + + VERSION = tfds.core.Version("2.1.0") + RELEASE_NOTES = { + "1.7.1": "First version", + "1.8.0": "Add ispu, genjets, genmet; disable genjet_idx; improved merging", + "2.0.0": "New truth def based primarily on CaloParticles", + "2.1.0": "Additional statistics", + } + MANUAL_DOWNLOAD_INSTRUCTIONS = """ + rsync -r --progress lxplus.cern.ch:/eos/user/j/jpata/mlpf/tensorflow_datasets/cms/cms_pf_vbf ~/tensorflow_datasets/ + """ + + def __init__(self, *args, **kwargs): + kwargs["file_format"] = tfds.core.FileFormat.ARRAY_RECORD + super(CmsPfVbf, self).__init__(*args, **kwargs) + + def _info(self) -> tfds.core.DatasetInfo: + """Returns the dataset metadata.""" + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict( + { + "X": tfds.features.Tensor(shape=(None, len(X_FEATURES)), dtype=tf.float32), + "ygen": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), + "ycand": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32), + "genmet": tfds.features.Scalar(dtype=tf.float32), + "genjets": tfds.features.Tensor(shape=(None, 4), dtype=tf.float32), + } + ), + supervised_keys=("X", "ygen"), + homepage="", + citation=_CITATION, + metadata=tfds.core.MetadataDict(x_features=X_FEATURES, y_features=Y_FEATURES), + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + """Returns SplitGenerators.""" + path = dl_manager.manual_dir + sample_dir = "VBF_TuneCP5_14TeV_pythia8_cfi" + return cms_utils.split_sample(path / sample_dir / "raw") + + def _generate_examples(self, files): + return cms_utils.generate_examples(files) diff --git a/mlpf/plotting/plot_utils.py b/mlpf/plotting/plot_utils.py index ef4c89a12..25875c0d7 100644 --- a/mlpf/plotting/plot_utils.py +++ b/mlpf/plotting/plot_utils.py @@ -115,6 +115,7 @@ def get_class_names(sample_name): "clic_edm_zh_tautau_pf": r"$e^+e^- \rightarrow ZH \rightarrow \tau \tau$", "cms_pf_qcd": r"QCD $p_T \in [15, 3000]\ \mathrm{GeV}$+PU", "cms_pf_ztt": r"$\mathrm{Z}\rightarrow \mathrm{\tau}\mathrm{\tau}$+PU", + "cms_pf_vbf": r"VBF+PU", "cms_pf_ttbar": r"$\mathrm{t}\overline{\mathrm{t}}$+PU", "cms_pf_ttbar_nopu": r"$\mathrm{t}\overline{\mathrm{t}}$", "cms_pf_qcd_nopu": r"QCD $p_T \in [15, 3000]\ \mathrm{GeV}$", @@ -1076,6 +1077,68 @@ def plot_particle_multiplicity(X, yvals, class_names, epoch=None, cp_dir=None, c ) +def plot_particle_ratio(yvals, class_names, epoch=None, cp_dir=None, comet_experiment=None, title=None, sample=None, dataset=None): + msk_cand = yvals["cand_cls_id"] != 0 + msk_pred = yvals["pred_cls_id"] != 0 + msk_gen = yvals["gen_cls_id"] != 0 + + cand_pt = awkward.to_numpy(awkward.flatten(yvals["cand_pt"][msk_gen & msk_cand])) + pred_pt = awkward.to_numpy(awkward.flatten(yvals["pred_pt"][msk_gen & msk_pred])) + gen_cand_pt = awkward.to_numpy(awkward.flatten(yvals["gen_pt"][msk_gen & msk_cand])) + gen_pred_pt = awkward.to_numpy(awkward.flatten(yvals["gen_pt"][msk_gen & msk_pred])) + ratio_cand_pt = cand_pt / gen_cand_pt + ratio_pred_pt = pred_pt / gen_pred_pt + + cand_e = awkward.to_numpy(awkward.flatten(yvals["cand_energy"][msk_gen & msk_cand])) + pred_e = awkward.to_numpy(awkward.flatten(yvals["pred_energy"][msk_gen & msk_pred])) + gen_cand_e = awkward.to_numpy(awkward.flatten(yvals["gen_energy"][msk_gen & msk_cand])) + gen_pred_e = 
+    ratio_cand_e = cand_e / gen_cand_e
+    ratio_pred_e = pred_e / gen_pred_e
+
+    gen_cls_id = awkward.flatten(yvals["gen_cls_id"][msk_gen])
+    gen_cls_id1 = awkward.flatten(yvals["gen_cls_id"][msk_gen & msk_cand])
+    gen_cls_id2 = awkward.flatten(yvals["gen_cls_id"][msk_gen & msk_pred])
+    cls_ids = np.unique(awkward.values_astype(gen_cls_id, np.int64))
+    print("cls_ids", cls_ids)
+    for cls_id in cls_ids:
+        if cls_id == 0:
+            continue
+        clname = class_names[cls_id]
+
+        plt.figure()
+        b = np.linspace(0, 5, 100)
+        plt.hist(ratio_cand_pt[gen_cls_id1 == cls_id], bins=b, label="PF", histtype="step")
+        plt.hist(ratio_pred_pt[gen_cls_id2 == cls_id], bins=b, label="MLPF", histtype="step")
+        plt.legend(loc="best")
+        plt.xlabel("Reconstructed / target $p_T$")
+        if title:
+            plt.title(title + ", " + clname)
+        save_img(
+            "particle_pt_ratio_{}.png".format(cls_id),
+            epoch,
+            cp_dir=cp_dir,
+            comet_experiment=comet_experiment,
+        )
+        plt.clf()
+
+        plt.figure()
+        b = np.linspace(0, 5, 100)
+        plt.hist(ratio_cand_e[gen_cls_id1 == cls_id], bins=b, label="PF", histtype="step")
+        plt.hist(ratio_pred_e[gen_cls_id2 == cls_id], bins=b, label="MLPF", histtype="step")
+        plt.legend(loc="best")
+        plt.xlabel("Reconstructed / target $E$")
+        if title:
+            plt.title(title + ", " + clname)
+        save_img(
+            "particle_e_ratio_{}.png".format(cls_id),
+            epoch,
+            cp_dir=cp_dir,
+            comet_experiment=comet_experiment,
+        )
+        plt.clf()
+
+
 def plot_particles(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None, sample=None, dataset=None):
     msk_cand = yvals["cand_cls_id"] != 0
     cand_pt = awkward.to_numpy(awkward.flatten(yvals["cand_pt"][msk_cand], axis=1))
diff --git a/mlpf/pyg/PFDataset.py b/mlpf/pyg/PFDataset.py
index b9ee72541..7ca195e43 100644
--- a/mlpf/pyg/PFDataset.py
+++ b/mlpf/pyg/PFDataset.py
@@ -34,6 +34,33 @@ def __getitem__(self, item):
             ret["ycand"] = ret["ycand"][sortidx]
             ret["ygen"] = ret["ygen"][sortidx]

+        if self.ds.dataset_info.name.startswith("cms_"):
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 1) & (ret["ygen"][:, 0] == 2)] = 1
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 1) & (ret["ygen"][:, 0] == 5)] = 1
+
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 4) & (ret["ygen"][:, 0] == 1)] = 5
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 5) & (ret["ygen"][:, 0] == 1)] = 2
+            # ret["ygen"][:, 0][(ret["X"][:, 0] == 4) & (ret["ygen"][:, 0] == 6)] = 5
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 5) & (ret["ygen"][:, 0] == 6)] = 2
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 4) & (ret["ygen"][:, 0] == 7)] = 5
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 5) & (ret["ygen"][:, 0] == 7)] = 2
+
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 8) & (ret["ygen"][:, 0] == 1)] = 4
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 9) & (ret["ygen"][:, 0] == 1)] = 3
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 8) & (ret["ygen"][:, 0] == 2)] = 4
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 9) & (ret["ygen"][:, 0] == 2)] = 3
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 8) & (ret["ygen"][:, 0] == 6)] = 4
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 9) & (ret["ygen"][:, 0] == 6)] = 3
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 8) & (ret["ygen"][:, 0] == 7)] = 4
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 9) & (ret["ygen"][:, 0] == 7)] = 3
+
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 10) & (ret["ygen"][:, 0] == 1)] = 2
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 11) & (ret["ygen"][:, 0] == 1)] = 2
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 10) & (ret["ygen"][:, 0] == 6)] = 2
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 11) & (ret["ygen"][:, 0] == 6)] = 2
+            ret["ygen"][:, 0][(ret["X"][:, 0] == 10) & (ret["ygen"][:, 0] == 7)] = 2
0][(ret["X"][:, 0] == 10) & (ret["ygen"][:, 0] == 7)] = 2 + ret["ygen"][:, 0][(ret["X"][:, 0] == 11) & (ret["ygen"][:, 0] == 7)] = 2 + return ret def __len__(self): @@ -181,7 +208,7 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray): if world_size > 1: sampler = torch.utils.data.distributed.DistributedSampler(dataset) else: - sampler = torch.utils.data.RandomSampler(dataset) + sampler = torch.utils.data.SequentialSampler(dataset) # build dataloaders batch_size = config[f"{split}_dataset"][config["dataset"]][type_]["batch_size"] * config["gpu_batch_multiplier"] diff --git a/mlpf/pyg/inference.py b/mlpf/pyg/inference.py index 5b01624a3..8d874c025 100644 --- a/mlpf/pyg/inference.py +++ b/mlpf/pyg/inference.py @@ -11,6 +11,7 @@ import vector from jet_utils import build_dummy_array, match_two_jet_collections from plotting.plot_utils import ( + get_class_names, compute_met_and_ratio, load_eval_data, plot_jets, @@ -22,6 +23,7 @@ plot_met_response_binned, plot_num_elements, plot_particles, + plot_particle_ratio, plot_sum_energy, ) @@ -178,7 +180,7 @@ def make_plots(outpath, sample, dataset, dir_name=""): """Uses the predictions stored as .parquet files (see above) to make plots.""" mplhep.style.use(mplhep.styles.CMS) - + class_names = get_class_names(sample) os.system(f"mkdir -p {outpath}/plots{dir_name}/{sample}") plots_path = Path(f"{outpath}/plots{dir_name}/{sample}/") @@ -241,3 +243,4 @@ def make_plots(outpath, sample, dataset, dir_name=""): plot_met_response_binned(met_data, cp_dir=plots_path, dataset=dataset, sample=sample) plot_particles(yvals, cp_dir=plots_path, dataset=dataset, sample=sample) + plot_particle_ratio(yvals, class_names, cp_dir=plots_path, dataset=dataset, sample=sample) diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py index 12b654a7a..60d3d50fd 100644 --- a/mlpf/pyg/mlpf.py +++ b/mlpf/pyg/mlpf.py @@ -282,10 +282,11 @@ def __init__( decoding_dim = self.input_dim + embedding_dim # DNN that acts on the node level to predict the PID - self.nn_id = ffn(decoding_dim, num_classes, width, self.act, dropout_ff) + self.nn_binary_particle = ffn(decoding_dim, 2, width, self.act, dropout_ff) + self.nn_pid = ffn(decoding_dim, num_classes, width, self.act, dropout_ff) # elementwise DNN for node momentum regression - embed_dim = decoding_dim + num_classes + embed_dim = decoding_dim + 2 + num_classes self.nn_pt = RegressionOutput(pt_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero) self.nn_eta = RegressionOutput(eta_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero) self.nn_sin_phi = RegressionOutput(sin_phi_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero) @@ -319,20 +320,21 @@ def forward(self, X_features, mask): out_padded = conv(conv_input, mask) embeddings_reg.append(out_padded) + # id input if self.learned_representation_mode == "concat": final_embedding_id = torch.cat([Xfeat_normed] + embeddings_id, axis=-1) elif self.learned_representation_mode == "last": final_embedding_id = torch.cat([Xfeat_normed] + [embeddings_id[-1]], axis=-1) - preds_id = self.nn_id(final_embedding_id) + preds_binary_particle = self.nn_binary_particle(final_embedding_id) + preds_pid = self.nn_pid(final_embedding_id) # pred_charge = self.nn_charge(final_embedding_id) # regression input if self.learned_representation_mode == "concat": - final_embedding_reg = torch.cat([Xfeat_normed] + embeddings_reg + [preds_id], axis=-1) + final_embedding_reg = torch.cat([Xfeat_normed] + embeddings_reg + [preds_binary_particle.detach(), 
         elif self.learned_representation_mode == "last":
-            final_embedding_id = torch.cat([Xfeat_normed] + [embeddings_id[-1]], axis=-1)
-            final_embedding_reg = torch.cat([Xfeat_normed] + [embeddings_reg[-1]] + [preds_id], axis=-1)
+            final_embedding_reg = torch.cat([Xfeat_normed] + [embeddings_reg[-1]] + [preds_binary_particle.detach(), preds_pid.detach()], axis=-1)

         # The PFElement feature order in X_features defined in fcc/postprocessing.py
         preds_pt = self.nn_pt(X_features, final_embedding_reg, X_features[..., 1:2])
@@ -342,4 +344,4 @@ def forward(self, X_features, mask):
         preds_energy = self.nn_energy(X_features, final_embedding_reg, X_features[..., 5:6])

         preds_momentum = torch.cat([preds_pt, preds_eta, preds_sin_phi, preds_cos_phi, preds_energy], axis=-1)
-        return preds_id, preds_momentum
+        return preds_binary_particle, preds_pid, preds_momentum
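The forward pass above now returns separate binary (particle vs. no-particle) and PID logits; the loss changes in `training.py` below score them independently, with the PID term only counted where a true particle exists. A self-contained toy sketch of that factorized classification loss; the 100x weights and the masking pattern follow the diff, while all tensor shapes and values here are fabricated for illustration:

```python
# Toy-shape sketch of the factorized classification loss used in mlpf_loss below.
# Shapes are (batch, elements, classes); target class 0 means "no particle".
import torch

B, E, C = 2, 4, 9
logits_binary = torch.randn(B, E, 2)   # particle / no-particle head
logits_pid = torch.randn(B, E, C)      # particle-type head
target = torch.randint(0, C, (B, E))   # 0 = no particle
pad_mask = torch.ones(B, E)            # 1 = real element, 0 = padding

ce = torch.nn.CrossEntropyLoss(reduction="none")
loss_bin = 100 * ce(logits_binary.permute(0, 2, 1), (target != 0).long())
loss_pid = 100 * ce(logits_pid.permute(0, 2, 1), target)
loss_pid[target == 0] *= 0             # PID scored only on true particles
loss_bin[pad_mask == 0] *= 0           # padded elements contribute nothing
loss_pid[pad_mask == 0] *= 0
total_cls = (loss_bin.sum() + loss_pid.sum()) / pad_mask.sum()
```
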
loss["Sliced_Wasserstein_Loss"] = sliced_wasserstein_loss(ypred["momentum"].detach(), y["momentum"]).mean() - loss["Total"] = loss["Classification"] + loss["Regression"] + loss["Total"] = loss["Classification_binary"] + loss["Classification"] + loss["Regression"] + loss["Classification_binary"] = loss["Classification_binary"].detach() loss["Classification"] = loss["Classification"].detach() loss["Regression"] = loss["Regression"].detach() return loss @@ -254,6 +271,12 @@ def train_and_valid( loss_accum = 0.0 val_freq_time_0 = time.time() + + if not is_train: + cm_X_gen = np.zeros((13, 13)) + cm_X_pred = np.zeros((13, 13)) + cm_id = np.zeros((13, 13)) + for itrain, batch in iterator: batch = batch.to(rank, non_blocking=True) @@ -271,6 +294,17 @@ def train_and_valid( ypred = unpack_predictions(ypred) + if not is_train: + cm_X_gen += sklearn.metrics.confusion_matrix( + batch.X[:, :, 0][batch.mask].detach().cpu().numpy(), ygen["cls_id"][batch.mask].detach().cpu().numpy(), labels=range(13) + ) + cm_X_pred += sklearn.metrics.confusion_matrix( + batch.X[:, :, 0][batch.mask].detach().cpu().numpy(), ypred["cls_id"][batch.mask].detach().cpu().numpy(), labels=range(13) + ) + cm_id += sklearn.metrics.confusion_matrix( + ygen["cls_id"][batch.mask].detach().cpu().numpy(), ypred["cls_id"][batch.mask].detach().cpu().numpy(), labels=range(13) + ) + with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"): if is_train: loss = mlpf_loss(ygen, ypred, batch) @@ -302,6 +336,14 @@ def train_and_valid( tensorboard_writer.add_scalar("step/learning_rate", lr_schedule.get_last_lr()[0], step) tensorboard_writer.flush() loss_accum = 0.0 + + extra_state = {"step": step, "lr_schedule_state_dict": lr_schedule.state_dict()} + torch.save( + {"model_state_dict": get_model_state_dict(model), "optimizer_state_dict": optimizer.state_dict()}, + f"{outdir}/step_weights.pth", + ) + save_checkpoint(f"{outdir}/step_weights.pth", model, optimizer, extra_state) + if not (comet_experiment is None) and (itrain % comet_step_freq == 0): # this loss is not normalized to batch size comet_experiment.log_metrics(loss, prefix=f"{train_or_valid}", step=step) @@ -338,9 +380,11 @@ def train_and_valid( loss=intermediate_losses_t["Total"], reg_loss=intermediate_losses_t["Regression"], cls_loss=intermediate_losses_t["Classification"], + cls_binary_loss=intermediate_losses_t["Classification_binary"], val_loss=intermediate_losses_v["Total"], val_reg_loss=intermediate_losses_v["Regression"], val_cls_loss=intermediate_losses_v["Classification"], + val_cls_binary_loss=intermediate_losses_v["Classification_binary"], inside_epoch=epoch, step=(epoch - 1) * len(data_loader) + itrain, val_freq_time=val_freq_time.cpu().item(), @@ -356,6 +400,17 @@ def train_and_valid( comet_experiment.log_metrics(intermediate_losses_v, prefix="valid", step=step) val_freq_time_0 = time.time() # reset intermediate validation spacing timer + if not is_train and comet_experiment: + comet_experiment.log_confusion_matrix( + matrix=cm_X_gen, title="Element to target", row_label="X", column_label="target", epoch=epoch, file_name="cm_X_gen.json" + ) + comet_experiment.log_confusion_matrix( + matrix=cm_X_pred, title="Element to pred", row_label="X", column_label="pred", epoch=epoch, file_name="cm_X_pred.json" + ) + comet_experiment.log_confusion_matrix( + matrix=cm_id, title="Target to pred", row_label="gen", column_label="pred", epoch=epoch, file_name="cm_id.json" + ) + num_data = torch.tensor(len(data_loader), device=rank) # sum up the number of steps 
     if world_size > 1:
@@ -414,7 +469,7 @@ def train_mlpf(

     t0_initial = time.time()

-    losses_of_interest = ["Total", "Classification", "Regression"]
+    losses_of_interest = ["Total", "Classification", "Classification_binary", "Regression"]

     losses = {}
     losses["train"], losses["valid"] = {}, {}
@@ -519,9 +574,11 @@ def train_mlpf(
                     loss=losses_t["Total"],
                     reg_loss=losses_t["Regression"],
                     cls_loss=losses_t["Classification"],
+                    cls_binary_loss=losses_t["Classification_binary"],
                     val_loss=losses_v["Total"],
                     val_reg_loss=losses_v["Regression"],
                     val_cls_loss=losses_v["Classification"],
+                    val_cls_binary_loss=losses_v["Classification_binary"],
                     epoch=epoch,
                 )
             if (rank == 0) or (rank == "cpu"):
diff --git a/mlpf/pyg/utils.py b/mlpf/pyg/utils.py
index 18762f3d0..aab90741b 100644
--- a/mlpf/pyg/utils.py
+++ b/mlpf/pyg/utils.py
@@ -171,7 +171,8 @@ def unpack_target(y):

 def unpack_predictions(preds):
     ret = {}
-    ret["cls_id_onehot"], ret["momentum"] = preds
+    ret["cls_binary"], ret["cls_id_onehot"], ret["momentum"] = preds
+    # ret["cls_id_onehot"], ret["momentum"] = preds

     # ret["charge"] = torch.argmax(ret["charge"], axis=1, keepdim=True) - 1

@@ -182,8 +183,14 @@ def unpack_predictions(preds):
     ret["cos_phi"] = ret["momentum"][..., 3]
     ret["energy"] = ret["momentum"][..., 4]

-    # get PID with the maximum proba
-    ret["cls_id"] = torch.argmax(ret["cls_id_onehot"], axis=-1)
+    # first get the cases where a particle was predicted
+    ret["cls_id"] = torch.argmax(ret["cls_binary"], axis=-1)
+    # when a particle was predicted, get the particle ID
+    ret["cls_id"][ret["cls_id"] == 1] = torch.argmax(ret["cls_id_onehot"], axis=-1)[ret["cls_id"] == 1]
+
+    # get the predicted particle ID
+    # ret["cls_id"] = torch.argmax(ret["cls_id_onehot"], axis=-1)
+
     # particle properties
     ret["phi"] = torch.atan2(ret["sin_phi"], ret["cos_phi"])
     ret["p4"] = torch.cat(
diff --git a/notebooks/cms/cms-simvalidation.ipynb b/notebooks/cms/cms-simvalidation.ipynb
index 9a6ac6c3c..7485c2851 100644
--- a/notebooks/cms/cms-simvalidation.ipynb
+++ b/notebooks/cms/cms-simvalidation.ipynb
@@ -78,7 +78,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!ls -lrt /local/joosep/mlpf/cms/v3_pre1_nopu/TTbar_14TeV_TuneCUETP8M1_cfi/raw"
+    "!ls -lrt /local/joosep/mlpf/cms/20240702_cptruthdef/nopu/TTbar_14TeV_TuneCUETP8M1_cfi/raw"
    ]
   },
   {
@@ -109,7 +109,7 @@
     "pickle_data = sum(\n",
     "    [\n",
     "        pickle.load(bz2.BZ2File(f, \"r\"))\n",
-    "        for f in tqdm.tqdm(list(glob.glob(\"/local/joosep/mlpf/cms/v3_pre1_nopu//{}/*/*.pkl.bz2\".format(sample)))[:maxfiles])\n",
+    "        for f in tqdm.tqdm(list(glob.glob(\"/local/joosep/mlpf/cms/20240702_cptruthdef/nopu/{}/*/*.pkl.bz2\".format(sample)))[:maxfiles])\n",
     "    ],\n",
     "    [],\n",
     ")\n",
@@ -162,7 +162,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "plt.hist([len(x) for x in arrs_awk[\"Xelem\"][\"typ\"]], bins=100);"
+    "plt.figure()\n",
+    "plt.hist([len(x) for x in arrs_awk[\"Xelem\"][\"typ\"]], bins=100)\n",
+    "plt.show()"
    ]
   },
   {
@@ -222,9 +224,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "plt.figure()\n",
     "b = np.linspace(0, 10000, 101)\n",
     "plt.hist(awkward.sum(arrs_awk[\"ygen\"][\"e\"], axis=1), bins=b)\n",
-    "plt.yscale(\"log\")"
+    "plt.yscale(\"log\")\n",
+    "plt.show()"
    ]
   },
   {
@@ -234,11 +238,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "plt.figure()\n",
     "b = np.linspace(0, 1e5, 100)\n",
     "plt.hist(awkward.sum(arrs_awk[\"Xelem\"][\"e\"], axis=1), bins=b, histtype=\"step\", lw=2)\n",
     "plt.hist(awkward.sum(arrs_awk[\"ygen\"][\"e\"], axis=1), bins=b, histtype=\"step\", lw=2)\n",
"plt.hist(awkward.sum(arrs_awk[\"ycand\"][\"e\"], axis=1), bins=b, histtype=\"step\", lw=2)\n", - "plt.yscale(\"log\")" + "plt.yscale(\"log\")\n", + "plt.show()" ] }, { @@ -267,6 +273,7 @@ "\n", "#cms_label(ax)\n", "#sample_label(ax, sample)\n", + "plt.show()\n", "plt.savefig(plot_outpath + \"pf_vs_truth_sume.pdf\", bbox_inches=\"tight\")" ] }, @@ -310,6 +317,7 @@ "\n", "#cms_label(ax)\n", "#sample_label(ax, sample)\n", + "plt.show()\n", "plt.savefig(plot_outpath + \"pf_vs_truth_met.pdf\", bbox_inches=\"tight\")" ] }, @@ -321,7 +329,7 @@ "outputs": [], "source": [ "for pid in [\n", - " 0,\n", + " 0\n", "]:\n", " if pid == 0:\n", " msk = arrs_flat[\"ygen\"][\"typ\"] != pid\n", @@ -380,6 +388,7 @@ " # sample_label(ax, sample, \", \" + CLASS_NAMES_CMS[CLASS_LABELS_CMS.index(pid)])\n", " plt.xlabel(\"MLPF truth $\\phi$\")\n", " plt.ylabel(\"PFElement $\\phi$\")\n", + " plt.show()\n", " plt.savefig(plot_outpath + \"truth_vs_pfelement_phi_{}.pdf\".format(pid), bbox_inches=\"tight\")\n", "\n", "# data1 = awkward.flatten(Xelem_e[msk])\n", @@ -418,6 +427,37 @@ " ycand_typ_id[ycand_typ_f == CLASS_LABELS_CMS[i]] = i" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "4fd92d6f-bb46-4634-ae0c-2fe9de614a89", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "ygen_typ_f" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b19716d-f5bb-445c-b21c-78942f1df9f7", + "metadata": {}, + "outputs": [], + "source": [ + "np.unique(Xelem_typ_f, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e92e2093-9009-4e5f-93a5-7de92e0fc389", + "metadata": {}, + "outputs": [], + "source": [ + "np.unique(Xelem_typ_f[ygen_typ_f == 11], return_counts=True)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -453,6 +493,7 @@ "plt.title(\"PF\")\n", "\n", "plt.tight_layout()\n", + "plt.show()\n", "plt.savefig(plot_outpath + \"primary_element.pdf\", bbox_inches=\"tight\")" ] }, @@ -890,7 +931,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/parameters/pytorch/pyg-cms-ttbar-nopu.yaml b/parameters/pytorch/pyg-cms-ttbar-nopu.yaml new file mode 100644 index 000000000..f39f936aa --- /dev/null +++ b/parameters/pytorch/pyg-cms-ttbar-nopu.yaml @@ -0,0 +1,122 @@ +backend: pytorch + +dataset: cms +sort_data: yes +data_dir: +gpus: 1 +gpu_batch_multiplier: 1 +load: +num_epochs: 100 +patience: 20 +lr: 0.0002 +lr_schedule: cosinedecay # constant, cosinedecay, onecycle +conv_type: attention +ntrain: +ntest: +nvalid: +num_workers: 0 +prefetch_factor: +checkpoint_freq: 1 +comet_name: particleflow-pt +comet_offline: False +comet_step_freq: 10 +dtype: bfloat16 +val_freq: # run an extra validation run every val_freq training steps + +model: + trainable: all + # - nn_energy + # - nn_pt + + learned_representation_mode: last #last, concat + input_encoding: joint #split, joint + pt_mode: linear + eta_mode: linear + sin_phi_mode: linear + cos_phi_mode: linear + energy_mode: linear + + gnn_lsh: + conv_type: gnn_lsh + embedding_dim: 512 + width: 512 + num_convs: 3 + dropout_ff: 0.0 + activation: "elu" + # gnn-lsh specific parameters + bin_size: 640 + max_num_bins: 200 + distance_dim: 128 + layernorm: True + num_node_messages: 2 + ffn_dist_hidden_dim: 128 + ffn_dist_num_layers: 2 + + attention: + conv_type: attention + num_convs: 8 + dropout_ff: 0.0 + dropout_conv_id_mha: 0.0 + dropout_conv_id_ff: 0.0 + dropout_conv_reg_mha: 0.0 + dropout_conv_reg_ff: 
+    activation: "relu"
+    head_dim: 16
+    num_heads: 32
+    attention_type: flash
+
+  mamba:
+    conv_type: mamba
+    embedding_dim: 1024
+    width: 1024
+    num_convs: 4
+    dropout_ff: 0.0
+    activation: "elu"
+    # mamba specific parameters
+    d_state: 32
+    d_conv: 4
+    expand: 2
+
+lr_schedule_config:
+  onecycle:
+    pct_start: 0.3
+
+raytune:
+  local_dir: # Note: please specify an absolute path
+  sched: asha # asha, hyperband
+  search_alg: hyperopt # bayes, bohb, hyperopt, nevergrad, scikit
+  default_metric: "val_loss"
+  default_mode: "min"
+  # Tune schedule specific parameters
+  asha:
+    max_t: 200
+    reduction_factor: 4
+    brackets: 1
+    grace_period: 10
+  hyperband:
+    max_t: 200
+    reduction_factor: 4
+  hyperopt:
+    n_random_steps: 10
+  nevergrad:
+    n_random_steps: 10
+
+train_dataset:
+  cms:
+    physical_nopu:
+      batch_size: 10
+      samples:
+        cms_pf_ttbar_nopu:
+          version: 2.0.0
+
+valid_dataset:
+  cms:
+    physical_nopu:
+      batch_size: 10
+      samples:
+        cms_pf_ttbar_nopu:
+          version: 2.0.0
+
+test_dataset:
+  cms_pf_ttbar_nopu:
+    version: 2.0.0
diff --git a/parameters/pytorch/pyg-cms.yaml b/parameters/pytorch/pyg-cms.yaml
index 47ec36f72..c98f6cb5e 100644
--- a/parameters/pytorch/pyg-cms.yaml
+++ b/parameters/pytorch/pyg-cms.yaml
@@ -8,7 +8,7 @@ gpu_batch_multiplier: 1
 load:
 num_epochs: 100
 patience: 20
-lr: 0.0001
+lr: 0.0002
 lr_schedule: cosinedecay # constant, cosinedecay, onecycle
 conv_type: attention
 ntrain:
@@ -54,7 +54,7 @@ model:

   attention:
     conv_type: attention
-    num_convs: 4
+    num_convs: 8
     dropout_ff: 0.0
     dropout_conv_id_mha: 0.0
     dropout_conv_id_ff: 0.0
@@ -103,20 +103,6 @@ raytune:

 train_dataset:
   cms:
-#    physical_nopu:
-#      batch_size: 30
-#      samples:
-#        cms_pf_ttbar_nopu:
-#          version: 2.0.0
-#        cms_pf_vbf_nopu:
-#          version: 2.0.0
-#        cms_pf_qcd_nopu:
-#          version: 2.0.0
-#    multiparticlegun:
-#      batch_size: 4
-#      samples:
-#        cms_pf_multi_particle_gun:
-#          version: 2.0.0
     physical_pu:
       batch_size: 1
       samples:
@@ -127,20 +113,6 @@ train_dataset:

 valid_dataset:
   cms:
-#    physical_nopu:
-#      batch_size: 30
-#      samples:
-#        cms_pf_ttbar_nopu:
-#          version: 2.0.0
-#        cms_pf_vbf_nopu:
-#          version: 2.0.0
-#        cms_pf_qcd_nopu:
-#          version: 2.0.0
-#    multiparticlegun:
-#      batch_size: 5
-#      samples:
-#        cms_pf_multi_particle_gun:
-#          version: 2.0.0
     physical_pu:
       batch_size: 1
       samples:
         cms_pf_qcd:
           version: 2.0.0
         cms_pf_ttbar:
           version: 2.0.0
-
 test_dataset:
   cms_pf_ttbar:
     version: 2.0.0
   cms_pf_qcd:
     version: 2.0.0
-#  cms_pf_ttbar_nopu:
-#    version: 2.0.0
-#  cms_pf_vbf_nopu:
-#    version: 2.0.0
-#  cms_pf_qcd_nopu:
-#    version: 2.0.0
diff --git a/scripts/clic/postprocessing_jobs.py b/scripts/clic/postprocessing_jobs.py
index 4ca10cc98..50ada8647 100644
--- a/scripts/clic/postprocessing_jobs.py
+++ b/scripts/clic/postprocessing_jobs.py
@@ -19,7 +19,7 @@ def write_script(infiles, outpath):
     for inf in infiles:
         s += [
-            "singularity exec -B /local /home/software/singularity/pytorch.simg:2024-07-08 python3 "
+            "singularity exec -B /local /home/software/singularity/pytorch.simg:2024-08-02 python3 "
             + f"scripts/clic/postprocessing.py --input {inf} --outpath {outpath}"
         ]
     ret = "\n".join(s)
@@ -29,7 +29,7 @@

 samples = [
-    # ("/local/joosep/clic_edm4hep/2024_07/p8_ee_qq_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_qq_ecm380/"),
+    ("/local/joosep/clic_edm4hep/2024_07/p8_ee_qq_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_qq_ecm380/"),
     ("/local/joosep/clic_edm4hep/2024_07/p8_ee_tt_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_tt_ecm380/"),
 ]
diff --git a/scripts/tallinn/a100/pytorch-small-eval-clic.sh b/scripts/tallinn/a100/pytorch-small-eval-clic.sh
index 075263bcd..a10df0bda 100644
--- a/scripts/tallinn/a100/pytorch-small-eval-clic.sh
+++ b/scripts/tallinn/a100/pytorch-small-eval-clic.sh
@@ -7,10 +7,10 @@ IMG=/home/software/singularity/pytorch.simg:2024-08-02
 cd ~/particleflow

-WEIGHTS=experiments/pyg-clic_20240807_134034_168101/checkpoints/checkpoint-47-9.910686.pth
+WEIGHTS=experiments/pyg-clic_20240807_134034_168101/checkpoints/checkpoint-100-9.413720.pth
 singularity exec -B /scratch/persistent --nv \
     --env PYTHONPATH=hep_tfds \
     --env KERAS_BACKEND=torch \
     $IMG python3 mlpf/pyg_pipeline.py --dataset clic --gpus 1 \
     --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pytorch/pyg-clic.yaml \
-    --test --make-plots --gpu-batch-multiplier 100 --load $WEIGHTS --dtype bfloat16
+    --test --make-plots --gpu-batch-multiplier 200 --load $WEIGHTS --dtype bfloat16
diff --git a/scripts/tallinn/a100/pytorch-small-eval-cms.sh b/scripts/tallinn/a100/pytorch-small-eval-cms.sh
index db9145ce5..b85fa642f 100644
--- a/scripts/tallinn/a100/pytorch-small-eval-cms.sh
+++ b/scripts/tallinn/a100/pytorch-small-eval-cms.sh
@@ -1,16 +1,16 @@
 #!/bin/bash
 #SBATCH --partition gpu
 #SBATCH --gres gpu:mig:1
-#SBATCH --mem-per-gpu 150G
+#SBATCH --mem-per-gpu 50G
 #SBATCH -o logs/slurm-%x-%j-%N.out

-IMG=/home/software/singularity/pytorch.simg:2024-07-08
+IMG=/home/software/singularity/pytorch.simg:2024-08-02
 cd ~/particleflow

-WEIGHTS=experiments/pyg-cms_20240804_095032_809397/checkpoints/checkpoint-16-19.681200.pth
+WEIGHTS=experiments/pyg-cms-ttbar-nopu_20240815_233931_332621/checkpoints/checkpoint-17-17.282402.pth
 singularity exec -B /scratch/persistent --nv \
     --env PYTHONPATH=hep_tfds \
     --env KERAS_BACKEND=torch \
-    $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 1 \
-    --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pytorch/pyg-cms.yaml \
-    --test --make-plots --gpu-batch-multiplier 20 --load $WEIGHTS --ntest 10000 --dtype bfloat16
+    $IMG python mlpf/pyg_pipeline.py --dataset cms --gpus 1 \
+    --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pytorch/pyg-cms-ttbar-nopu.yaml \
+    --test --make-plots --gpu-batch-multiplier 5 --load $WEIGHTS --ntest 10000 --dtype bfloat16
diff --git a/scripts/tallinn/a100/pytorch.sh b/scripts/tallinn/a100/pytorch.sh
index db53f0cab..db91084de 100755
--- a/scripts/tallinn/a100/pytorch.sh
+++ b/scripts/tallinn/a100/pytorch.sh
@@ -4,13 +4,13 @@
 #SBATCH --mem-per-gpu 200G
 #SBATCH -o logs/slurm-%x-%j-%N.out

-IMG=/home/software/singularity/pytorch.simg:2024-07-08
+IMG=/home/software/singularity/pytorch.simg:2024-08-02
 cd ~/particleflow
 ulimit -n 10000
 singularity exec -B /scratch/persistent --nv \
     --env PYTHONPATH=hep_tfds \
     --env KERAS_BACKEND=torch \
-    $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 1 \
-    --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pytorch/pyg-cms.yaml \
-    --train --conv-type attention --num-epochs 20 --gpu-batch-multiplier 10 --num-workers 4 --prefetch-factor 100 --checkpoint-freq 1 --comet
+    $IMG python3 mlpf/pyg_pipeline.py --dataset cms --gpus 1 \
+    --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pytorch/pyg-cms-ttbar-nopu.yaml \
+    --train --test --make-plots --num-epochs 20 --conv-type attention --gpu-batch-multiplier 10 --num-workers 4 --prefetch-factor 100 --checkpoint-freq 1 --comet --ntrain 50000 --nvalid 5000 --ntest 5000
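
Two notes with sketches, referring back to earlier hunks. First, the chain of masked assignments in `PFDataset.__getitem__` reads as a (element type, old target class) -> new class lookup, with class indices following the `CLASS_LABELS_CMS` ordering. This table-driven restatement is illustrative only: the pairs are copied from the diff (the commented-out element-4/class-6 case is deliberately excluded), and the helper name is made up:

```python
# Illustrative restatement of the CMS target relabeling in PFDataset.__getitem__.
# Keys are (X element type, old ygen class); values are the new ygen class.
import torch

CMS_RELABEL = {
    (1, 2): 1, (1, 5): 1,
    (4, 1): 5, (4, 7): 5,
    (5, 1): 2, (5, 6): 2, (5, 7): 2,
    (8, 1): 4, (8, 2): 4, (8, 6): 4, (8, 7): 4,
    (9, 1): 3, (9, 2): 3, (9, 6): 3, (9, 7): 3,
    (10, 1): 2, (10, 6): 2, (10, 7): 2,
    (11, 1): 2, (11, 6): 2, (11, 7): 2,
}

def relabel_targets(X: torch.Tensor, ygen: torch.Tensor) -> torch.Tensor:
    # X and ygen are per-event 2D tensors; column 0 holds the type/class id
    for (elem_type, old_cls), new_cls in CMS_RELABEL.items():
        sel = (X[:, 0] == elem_type) & (ygen[:, 0] == old_cls)
        ygen[sel, 0] = new_cls
    return ygen
```

No rule's output class is another rule's input class for the same element type, so the result does not depend on the order the entries are applied in.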
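Second, the two-stage decoding in `unpack_predictions` (the binary head decides whether a particle exists, the PID head then picks the class) can be checked in isolation; a minimal sketch with fabricated logits:

```python
# Minimal sketch of the two-stage decode in unpack_predictions: argmax of the
# binary head yields 0/1; where it is 1, the PID head's argmax replaces it.
# Since the PID output still contains class 0, a predicted particle can decode
# back to 0 when the PID head prefers class 0 on that element.
import torch

cls_binary = torch.tensor([[[2.0, -1.0], [-1.0, 2.0]]])  # (batch=1, elem=2, 2)
cls_id_onehot = torch.randn(1, 2, 9)                     # (batch, elem, num_classes)

cls_id = torch.argmax(cls_binary, axis=-1)               # 0 = no particle, 1 = particle
pid = torch.argmax(cls_id_onehot, axis=-1)
cls_id[cls_id == 1] = pid[cls_id == 1]
print(cls_id)  # element 0 stays 0; element 1 takes the PID argmax
```
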