Skip to content

Commit

Permalink
fixes for pytorch, CMS t1tttt dataset, update response plots (jpata#232)
Browse files Browse the repository at this point in the history
  * fixes for pytorch, CMS t1tttt dataset, update response plots
  • Loading branch information
jpata authored Oct 11, 2023
1 parent 0547572 commit 59b5d97
Show file tree
Hide file tree
Showing 17 changed files with 453 additions and 334 deletions.
3 changes: 2 additions & 1 deletion mlpf/data_cms/genjob_pu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,5 @@ cmsRun $CMSSWDIR/src/Validation/RecoParticleFlow/test/pfanalysis_ntuple.py
mv pfntuple.root pfntuple_${SEED}.root
python3 ${MLPF_PATH}/mlpf/data_cms/postprocessing2.py --input pfntuple_${SEED}.root --outpath ./ --save-normalized-table
bzip2 -z pfntuple_${SEED}.pkl
#rm step*.root
cp *.pkl.bz2 $OUTDIR/
rm -Rf $WORKDIR
2 changes: 1 addition & 1 deletion mlpf/data_cms/prepare_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# "ZTT_All_hadronic_14TeV_TuneCUETP8M1_cfi",
# "QCDForPF_14TeV_TuneCUETP8M1_cfi",
# "QCD_Pt_3000_7000_14TeV_TuneCUETP8M1_cfi",
# "SMS-T1tttt_mGl-1500_mLSP-100_TuneCP5_14TeV_pythia8_cfi",
("SMS-T1tttt_mGl-1500_mLSP-100_TuneCP5_14TeV_pythia8_cfi", 200000, 202050),
# "ZpTT_1500_14TeV_TuneCP5_cfi",
]

Expand Down
4 changes: 4 additions & 0 deletions mlpf/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,8 @@ def plots(train_dir, max_files):
load_loss_history,
loss_plot,
plot_jet_response_binned,
plot_jet_response_binned_separate,
plot_jet_response_binned_eta,
plot_met_response_binned,
get_class_names,
plot_rocs,
Expand Down Expand Up @@ -1387,6 +1389,8 @@ def plots(train_dir, max_files):
plot_particles(yvals, cp_dir=cp_dir, title=_title)

plot_jet_response_binned(yvals, cp_dir=cp_dir, title=_title)
plot_jet_response_binned_eta(yvals, cp_dir=cp_dir, title=_title)
plot_jet_response_binned_separate(yvals, cp_dir=cp_dir, title=_title)
plot_met_response_binned(met_data, cp_dir=cp_dir, title=_title)

mom_data = compute_3dmomentum_and_ratio(yvals)
Expand Down
214 changes: 174 additions & 40 deletions mlpf/plotting/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,15 @@
"gen_met": "$p_{\mathrm{T,gen}}^\text{miss}$ [GeV]",
"gen_mom": "$p_{\mathrm{gen}}$ [GeV]",
"gen_jet": "jet $p_{\mathrm{T,gen}}$ [GeV]",
"gen_jet_eta": "jet $\eta_{\mathrm{gen}}$ [GeV]",
"reco_met": "$p_{\mathrm{T,reco}}^\text{miss}$ [GeV]",
"reco_gen_met_ratio": "$p_{\mathrm{T,reco}}^\mathrm{miss} / p_{\\mathrm{T,gen}}^\mathrm{miss}$",
"reco_gen_mom_ratio": "$p_{\mathrm{reco}} / p_{\\mathrm{gen}}$",
"reco_gen_jet_ratio": "jet $p_{\mathrm{T,reco}} / p_{\\mathrm{T,gen}}$",
"gen_met_range": "${} \less p_{{\mathrm{{T,gen}}}}^\mathrm{{miss}}\leq {}$",
"gen_mom_range": "${} \less p_{{\mathrm{{gen}}}}\leq {}$",
"gen_jet_range": "${} \less p_{{\mathrm{{T,gen}}}} \leq {}$",
"gen_jet_range_eta": "${} \less \eta_{{\mathrm{{gen}}}} \leq {}$",
}


Expand Down Expand Up @@ -296,6 +298,12 @@ def compute_jet_ratio(data, yvals):
axis=1,
)
)
ret["jet_gen_to_pred_geneta"] = awkward.to_numpy(
awkward.flatten(
vector.awk(data["jets"]["gen"][data["matched_jets"]["gen_to_pred"]["gen"]]).eta,
axis=1,
)
)
ret["jet_gen_to_pred_predpt"] = awkward.to_numpy(
awkward.flatten(
vector.awk(data["jets"]["pred"][data["matched_jets"]["gen_to_pred"]["pred"]]).pt,
Expand All @@ -308,6 +316,12 @@ def compute_jet_ratio(data, yvals):
axis=1,
)
)
ret["jet_gen_to_cand_geneta"] = awkward.to_numpy(
awkward.flatten(
vector.awk(data["jets"]["gen"][data["matched_jets"]["gen_to_cand"]["gen"]]).eta,
axis=1,
)
)
ret["jet_gen_to_cand_candpt"] = awkward.to_numpy(
awkward.flatten(
vector.awk(data["jets"]["cand"][data["matched_jets"]["gen_to_cand"]["cand"]]).pt,
Expand Down Expand Up @@ -978,6 +992,70 @@ def plot_particles(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=
)


def plot_jet_response_binned_separate(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None):
pf_genjet_pt = yvals["jet_gen_to_cand_genpt"]
mlpf_genjet_pt = yvals["jet_gen_to_pred_genpt"]

pf_response = yvals["jet_ratio_cand"]
mlpf_response = yvals["jet_ratio_pred"]

genjet_bins = [10, 20, 40, 60, 80, 100, 200]

x_vals = []
pf_vals = []
mlpf_vals = []
b = np.linspace(0, 2, 100)

for ibin in range(len(genjet_bins) - 1):
lim_low = genjet_bins[ibin]
lim_hi = genjet_bins[ibin + 1]
x_vals.append(np.mean([lim_low, lim_hi]))

mask_genjet = (pf_genjet_pt > lim_low) & (pf_genjet_pt <= lim_hi)
pf_subsample = pf_response[mask_genjet]
if len(pf_subsample) > 0:
pf_p25 = np.percentile(pf_subsample, 25)
pf_p50 = np.percentile(pf_subsample, 50)
pf_p75 = np.percentile(pf_subsample, 75)
else:
pf_p25 = 0
pf_p50 = 0
pf_p75 = 0
pf_vals.append([pf_p25, pf_p50, pf_p75])

mask_genjet = (mlpf_genjet_pt > lim_low) & (mlpf_genjet_pt <= lim_hi)
mlpf_subsample = mlpf_response[mask_genjet]

if len(mlpf_subsample) > 0:
mlpf_p25 = np.percentile(mlpf_subsample, 25)
mlpf_p50 = np.percentile(mlpf_subsample, 50)
mlpf_p75 = np.percentile(mlpf_subsample, 75)
else:
mlpf_p25 = 0
mlpf_p50 = 0
mlpf_p75 = 0
mlpf_vals.append([mlpf_p25, mlpf_p50, mlpf_p75])

plt.figure()
plt.hist(pf_subsample, bins=b, histtype="step", lw=2, label="PF")
plt.hist(mlpf_subsample, bins=b, histtype="step", lw=2, label="MLPF")
plt.xlim(0, 2)
plt.xticks([0, 0.5, 1, 1.5, 2])
plt.ylabel("Matched jets / bin")
plt.xlabel(labels["reco_gen_jet_ratio"])
plt.axvline(1.0, ymax=0.7, color="black", ls="--")
plt.legend(loc=1, fontsize=16)
plt.title(labels["gen_jet_range"].format(lim_low, lim_hi))
plt.yscale("log")

save_img(
"jet_response_binned_pt{}.png".format(lim_low),
epoch,
cp_dir=cp_dir,
comet_experiment=comet_experiment,
)


def plot_jet_response_binned(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None):
pf_genjet_pt = yvals["jet_gen_to_cand_genpt"]
mlpf_genjet_pt = yvals["jet_gen_to_pred_genpt"]
Expand Down Expand Up @@ -1049,19 +1127,10 @@ def plot_jet_response_binned(yvals, epoch=None, cp_dir=None, comet_experiment=No
mlpf_vals = np.array(mlpf_vals)

# Plot median and IQR as a function of gen pt
fig, axs = plt.subplots(2, 1, sharex=True)
plt.sca(axs[0])
plt.plot(x_vals, pf_vals[:, 1], marker="o", label="PF")
plt.plot(x_vals, mlpf_vals[:, 1], marker="o", label="MLPF")
plt.ylim(0.75, 1.25)
plt.axhline(1.0, color="black", ls="--")
plt.ylabel("Response median")
plt.legend(title=title)

plt.sca(axs[1])
plt.plot(x_vals, pf_vals[:, 2] - pf_vals[:, 0], marker="o", label="PF")
plt.plot(x_vals, mlpf_vals[:, 2] - mlpf_vals[:, 0], marker="o", label="MLPF")
plt.ylabel("Response IQR")
plt.figure()
plt.plot(x_vals, (pf_vals[:, 2] - pf_vals[:, 0]) / pf_vals[:, 1], marker="o", label="PF")
plt.plot(x_vals, (mlpf_vals[:, 2] - mlpf_vals[:, 0]) / mlpf_vals[:, 1], marker="o", label="MLPF")
plt.ylabel("Response IQR / median")
plt.xlabel(labels["gen_jet"])

plt.tight_layout()
Expand All @@ -1073,6 +1142,91 @@ def plot_jet_response_binned(yvals, epoch=None, cp_dir=None, comet_experiment=No
)


def plot_jet_response_binned_eta(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None):
pf_genjet_eta = yvals["jet_gen_to_cand_geneta"]
mlpf_genjet_eta = yvals["jet_gen_to_pred_geneta"]

pf_response = yvals["jet_ratio_cand"]
mlpf_response = yvals["jet_ratio_pred"]

genjet_bins = [-4, -3, -2, -1, 0, 1, 2, 3, 4]

x_vals = []
pf_vals = []
mlpf_vals = []
b = np.linspace(0, 2, 100)

fig, axs = plt.subplots(3, 3, figsize=(3 * 5, 3 * 5))
axs = axs.flatten()
for ibin in range(len(genjet_bins) - 1):
lim_low = genjet_bins[ibin]
lim_hi = genjet_bins[ibin + 1]
x_vals.append(np.mean([lim_low, lim_hi]))

mask_genjet = (pf_genjet_eta > lim_low) & (pf_genjet_eta <= lim_hi)
pf_subsample = pf_response[mask_genjet]
if len(pf_subsample) > 0:
pf_p25 = np.percentile(pf_subsample, 25)
pf_p50 = np.percentile(pf_subsample, 50)
pf_p75 = np.percentile(pf_subsample, 75)
else:
pf_p25 = 0
pf_p50 = 0
pf_p75 = 0
pf_vals.append([pf_p25, pf_p50, pf_p75])

mask_genjet = (mlpf_genjet_eta > lim_low) & (mlpf_genjet_eta <= lim_hi)
mlpf_subsample = mlpf_response[mask_genjet]

if len(mlpf_subsample) > 0:
mlpf_p25 = np.percentile(mlpf_subsample, 25)
mlpf_p50 = np.percentile(mlpf_subsample, 50)
mlpf_p75 = np.percentile(mlpf_subsample, 75)
else:
mlpf_p25 = 0
mlpf_p50 = 0
mlpf_p75 = 0
mlpf_vals.append([mlpf_p25, mlpf_p50, mlpf_p75])

plt.sca(axs[ibin])
plt.hist(pf_subsample, bins=b, histtype="step", lw=2, label="PF")
plt.hist(mlpf_subsample, bins=b, histtype="step", lw=2, label="MLPF")
plt.xlim(0, 2)
plt.xticks([0, 0.5, 1, 1.5, 2])
plt.ylabel("Matched jets / bin")
plt.xlabel(labels["reco_gen_jet_ratio"])
plt.axvline(1.0, ymax=0.7, color="black", ls="--")
plt.legend(loc=1, fontsize=16)
plt.title(labels["gen_jet_range_eta"].format(lim_low, lim_hi))
plt.yscale("log")

plt.tight_layout()
save_img(
"jet_response_binned_eta.png",
epoch,
cp_dir=cp_dir,
comet_experiment=comet_experiment,
)

x_vals = np.array(x_vals)
pf_vals = np.array(pf_vals)
mlpf_vals = np.array(mlpf_vals)

# Plot median and IQR as a function of gen pt
plt.figure()
plt.plot(x_vals, (pf_vals[:, 2] - pf_vals[:, 0]) / pf_vals[:, 1], marker="o", label="PF")
plt.plot(x_vals, (mlpf_vals[:, 2] - mlpf_vals[:, 0]) / mlpf_vals[:, 1], marker="o", label="MLPF")
plt.ylabel("Response IQR / median")
plt.xlabel(labels["gen_jet_eta"])
plt.tight_layout()
save_img(
"jet_response_med_iqr_eta.png",
epoch,
cp_dir=cp_dir,
comet_experiment=comet_experiment,
)


def plot_met_response_binned(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None):
genmet = yvals["gen_met"]

Expand Down Expand Up @@ -1141,21 +1295,10 @@ def plot_met_response_binned(yvals, epoch=None, cp_dir=None, comet_experiment=No
mlpf_vals = np.array(mlpf_vals)

# Plot median and IQR as a function of gen pt
fig, axs = plt.subplots(2, 1, sharex=True)
plt.sca(axs[0])
plt.plot(x_vals, pf_vals[:, 1], marker="o", label="PF")
plt.plot(x_vals, mlpf_vals[:, 1], marker="o", label="MLPF")
plt.ylim(0.75, 1.25)
plt.axhline(1.0, color="black", ls="--")
plt.ylabel("Response median")
if title:
plt.title(title)
plt.legend()

plt.sca(axs[1])
plt.plot(x_vals, pf_vals[:, 2] - pf_vals[:, 0], marker="o", label="PF")
plt.plot(x_vals, mlpf_vals[:, 2] - mlpf_vals[:, 0], marker="o", label="MLPF")
plt.ylabel("Response IQR")
plt.figure()
plt.plot(x_vals, (pf_vals[:, 2] - pf_vals[:, 0]) / pf_vals[:, 1], marker="o", label="PF")
plt.plot(x_vals, (mlpf_vals[:, 2] - mlpf_vals[:, 0]) / mlpf_vals[:, 1], marker="o", label="MLPF")
plt.ylabel("Response IQR / median")
plt.legend()
if title:
plt.title(title)
Expand Down Expand Up @@ -1238,18 +1381,9 @@ def plot_3dmomentum_response_binned(yvals, epoch=None, cp_dir=None, comet_experi
mlpf_vals = np.array(mlpf_vals)

# Plot median and IQR as a function of gen pt
fig, axs = plt.subplots(2, 1, sharex=True)
plt.sca(axs[0])
plt.plot(x_vals, pf_vals[:, 1], marker="o", label="PF")
plt.plot(x_vals, mlpf_vals[:, 1], marker="o", label="MLPF")
plt.ylim(0.75, 1.25)
plt.axhline(1.0, color="black", ls="--")
plt.ylabel("Response median")
plt.legend(title=title)

plt.sca(axs[1])
plt.plot(x_vals, pf_vals[:, 2] - pf_vals[:, 0], marker="o", label="PF")
plt.plot(x_vals, mlpf_vals[:, 2] - mlpf_vals[:, 0], marker="o", label="MLPF")
plt.figure()
plt.plot(x_vals, (pf_vals[:, 2] - pf_vals[:, 0]) / pf_vals[:, 1], marker="o", label="PF")
plt.plot(x_vals, (mlpf_vals[:, 2] - mlpf_vals[:, 0]) / mlpf_vals[:, 1], marker="o", label="MLPF")
plt.ylabel("Response IQR")
plt.xlabel(labels["gen_mom"])

Expand Down
39 changes: 22 additions & 17 deletions mlpf/pyg/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def particle_array_to_awkward(batch_ids, arr_id, arr_p4):
return ret


@torch.no_grad()
def run_predictions(rank, mlpf, loader, sample, outpath):
"""Runs inference on the given sample and stores the output as .parquet files."""

Expand All @@ -53,8 +54,9 @@ def run_predictions(rank, mlpf, loader, sample, outpath):

ti = time.time()

for i, batch in tqdm.tqdm(enumerate(loader)):
event = batch.to(rank)
for i, event in tqdm.tqdm(enumerate(loader), total=len(loader)):
event.X = event.X.to(rank)
event.batch = event.batch.to(rank)

# recall target ~ ["PDG", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"]
target_ids = event.ygen[:, 0].long()
Expand All @@ -65,10 +67,13 @@ def run_predictions(rank, mlpf, loader, sample, outpath):

# make mlpf forward pass
pred_ids_one_hot, pred_momentum, pred_charge = mlpf(event)
pred_ids_one_hot = pred_ids_one_hot.detach().cpu()
pred_momentum = pred_momentum.detach().cpu()
pred_charge = pred_charge.detach().cpu()

pred_ids = torch.argmax(pred_ids_one_hot.detach(), axis=-1)
pred_charge = torch.argmax(pred_charge.detach(), axis=1, keepdim=True) - 1
pred_p4 = torch.cat([pred_charge, pred_momentum.detach()], axis=-1)
pred_ids = torch.argmax(pred_ids_one_hot, axis=-1)
pred_charge = torch.argmax(pred_charge, axis=1, keepdim=True) - 1
pred_p4 = torch.cat([pred_charge, pred_momentum], axis=-1)

batch_ids = event.batch.cpu().numpy()
awkvals = {
Expand All @@ -80,22 +85,22 @@ def run_predictions(rank, mlpf, loader, sample, outpath):
gen_p4, cand_p4, pred_p4 = [], [], []
gen_cls, cand_cls, pred_cls = [], [], []
Xs = []
for _ibatch in np.unique(event.batch.cpu().numpy()):
msk_batch = event.batch == _ibatch
msk_gen = target_ids[msk_batch] != 0
msk_cand = cand_ids[msk_batch] != 0
msk_pred = pred_ids[msk_batch] != 0
for _ibatch in np.unique(batch_ids):
msk_batch = batch_ids == _ibatch
msk_gen = (target_ids[msk_batch] != 0).numpy()
msk_cand = (cand_ids[msk_batch] != 0).numpy()
msk_pred = (pred_ids[msk_batch] != 0).numpy()

Xs.append(event.X[msk_batch].cpu().numpy())

gen_p4.append(event.ygen[msk_batch, 1:][msk_gen])
gen_cls.append(target_ids[msk_batch][msk_gen])
gen_p4.append(event.ygen[msk_batch, 1:][msk_gen].numpy())
gen_cls.append(target_ids[msk_batch][msk_gen].numpy())

cand_p4.append(event.ycand[msk_batch, 1:][msk_cand])
cand_cls.append(cand_ids[msk_batch][msk_cand])
cand_p4.append(event.ycand[msk_batch, 1:][msk_cand].numpy())
cand_cls.append(cand_ids[msk_batch][msk_cand].numpy())

pred_p4.append(pred_momentum[msk_batch, :][msk_pred])
pred_cls.append(pred_ids[msk_batch][msk_pred])
pred_p4.append(pred_momentum[msk_batch, :][msk_pred].numpy())
pred_cls.append(pred_ids[msk_batch][msk_pred].numpy())

Xs = awkward.from_iter(Xs)
gen_p4 = awkward.from_iter(gen_p4)
Expand Down Expand Up @@ -158,7 +163,7 @@ def run_predictions(rank, mlpf, loader, sample, outpath):
)
_logger.info(f"Saved predictions at {outpath}/preds/{sample}/pred_{rank}_{i}.parquet")

if i == 2:
if i == 100:
break

_logger.info(f"Time taken to make predictions on device {rank} is: {((time.time() - ti) / 60):.2f} min")
Expand Down
Loading

0 comments on commit 59b5d97

Please sign in to comment.