diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py
index 25fbc049..9d8b59da 100644
--- a/src/mpol/crossval.py
+++ b/src/mpol/crossval.py
@@ -12,7 +12,7 @@
 from mpol.datasets import Dartboard, GriddedDataset
 from mpol.precomposed import SimpleNet
-from mpol.training import TrainTest
+from mpol.training import TrainTest, train_to_dirty_image
 from mpol.plot import split_diagnostics_fig
@@ -49,6 +49,10 @@ class CrossValidate:
         {"sparsity":{"lambda":1e-3, "guess":False},
         "entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
         }
+    start_dirty_image : bool, default=False
+        Whether to start the RML optimization loop by initializing the model
+        image to a dirty image of the observed data. If False, the optimization
+        loop will start with a blank image.
     train_diag_step : int, default=None
         Interval at which training diagnostics are output. If None, no
         diagnostics will be generated.
@@ -67,11 +71,12 @@ class CrossValidate:
         Whether to print notification messages.
     """

-    def __init__(self, coords, imager, kfolds=5, split_method="random_cell",
-                 seed=None, learn_rate=0.5, epochs=10000, convergence_tol=1e-3,
-                 regularizers={}, train_diag_step=None, split_diag_fig=False,
-                 store_cv_diagnostics=False, save_prefix=None, device=None,
-                 verbose=True
+    def __init__(self, coords, imager, learn_rate=0.5,
+                 regularizers={}, epochs=10000, convergence_tol=1e-3,
+                 start_dirty_image=False,
+                 train_diag_step=None, kfolds=5, split_method="random_cell",
+                 split_diag_fig=False, store_cv_diagnostics=False,
+                 save_prefix=None, verbose=True, device=None, seed=None,
     ):
         self._coords = coords
         self._imager = imager
@@ -82,6 +87,7 @@ def __init__(self, coords, imager, kfolds=5, split_method="random_cell",
         self._epochs = epochs
         self._convergence_tol = convergence_tol
         self._regularizers = regularizers
+        self._start_dirty_image = start_dirty_image
         self._train_diag_step = train_diag_step
         self._split_diag_fig = split_diag_fig
         self._store_cv_diagnostics = store_cv_diagnostics
@@ -89,10 +95,10 @@ def __init__(self, coords, imager, kfolds=5, split_method="random_cell",
         self._device = device
         self._verbose = verbose

-        self._model = None
-        self._diagnostics = None
         self._split_figure = None
-        self._train_figure = None
+
+        # used to collect objects across all kfolds
+        self._diagnostics = None

     def split_dataset(self, dataset):
         r"""
@@ -165,14 +171,23 @@ def run_crossval(self, dataset):
             # if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
             #     train_set, test_set = train_set.to(self._device), test_set.to(self._device)

-            # create a new model and optimizer for this k_fold
-            self._model = SimpleNet(coords=self._coords, nchan=self._imager.nchan)
-            # if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
-            #     self._model = self._model.to(self._device)
-
-            optimizer = torch.optim.Adam(self._model.parameters(), lr=self._learn_rate)
-
-            trainer = TrainTest(
+            model = SimpleNet(coords=self._coords, nchan=self._imager.nchan)
+            if self._start_dirty_image is True:
+                if kk == 0:
+                    if self._verbose:
+                        logging.info(
+                            "\n Pre-training to dirty image to initialize subsequent optimization loops"
+                        )
+                    # initial short training loop to get model image to approximate dirty image
+                    model_pretrained = train_to_dirty_image(model=model, imager=self._imager)
+                    # save the model to a state we can load in subsequent kfolds
+                    torch.save(model_pretrained.state_dict(), f=self._save_prefix + "_dirty_image_model.pt")
+                else:
+                    # create a new model for this kfold, initializing it to the model pretrained on the dirty image
+                    model.load_state_dict(torch.load(self._save_prefix + "_dirty_image_model.pt"))
+
+            optimizer = torch.optim.Adam(model.parameters(), lr=self._learn_rate)
+            trainer = TrainTest(
                 imager=self._imager,
                 optimizer=optimizer,
                 epochs=self._epochs,
@@ -185,8 +200,13 @@ def run_crossval(self, dataset):
             )

             # run training
-            loss, loss_history = trainer.train(self._model, train_set)
+            loss, loss_history = trainer.train(model, train_set)
+
+            # run testing
+            all_scores.append(trainer.test(model, test_set))

+            # store objects from the most recent kfold for diagnostics
+            self._model = model
             if self._store_cv_diagnostics:
                 self._diagnostics["loss_histories"].append(loss_history)
             # update regularizer strength values
diff --git a/src/mpol/training.py b/src/mpol/training.py
index 22223954..518ccf3a 100644
--- a/src/mpol/training.py
+++ b/src/mpol/training.py
@@ -4,6 +4,58 @@
 from mpol.losses import TSV, TV_image, entropy, nll_gridded, sparsity
 from mpol.plot import train_diagnostics_fig

+def train_to_dirty_image(model, imager, robust=0.5, learn_rate=100, niter=1000):
+    r"""
+    Train against a dirty image of the observed visibilities using a loss function
+    of the mean squared error between the RML model image pixel fluxes and the
+    dirty image pixel fluxes. Useful for initializing a separate RML optimization
+    loop at a reasonable starting image.
+
+    Parameters
+    ----------
+    model : `torch.nn.Module` object
+        A neural network module; instance of the `mpol.precomposed.SimpleNet` class.
+    imager : :class:`mpol.gridding.DirtyImager` object
+        Instance of the `mpol.gridding.DirtyImager` class.
+    robust : float, default=0.5
+        Robust weighting parameter used to create a dirty image.
+    learn_rate : float, default=100
+        Learning rate for optimization loop
+    niter : int, default=1000
+        Number of iterations for optimization loop
+
+    Returns
+    -------
+    model : `torch.nn.Module` object
+        The input `model` updated with the state of the training to the
+        dirty image
+    """
+    logging.info(" Initializing model to dirty image")
+
+    img, beam = imager.get_dirty_image(weighting="briggs",
+                                       robust=robust,
+                                       unit="Jy/arcsec^2")
+    dirty_image = torch.tensor(img.copy())
+    optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
+
+    losses = []
+    for ii in range(niter):
+        optimizer.zero_grad()
+
+        model()
+
+        sky_cube = model.icube.sky_cube
+
+        lossfunc = torch.nn.MSELoss(reduction="sum")
+        # MSELoss calculates mean squared error (squared L2 norm), so sqrt it
+        loss = (lossfunc(sky_cube, dirty_image)) ** 0.5
+        losses.append(loss.item())
+
+        loss.backward()
+        optimizer.step()
+
+    return model
+

 class TrainTest:
     r"""
diff --git a/test/train_test_test.py b/test/train_test_test.py
index dffb7aad..fb1894b3 100644
--- a/test/train_test_test.py
+++ b/test/train_test_test.py
@@ -6,7 +6,7 @@
 from mpol import losses, precomposed
 from mpol.plot import train_diagnostics_fig
-from mpol.training import TrainTest
+from mpol.training import TrainTest, train_to_dirty_image
 from mpol.constants import *

@@ -136,6 +136,14 @@ def test_standalone_train_loop(coords, dataset_cont, tmp_path):
     plt.close("all")


+def test_train_to_dirty_image(coords, dataset, imager):
+    # run a training loop against a dirty image
+    nchan = dataset.nchan
+    model = precomposed.SimpleNet(coords=coords, nchan=nchan)
+
+    train_to_dirty_image(model, imager, niter=10)
+
+
 def test_tensorboard(coords, dataset_cont, tmp_path):
     # not using TrainTest class,
     # set everything up to run on a single channel and then
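
For reference, a minimal usage sketch of the new train_to_dirty_image helper. This is illustrative and not part of the diff: the uu, vv, weight, and data arrays are synthetic placeholders for a loaded observation, and the GridCoords/DirtyImager constructor arguments are assumptions based on the mpol.coordinates and mpol.gridding APIs this PR builds on.

# illustrative sketch only (not from the diff); placeholder observation arrays
import numpy as np
from mpol import coordinates, gridding, precomposed
from mpol.training import train_to_dirty_image

# placeholder single-channel observation: random (u, v) points in kilolambda,
# unit weights, zero visibilities (a real script would load these from disk)
nvis = 1000
uu = np.random.uniform(-500, 500, (1, nvis))
vv = np.random.uniform(-500, 500, (1, nvis))
weight = np.ones((1, nvis))
data = np.zeros((1, nvis), dtype=np.complex128)

coords = coordinates.GridCoords(cell_size=0.005, npix=800)
imager = gridding.DirtyImager(
    coords=coords,
    uu=uu,
    vv=vv,
    weight=weight,
    data_re=data.real,
    data_im=data.imag,
)
model = precomposed.SimpleNet(coords=coords, nchan=1)

# short pre-training loop: fit the model image to the robust=0.5 dirty image
# so a subsequent RML optimization starts from a sensible state
model = train_to_dirty_image(model, imager, robust=0.5, niter=1000)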
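
Similarly, a sketch of the new start_dirty_image option on CrossValidate (again illustrative; coords, imager, and a gridded dataset are assumed to exist as in the sketch above, and the argument values are arbitrary):

# illustrative sketch only (not from the diff)
from mpol.crossval import CrossValidate

cv = CrossValidate(
    coords,
    imager,
    start_dirty_image=True,
    # save_prefix is effectively required when start_dirty_image=True: the
    # pretrained state is written to save_prefix + "_dirty_image_model.pt"
    # and reloaded in each subsequent kfold
    save_prefix="demo",
    kfolds=5,
)
cv.run_crossval(dataset)

One design note visible in the diff: the dirty-image pretraining runs only once (on kfold 0), and later folds simply load the saved state dict, so the initialization cost is paid a single time per cross-validation run.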