diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index 98788a322..533cbeb62 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -24,7 +24,7 @@ from pyg.mlpf import MLPF from pyg.PFDataset import Collater, InterleavedIterator, PFDataLoader, PFDataset from pyg.training import train_mlpf -from pyg.utils import CLASS_LABELS, X_FEATURES, save_HPs, load_checkpoint +from pyg.utils import CLASS_LABELS, X_FEATURES, load_checkpoint, save_HPs from utils import create_experiment_dir logging.basicConfig(level=logging.INFO) @@ -123,7 +123,7 @@ def run(rank, world_size, config, args, outdir, logfile): for split in ["train", "valid"]: # build train, valid dataset and dataloaders loaders[split] = [] # build dataloader for physical and gun samples seperately - for type_ in config[f"{split}_dataset"][config["dataset"]]: # will be "physical", "gun" + for type_ in config[f"{split}_dataset"][config["dataset"]]: # will be "physical", "gun", "multiparticlegun" dataset = [] for sample in config[f"{split}_dataset"][config["dataset"]][type_]["samples"]: version = config[f"{split}_dataset"][config["dataset"]][type_]["samples"][sample]["version"] @@ -159,7 +159,7 @@ def run(rank, world_size, config, args, outdir, logfile): ) ) - loaders[split] = InterleavedIterator(loaders[split]) # will interleave just two dataloaders + loaders[split] = InterleavedIterator(loaders[split]) # will interleave maximum of three dataloaders train_mlpf( rank, diff --git a/parameters/pyg-cms.yaml b/parameters/pyg-cms.yaml index a10295cc3..9bb7ee39d 100644 --- a/parameters/pyg-cms.yaml +++ b/parameters/pyg-cms.yaml @@ -95,6 +95,9 @@ train_dataset: version: 1.6.0 cms_pf_single_proton: version: 1.6.0 + multiparticlegun: + batch_size: 2 + samples: cms_pf_multi_particle_gun: version: 1.6.0 @@ -141,5 +144,8 @@ test_dataset: version: 1.6.0 cms_pf_single_proton: version: 1.6.0 + multiparticlegun: + batch_size: 2 + samples: cms_pf_multi_particle_gun: version: 1.6.0