Learn rate scheduler #207

Merged: 18 commits, Nov 30, 2023
Changes from all commits
58 changes: 35 additions & 23 deletions src/mpol/crossval.py
@@ -28,19 +28,8 @@ class CrossValidate:
Instance of the `mpol.coordinates.GridCoords` class.
imager : `mpol.gridding.DirtyImager` object
Instance of the `mpol.gridding.DirtyImager` class.
kfolds : int, default=5
Number of k-folds to use in cross-validation
split_method : str, default='random_cell'
Method to split full dataset into train/test subsets
seed : int, default=None
Seed for random number generator used in splitting data
learn_rate : float, default=0.5
Neural network learning rate
epochs : int, default=10000
Number of training iterations
convergence_tol : float, default=1e-3
Tolerance for training iteration stopping criterion as assessed by
loss function (suggested <= 1e-3)
learn_rate : float, default=0.3
Initial learning rate
regularizers : nested dict, default={}
Dictionary of image regularizers to use. For each, a dict of the
strength ('lambda', float), whether to guess an initial value for lambda
@@ -49,13 +38,25 @@ class CrossValidate:
{"sparsity":{"lambda":1e-3, "guess":False},
"entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
}
epochs : int, default=10000
Number of training iterations
convergence_tol : float, default=1e-5
Tolerance for training iteration stopping criterion as assessed by
loss function (suggested <= 1e-3)
schedule_factor : float, default=0.995
Factor by which the `torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler
reduces the learning rate when the loss stops decreasing
start_dirty_image : bool, default=False
Whether to start the RML optimization loop by initializing the model
image to a dirty image of the observed data. If False, the optimization
loop will start with a blank image.
train_diag_step : int, default=None
Interval at which training diagnostics are output. If None, no
diagnostics will be generated.
kfolds : int, default=5
Number of k-folds to use in cross-validation
split_method : str, default='random_cell'
Method to split full dataset into train/test subsets
split_diag_fig : bool, default=False
Whether to generate a diagnostic figure of dataset splitting into
train/test sets.
@@ -64,36 +65,40 @@ class CrossValidate:
save_prefix : str, default=None
Prefix (path) used for saved figure names. If None, figures won't be
saved
device : torch.device, default=None
Which hardware device to perform operations on (e.g., 'cuda:0').
'None' defaults to current device.
verbose : bool, default=True
Whether to print notification messages.
device : torch.device, default=None
Which hardware device to perform operations on (e.g., 'cuda:0').
'None' defaults to current device.
seed : int, default=None
Seed for random number generator used in splitting data
"""

def __init__(self, coords, imager, learn_rate=0.5,
regularizers={}, epochs=10000, convergence_tol=1e-3,
def __init__(self, coords, imager, learn_rate=0.3,
regularizers={}, epochs=10000, convergence_tol=1e-5,
schedule_factor=0.995,
start_dirty_image=False,
train_diag_step=None, kfolds=5, split_method="random_cell",
split_diag_fig=False, store_cv_diagnostics=False,
save_prefix=None, verbose=True, device=None, seed=None,
):
self._coords = coords
self._imager = imager
self._kfolds = kfolds
self._split_method = split_method
self._seed = seed
self._learn_rate = learn_rate
self._regularizers = regularizers
self._epochs = epochs
self._convergence_tol = convergence_tol
self._regularizers = regularizers
self._schedule_factor = schedule_factor
self._start_dirty_image = start_dirty_image
self._train_diag_step = train_diag_step
self._kfolds = kfolds
self._split_method = split_method
self._split_diag_fig = split_diag_fig
self._store_cv_diagnostics = store_cv_diagnostics
self._save_prefix = save_prefix
self._device = device
self._verbose = verbose
self._device = device
self._seed = seed

self._split_figure = None

@@ -186,10 +191,17 @@ def run_crossval(self, dataset):
# create a new model for this kfold, initializing it to the model pretrained on the dirty image
model.load_state_dict(torch.load(self._save_prefix + "_dirty_image_model.pt"))

# create a new optimizer and scheduler for this kfold
optimizer = torch.optim.Adam(model.parameters(), lr=self._learn_rate)
if self._schedule_factor is None:
scheduler = None
else:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=self._schedule_factor)

trainer = TrainTest(
imager=self._imager,
optimizer=optimizer,
scheduler=scheduler,
epochs=self._epochs,
convergence_tol=self._convergence_tol,
regularizers=self._regularizers,
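The `schedule_factor` argument above is passed straight to `torch.optim.lr_scheduler.ReduceLROnPlateau`. As a quick illustration of what that factor does (a standalone sketch with a hypothetical one-parameter model, not MPoL code): the scheduler multiplies the current learning rate by `factor` once the monitored loss has failed to improve for its `patience` (default 10) consecutive steps.

import torch

# hypothetical toy parameter, standing in for a real model
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.995)

# feed the scheduler a flat, non-improving loss value
for epoch in range(12):
    scheduler.step(1.0)

# after the default patience of 10 non-improving epochs, the learning rate
# has been multiplied by `factor`:
print(optimizer.param_groups[0]["lr"])  # 0.3 * 0.995 = 0.2985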
38 changes: 20 additions & 18 deletions src/mpol/training.py
@@ -64,34 +64,38 @@ class TrainTest:
Args:
imager (:class:`mpol.gridding.DirtyImager` object): Instance of the `mpol.gridding.DirtyImager` class.
optimizer (:class:`torch.optim` object): PyTorch optimizer class for the training loop.
epochs (int): Number of training iterations, default=10000
convergence_tol (float): Tolerance for training iteration stopping criterion as assessed by
loss function (suggested <= 1e-3)
scheduler (:class:`torch.optim.lr_scheduler` object, default=None): Scheduler for adjusting learning rate during optimization.
regularizers (nested dict): Dictionary of image regularizers to use. For each, a dict of the strength ('lambda', float), whether to guess an initial value for lambda ('guess', bool), and other quantities needed to compute their loss term.

Example:
``{"sparsity":{"lambda":1e-3, "guess":False},
"entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
}``

epochs (int): Number of training iterations, default=10000
convergence_tol (float): Tolerance for training iteration stopping criterion as assessed by
loss function (suggested <= 1e-3)
train_diag_step (int): Interval at which training diagnostics are output. If None, no diagnostics will be generated.
kfold (int): The k-fold of the current training set (for diagnostics)
save_prefix (str): Prefix (path) used for saved figure names. If None, figures won't be saved
verbose (bool): Whether to print notification messages
"""

def __init__(self, imager, optimizer, epochs=10000, convergence_tol=1e-3,
regularizers={}, train_diag_step=None, kfold=None,
save_prefix=None, verbose=True
):
def __init__(self, imager, optimizer, scheduler=None, regularizers={},
epochs=10000, convergence_tol=1e-5,
train_diag_step=None,
kfold=None, save_prefix=None, verbose=True
):
self._imager = imager
self._optimizer = optimizer
self._optimizer = optimizer
self._scheduler = scheduler
self._regularizers = regularizers
self._epochs = epochs
self._convergence_tol = convergence_tol
self._regularizers = regularizers

self._train_diag_step = train_diag_step
self._save_prefix = save_prefix
self._kfold = kfold
self._save_prefix = save_prefix
self._verbose = verbose

self._train_figure = None
@@ -200,7 +204,7 @@ def loss_eval(self, vis, dataset, sky_cube=None):
loss += self._regularizers['TSV']['lambda'] * TSV(sky_cube)

return loss


def train(self, model, dataset):
r"""
@@ -227,7 +231,7 @@ def train(self, model, dataset):

count = 0
losses = []
self._train_state = {}
learn_rates = []

# guess initial strengths for regularizers in `self._regularizers`
# that have 'guess':True
@@ -246,7 +250,7 @@
)

# check early-on whether the loss isn't evolving
if count == 20:
if count == 10:
loss_arr = np.array(losses)
if all(0.9 <= loss_arr[:-1] / loss_arr[1:]) and all(
loss_arr[:-1] / loss_arr[1:] <= 1.1
@@ -277,11 +281,9 @@
# update model parameters via gradient descent
self._optimizer.step()

# store current training parameter values
# TODO: store hyperpar values, access in crossval.py
self._train_state["kfold"] = self._kfold
self._train_state["epoch"] = count
self._train_state["learn_rate"] = self._optimizer.state_dict()['param_groups'][0]['lr']
if self._scheduler is not None:
self._scheduler.step(loss)
learn_rates.append(self._optimizer.param_groups[0]['lr'])

# generate optional fit diagnostics
if self._train_diag_step is not None and (count % self._train_diag_step == 0 or count == self._epochs or self.loss_convergence(np.array(losses))):
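With the scheduler now passed into `TrainTest` and stepped on the epoch loss right after `optimizer.step()`, a typical call looks roughly like the following. This is a sketch only: it assumes `coords`, `imager`, and a gridded `dataset` analogous to the test fixtures, and that `SimpleNet` is importable from `mpol.precomposed` as in the tests.

import torch
from mpol import precomposed
from mpol.training import TrainTest

model = precomposed.SimpleNet(coords=coords, nchan=dataset.nchan)

optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.995)

trainer = TrainTest(
    imager=imager,
    optimizer=optimizer,
    scheduler=scheduler,      # None (the default) disables learning-rate scheduling
    epochs=10000,
    convergence_tol=1e-5,
)
loss, loss_history = trainer.train(model, dataset)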
2 changes: 1 addition & 1 deletion test/conftest.py
@@ -208,7 +208,7 @@ def crossvalidation_products(mock_visibility_data):
def generic_parameters(tmp_path):
# generic model parameters to test training loop and cross-val loop
regularizers = {
"entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10},
"entropy": {"lambda":1e-3, "guess":False, "prior_intensity":1e-10},
}

train_pars = {"epochs":15, "convergence_tol":1e-3,
30 changes: 27 additions & 3 deletions test/train_test_test.py
@@ -16,7 +16,8 @@ def test_traintestclass_training(coords, imager, dataset, generic_parameters):
model = precomposed.SimpleNet(coords=coords, nchan=nchan)

train_pars = generic_parameters["train_pars"]
# bypass TrainTest.loss_lambda_guess

# no regularizers
train_pars["regularizers"] = {}

learn_rate = generic_parameters["crossval_pars"]["learn_rate"]
@@ -27,6 +28,26 @@ def test_traintestclass_training(coords, imager, dataset, generic_parameters):
loss, loss_history = trainer.train(model, dataset)


def test_traintestclass_training_scheduler(coords, imager, dataset, generic_parameters):
# using the TrainTest class, run a training loop with regularizers,
# using the learning rate scheduler
nchan = dataset.nchan
model = precomposed.SimpleNet(coords=coords, nchan=nchan)

train_pars = generic_parameters["train_pars"]

learn_rate = generic_parameters["crossval_pars"]["learn_rate"]

optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

# use a scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.995)
train_pars["scheduler"] = scheduler

trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
loss, loss_history = trainer.train(model, dataset)


def test_traintestclass_training_guess(coords, imager, dataset, generic_parameters):
# using the TrainTest class, run a training loop with regularizers,
# with a call to the regularizer strength guesser
@@ -37,6 +58,8 @@ def test_traintestclass_training_guess(coords, imager, dataset, generic_paramete

learn_rate = generic_parameters["crossval_pars"]["learn_rate"]

train_pars['regularizers']['entropy']['guess'] = True

optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
@@ -60,8 +83,9 @@ def test_traintestclass_train_diagnostics_fig(coords, imager, dataset, generic_p
trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
loss, loss_history = trainer.train(model, dataset)

train_state = trainer.train_state
train_fig, train_axes = train_diagnostics_fig(model, losses=loss_history, train_state=train_state)
train_fig, train_axes = train_diagnostics_fig(model,
losses=loss_history,
)
train_fig.savefig(tmp_path / "train_diagnostics_fig.png", dpi=300)
plt.close("all")

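At the cross-validation level, the new behaviour surfaces as the `learn_rate` and `schedule_factor` arguments on `CrossValidate`. A minimal sketch of the updated interface, again assuming `coords`, `imager`, and a gridded `dataset` already exist (the return value of `run_crossval` is not shown in this diff):

from mpol.crossval import CrossValidate

cv = CrossValidate(
    coords, imager,
    learn_rate=0.3,          # initial learning rate (new default)
    schedule_factor=0.995,   # ReduceLROnPlateau factor; set to None to disable the scheduler
    epochs=10000,
    convergence_tol=1e-5,
    kfolds=5,
    split_method="random_cell",
)
cv.run_crossval(dataset)     # runs the k-fold train/test loop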