diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py
index 9d8b59da..2adfd124 100644
--- a/src/mpol/crossval.py
+++ b/src/mpol/crossval.py
@@ -28,19 +28,8 @@ class CrossValidate:
         Instance of the `mpol.coordinates.GridCoords` class.
     imager : `mpol.gridding.DirtyImager` object
         Instance of the `mpol.gridding.DirtyImager` class.
-    kfolds : int, default=5
-        Number of k-folds to use in cross-validation
-    split_method : str, default='random_cell'
-        Method to split full dataset into train/test subsets
-    seed : int, default=None
-        Seed for random number generator used in splitting data
-    learn_rate : float, default=0.5
-        Neural network learning rate
-    epochs : int, default=10000
-        Number of training iterations
-    convergence_tol : float, default=1e-3
-        Tolerance for training iteration stopping criterion as assessed by
-        loss function (suggested <= 1e-3)
+    learn_rate : float, default=0.3
+        Initial learning rate
     regularizers : nested dict, default={}
         Dictionary of image regularizers to use. For each, a dict of the
         strength ('lambda', float), whether to guess an initial value for lambda
@@ -49,6 +38,14 @@ class CrossValidate:
         {"sparsity":{"lambda":1e-3, "guess":False},
         "entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
         }
+    epochs : int, default=10000
+        Number of training iterations
+    convergence_tol : float, default=1e-5
+        Tolerance for training iteration stopping criterion as assessed by
+        loss function (suggested <= 1e-3)
+    schedule_factor : float, default=0.995
+        For the `torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, factor
+        by which the learning rate is reduced when the loss stops decreasing
     start_dirty_image : bool, default=False
         Whether to start the RML optimization loop by initializing the model
         image to a dirty image of the observed data. If False, the optimization
@@ -56,6 +53,10 @@ class CrossValidate:
     train_diag_step : int, default=None
         Interval at which training diagnostics are output. If None, no
         diagnostics will be generated.
+    kfolds : int, default=5
+        Number of k-folds to use in cross-validation
+    split_method : str, default='random_cell'
+        Method to split full dataset into train/test subsets
     split_diag_fig : bool, default=False
         Whether to generate a diagnostic figure of dataset splitting into
         train/test sets.
@@ -64,15 +65,18 @@ class CrossValidate:
     save_prefix : str, default=None
         Prefix (path) used for saved figure names.
        If None, figures won't be saved
-    device : torch.device, default=None
-        Which hardware device to perform operations on (e.g., 'cuda:0').
-        'None' defaults to current device.
     verbose : bool, default=True
         Whether to print notification messages.
+    device : torch.device, default=None
+        Which hardware device to perform operations on (e.g., 'cuda:0').
+        'None' defaults to current device.
+    seed : int, default=None
+        Seed for random number generator used in splitting data
     """

-    def __init__(self, coords, imager, learn_rate=0.5,
-                 regularizers={}, epochs=10000, convergence_tol=1e-3,
+    def __init__(self, coords, imager, learn_rate=0.3,
+                 regularizers={}, epochs=10000, convergence_tol=1e-5,
+                 schedule_factor=0.995,
                  start_dirty_image=False, train_diag_step=None, kfolds=5,
                  split_method="random_cell", split_diag_fig=False,
                  store_cv_diagnostics=False,
@@ -80,20 +84,21 @@ def __init__(self, coords, imager, learn_rate=0.5,
                  ):
         self._coords = coords
         self._imager = imager
-        self._kfolds = kfolds
-        self._split_method = split_method
-        self._seed = seed
         self._learn_rate = learn_rate
+        self._regularizers = regularizers
         self._epochs = epochs
         self._convergence_tol = convergence_tol
-        self._regularizers = regularizers
+        self._schedule_factor = schedule_factor
         self._start_dirty_image = start_dirty_image
         self._train_diag_step = train_diag_step
+        self._kfolds = kfolds
+        self._split_method = split_method
         self._split_diag_fig = split_diag_fig
         self._store_cv_diagnostics = store_cv_diagnostics
         self._save_prefix = save_prefix
-        self._device = device
         self._verbose = verbose
+        self._device = device
+        self._seed = seed

         self._split_figure = None
@@ -186,10 +191,17 @@ def run_crossval(self, dataset):
                 # create a new model for this kfold, initializing it to the model pretrained on the dirty image
                 model.load_state_dict(torch.load(self._save_prefix + "_dirty_image_model.pt"))

+            # create a new optimizer and scheduler for this kfold
             optimizer = torch.optim.Adam(model.parameters(), lr=self._learn_rate)
+            if self._schedule_factor is None:
+                scheduler = None
+            else:
+                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=self._schedule_factor)
+
             trainer = TrainTest(
                 imager=self._imager,
                 optimizer=optimizer,
+                scheduler=scheduler,
                 epochs=self._epochs,
                 convergence_tol=self._convergence_tol,
                 regularizers=self._regularizers,
diff --git a/src/mpol/training.py b/src/mpol/training.py
index 518ccf3a..5eec3fca 100644
--- a/src/mpol/training.py
+++ b/src/mpol/training.py
@@ -64,9 +64,7 @@ class TrainTest:
     Args:
         imager (:class:`mpol.gridding.DirtyImager` object): Instance of the `mpol.gridding.DirtyImager` class.
         optimizer (:class:`torch.optim` object): PyTorch optimizer class for the training loop.
-        epochs (int): Number of training iterations, default=10000
-        convergence_tol (float): Tolerance for training iteration stopping criterion as assessed by
-            loss function (suggested <= 1e-3)
+        scheduler (:class:`torch.optim.lr_scheduler` object, default=None): Scheduler for adjusting learning rate during optimization.
         regularizers (nested dict): Dictionary of image regularizers to use. For each, a dict of the
             strength ('lambda', float), whether to guess an initial value for lambda
             ('guess', bool), and other quantities needed to compute their loss term. Example:
@@ -74,24 +72,30 @@ class TrainTest:
             "entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
             }``
+        epochs (int): Number of training iterations, default=10000
+        convergence_tol (float): Tolerance for training iteration stopping criterion as assessed by
+            loss function (suggested <= 1e-3)
         train_diag_step (int): Interval at which training diagnostics are output.
             If None, no diagnostics will be generated.
         kfold (int): The k-fold of the current training set (for diagnostics)
         save_prefix (str): Prefix (path) used for saved figure names.
            If None, figures won't be saved
         verbose (bool): Whether to print notification messages
     """

-    def __init__(self, imager, optimizer, epochs=10000, convergence_tol=1e-3,
-                 regularizers={}, train_diag_step=None, kfold=None,
-                 save_prefix=None, verbose=True
-                 ):
+    def __init__(self, imager, optimizer, scheduler=None, regularizers={},
+                 epochs=10000, convergence_tol=1e-5,
+                 train_diag_step=None,
+                 kfold=None, save_prefix=None, verbose=True
+                 ):
         self._imager = imager
-        self._optimizer = optimizer
+        self._optimizer = optimizer
+        self._scheduler = scheduler
+        self._regularizers = regularizers
         self._epochs = epochs
         self._convergence_tol = convergence_tol
-        self._regularizers = regularizers
+
         self._train_diag_step = train_diag_step
-        self._save_prefix = save_prefix
         self._kfold = kfold
+        self._save_prefix = save_prefix
         self._verbose = verbose
         self._train_figure = None
@@ -200,7 +204,7 @@ def loss_eval(self, vis, dataset, sky_cube=None):
             loss += self._regularizers['TSV']['lambda'] * TSV(sky_cube)

         return loss
-    
+

     def train(self, model, dataset):
         r"""
@@ -227,7 +231,7 @@ def train(self, model, dataset):
         count = 0
         losses = []
-        self._train_state = {}
+        learn_rates = []

         # guess initial strengths for regularizers in `self._regularizers`
         # that have 'guess':True
@@ -246,7 +250,7 @@ def train(self, model, dataset):
                 )

             # check early-on whether the loss isn't evolving
-            if count == 20:
+            if count == 10:
                 loss_arr = np.array(losses)
                 if all(0.9 <= loss_arr[:-1] / loss_arr[1:]) and all(
                     loss_arr[:-1] / loss_arr[1:] <= 1.1
@@ -277,11 +281,9 @@ def train(self, model, dataset):
             # update model parameters via gradient descent
             self._optimizer.step()

-            # store current training parameter values
-            # TODO: store hyperpar values, access in crossval.py
-            self._train_state["kfold"] = self._kfold
-            self._train_state["epoch"] = count
-            self._train_state["learn_rate"] = self._optimizer.state_dict()['param_groups'][0]['lr']
+            if self._scheduler is not None:
+                self._scheduler.step(loss)
+            learn_rates.append(self._optimizer.param_groups[0]['lr'])

             # generate optional fit diagnostics
             if self._train_diag_step is not None and (count % self._train_diag_step == 0 or count == self._epochs or self.loss_convergence(np.array(losses))):
diff --git a/test/conftest.py b/test/conftest.py
index 39eee5e6..dc405fb1 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -208,7 +208,7 @@ def crossvalidation_products(mock_visibility_data):
 def generic_parameters(tmp_path):
     # generic model parameters to test training loop and cross-val loop
     regularizers = {
-        "entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10},
+        "entropy": {"lambda":1e-3, "guess":False, "prior_intensity":1e-10},
     }

     train_pars = {"epochs":15, "convergence_tol":1e-3,
diff --git a/test/train_test_test.py b/test/train_test_test.py
index fb1894b3..fd457e4c 100644
--- a/test/train_test_test.py
+++ b/test/train_test_test.py
@@ -16,7 +16,8 @@ def test_traintestclass_training(coords, imager, dataset, generic_parameters):
     model = precomposed.SimpleNet(coords=coords, nchan=nchan)

     train_pars = generic_parameters["train_pars"]
-    # bypass TrainTest.loss_lambda_guess
+
+    # no regularizers
     train_pars["regularizers"] = {}

     learn_rate = generic_parameters["crossval_pars"]["learn_rate"]
@@ -27,6 +28,26 @@ def test_traintestclass_training(coords, imager, dataset, generic_parameters):
     loss, loss_history = trainer.train(model, dataset)


+def test_traintestclass_training_scheduler(coords, imager, dataset, generic_parameters):
+    # using the TrainTest class, run a training loop with regularizers,
+    # using the learning rate scheduler
+    nchan = dataset.nchan
+    model = precomposed.SimpleNet(coords=coords, nchan=nchan)
+
+    train_pars = generic_parameters["train_pars"]
+
+    learn_rate = generic_parameters["crossval_pars"]["learn_rate"]
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
+
+    # use a scheduler
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.995)
+    train_pars["scheduler"] = scheduler
+
+    trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
+    loss, loss_history = trainer.train(model, dataset)
+
+
 def test_traintestclass_training_guess(coords, imager, dataset, generic_parameters):
     # using the TrainTest class, run a training loop with regularizers,
     # with a call to the regularizer strength guesser
@@ -37,6 +58,8 @@ def test_traintestclass_training_guess(coords, imager, dataset, generic_paramete

     learn_rate = generic_parameters["crossval_pars"]["learn_rate"]

+    train_pars['regularizers']['entropy']['guess'] = True
+
     optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

     trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
@@ -60,8 +83,9 @@ def test_traintestclass_train_diagnostics_fig(coords, imager, dataset, generic_p
     trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
     loss, loss_history = trainer.train(model, dataset)

-    train_state = trainer.train_state
-    train_fig, train_axes = train_diagnostics_fig(model, losses=loss_history, train_state=train_state)
+    train_fig, train_axes = train_diagnostics_fig(model,
+                                                  losses=loss_history,
+                                                  )

     train_fig.savefig(tmp_path / "train_diagnostics_fig.png", dpi=300)
     plt.close("all")
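For reference, the standalone sketch below (not part of this diff; the toy tensor `x`, loop length, and printed summary are illustrative assumptions) mirrors the scheduler pattern this PR wires into `TrainTest.train`: `ReduceLROnPlateau(optimizer, mode='min', factor=0.995)` multiplies the learning rate by `factor` when the monitored loss stops decreasing, `scheduler.step(loss)` is called after each `optimizer.step()`, and the current rate is read back from `optimizer.param_groups[0]['lr']`.

    import torch

    # toy parameter and quadratic loss standing in for the RML image model and loss
    x = torch.tensor([5.0], requires_grad=True)
    optimizer = torch.optim.Adam([x], lr=0.3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.995)

    learn_rates = []
    for epoch in range(200):
        optimizer.zero_grad()
        loss = (x ** 2).sum()
        loss.backward()
        optimizer.step()

        scheduler.step(loss)  # reduce lr by `factor` once the loss stops decreasing
        learn_rates.append(optimizer.param_groups[0]['lr'])

    print(learn_rates[0], learn_rates[-1])  # final rate is slightly lower once the loss has plateaued

Because `factor=0.995` is close to 1, each plateau (with the scheduler's default patience of 10 epochs) only nudges the rate down slightly, giving a gentle decay over the long `epochs=10000` runs used here rather than the aggressive drops a factor like 0.1 would produce.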