diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py
index 25fbc049..e95f0803 100644
--- a/src/mpol/crossval.py
+++ b/src/mpol/crossval.py
@@ -12,7 +12,7 @@
 from mpol.datasets import Dartboard, GriddedDataset
 from mpol.precomposed import SimpleNet
-from mpol.training import TrainTest
+from mpol.training import TrainTest, train_to_dirty_image
 from mpol.plot import split_diagnostics_fig
@@ -28,19 +28,8 @@ class CrossValidate:
         Instance of the `mpol.coordinates.GridCoords` class.
     imager : `mpol.gridding.DirtyImager` object
         Instance of the `mpol.gridding.DirtyImager` class.
-    kfolds : int, default=5
-        Number of k-folds to use in cross-validation
-    split_method : str, default='random_cell'
-        Method to split full dataset into train/test subsets
-    seed : int, default=None
-        Seed for random number generator used in splitting data
-    learn_rate : float, default=0.5
-        Neural network learning rate
-    epochs : int, default=10000
-        Number of training iterations
-    convergence_tol : float, default=1e-3
-        Tolerance for training iteration stopping criterion as assessed by
-        loss function (suggested <= 1e-3)
+    learn_rate : float, default=0.3
+        Initial learning rate
     regularizers : nested dict, default={}
         Dictionary of image regularizers to use. For each, a dict of the
         strength ('lambda', float), whether to guess an initial value for lambda
@@ -49,9 +38,25 @@ class CrossValidate:
         {"sparsity":{"lambda":1e-3, "guess":False},
         "entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
         }
+    epochs : int, default=10000
+        Number of training iterations
+    convergence_tol : float, default=1e-5
+        Tolerance for training iteration stopping criterion as assessed by
+        loss function (suggested <= 1e-3)
+    schedule_factor : float, default=0.995
+        For the `torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, factor
+        by which the learning rate is reduced when the loss stops decreasing
+    start_dirty_image : bool, default=False
+        Whether to start the RML optimization loop by initializing the model
+        image to a dirty image of the observed data. If False, the optimization
+        loop will start with a blank image.
     train_diag_step : int, default=None
         Interval at which training diagnostics are output. If None, no
         diagnostics will be generated.
+    kfolds : int, default=5
+        Number of k-folds to use in cross-validation
+    split_method : str, default='random_cell'
+        Method to split full dataset into train/test subsets
     split_diag_fig : bool, default=False
         Whether to generate a diagnostic figure of dataset splitting into
         train/test sets.
@@ -60,39 +65,45 @@ class CrossValidate:
     save_prefix : str, default=None
         Prefix (path) used for saved figure names. If None, figures won't be
         saved
-    device : torch.device, default=None
-        Which hardware device to perform operations on (e.g., 'cuda:0').
-        'None' defaults to current device.
     verbose : bool, default=True
         Whether to print notification messages.
+    device : torch.device, default=None
+        Which hardware device to perform operations on (e.g., 'cuda:0').
+        'None' defaults to current device.
+    seed : int, default=None
+        Seed for random number generator used in splitting data
     """

-    def __init__(self, coords, imager, kfolds=5, split_method="random_cell",
-                 seed=None, learn_rate=0.5, epochs=10000, convergence_tol=1e-3,
-                 regularizers={}, train_diag_step=None, split_diag_fig=False,
-                 store_cv_diagnostics=False, save_prefix=None, device=None,
-                 verbose=True
+    def __init__(self, coords, imager, learn_rate=0.3,
+                 regularizers={}, epochs=10000, convergence_tol=1e-5,
+                 schedule_factor=0.995,
+                 start_dirty_image=False,
+                 train_diag_step=None, kfolds=5, split_method="random_cell",
+                 split_diag_fig=False, store_cv_diagnostics=False,
+                 save_prefix=None, verbose=True, device=None, seed=None,
     ):
         self._coords = coords
         self._imager = imager
-        self._kfolds = kfolds
-        self._split_method = split_method
-        self._seed = seed
         self._learn_rate = learn_rate
+        self._regularizers = regularizers
         self._epochs = epochs
         self._convergence_tol = convergence_tol
-        self._regularizers = regularizers
+        self._schedule_factor = schedule_factor
+        self._start_dirty_image = start_dirty_image
         self._train_diag_step = train_diag_step
+        self._kfolds = kfolds
+        self._split_method = split_method
         self._split_diag_fig = split_diag_fig
         self._store_cv_diagnostics = store_cv_diagnostics
         self._save_prefix = save_prefix
-        self._device = device
         self._verbose = verbose
+        self._device = device
+        self._seed = seed

-        self._model = None
-        self._diagnostics = None
         self._split_figure = None
-        self._train_figure = None
+
+        # used to collect objects across all kfolds
+        self._diagnostics = None

     def split_dataset(self, dataset):
         r"""
@@ -159,22 +170,38 @@ def run_crossval(self, dataset):
         for kk, (train_set, test_set) in enumerate(split_iterator):
             if self._verbose:
                 logging.info(
-                    "\nCross-validation: k-fold {} of {}".format(kk, self._kfolds)
+                    "\nCross-validation: k-fold {} of {}".format(kk, self._kfolds - 1)
                 )

             # if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
             #     train_set, test_set = train_set.to(self._device), test_set.to(self._device)

-            # create a new model and optimizer for this k_fold
-            self._model = SimpleNet(coords=self._coords, nchan=self._imager.nchan)
-            # if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
-            #     self._model = self._model.to(self._device)
-
-            optimizer = torch.optim.Adam(self._model.parameters(), lr=self._learn_rate)
-
-            trainer = TrainTest(
+            model = SimpleNet(coords=self._coords, nchan=self._imager.nchan)
+            if self._start_dirty_image is True:
+                if kk == 0:
+                    if self._verbose:
+                        logging.info(
+                            "\n Pre-training to dirty image to initialize subsequent optimization loops"
+                        )
+                    # initial short training loop to get model image to approximate dirty image
+                    model_pretrained = train_to_dirty_image(model=model, imager=self._imager)
+                    # save the model to a state we can load in subsequent kfolds
+                    torch.save(model_pretrained.state_dict(), f=self._save_prefix + "_dirty_image_model.pt")
+                else:
+                    # create a new model for this kfold, initializing it to the model pretrained on the dirty image
+                    model.load_state_dict(torch.load(self._save_prefix + "_dirty_image_model.pt"))
+
+            # create a new optimizer and scheduler for this kfold
+            optimizer = torch.optim.Adam(model.parameters(), lr=self._learn_rate)
+            if self._schedule_factor is None:
+                scheduler = None
+            else:
+                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=self._schedule_factor)
+
+            trainer = TrainTest(
                imager=self._imager,
                optimizer=optimizer,
+               scheduler=scheduler,
                epochs=self._epochs,
                convergence_tol=self._convergence_tol,
                regularizers=self._regularizers,
@@ -185,17 +212,21 @@ def run_crossval(self, dataset):
             )

             # run training
-            loss, loss_history = trainer.train(self._model, train_set)
+            loss, loss_history = trainer.train(model, train_set)

-            if self._store_cv_diagnostics:
-                self._diagnostics["loss_histories"].append(loss_history)
-            # update regularizer strength values
-            self._regularizers = trainer.regularizers
-            # store the most recent train figure for diagnostics
-            self._train_figure = trainer.train_figure
-            # run testing
-            all_scores.append(trainer.test(self._model, test_set))
+            all_scores.append(trainer.test(model, test_set))
+
+            # store objects from the most recent kfold for diagnostics
+            self._model = model
+            self._train_figure = trainer.train_figure
+
+            # collect objects from this kfold to store
+            if self._store_cv_diagnostics:
+                self._diagnostics["models"].append(self._model)
+                self._diagnostics["regularizers"].append(self._regularizers)
+                self._diagnostics["loss_histories"].append(loss_history)
+                self._diagnostics["train_figures"].append(self._train_figure)

         # average individual test scores to get the cross-val metric for chosen
         # hyperparameters
@@ -204,33 +235,33 @@ def run_crossval(self, dataset):
             "std": np.std(all_scores),
             "all": all_scores,
         }
-
+
         return cv_score
-
+
     @property
     def model(self):
-        """SimpleNet class instance"""
+        """For the most recent kfold, trained model (`SimpleNet` class instance)"""
         return self._model

     @property
     def regularizers(self):
-        """Dict containing regularizers used and their strengths"""
+        """For the most recent kfold, dict containing regularizers used and their strengths"""
         return self._regularizers

     @property
-    def diagnostics(self):
-        """Dict containing diagnostics of the cross-validation loop"""
-        return self._diagnostics
-
+    def train_figure(self):
+        """For the most recent kfold, (fig, axes) showing training progress"""
+        return self._train_figure
+
     @property
     def split_figure(self):
        """(fig, axes) of train/test splitting diagnostic figure"""
        return self._split_figure
-
+
     @property
-    def train_figure(self):
-        """(fig, axes) of most recent training diagnostic figure"""
-        return self._train_figure
+    def diagnostics(self):
+        """Dict containing diagnostics of the cross-validation loop across all kfolds: models, regularizers, loss values, training figures"""
+        return self._diagnostics


 class RandomCellSplitGridded:
diff --git a/src/mpol/plot.py b/src/mpol/plot.py
index 6a5d8cf7..4ce81184 100644
--- a/src/mpol/plot.py
+++ b/src/mpol/plot.py
@@ -2,6 +2,7 @@
 import matplotlib.pyplot as plt
 import matplotlib.colors as mco
 from matplotlib.patches import Ellipse
+import torch

 from astropy.visualization.mpl_normalize import simple_norm
@@ -104,7 +105,7 @@ def get_residual_image(model, u, v, V, weights, robust=0.5):

 def plot_image(image, extent, cmap="inferno", norm=None, ax=None,
-               clab=r"Jy arcsec$^{-2}$",
+               clab=r"I [Jy arcsec$^{-2}$]",
                xlab=r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]",
                ylab=r"$\Delta \delta$ [${}^{\prime\prime}$]",
                ):
@@ -153,7 +154,7 @@ def plot_image(image, extent, cmap="inferno", norm=None, ax=None,
         norm=norm,
     )

-    cbar = plt.colorbar(im, ax=ax, location="right", pad=0.1)
+    cbar = plt.colorbar(im, ax=ax, location="right", pad=0.1, shrink=0.7)
     cbar.set_label(clab)

     ax.set_xlabel(xlab)
@@ -399,27 +400,41 @@ def split_diagnostics_fig(splitter, channel=0, save_prefix=None):

     return fig, axes

-def train_diagnostics_fig(model, losses=[], train_state=None, channel=0,
-                          save_prefix=None):
+def train_diagnostics_fig(model, losses=None, learn_rates=None, fluxes=None,
+                          old_model_image=None, old_model_epoch=None,
+                          kfold=None, epoch=None,
+                          channel=0, save_prefix=None):
    """
    Figure for model diagnostics during an optimization loop. For a `model` in
    a given state, plots the current:
-        - model image (both linear and arcsinh colormap normalization)
+        - model image
+        - flux of model image
        - gradient image
+        - difference image between `old_model_image` and current model image
        - loss function
+        - learning rate

    Parameters
    ----------
    model : `torch.nn.Module` object
-        A neural network; instance of the `mpol.precomposed.SimpleNet` class.
+        A neural network module; instance of the `mpol.precomposed.SimpleNet` class.
    losses : list
        Loss value at each epoch in the training loop
-    train_state : dict, default=None
-        Dictionary containing current training parameter values. Used for
-        figure title and savefile name.
+    learn_rates : list
+        Learning rate at each epoch in the training loop
+    fluxes : list
+        Total flux in model image at each epoch in the training loop
+    old_model_image : 2D image array, default=None
+        Model image of a previous epoch for comparison to current image
+    old_model_epoch : int
+        Epoch of `old_model_image`
+    kfold : int, default=None
+        Current cross-validation k-fold
+    epoch : int, default=None
+        Current training epoch
    channel : int, default=0
        Channel (of the datasets in `splitter`) to use to generate figure
-    save_prefix : string, default = None
+    save_prefix : str, default = None
        Prefix for saved figure name. If None, the figure won't be saved

    Returns
    -------
@@ -430,39 +445,66 @@ def train_diagnostics_fig(model, losses=[], train_state=None, channel=0,
        Axes of the generated figure
    """
    fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(8, 8))
+    axes[1][1].remove()

-    fig.suptitle(train_state)
-
+    fig.suptitle(f"Pixel size {model.coords.cell_size * 1e3:.2f} mas, N_pix {model.coords.npix}\nk-fold {kfold}, epoch {epoch}", fontsize=10)
+
    mod_im = torch2npy(model.icube.sky_cube[channel])
    mod_grad = torch2npy(packed_cube_to_sky_cube(model.bcube.base_cube.grad)[channel])
    extent = model.icube.coords.img_ext

    # model image (linear colormap)
-    ax = axes[0,0]
-    plot_image(mod_im, extent, ax=ax, xlab='', ylab='')
-    ax.set_title("Model image")
+    # ax = axes[0,0]
+    # plot_image(mod_im, extent, ax=ax, xlab='', ylab='')
+    # ax.set_title("Model image")

    # model image (asinh colormap)
-    ax = axes[0,1]
-    plot_image(mod_im, extent, ax=ax, norm=get_image_cmap_norm(mod_im, stretch='asinh'))
-    ax.set_title("Model image (asinh stretch)")
+    ax = axes[0,0]
+    plot_image(mod_im, extent, ax=ax, xlab='', ylab='', norm=get_image_cmap_norm(mod_im, stretch='asinh'))
+    ax.set_title("Model image", fontsize=10)

    # gradient image
    ax = axes[1,0]
-    plot_image(mod_grad, extent, ax=ax, xlab='', ylab='')
-    ax.set_title("Gradient image")
+    plot_image(mod_grad, extent, ax=ax)
+    ax.set_title("Gradient image", fontsize=10)
+
+    if old_model_image is not None:
+        # current model image - model image at last stored epoch
+        ax = axes[0,1]
+        diff_image = mod_im - old_model_image
+        diff_im_norm = get_image_cmap_norm(diff_image, symmetric=True)
+        plot_image(diff_image, extent, cmap='RdBu_r', ax=ax, xlab='', ylab='', norm=diff_im_norm)
+        ax.set_title(f"Difference (epoch {epoch} - {old_model_epoch})", fontsize=10)
+
+    if losses is not None:
+        # loss function
+        ax = fig.add_subplot(426)
+        ax.semilogy(losses, 'k', label=f"{losses[-1]:.3f}")
+        ax.legend(loc='upper right')
+        ax.xaxis.set_tick_params(labelbottom=False)
+        ax.set_ylabel('Loss')
+
+    if learn_rates is not None:
+        # learning rate
+        ax = fig.add_subplot(428)
+        ax.plot(learn_rates, 'k', label=f"{learn_rates[-1]:.3e}")
+        ax.legend(loc='upper right')
+        ax.set_xlabel('Epoch')
+        ax.set_ylabel('Learn rate')

-    # loss function
-    ax = axes[1,1]
-    ax.semilogy(losses, 'k')
-    ax.set_xlabel('Epoch')
-    ax.set_ylabel('Loss')
-    ax.set_title("Loss function")
+    plt.tight_layout()

-    fig.subplots_adjust(wspace=0.25)
+    if fluxes is not None:
+        # total flux in model image
+        ax = fig.add_axes([0.08, 0.465, 0.3, 0.08])
+        ax.plot(fluxes, 'k', label=f"{fluxes[-1]:.4f}")
+        ax.legend(loc='upper right', fontsize=8)
+        ax.tick_params(labelsize=8)
+        # ax.set_xlabel('Epoch', fontsize=8)
+        ax.set_ylabel('Flux [Jy]', fontsize=8)

    if save_prefix is not None:
-        fig.savefig(save_prefix + '_train_diag_kfold{}_epoch{:05d}.png'.format(train_state["kfold"], train_state["epoch"]), dpi=300)
+        fig.savefig(save_prefix + f"_train_diag_kfold{kfold}_epoch{epoch:05d}.png", dpi=300)

    plt.close()
diff --git a/src/mpol/training.py b/src/mpol/training.py
index 22223954..4a480904 100644
--- a/src/mpol/training.py
+++ b/src/mpol/training.py
@@ -4,6 +4,60 @@
 from mpol.losses import TSV, TV_image, entropy, nll_gridded, sparsity
 from mpol.plot import train_diagnostics_fig
+from mpol.utils import torch2npy
+
+def train_to_dirty_image(model, imager, robust=0.5, learn_rate=100, niter=1000):
+    r"""
+    Train against a dirty image of the observed visibilities using a loss function
+    of the mean squared error between the RML model image pixel fluxes and the
+    dirty image pixel fluxes. Useful for initializing a separate RML optimization
+    loop at a reasonable starting image.
+
+    Parameters
+    ----------
+    model : `torch.nn.Module` object
+        A neural network module; instance of the `mpol.precomposed.SimpleNet` class.
+    imager : :class:`mpol.gridding.DirtyImager` object
+        Instance of the `mpol.gridding.DirtyImager` class.
+    robust : float, default=0.5
+        Robust weighting parameter used to create a dirty image.
+    learn_rate : float, default=100
+        Learning rate for optimization loop
+    niter : int, default=1000
+        Number of iterations for optimization loop
+
+    Returns
+    -------
+    model : `torch.nn.Module` object
+        The input `model` updated with the state of the training to the
+        dirty image
+    """
+    logging.info(" Initializing model to dirty image")
+
+    img, beam = imager.get_dirty_image(weighting="briggs",
+                                       robust=robust,
+                                       unit="Jy/arcsec^2")
+    dirty_image = torch.tensor(img.copy())
+    optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
+
+    losses = []
+    for ii in range(niter):
+        optimizer.zero_grad()
+
+        model()
+
+        sky_cube = model.icube.sky_cube
+
+        lossfunc = torch.nn.MSELoss(reduction="sum")
+        # with reduction="sum", MSELoss returns the summed squared error (squared L2 norm), so take its square root
+        loss = (lossfunc(sky_cube, dirty_image)) ** 0.5
+        losses.append(loss.item())
+
+        loss.backward()
+        optimizer.step()
+
+    return model
+

 class TrainTest:
     r"""
@@ -12,9 +66,7 @@ class TrainTest:
     Args:
        imager (:class:`mpol.gridding.DirtyImager` object): Instance of the `mpol.gridding.DirtyImager` class.
        optimizer (:class:`torch.optim` object): PyTorch optimizer class for the training loop.
-        epochs (int): Number of training iterations, default=10000
-        convergence_tol (float): Tolerance for training iteration stopping criterion as assessed by
-            loss function (suggested <= 1e-3)
+        scheduler (:class:`torch.optim.lr_scheduler` object, default=None): Scheduler for adjusting learning rate during optimization.
        regularizers (nested dict): Dictionary of image regularizers to use. For each, a dict of the
            strength ('lambda', float), whether to guess an initial value for lambda ('guess', bool),
            and other quantities needed to compute their loss term. Example:
@@ -22,24 +74,30 @@ class TrainTest:
            "entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
            }``
+        epochs (int): Number of training iterations, default=10000
+        convergence_tol (float): Tolerance for training iteration stopping criterion as assessed by
+            loss function (suggested <= 1e-3)
        train_diag_step (int): Interval at which training diagnostics are output. If None, no
            diagnostics will be generated.
        kfold (int): The k-fold of the current training set (for diagnostics)
        save_prefix (str): Prefix (path) used for saved figure names. If None, figures won't be saved
        verbose (bool): Whether to print notification messages
    """

-    def __init__(self, imager, optimizer, epochs=10000, convergence_tol=1e-3,
-                 regularizers={}, train_diag_step=None, kfold=None,
-                 save_prefix=None, verbose=True
-                 ):
+    def __init__(self, imager, optimizer, scheduler=None, regularizers={},
+                 epochs=10000, convergence_tol=1e-5,
+                 train_diag_step=None,
+                 kfold=None, save_prefix=None, verbose=True
+                 ):
        self._imager = imager
-        self._optimizer = optimizer
+        self._optimizer = optimizer
+        self._scheduler = scheduler
+        self._regularizers = regularizers
        self._epochs = epochs
        self._convergence_tol = convergence_tol
-        self._regularizers = regularizers
+
        self._train_diag_step = train_diag_step
-        self._save_prefix = save_prefix
        self._kfold = kfold
+        self._save_prefix = save_prefix
        self._verbose = verbose

        self._train_figure = None
@@ -78,7 +136,10 @@ def loss_lambda_guess(self):

        The guesses update `lambda` values in `self._regularizers`.
        """
-
+        if self._verbose:
+            logging.info(" Updating regularizer strengths with automated "
+                        f"guessing. Initial values: {self._regularizers}")
+
        # generate images of the data using two briggs robust values
        img1, _ = self._imager.get_dirty_image(weighting='briggs', robust=0.0)
        img2, _ = self._imager.get_dirty_image(weighting='briggs', robust=0.5)
@@ -114,6 +175,9 @@ def loss_lambda_guess(self):
            guess_TSV = 1 / (loss_TSV2 - loss_TSV1)
            self._regularizers['TSV']['lambda'] = guess_TSV.numpy().item()

+        if self._verbose:
+            logging.info(f" Updated values: {self._regularizers}")
+

    def loss_eval(self, vis, dataset, sky_cube=None):
        r"""
@@ -148,7 +212,7 @@ def loss_eval(self, vis, dataset, sky_cube=None):
                loss += self._regularizers['TSV']['lambda'] * TSV(sky_cube)

        return loss
-
+

    def train(self, model, dataset):
        r"""
@@ -159,7 +223,7 @@ def train(self, model, dataset):
        Parameters
        ----------
        model : `torch.nn.Module` object
-            A neural network; instance of the `mpol.precomposed.SimpleNet` class.
+            A neural network module; instance of the `mpol.precomposed.SimpleNet` class.
        dataset : PyTorch dataset object
            Instance of the `mpol.datasets.GriddedDataset` class.
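Reviewer note (illustrative sketch, not part of this patch): the new `scheduler` hook above is exercised by `test_traintestclass_training_scheduler` in the test changes below; standalone usage would look roughly like the following, assuming `coords`, `imager`, and a gridded `dataset` already exist (as the test fixtures provide), with the epoch count an arbitrary placeholder.

import torch
from mpol.precomposed import SimpleNet
from mpol.training import TrainTest

model = SimpleNet(coords=coords, nchan=dataset.nchan)
optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
# ReduceLROnPlateau multiplies the learning rate by `factor` when the loss stops improving;
# per this patch, TrainTest.train() calls scheduler.step(loss) after each optimizer step.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.995)
trainer = TrainTest(imager=imager, optimizer=optimizer, scheduler=scheduler,
                    regularizers={}, epochs=300, convergence_tol=1e-5)
loss, loss_history = trainer.train(model, dataset)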
@@ -170,12 +234,16 @@ def train(self, model, dataset):
        losses : list of float
            Loss value at each iteration (epoch) in the loop
        """
+        # set model to training mode
        model.train()

        count = 0
+        fluxes = []
        losses = []
-        self._train_state = {}
+        learn_rates = []
+        old_mod_im = None
+        old_mod_epoch = None

        # guess initial strengths for regularizers in `self._regularizers`
        # that have 'guess':True
@@ -194,7 +262,7 @@ def train(self, model, dataset):
            )

            # check early-on whether the loss isn't evolving
-            if count == 20:
+            if count == 10:
                loss_arr = np.array(losses)
                if all(0.9 <= loss_arr[:-1] / loss_arr[1:]) and all(
                    loss_arr[:-1] / loss_arr[1:] <= 1.1
@@ -215,6 +283,10 @@ def train(self, model, dataset):
            # get predicted sky cube corresponding to model visibilities
            sky_cube = model.icube.sky_cube

+            # total flux in model image
+            total_flux = model.coords.cell_size ** 2 * torch.sum(sky_cube)
+            fluxes.append(torch2npy(total_flux))
+
            # calculate loss between model visibilities and data
            loss = self.loss_eval(vis, dataset, sky_cube)
            losses.append(loss.item())
@@ -225,20 +297,25 @@ def train(self, model, dataset):
            # update model parameters via gradient descent
            self._optimizer.step()

-            # store current training parameter values
-            # TODO: store hyperpar values, access in crossval.py
-            self._train_state["kfold"] = self._kfold
-            self._train_state["epoch"] = count
-            self._train_state["learn_rate"] = self._optimizer.state_dict()['param_groups'][0]['lr']
+            if self._scheduler is not None:
+                self._scheduler.step(loss)
+            learn_rates.append(self._optimizer.param_groups[0]['lr'])

            # generate optional fit diagnostics
            if self._train_diag_step is not None and (count % self._train_diag_step == 0 or count == self._epochs or self.loss_convergence(np.array(losses))):
                train_fig, train_axes = train_diagnostics_fig(
-                    model, losses=losses, train_state=self._train_state,
+                    model, losses=losses, learn_rates=learn_rates, fluxes=fluxes,
+                    old_model_image=old_mod_im,
+                    old_model_epoch=old_mod_epoch,
+                    kfold=self._kfold, epoch=count,
                    save_prefix=self._save_prefix
                )
                self._train_figure = (train_fig, train_axes)
+
+                # temporarily store the current model image for use in next call to `train_diagnostics_fig`
+                old_mod_im = torch2npy(model.icube.sky_cube[0]) # TODO: support 'channel' (in TrainTest)
+                old_mod_epoch = count * 1
+
            count += 1

        if self._verbose:
@@ -260,7 +337,7 @@ def test(self, model, dataset):
        Parameters
        ----------
        model : `torch.nn.Module` object
-            A neural network; instance of the `mpol.precomposed.SimpleNet` class.
+            A neural network module; instance of the `mpol.precomposed.SimpleNet` class.
        dataset : PyTorch dataset object
            Instance of the `mpol.datasets.GriddedDataset` class.
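Reviewer note (illustrative sketch, not part of this patch): the dirty-image initialization added in training.py can be used directly via `train_to_dirty_image`, or through the cross-validation loop with `start_dirty_image=True`; in the latter case `run_crossval` saves the pre-trained state to `save_prefix + "_dirty_image_model.pt"`, so a writable `save_prefix` is required. `coords`, `imager`, and `dataset` are assumed to exist as in the test fixtures, and the prefix/seed values below are placeholders.

from mpol.crossval import CrossValidate
from mpol.precomposed import SimpleNet
from mpol.training import train_to_dirty_image

# pre-train a model so its image approximates the robust=0.5 dirty image
model = SimpleNet(coords=coords, nchan=dataset.nchan)
model = train_to_dirty_image(model, imager, robust=0.5, niter=1000)

# or let CrossValidate pre-train on k-fold 0 and reuse that state for later folds
cv = CrossValidate(coords, imager, learn_rate=0.3, schedule_factor=0.995,
                   start_dirty_image=True, kfolds=5, seed=42,
                   save_prefix="./mpol_cv")
cv_score = cv.run_crossval(dataset)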
@@ -281,6 +358,7 @@ def test(self, model, dataset):
        # return loss value
        return loss.item()

+
    @property
    def regularizers(self):
        """Dict containing regularizers used and their strengths"""
@@ -290,8 +368,3 @@ def regularizers(self):
    def train_figure(self):
        """(fig, axes) of figure showing training diagnostics"""
        return self._train_figure
-
-    @property
-    def train_state(self):
-        """Dict containing parameters of interest in the training loop"""
-        return self._train_state
\ No newline at end of file
diff --git a/test/conftest.py b/test/conftest.py
index 39eee5e6..dc405fb1 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -208,7 +208,7 @@ def crossvalidation_products(mock_visibility_data):
 def generic_parameters(tmp_path):
     # generic model parameters to test training loop and cross-val loop
     regularizers = {
-        "entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10},
+        "entropy": {"lambda":1e-3, "guess":False, "prior_intensity":1e-10},
     }

     train_pars = {"epochs":15, "convergence_tol":1e-3,
diff --git a/test/train_test_test.py b/test/train_test_test.py
index dffb7aad..9826d9fc 100644
--- a/test/train_test_test.py
+++ b/test/train_test_test.py
@@ -6,8 +6,8 @@
 from mpol import losses, precomposed
 from mpol.plot import train_diagnostics_fig
-from mpol.training import TrainTest
-from mpol.constants import *
+from mpol.training import TrainTest, train_to_dirty_image
+from mpol.utils import torch2npy


 def test_traintestclass_training(coords, imager, dataset, generic_parameters):
@@ -16,7 +16,8 @@
     model = precomposed.SimpleNet(coords=coords, nchan=nchan)

     train_pars = generic_parameters["train_pars"]
-    # bypass TrainTest.loss_lambda_guess
+
+    # no regularizers
     train_pars["regularizers"] = {}

     learn_rate = generic_parameters["crossval_pars"]["learn_rate"]
@@ -27,6 +28,26 @@
     loss, loss_history = trainer.train(model, dataset)


+def test_traintestclass_training_scheduler(coords, imager, dataset, generic_parameters):
+    # using the TrainTest class, run a training loop with regularizers,
+    # using the learning rate scheduler
+    nchan = dataset.nchan
+    model = precomposed.SimpleNet(coords=coords, nchan=nchan)
+
+    train_pars = generic_parameters["train_pars"]
+
+    learn_rate = generic_parameters["crossval_pars"]["learn_rate"]
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
+
+    # use a scheduler
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.995)
+    train_pars["scheduler"] = scheduler
+
+    trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
+    loss, loss_history = trainer.train(model, dataset)
+
+
 def test_traintestclass_training_guess(coords, imager, dataset, generic_parameters):
     # using the TrainTest class, run a training loop with regularizers,
     # with a call to the regularizer strength guesser
@@ -37,6 +58,8 @@
     learn_rate = generic_parameters["crossval_pars"]["learn_rate"]

+    train_pars['regularizers']['entropy']['guess'] = True
+
     optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

     trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
@@ -60,8 +83,16 @@ def test_traintestclass_train_diagnostics_fig(coords, imager, dataset, generic_p
     trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
     loss, loss_history = trainer.train(model, dataset)

-    train_state = trainer.train_state
-    train_fig, train_axes = train_diagnostics_fig(model, losses=loss_history, train_state=train_state)
+    learn_rates = np.repeat(learn_rate, len(loss_history))
+
+    old_mod_im = torch2npy(model.icube.sky_cube[0])
+
+    train_fig, train_axes = train_diagnostics_fig(model,
+                                                  losses=loss_history,
+                                                  learn_rates=learn_rates,
+                                                  fluxes=np.zeros(len(loss_history)),
+                                                  old_model_image=old_mod_im
+                                                  )

     train_fig.savefig(tmp_path / "train_diagnostics_fig.png", dpi=300)
     plt.close("all")
@@ -136,6 +167,14 @@ def test_standalone_train_loop(coords, dataset_cont, tmp_path):
     plt.close("all")


+def test_train_to_dirty_image(coords, dataset, imager):
+    # run a training loop against a dirty image
+    nchan = dataset.nchan
+    model = precomposed.SimpleNet(coords=coords, nchan=nchan)
+
+    train_to_dirty_image(model, imager, niter=10)
+
+
 def test_tensorboard(coords, dataset_cont, tmp_path):
     # not using TrainTest class,
     # set everything up to run on a single channel and then