Skip to content

Commit

Permalink
Merge pull request #206 from MPoL-dev/train_dirty_initial
Browse files Browse the repository at this point in the history
Add utility for training to dirty image and integrate into cross-val
  • Loading branch information
jeffjennings authored Nov 30, 2023
2 parents 6c6fd68 + 01858b2 commit 5826b92
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 19 deletions.
56 changes: 38 additions & 18 deletions src/mpol/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from mpol.datasets import Dartboard, GriddedDataset
from mpol.precomposed import SimpleNet
from mpol.training import TrainTest
from mpol.training import TrainTest, train_to_dirty_image
from mpol.plot import split_diagnostics_fig


Expand Down Expand Up @@ -49,6 +49,10 @@ class CrossValidate:
{"sparsity":{"lambda":1e-3, "guess":False},
"entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
}
start_dirty_image : bool, default=False
Whether to start the RML optimization loop by initializing the model
image to a dirty image of the observed data. If False, the optimization
loop will start with a blank image.
train_diag_step : int, default=None
Interval at which training diagnostics are output. If None, no
diagnostics will be generated.
Expand All @@ -67,11 +71,12 @@ class CrossValidate:
Whether to print notification messages.
"""

def __init__(self, coords, imager, kfolds=5, split_method="random_cell",
seed=None, learn_rate=0.5, epochs=10000, convergence_tol=1e-3,
regularizers={}, train_diag_step=None, split_diag_fig=False,
store_cv_diagnostics=False, save_prefix=None, device=None,
verbose=True
def __init__(self, coords, imager, learn_rate=0.5,
regularizers={}, epochs=10000, convergence_tol=1e-3,
start_dirty_image=False,
train_diag_step=None, kfolds=5, split_method="random_cell",
split_diag_fig=False, store_cv_diagnostics=False,
save_prefix=None, verbose=True, device=None, seed=None,
):
self._coords = coords
self._imager = imager
Expand All @@ -82,17 +87,18 @@ def __init__(self, coords, imager, kfolds=5, split_method="random_cell",
self._epochs = epochs
self._convergence_tol = convergence_tol
self._regularizers = regularizers
self._start_dirty_image = start_dirty_image
self._train_diag_step = train_diag_step
self._split_diag_fig = split_diag_fig
self._store_cv_diagnostics = store_cv_diagnostics
self._save_prefix = save_prefix
self._device = device
self._verbose = verbose

self._model = None
self._diagnostics = None
self._split_figure = None
self._train_figure = None

# used to collect objects across all kfolds
self._diagnostics = None

def split_dataset(self, dataset):
r"""
Expand Down Expand Up @@ -165,14 +171,23 @@ def run_crossval(self, dataset):
# if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
# train_set, test_set = train_set.to(self._device), test_set.to(self._device)

# create a new model and optimizer for this k_fold
self._model = SimpleNet(coords=self._coords, nchan=self._imager.nchan)
# if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
# self._model = self._model.to(self._device)

optimizer = torch.optim.Adam(self._model.parameters(), lr=self._learn_rate)

trainer = TrainTest(
model = SimpleNet(coords=self._coords, nchan=self._imager.nchan)
if self._start_dirty_image is True:
if kk == 0:
if self._verbose:
logging.info(
"\n Pre-training to dirty image to initialize subsequent optimization loops"
)
# initial short training loop to get model image to approximate dirty image
model_pretrained = train_to_dirty_image(model=model, imager=self._imager)
# save the model to a state we can load in subsequent kfolds
torch.save(model_pretrained.state_dict(), f=self._save_prefix + "_dirty_image_model.pt")
else:
# create a new model for this kfold, initializing it to the model pretrained on the dirty image
model.load_state_dict(torch.load(self._save_prefix + "_dirty_image_model.pt"))

optimizer = torch.optim.Adam(model.parameters(), lr=self._learn_rate)
trainer = TrainTest(
imager=self._imager,
optimizer=optimizer,
epochs=self._epochs,
Expand All @@ -185,8 +200,13 @@ def run_crossval(self, dataset):
)

# run training
loss, loss_history = trainer.train(self._model, train_set)
loss, loss_history = trainer.train(model, train_set)

# run testing
all_scores.append(trainer.test(model, test_set))

# store objects from the most recent kfold for diagnostics
self._model = model
if self._store_cv_diagnostics:
self._diagnostics["loss_histories"].append(loss_history)
# update regularizer strength values
Expand Down
52 changes: 52 additions & 0 deletions src/mpol/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,58 @@

from mpol.losses import TSV, TV_image, entropy, nll_gridded, sparsity
from mpol.plot import train_diagnostics_fig
def train_to_dirty_image(model, imager, robust=0.5, learn_rate=100, niter=1000):
    r"""
    Train a model so that its image cube approximates the dirty image of the
    observed visibilities. Useful for initializing a separate RML optimization
    loop at a reasonable starting image.

    The loss minimized here is the L2 norm of the residual image, i.e. the
    square root of the summed squared differences between the model image
    pixel fluxes and the dirty image pixel fluxes.

    Parameters
    ----------
    model : `torch.nn.Module` object
        A neural network module; instance of the `mpol.precomposed.SimpleNet`
        class. It is updated in place.
    imager : :class:`mpol.gridding.DirtyImager` object
        Instance of the `mpol.gridding.DirtyImager` class.
    robust : float, default=0.5
        Robust weighting parameter used to create a dirty image.
    learn_rate : float, default=100
        Learning rate for optimization loop
    niter : int, default=1000
        Number of iterations for optimization loop

    Returns
    -------
    model : `torch.nn.Module` object
        The input `model` updated with the state of the training to the
        dirty image
    """
    logging.info(" Initializing model to dirty image")

    # the beam returned alongside the image is not needed for the fit
    img, _ = imager.get_dirty_image(
        weighting="briggs", robust=robust, unit="Jy/arcsec^2"
    )
    dirty_image = torch.tensor(img.copy())
    optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)

    # the loss module is loop-invariant, so build it once; with
    # reduction="sum" it yields the summed (not mean) squared error
    lossfunc = torch.nn.MSELoss(reduction="sum")

    for _ in range(niter):
        optimizer.zero_grad()

        # run the forward pass, then read the image cube held by the model
        model()
        sky_cube = model.icube.sky_cube

        # square root of the summed squared error = L2 norm of the residual
        loss = lossfunc(sky_cube, dirty_image) ** 0.5

        loss.backward()
        optimizer.step()

    return model


class TrainTest:
r"""
Expand Down
10 changes: 9 additions & 1 deletion test/train_test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from mpol import losses, precomposed
from mpol.plot import train_diagnostics_fig
from mpol.training import TrainTest
from mpol.training import TrainTest, train_to_dirty_image
from mpol.constants import *


Expand Down Expand Up @@ -136,6 +136,14 @@ def test_standalone_train_loop(coords, dataset_cont, tmp_path):
plt.close("all")


def test_train_to_dirty_image(coords, dataset, imager):
    """Smoke-test a short training loop against a dirty image."""
    # build a model with one image plane per channel of the dataset
    net = precomposed.SimpleNet(coords=coords, nchan=dataset.nchan)

    # only a handful of iterations: this checks the loop runs, not convergence
    train_to_dirty_image(net, imager, niter=10)


def test_tensorboard(coords, dataset_cont, tmp_path):
# not using TrainTest class,
# set everything up to run on a single channel and then
Expand Down

0 comments on commit 5826b92

Please sign in to comment.