Skip to content

Commit

Permalink
CMS dataset relabel, generate v2.1.0 with more stats, separate binary…
Browse files Browse the repository at this point in the history
… classifier (jpata#340)

* dataset relabeling
  • Loading branch information
jpata authored Aug 16, 2024
1 parent 7ad13f4 commit 2579e76
Show file tree
Hide file tree
Showing 20 changed files with 442 additions and 85 deletions.
4 changes: 4 additions & 0 deletions mlpf/data_cms/postprocessing2.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,10 @@ def process(args):
"genmet": genmet,
}

# print("trk", ygen[Xelem["typ"] == 1]["typ"])
# print("ecal", ygen[Xelem["typ"] == 4]["typ"])
# print("hcal", ygen[Xelem["typ"] == 5]["typ"])

if args.save_full_graph:
data["full_graph"] = g

Expand Down
3 changes: 2 additions & 1 deletion mlpf/heptfds/clic_pf_edm4hep/qq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class ClicEdmQqPf(tfds.core.GeneratorBasedBuilder):
VERSION = tfds.core.Version("2.0.0")
VERSION = tfds.core.Version("2.1.0")
RELEASE_NOTES = {
"1.0.0": "Initial release.",
"1.1.0": "update stats, move to 380 GeV",
Expand All @@ -36,6 +36,7 @@ class ClicEdmQqPf(tfds.core.GeneratorBasedBuilder):
"1.4.0": "Fix ycand matching",
"1.5.0": "Regenerate with ARRAY_RECORD",
"2.0.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
"2.1.0": "Bump dataset size",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
For the raw input files in ROOT EDM4HEP format, please see the citation above.
Expand Down
3 changes: 2 additions & 1 deletion mlpf/heptfds/clic_pf_edm4hep/ttbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class ClicEdmTtbarPf(tfds.core.GeneratorBasedBuilder):
VERSION = tfds.core.Version("2.0.0")
VERSION = tfds.core.Version("2.1.0")
RELEASE_NOTES = {
"1.0.0": "Initial release.",
"1.1.0": "update stats, move to 380 GeV",
Expand All @@ -35,6 +35,7 @@ class ClicEdmTtbarPf(tfds.core.GeneratorBasedBuilder):
"1.4.0": "Fix ycand matching",
"1.5.0": "Regenerate with ARRAY_RECORD",
"2.0.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
"2.1.0": "Bump dataset size",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
For the raw input files in ROOT EDM4HEP format, please see the citation above.
Expand Down
6 changes: 1 addition & 5 deletions mlpf/heptfds/cms_pf/cms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
]


def prepare_data_cms(fn, with_jet_idx=False):
def prepare_data_cms(fn):
Xs = []
ygens = []
ycands = []
Expand Down Expand Up @@ -147,10 +147,6 @@ def prepare_data_cms(fn, with_jet_idx=False):
ygen["typ_idx"] = np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ygen["typ"]], dtype=np.float32)
ycand["typ_idx"] = np.array([CLASS_LABELS_CMS.index(abs(int(i))) for i in ycand["typ"]], dtype=np.float32)

if with_jet_idx:
ygen["jet_idx"] = np.zeros(len(ygen["typ"]), dtype=np.float32)
ycand["jet_idx"] = np.zeros(len(ycand["typ"]), dtype=np.float32)

Xelem_flat = ak.to_numpy(
np.stack(
[Xelem[k] for k in X_FEATURES],
Expand Down
3 changes: 2 additions & 1 deletion mlpf/heptfds/cms_pf/qcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class CmsPfQcd(tfds.core.GeneratorBasedBuilder):
"""DatasetBuilder for cms_pf_qcd dataset."""

VERSION = tfds.core.Version("2.0.0")
VERSION = tfds.core.Version("2.1.0")
RELEASE_NOTES = {
"1.3.0": "12_2_0_pre2 generation with updated caloparticle/trackingparticle",
"1.3.1": "Remove PS again",
Expand All @@ -32,6 +32,7 @@ class CmsPfQcd(tfds.core.GeneratorBasedBuilder):
"1.7.0": "Add cluster shape vars",
"1.7.1": "Increase stats to 400k events",
"2.0.0": "New truth def based primarily on CaloParticles",
"2.1.0": "Additional stats",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
rsync -r --progress lxplus.cern.ch:/eos/user/j/jpata/mlpf/tensorflow_datasets/cms/cms_pf_qcd ~/tensorflow_datasets/
Expand Down
3 changes: 2 additions & 1 deletion mlpf/heptfds/cms_pf/ttbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class CmsPfTtbar(tfds.core.GeneratorBasedBuilder):
"""DatasetBuilder for cms_pf dataset."""

VERSION = tfds.core.Version("2.0.0")
VERSION = tfds.core.Version("2.1.0")
RELEASE_NOTES = {
"1.0.0": "Initial release.",
"1.1.0": "Add muon type, fix electron GSF association",
Expand All @@ -36,6 +36,7 @@ class CmsPfTtbar(tfds.core.GeneratorBasedBuilder):
"1.7.1": "Increase stats to 400k events",
"1.8.0": "Add ispu, genjets, genmet; disable genjet_idx; improved merging",
"2.0.0": "New truth def based primarily on CaloParticles",
"2.1.0": "Additional stats",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
rsync -r --progress lxplus.cern.ch:/eos/user/j/jpata/mlpf/tensorflow_datasets/cms/cms_pf_ttbar ~/tensorflow_datasets/
Expand Down
66 changes: 66 additions & 0 deletions mlpf/heptfds/cms_pf/vbf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""CMS PF VBF dataset."""
import cms_utils
import tensorflow as tf

import tensorflow_datasets as tfds

# Feature name lists are taken from cms_utils; their lengths set the last
# tensor dimension declared in CmsPfVbf._info, and the lists themselves are
# stored in the dataset metadata.
X_FEATURES = cms_utils.X_FEATURES
Y_FEATURES = cms_utils.Y_FEATURES

# Human-readable description handed to tfds.core.DatasetInfo.
_DESCRIPTION = """
Dataset generated with CMSSW and full detector sim.
VBF events with PU 55-75 in a Run3 setup.
"""

# TODO(cms_pf): BibTeX citation
# Intentionally blank for now; handed to tfds.core.DatasetInfo unchanged.
_CITATION = """
"""


class CmsPfVbf(tfds.core.GeneratorBasedBuilder):
    """TFDS builder for the ``cms_pf_vbf`` dataset.

    The raw inputs are not downloaded automatically: they are read from the
    manual directory (see ``MANUAL_DOWNLOAD_INSTRUCTIONS``), and the actual
    file splitting and example generation are delegated to ``cms_utils``.
    """

    VERSION = tfds.core.Version("2.1.0")
    RELEASE_NOTES = {
        "1.7.1": "First version",
        "1.8.0": "Add ispu, genjets, genmet; disable genjet_idx; improved merging",
        "2.0.0": "New truth def based primarily on CaloParticles",
        "2.1.0": "Additional statistics",
    }
    MANUAL_DOWNLOAD_INSTRUCTIONS = """
rsync -r --progress lxplus.cern.ch:/eos/user/j/jpata/mlpf/tensorflow_datasets/cms/cms_pf_vbf ~/tensorflow_datasets/
"""

    def __init__(self, *args, **kwargs):
        """Create the builder, always forcing the ARRAY_RECORD file format."""
        # Any caller-supplied file_format is overridden on purpose.
        kwargs.update(file_format=tfds.core.FileFormat.ARRAY_RECORD)
        super().__init__(*args, **kwargs)

    def _info(self) -> tfds.core.DatasetInfo:
        """Describe the features, supervised keys and metadata of the dataset."""
        num_x = len(X_FEATURES)
        num_y = len(Y_FEATURES)

        # Per-event tensors have a variable first dimension (shape None).
        feature_spec = tfds.features.FeaturesDict(
            {
                "X": tfds.features.Tensor(shape=(None, num_x), dtype=tf.float32),
                "ygen": tfds.features.Tensor(shape=(None, num_y), dtype=tf.float32),
                "ycand": tfds.features.Tensor(shape=(None, num_y), dtype=tf.float32),
                "genmet": tfds.features.Scalar(dtype=tf.float32),
                "genjets": tfds.features.Tensor(shape=(None, 4), dtype=tf.float32),
            }
        )
        feature_names = tfds.core.MetadataDict(x_features=X_FEATURES, y_features=Y_FEATURES)

        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=feature_spec,
            supervised_keys=("X", "ygen"),
            homepage="",
            citation=_CITATION,
            metadata=feature_names,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Build the splits from the ``raw`` folder of the VBF sample in the manual dir."""
        raw_dir = dl_manager.manual_dir / "VBF_TuneCP5_14TeV_pythia8_cfi" / "raw"
        return cms_utils.split_sample(raw_dir)

    def _generate_examples(self, files):
        """Yield examples for ``files`` via ``cms_utils.generate_examples``."""
        return cms_utils.generate_examples(files)
63 changes: 63 additions & 0 deletions mlpf/plotting/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def get_class_names(sample_name):
"clic_edm_zh_tautau_pf": r"$e^+e^- \rightarrow ZH \rightarrow \tau \tau$",
"cms_pf_qcd": r"QCD $p_T \in [15, 3000]\ \mathrm{GeV}$+PU",
"cms_pf_ztt": r"$\mathrm{Z}\rightarrow \mathrm{\tau}\mathrm{\tau}$+PU",
"cms_pf_vbf": r"VBF+PU",
"cms_pf_ttbar": r"$\mathrm{t}\overline{\mathrm{t}}$+PU",
"cms_pf_ttbar_nopu": r"$\mathrm{t}\overline{\mathrm{t}}$",
"cms_pf_qcd_nopu": r"QCD $p_T \in [15, 3000]\ \mathrm{GeV}$",
Expand Down Expand Up @@ -1076,6 +1077,68 @@ def plot_particle_multiplicity(X, yvals, class_names, epoch=None, cp_dir=None, c
)


def plot_particle_ratio(yvals, class_names, epoch=None, cp_dir=None, comet_experiment=None, title=None, sample=None, dataset=None):
    """Histogram the reconstructed/target pT and energy ratios for each target class.

    For every non-zero class id found in ``yvals["gen_cls_id"]`` two figures are
    produced and passed to ``save_img``: ``particle_pt_ratio_<cls>.png`` and
    ``particle_e_ratio_<cls>.png``. Each overlays the PF ("cand") and MLPF
    ("pred") ratio to the target ("gen") value.

    Args:
        yvals: mapping providing ``{cand,pred,gen}_cls_id``, ``{cand,pred,gen}_pt``
            and ``{cand,pred,gen}_energy``; the entries are masked and flattened
            with awkward (presumably jagged per-event arrays -- confirm with caller).
        class_names: container indexed by integer class id, used in the plot title.
        epoch, cp_dir, comet_experiment: forwarded unchanged to ``save_img``.
        title: optional title prefix; the class name is appended when given.
        sample, dataset: accepted for call-site symmetry; not used here.
    """
    msk_cand = yvals["cand_cls_id"] != 0
    msk_pred = yvals["pred_cls_id"] != 0
    msk_gen = yvals["gen_cls_id"] != 0

    # A ratio only makes sense where a target particle AND a reconstructed
    # particle (PF or MLPF respectively) are both present.
    msk_gen_cand = msk_gen & msk_cand
    msk_gen_pred = msk_gen & msk_pred

    def _flat(key, msk):
        # Select with the mask, then collapse to a flat numpy array.
        return awkward.to_numpy(awkward.flatten(yvals[key][msk]))

    # (key suffix in yvals, tag used in the output file name, x-axis label)
    quantities = (
        ("pt", "pt", "Reconstructed / target $p_T$"),
        ("energy", "e", "Reconstructed / target $E$"),
    )

    # Per quantity: (PF / target, MLPF / target)
    ratios = {}
    for key, _, _ in quantities:
        ratios[key] = (
            _flat("cand_" + key, msk_gen_cand) / _flat("gen_" + key, msk_gen_cand),
            _flat("pred_" + key, msk_gen_pred) / _flat("gen_" + key, msk_gen_pred),
        )

    gen_cls_id = awkward.flatten(yvals["gen_cls_id"][msk_gen])
    # Target class ids aligned element-by-element with the PF and MLPF ratios.
    gen_cls_id_cand = awkward.flatten(yvals["gen_cls_id"][msk_gen_cand])
    gen_cls_id_pred = awkward.flatten(yvals["gen_cls_id"][msk_gen_pred])

    cls_ids = np.unique(awkward.values_astype(gen_cls_id, np.int64))
    print("cls_ids", cls_ids)

    bins = np.linspace(0, 5, 100)
    for cls_id in cls_ids:
        if cls_id == 0:
            continue
        clname = class_names[cls_id]

        for key, tag, xlabel in quantities:
            ratio_cand, ratio_pred = ratios[key]

            plt.figure()
            plt.hist(ratio_cand[gen_cls_id_cand == cls_id], bins=bins, label="PF", histtype="step")
            plt.hist(ratio_pred[gen_cls_id_pred == cls_id], bins=bins, label="MLPF", histtype="step")
            plt.legend(loc="best")
            if title:
                plt.title(title + ", " + clname)
            # The axis label must be set before saving, otherwise the written
            # image has no x-axis label.
            plt.xlabel(xlabel)
            save_img(
                "particle_{}_ratio_{}.png".format(tag, cls_id),
                epoch,
                cp_dir=cp_dir,
                comet_experiment=comet_experiment,
            )
            # NOTE(review): clf() clears the figure but does not close it, so one
            # figure per class/quantity stays open -- consider plt.close().
            plt.clf()


def plot_particles(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None, sample=None, dataset=None):
msk_cand = yvals["cand_cls_id"] != 0
cand_pt = awkward.to_numpy(awkward.flatten(yvals["cand_pt"][msk_cand], axis=1))
Expand Down
29 changes: 28 additions & 1 deletion mlpf/pyg/PFDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,33 @@ def __getitem__(self, item):
ret["ycand"] = ret["ycand"][sortidx]
ret["ygen"] = ret["ygen"][sortidx]

if self.ds.dataset_info.name.startswith("cms_"):
ret["ygen"][:, 0][(ret["X"][:, 0] == 1) & (ret["ygen"][:, 0] == 2)] = 1
ret["ygen"][:, 0][(ret["X"][:, 0] == 1) & (ret["ygen"][:, 0] == 5)] = 1

ret["ygen"][:, 0][(ret["X"][:, 0] == 4) & (ret["ygen"][:, 0] == 1)] = 5
ret["ygen"][:, 0][(ret["X"][:, 0] == 5) & (ret["ygen"][:, 0] == 1)] = 2
# ret["ygen"][:, 0][(ret["X"][:, 0]==4) & (ret["ygen"][:, 0] == 6)] = 5
ret["ygen"][:, 0][(ret["X"][:, 0] == 5) & (ret["ygen"][:, 0] == 6)] = 2
ret["ygen"][:, 0][(ret["X"][:, 0] == 4) & (ret["ygen"][:, 0] == 7)] = 5
ret["ygen"][:, 0][(ret["X"][:, 0] == 5) & (ret["ygen"][:, 0] == 7)] = 2

ret["ygen"][:, 0][(ret["X"][:, 0] == 8) & (ret["ygen"][:, 0] == 1)] = 4
ret["ygen"][:, 0][(ret["X"][:, 0] == 9) & (ret["ygen"][:, 0] == 1)] = 3
ret["ygen"][:, 0][(ret["X"][:, 0] == 8) & (ret["ygen"][:, 0] == 2)] = 4
ret["ygen"][:, 0][(ret["X"][:, 0] == 9) & (ret["ygen"][:, 0] == 2)] = 3
ret["ygen"][:, 0][(ret["X"][:, 0] == 8) & (ret["ygen"][:, 0] == 6)] = 4
ret["ygen"][:, 0][(ret["X"][:, 0] == 9) & (ret["ygen"][:, 0] == 6)] = 3
ret["ygen"][:, 0][(ret["X"][:, 0] == 8) & (ret["ygen"][:, 0] == 7)] = 4
ret["ygen"][:, 0][(ret["X"][:, 0] == 9) & (ret["ygen"][:, 0] == 7)] = 3

ret["ygen"][:, 0][(ret["X"][:, 0] == 10) & (ret["ygen"][:, 0] == 1)] = 2
ret["ygen"][:, 0][(ret["X"][:, 0] == 11) & (ret["ygen"][:, 0] == 1)] = 2
ret["ygen"][:, 0][(ret["X"][:, 0] == 10) & (ret["ygen"][:, 0] == 6)] = 2
ret["ygen"][:, 0][(ret["X"][:, 0] == 11) & (ret["ygen"][:, 0] == 6)] = 2
ret["ygen"][:, 0][(ret["X"][:, 0] == 10) & (ret["ygen"][:, 0] == 7)] = 2
ret["ygen"][:, 0][(ret["X"][:, 0] == 11) & (ret["ygen"][:, 0] == 7)] = 2

return ret

def __len__(self):
Expand Down Expand Up @@ -181,7 +208,7 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray):
if world_size > 1:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
sampler = torch.utils.data.RandomSampler(dataset)
sampler = torch.utils.data.SequentialSampler(dataset)

# build dataloaders
batch_size = config[f"{split}_dataset"][config["dataset"]][type_]["batch_size"] * config["gpu_batch_multiplier"]
Expand Down
5 changes: 4 additions & 1 deletion mlpf/pyg/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import vector
from jet_utils import build_dummy_array, match_two_jet_collections
from plotting.plot_utils import (
get_class_names,
compute_met_and_ratio,
load_eval_data,
plot_jets,
Expand All @@ -22,6 +23,7 @@
plot_met_response_binned,
plot_num_elements,
plot_particles,
plot_particle_ratio,
plot_sum_energy,
)

Expand Down Expand Up @@ -178,7 +180,7 @@ def make_plots(outpath, sample, dataset, dir_name=""):
"""Uses the predictions stored as .parquet files (see above) to make plots."""

mplhep.style.use(mplhep.styles.CMS)

class_names = get_class_names(sample)
os.system(f"mkdir -p {outpath}/plots{dir_name}/{sample}")

plots_path = Path(f"{outpath}/plots{dir_name}/{sample}/")
Expand Down Expand Up @@ -241,3 +243,4 @@ def make_plots(outpath, sample, dataset, dir_name=""):
plot_met_response_binned(met_data, cp_dir=plots_path, dataset=dataset, sample=sample)

plot_particles(yvals, cp_dir=plots_path, dataset=dataset, sample=sample)
plot_particle_ratio(yvals, class_names, cp_dir=plots_path, dataset=dataset, sample=sample)
16 changes: 9 additions & 7 deletions mlpf/pyg/mlpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,11 @@ def __init__(
decoding_dim = self.input_dim + embedding_dim

# DNN that acts on the node level to predict the PID
self.nn_id = ffn(decoding_dim, num_classes, width, self.act, dropout_ff)
self.nn_binary_particle = ffn(decoding_dim, 2, width, self.act, dropout_ff)
self.nn_pid = ffn(decoding_dim, num_classes, width, self.act, dropout_ff)

# elementwise DNN for node momentum regression
embed_dim = decoding_dim + num_classes
embed_dim = decoding_dim + 2 + num_classes
self.nn_pt = RegressionOutput(pt_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
self.nn_eta = RegressionOutput(eta_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
self.nn_sin_phi = RegressionOutput(sin_phi_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
Expand Down Expand Up @@ -319,20 +320,21 @@ def forward(self, X_features, mask):
out_padded = conv(conv_input, mask)
embeddings_reg.append(out_padded)

# id input
if self.learned_representation_mode == "concat":
final_embedding_id = torch.cat([Xfeat_normed] + embeddings_id, axis=-1)
elif self.learned_representation_mode == "last":
final_embedding_id = torch.cat([Xfeat_normed] + [embeddings_id[-1]], axis=-1)
preds_id = self.nn_id(final_embedding_id)
preds_binary_particle = self.nn_binary_particle(final_embedding_id)
preds_pid = self.nn_pid(final_embedding_id)

# pred_charge = self.nn_charge(final_embedding_id)

# regression input
if self.learned_representation_mode == "concat":
final_embedding_reg = torch.cat([Xfeat_normed] + embeddings_reg + [preds_id], axis=-1)
final_embedding_reg = torch.cat([Xfeat_normed] + embeddings_reg + [preds_binary_particle.detach(), preds_pid.detach()], axis=-1)
elif self.learned_representation_mode == "last":
final_embedding_id = torch.cat([Xfeat_normed] + [embeddings_id[-1]], axis=-1)
final_embedding_reg = torch.cat([Xfeat_normed] + [embeddings_reg[-1]] + [preds_id], axis=-1)
final_embedding_reg = torch.cat([Xfeat_normed] + [embeddings_reg[-1]] + [preds_binary_particle.detach(), preds_pid.detach()], axis=-1)

# The PFElement feature order in X_features defined in fcc/postprocessing.py
preds_pt = self.nn_pt(X_features, final_embedding_reg, X_features[..., 1:2])
Expand All @@ -342,4 +344,4 @@ def forward(self, X_features, mask):
preds_energy = self.nn_energy(X_features, final_embedding_reg, X_features[..., 5:6])
preds_momentum = torch.cat([preds_pt, preds_eta, preds_sin_phi, preds_cos_phi, preds_energy], axis=-1)

return preds_id, preds_momentum
return preds_binary_particle, preds_pid, preds_momentum
Loading

0 comments on commit 2579e76

Please sign in to comment.