first pass at updating skimmer
rkansal47 committed Jul 22, 2024
1 parent 753838d commit f761bbe
Showing 4 changed files with 789 additions and 542 deletions.
5 changes: 4 additions & 1 deletion src/HHbbVV/hh_vars.py
@@ -259,7 +259,10 @@
("lp_sf_pt_extrap_vars", 100),
("lp_sf_sys_down", 1),
("lp_sf_sys_up", 1),
("lp_sf_np_down", 1),
("lp_sf_np_up", 1),
("lp_sf_double_matched_event", 1),
("lp_sf_boundary_quarks", 1),
("lp_sf_inside_boundary_quarks", 1),
("lp_sf_outside_boundary_quarks", 1),
("lp_sf_unmatched_quarks", 1),
]
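
Each entry above pairs a variable name with its number of columns in the skimmer output (the "# (var, # columns)" note in bbVVSkimmer.py below). A minimal sketch, under that assumption, of expanding the pairs into flat column labels; the shortened list here is illustrative, not the full lp_sf_vars:

lp_sf_vars = [
    ("lp_sf_pt_extrap_vars", 100),
    ("lp_sf_np_down", 1),
    ("lp_sf_np_up", 1),
]
# one (name, index) pair per output column, e.g. ("lp_sf_pt_extrap_vars", 0) ... ("lp_sf_pt_extrap_vars", 99)
columns = [(name, i) for name, ncols in lp_sf_vars for i in range(ncols)]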
14 changes: 10 additions & 4 deletions src/HHbbVV/processors/bbVVSkimmer.py
@@ -120,6 +120,7 @@ class bbVVSkimmer(SkimmerABC):
jecs = hh_vars.jecs

# only the branches necessary for templates and post processing
# IMPORTANT: Add Lund plane branches in hh_vars.py
min_branches = [ # noqa: RUF012
"ak8FatJetPhi",
"ak8FatJetEta",
@@ -673,15 +674,16 @@ def process(self, events: ak.Array):
# Lund plane SFs
################

lp_hist = None

if isSignal and self._systematics and self._lp_sfs:
# (var, # columns)
logging.info("Starting LP SFs and saving: " + str(hh_vars.lp_sf_vars))

if len(skimmed_events["weight"]):
genbb = genbb[sel_all]
genq = genq[sel_all]

sf_dicts = []
sf_dicts, lp_hists = [], []
lp_num_jets = num_jets if self._save_all else 1

for i in range(lp_num_jets):
@@ -713,7 +715,7 @@ def process(self, events: ak.Array):

for key, (selector, gen_quarks, num_prongs) in selectors.items():
if np.sum(selector) > 0:
selected_sfs[key] = get_lund_SFs(
selected_sfs[key], lp_hist = get_lund_SFs(
year,
events[sel_all][selector],
fatjets[sel_all][selector],
@@ -724,10 +726,13 @@
), # giving HVV jet index if only doing LP SFs for HVV jet
num_prongs,
gen_quarks[selector],
weights_dict["weight"][sel_all][selector],
trunc_gauss=False,
lnN=True,
)

lp_hists.append(lp_hist)

sf_dict = {}

# collect all the scale factors, fill in 1s for unmatched jets
@@ -743,6 +748,7 @@
sf_dicts.append(sf_dict)

sf_dicts = concatenate_dicts(sf_dicts)
lp_hist = sum(lp_hists)

else:
logging.info("No signal events selected")
@@ -790,7 +796,7 @@
fname = events.behavior["__events_factory__"]._partition_key.replace("/", "_") + ".parquet"
self.dump_table(pddf, fname)

return {year: {dataset: {"totals": totals_dict, "cutflow": cutflow}}}
return {year: {dataset: {"totals": totals_dict, "cutflow": cutflow, "lp_hist": lp_hist}}}

def postprocess(self, accumulator):
return accumulator
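
With this change the per-jet LP histograms are summed (lp_hist = sum(lp_hists)) and returned in the accumulator next to the totals and cutflow. A minimal downstream sketch, with hypothetical chunk outputs and year/dataset keys, relying only on the fact that hist.Hist objects add bin-by-bin:

from coffea.processor import accumulate

# merge per-chunk skimmer outputs; nested dicts and hist.Hist values are added together
out = accumulate([out_chunk_1, out_chunk_2])  # out_chunk_* are hypothetical process() outputs
lp_hist = out["2018"]["signal_dataset"]["lp_hist"]  # hypothetical year / dataset keys
pt_proj = lp_hist.project("subjet_pt")  # quick look at the signal subjet-pT spectrum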
108 changes: 91 additions & 17 deletions src/HHbbVV/processors/corrections.py
@@ -16,6 +16,7 @@

import awkward as ak
import correctionlib
import hist
import numpy as np
from coffea import util as cutil
from coffea.analysis_tools import Weights
@@ -672,7 +673,8 @@ def add_trig_effs(weights: Weights, fatjets: FatJetArray, year: str, num_jets: i
ratio_sys_down,
pt_extrap_lookups_dict,
bratio,
) = (None, None, None, None, None, None, None)
ratio_edges,
) = (None, None, None, None, None, None, None, None)


def _get_lund_lookups(year: str, seed: int = 42, lnN: bool = True, trunc_gauss: bool = False):
@@ -785,16 +787,17 @@ def _np_pad(arr: np.ndarray, target: int = MAX_PT_FPARAMS):
ratio_sys_down,
pt_extrap_lookups_dict,
bratio,
ratio_edges,
)


def _get_flat_lp_vars(lds, kt_subjets_pt):
if len(lds) != 1:
# flatten and save offsets to unflatten afterwards
if type(lds.layout) == ak._ext.ListOffsetArray64:
if type(lds.layout) is ak._ext.ListOffsetArray64:
ld_offsets = lds.kt.layout.offsets
flat_subjet_pt = ak.flatten(kt_subjets_pt)
elif type(lds.layout) == ak._ext.ListArray64:
elif type(lds.layout) is ak._ext.ListArray64:
ld_offsets = lds.layout.toListOffsetArray64(False).offsets
flat_subjet_pt = kt_subjets_pt
else:
@@ -820,17 +823,14 @@ def _get_lund_arrays(
Gets the ``num_prongs`` subjet pTs and Delta and kT per primary LP splitting of fatjets at
``fatjet_idx`` in each event.
Features are flattened (for now), and offsets are saved in ``ld_offsets`` to recover the event
structure.
Args:
events (NanoEventsArray): nano events
jec_fatjets (FatJetArray): post-JEC fatjets, used to update subjet pTs.
fatjet_idx (int | ak.Array): fatjet index
num_prongs (int): number of prongs / subjets per jet to reweight
Returns:
flat_logD, flat_logkt, flat_subjet_pt, ld_offsets, kt_subjets_vec
lds, kt_subjets_vec, kt_subjets_pt: lund declusterings, subjet 4-vectors, JEC-corrected subjet pTs
"""

# jet definitions for LP SFs
Expand Down Expand Up @@ -880,6 +880,27 @@ def _get_lund_arrays(
return lds, kt_subjets_vec, kt_subjets_pt


def _get_flat_lund_arrays(events, jec_fatjet, fatjet_idx, num_prongs):
"""Wrapper for the _get_lund_arrays and _get_flat_lp_vars functions
returns: lds - lund declusterings,
kt_subjets_vec - subjet 4-vectors,
kt_subjets_pt - JEC-corrected subjet pTs,
ld_offsets - offsets for jagged structure,
flat_logD - flattened log(0.8/Delta),
flat_logkt - flattened log(kT/GeV),
flat_subjet_pt - flattened JEC-corrected subjet pTs
"""
lds, kt_subjets_vec, kt_subjets_pt = _get_lund_arrays(
events, jec_fatjet, fatjet_idx, num_prongs
)

lds_flat = ak.flatten(lds, axis=1)
ld_offsets, flat_logD, flat_logkt, flat_subjet_pt = _get_flat_lp_vars(lds_flat, kt_subjets_pt)

return lds, kt_subjets_vec, kt_subjets_pt, ld_offsets, flat_logD, flat_logkt, flat_subjet_pt


def _calc_lund_SFs(
flat_logD: np.ndarray,
flat_logkt: np.ndarray,
@@ -969,6 +990,7 @@ def get_lund_SFs(
fatjet_idx: int | ak.Array,
num_prongs: int,
gen_quarks: GenParticleArray,
weights: np.ndarray,
seed: int = 42,
trunc_gauss: bool = False,
lnN: bool = True,
@@ -984,6 +1006,8 @@
jec_fatjets (FatJetArray): post-JEC fatjets, used to update subjet pTs.
fatjet_idx (int | ak.Array): fatjet index
num_prongs (int): number of prongs / subjets per jet to reweight
gen_quarks (GenParticleArray): gen quarks
weights (np.ndarray): event weights, for filling the LP histogram
seed (int, optional): seed for random smearings. Defaults to 42.
trunc_gauss (bool, optional): use truncated gaussians for smearing. Defaults to False.
lnN (bool, optional): use log normals for smearings. Defaults to True.
Expand All @@ -995,7 +1019,7 @@ def get_lund_SFs(
"""

# global variable to not have to load + smear LP ratios each time
global ratio_smeared_lookups, ratio_lnN_smeared_lookups, ratio_sys_up, ratio_sys_down, pt_extrap_lookups_dict, bratio, lp_year # noqa: PLW0603
global ratio_smeared_lookups, ratio_lnN_smeared_lookups, ratio_sys_up, ratio_sys_down, pt_extrap_lookups_dict, bratio, ratio_edges, lp_year # noqa: PLW0603

if (
(lnN and ratio_lnN_smeared_lookups is None)
@@ -1009,23 +1033,46 @@
ratio_sys_down,
pt_extrap_lookups_dict,
bratio,
ratio_edges,
) = _get_lund_lookups(year, seed, lnN, trunc_gauss)
lp_year = year

ratio_nominal = ratio_lnN_smeared_lookups[0] if lnN else ratio_smeared_lookups[0]

jec_fatjet = jec_fatjets[np.arange(len(jec_fatjets)), fatjet_idx]

lds, kt_subjets_vec, kt_subjets_pt = _get_lund_arrays(
events, jec_fatjet, fatjet_idx, num_prongs
# get lund plane declusterings, subjets, and flattened LP vars
lds, kt_subjets_vec, kt_subjets_pt, ld_offsets, flat_logD, flat_logkt, flat_subjet_pt = (
_get_flat_lund_arrays(events, jec_fatjet, fatjet_idx, num_prongs)
)

lds_flat = ak.flatten(lds, axis=1)
ld_offsets, flat_logD, flat_logkt, flat_subjet_pt = _get_flat_lp_vars(lds_flat, kt_subjets_pt)
################################################################################################
# ---- Fill LP histogram for signal for distortion uncertainty ---- #
################################################################################################

sfs = {}
lp_hist = hist.Hist(
hist.axis.Variable(ratio_edges[0], name="subjet_pt", label="Subjet pT [GeV]"),
hist.axis.Variable(ratio_edges[1], name="logD", label="ln(0.8/Delta)"),
hist.axis.Variable(ratio_edges[2], name="logkt", label="ln(kT/GeV)"),
)

# repeat weights for each LP splitting
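# weights has one entry per event: the inner np.repeat turns this into one entry per subjet
# (num_prongs subjets per jet), and the outer repeat copies each subjet weight once per kT
# splitting in that subjet, so flat_weights lines up with the flattened subjet_pt/logD/logkt arrays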
flat_weights = np.repeat(
np.repeat(weights, num_prongs), ak.count(ak.flatten(lds.kt, axis=1), axis=1)
)

lp_hist.fill(
subjet_pt=flat_subjet_pt,
logD=flat_logD,
logkt=flat_logkt,
weight=flat_weights,
)

################################################################################################
# ---- get scale factors per jet + smearings for stat unc. + syst. variations + pt extrap unc. ---- #
################################################################################################

sfs = {}

if trunc_gauss:
sfs["lp_sf"] = _calc_lund_SFs(
@@ -1079,7 +1126,29 @@ def get_lund_SFs(
pt_extrap_lookups_dict["smeared_params"],
)

################################################################################################
# ---- get scale factors after re-clusteing with +/- one prong, for subjet matching uncs. ---- #
################################################################################################

for shift, nps in [("down", num_prongs - 1), ("up", num_prongs + 1)]:
# get lund plane declusterings, subjets, and flattened LP vars
_, _, _, np_ld_offsets, np_flat_logD, np_flat_logkt, np_flat_subjet_pt = (
_get_flat_lund_arrays(events, jec_fatjet, fatjet_idx, nps)
)

sfs[f"lp_sf_np_{shift}"] = _calc_lund_SFs(
np_flat_logD,
np_flat_logkt,
np_flat_subjet_pt,
np_ld_offsets,
nps,
ratio_lnN_smeared_lookups,
[pt_extrap_lookups_dict["params"]],
)

################################################################################################
# ---- b-quark related uncertainties ---- #
################################################################################################

if gen_bs is not None:
assert ak.all(
@@ -1125,7 +1194,9 @@ def get_lund_SFs(
# weird edge case where b-subjet has no splittings
sfs["lp_sfs_bl_ratio"] = 1.0

################################################################################################
# ---- subjet matching uncertainties ---- #
################################################################################################

matching_dR = 0.2
sj_matched = []
@@ -1147,9 +1218,12 @@ def get_lund_SFs(
sj_matched_idx_mask[~sj_matched] = -1

j_q_dr = gen_quarks.delta_r(jec_fatjet)
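# the fat jets here are AK8 jets (R = 0.8), so gen quarks with dR near 0.8 from the jet axis sit right at the jet boundary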
q_boundary = (j_q_dr > 0.7) * (j_q_dr < 0.9)
# events with quarks at the boundary of the jet
sfs["lp_sf_boundary_quarks"] = np.array(np.any(q_boundary, axis=1, keepdims=True))
# events with quarks at the inside boundary of the jet
q_boundary = (j_q_dr > 0.7) * (j_q_dr <= 0.8)
sfs["lp_sf_inside_boundary_quarks"] = np.array(np.any(q_boundary, axis=1, keepdims=True))
# events with quarks at the outside boundary of the jet
q_boundary = (j_q_dr > 0.8) * (j_q_dr <= 0.9)
sfs["lp_sf_outside_boundary_quarks"] = np.array(np.any(q_boundary, axis=1, keepdims=True))

# events which have more than one quark matched to the same subjet
sfs["lp_sf_double_matched_event"] = np.any(
@@ -1162,4 +1236,4 @@ def get_lund_SFs(
# OLD pT extrapolation uncertainty
sfs["lp_sf_num_sjpt_gt350"] = np.sum(kt_subjets_vec.pt > 350, axis=1, keepdims=True).to_numpy()

return sfs
return sfs, lp_hist
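
The histogram above is binned with the same ratio_edges as the LP data/MC ratio, so the signal's Lund-plane density can be compared bin-by-bin for the distortion uncertainty mentioned in the skimmer comment. A rough, hypothetical sketch of inspecting it downstream (not part of this commit):

# normalize the accumulated signal LP histogram to a unit-integral density
density = lp_hist.values() / lp_hist.sum()
# or look at one subjet-pT bin as a 2D (logD, logkt) distribution
first_pt_bin = lp_hist[0, :, :].values()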