Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cross detector #376

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
27 changes: 19 additions & 8 deletions mlpf/model/PFDataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import sys
from types import SimpleNamespace

import numpy as np
import tensorflow_datasets as tfds
import torch
import torch.utils.data

from mlpf.model.logger import _logger

import numpy as np
import sys


class TFDSDataSource:
def __init__(self, ds, sort):
Expand Down Expand Up @@ -117,7 +116,9 @@ def __init__(self, data_dir, name, split, num_samples=None, sort=False):
builder = tfds.builder(name, data_dir=data_dir)
except Exception:
_logger.error(
"Could not find dataset {} in {}, please check that you have downloaded the correct version of the dataset".format(name, data_dir)
"Could not find dataset {} in {}, please check that you have downloaded the correct version of the dataset".format(
name, data_dir
)
)
sys.exit(1)
self.ds = TFDSDataSource(builder.as_data_source(split=split), sort=sort)
Expand Down Expand Up @@ -156,15 +157,19 @@ def to(self, device, **kwargs):
class Collater:
def __init__(self, per_particle_keys_to_get, per_event_keys_to_get, **kwargs):
super(Collater, self).__init__(**kwargs)
self.per_particle_keys_to_get = per_particle_keys_to_get # these quantities are a variable-length tensor per each event
self.per_particle_keys_to_get = (
per_particle_keys_to_get # these quantities are a variable-length tensor per each event
)
self.per_event_keys_to_get = per_event_keys_to_get # these quantities are one value (scalar) per event

def __call__(self, inputs):
ret = {}

# per-particle quantities need to be padded across events of different size
for key_to_get in self.per_particle_keys_to_get:
ret[key_to_get] = torch.nn.utils.rnn.pad_sequence([torch.tensor(inp[key_to_get]).to(torch.float32) for inp in inputs], batch_first=True)
ret[key_to_get] = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(inp[key_to_get]).to(torch.float32) for inp in inputs], batch_first=True
)

# per-event quantities can be stacked across events
for key_to_get in self.per_event_keys_to_get:
Expand Down Expand Up @@ -229,12 +234,16 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray):
split_configs = config[f"{split}_dataset"][config["dataset"]][type_]["samples"][sample]["splits"]
print("split_configs", split_configs)

nevents = None
if not (config[f"n{split}"] is None):
nevents = config[f"n{split}"] // len(split_configs)

for split_config in split_configs:
ds = PFDataset(
config["data_dir"],
f"{sample}/{split_config}:{version}",
split,
num_samples=config[f"n{split}"],
num_samples=nevents,
sort=config["sort_data"],
).ds

Expand All @@ -258,7 +267,9 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray):
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
collate_fn=Collater(["X", "ytarget", "ytarget_pt_orig", "ytarget_e_orig", "genjets", "targetjets"], ["genmet"]),
collate_fn=Collater(
["X", "ytarget", "ytarget_pt_orig", "ytarget_e_orig", "genjets", "targetjets"], ["genmet"]
),
sampler=sampler,
num_workers=config["num_workers"],
prefetch_factor=config["prefetch_factor"],
Expand Down
36 changes: 25 additions & 11 deletions mlpf/model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,30 @@
import tqdm
import vector
from jet_utils import match_two_jet_collections
from plotting.plot_utils import (
get_class_names,
from plotting.plot_utils import ( # plot_elements,
compute_met_and_ratio,
get_class_names,
load_eval_data,
plot_jets,
plot_jet_ratio,
plot_jet_response_binned,
plot_jet_response_binned_vstarget,
plot_jet_response_binned_eta,
plot_jet_response_binned_vstarget,
plot_jets,
plot_met,
plot_met_ratio,
plot_met_response_binned,
plot_num_elements,
plot_particles,
plot_particle_ratio,
# plot_elements,
plot_particles,
)

from .logger import _logger
from .utils import unpack_predictions, unpack_target


def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_match_dr, outpath, dir_name, sample):
def predict_one_batch(
conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_etacut, jet_match_dr, outpath, dir_name, sample
):

# skip prediction if output exists
outfile = f"{outpath}/preds{dir_name}/{sample}/pred_{rank}_{i}.parquet"
Expand Down Expand Up @@ -62,7 +63,8 @@ def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_m
ycand = unpack_target(batch.ycand.to(torch.float32), model)
ypred = unpack_predictions(ypred)

genjets_msk = batch.genjets[:, :, 0].cpu() > jet_ptcut
genjets_msk = (batch.genjets[:, :, 0].cpu() > jet_ptcut) & (abs(batch.genjets[:, :, 1]).cpu() < jet_etacut)

genjets = awkward.unflatten(batch.genjets.cpu().to(torch.float64)[genjets_msk], torch.sum(genjets_msk, axis=1))
genjets = vector.awk(
awkward.zip(
Expand Down Expand Up @@ -125,7 +127,15 @@ def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_m
)

awkward.to_parquet(
awkward.Array({"inputs": Xs, "particles": awkvals, "jets": jets_coll, "matched_jets": matched_jets, "genmet": batch.genmet.cpu()}),
awkward.Array(
{
"inputs": Xs,
"particles": awkvals,
"jets": jets_coll,
"matched_jets": matched_jets,
"genmet": batch.genmet.cpu(),
}
),
outfile,
)
_logger.info(f"Saved predictions at {outfile}")
Expand All @@ -136,7 +146,9 @@ def predict_one_batch_args(args):


@torch.no_grad()
def run_predictions(world_size, rank, model, loader, sample, outpath, jetdef, jet_ptcut=15.0, jet_match_dr=0.1, dir_name=""):
def run_predictions(
world_size, rank, model, loader, sample, outpath, jetdef, jet_ptcut=15.0, jet_etacut=2.5, jet_match_dr=0.1, dir_name=""
):
"""Runs inference on the given sample and stores the output as .parquet files."""
if world_size > 1:
conv_type = model.module.conv_type
Expand All @@ -153,7 +165,9 @@ def run_predictions(world_size, rank, model, loader, sample, outpath, jetdef, je

ti = time.time()
for i, batch in iterator:
predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_match_dr, outpath, dir_name, sample)
predict_one_batch(
conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_etacut, jet_match_dr, outpath, dir_name, sample
)

_logger.info(f"Time taken to make predictions on device {rank} is: {((time.time() - ti) / 60):.2f} min")

Expand Down
Loading
Loading