Commit

Merge branch 'main' into crossval_image_diag_fig
jeffjennings authored Nov 30, 2023
2 parents cac0a4b + 14a76e5 commit b2996b8
Showing 5 changed files with 306 additions and 121 deletions.
151 changes: 91 additions & 60 deletions src/mpol/crossval.py
@@ -12,7 +12,7 @@

from mpol.datasets import Dartboard, GriddedDataset
from mpol.precomposed import SimpleNet
from mpol.training import TrainTest
from mpol.training import TrainTest, train_to_dirty_image
from mpol.plot import split_diagnostics_fig


@@ -28,19 +28,8 @@ class CrossValidate:
Instance of the `mpol.coordinates.GridCoords` class.
imager : `mpol.gridding.DirtyImager` object
Instance of the `mpol.gridding.DirtyImager` class.
kfolds : int, default=5
Number of k-folds to use in cross-validation
split_method : str, default='random_cell'
Method to split full dataset into train/test subsets
seed : int, default=None
Seed for random number generator used in splitting data
learn_rate : float, default=0.5
Neural network learning rate
epochs : int, default=10000
Number of training iterations
convergence_tol : float, default=1e-3
Tolerance for training iteration stopping criterion as assessed by
loss function (suggested <= 1e-3)
learn_rate : float, default=0.3
Initial learning rate
regularizers : nested dict, default={}
Dictionary of image regularizers to use. For each, a dict of the
strength ('lambda', float), whether to guess an initial value for lambda
@@ -49,9 +38,25 @@ class CrossValidate:
{"sparsity":{"lambda":1e-3, "guess":False},
"entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
}
epochs : int, default=10000
Number of training iterations
convergence_tol : float, default=1e-5
Tolerance for the training iteration stopping criterion, as assessed by
the loss function (suggested <= 1e-3)
schedule_factor : float, default=0.995
For the `torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, the factor
by which the learning rate is reduced when the training loss stops decreasing
start_dirty_image : bool, default=False
Whether to start the RML optimization loop by initializing the model
image to a dirty image of the observed data. If False, the optimization
loop will start with a blank image.
train_diag_step : int, default=None
Interval at which training diagnostics are output. If None, no
diagnostics will be generated.
kfolds : int, default=5
Number of k-folds to use in cross-validation
split_method : str, default='random_cell'
Method to split full dataset into train/test subsets
split_diag_fig : bool, default=False
Whether to generate a diagnostic figure of dataset splitting into
train/test sets.
@@ -60,39 +65,45 @@
save_prefix : str, default=None
Prefix (path) used for saved figure names. If None, figures won't be
saved
device : torch.device, default=None
Which hardware device to perform operations on (e.g., 'cuda:0').
'None' defaults to current device.
verbose : bool, default=True
Whether to print notification messages.
device : torch.device, default=None
Which hardware device to perform operations on (e.g., 'cuda:0').
If None, defaults to the current device.
seed : int, default=None
Seed for random number generator used in splitting data
"""

def __init__(self, coords, imager, kfolds=5, split_method="random_cell",
seed=None, learn_rate=0.5, epochs=10000, convergence_tol=1e-3,
regularizers={}, train_diag_step=None, split_diag_fig=False,
store_cv_diagnostics=False, save_prefix=None, device=None,
verbose=True
def __init__(self, coords, imager, learn_rate=0.3,
regularizers={}, epochs=10000, convergence_tol=1e-5,
schedule_factor=0.995,
start_dirty_image=False,
train_diag_step=None, kfolds=5, split_method="random_cell",
split_diag_fig=False, store_cv_diagnostics=False,
save_prefix=None, verbose=True, device=None, seed=None,
):
self._coords = coords
self._imager = imager
self._kfolds = kfolds
self._split_method = split_method
self._seed = seed
self._learn_rate = learn_rate
self._regularizers = regularizers
self._epochs = epochs
self._convergence_tol = convergence_tol
self._regularizers = regularizers
self._schedule_factor = schedule_factor
self._start_dirty_image = start_dirty_image
self._train_diag_step = train_diag_step
self._kfolds = kfolds
self._split_method = split_method
self._split_diag_fig = split_diag_fig
self._store_cv_diagnostics = store_cv_diagnostics
self._save_prefix = save_prefix
self._device = device
self._verbose = verbose
self._device = device
self._seed = seed

self._model = None
self._diagnostics = None
self._split_figure = None
self._train_figure = None

# used to collect objects across all kfolds
self._diagnostics = None

def split_dataset(self, dataset):
r"""
@@ -159,22 +170,38 @@ def run_crossval(self, dataset):
for kk, (train_set, test_set) in enumerate(split_iterator):
if self._verbose:
logging.info(
"\nCross-validation: k-fold {} of {}".format(kk, self._kfolds)
"\nCross-validation: k-fold {} of {}".format(kk, self._kfolds - 1)
)

# if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
# train_set, test_set = train_set.to(self._device), test_set.to(self._device)

# create a new model and optimizer for this k_fold
self._model = SimpleNet(coords=self._coords, nchan=self._imager.nchan)
# if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
# self._model = self._model.to(self._device)

optimizer = torch.optim.Adam(self._model.parameters(), lr=self._learn_rate)

trainer = TrainTest(
model = SimpleNet(coords=self._coords, nchan=self._imager.nchan)
if self._start_dirty_image is True:
if kk == 0:
if self._verbose:
logging.info(
"\n Pre-training to dirty image to initialize subsequent optimization loops"
)
# initial short training loop to get model image to approximate dirty image
model_pretrained = train_to_dirty_image(model=model, imager=self._imager)
# save the model to a state we can load in subsequent kfolds
torch.save(model_pretrained.state_dict(), f=self._save_prefix + "_dirty_image_model.pt")
else:
# create a new model for this kfold, initializing it to the model pretrained on the dirty image
model.load_state_dict(torch.load(self._save_prefix + "_dirty_image_model.pt"))
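# (kfold 0 runs the short pre-training loop once and saves its state_dict;
# every later kfold reloads that file, so all folds start from the same
# dirty-image initialization)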

# create a new optimizer and scheduler for this kfold
optimizer = torch.optim.Adam(model.parameters(), lr=self._learn_rate)
if self._schedule_factor is None:
scheduler = None
else:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=self._schedule_factor)
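# (in 'min' mode, ReduceLROnPlateau multiplies the learning rate by `factor`
# once the monitored training loss has stopped decreasing for `patience`
# epochs; `patience` is left at its PyTorch default of 10 here)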

trainer = TrainTest(
imager=self._imager,
optimizer=optimizer,
scheduler=scheduler,
epochs=self._epochs,
convergence_tol=self._convergence_tol,
regularizers=self._regularizers,
@@ -185,17 +212,21 @@ def run_crossval(self, dataset):
)

# run training
loss, loss_history = trainer.train(self._model, train_set)
loss, loss_history = trainer.train(model, train_set)

if self._store_cv_diagnostics:
self._diagnostics["loss_histories"].append(loss_history)
# update regularizer strength values
self._regularizers = trainer.regularizers
# store the most recent train figure for diagnostics
self._train_figure = trainer.train_figure

# run testing
all_scores.append(trainer.test(self._model, test_set))
all_scores.append(trainer.test(model, test_set))

# store objects from the most recent kfold for diagnostics
self._model = model
self._train_figure = trainer.train_figure

# collect objects from this kfold to store
if self._store_cv_diagnostics:
self._diagnostics["models"].append(self._model)
self._diagnostics["regularizers"].append(self._regularizers)
self._diagnostics["loss_histories"].append(loss_history)
self._diagnostics["train_figures"].append(self._train_figure)

# average individual test scores to get the cross-val metric for chosen
# hyperparameters
@@ -204,33 +235,33 @@
"std": np.std(all_scores),
"all": all_scores,
}

return cv_score

@property
def model(self):
"""SimpleNet class instance"""
"""For the most recent kfold, trained model (`SimpleNet` class instance)"""
return self._model

@property
def regularizers(self):
"""Dict containing regularizers used and their strengths"""
"""For the most recent kfold, dict containing regularizers used and their strengths"""
return self._regularizers

@property
def diagnostics(self):
"""Dict containing diagnostics of the cross-validation loop"""
return self._diagnostics

def train_figure(self):
"""For the most recent kfold, (fig, axes) showing training progress"""
return self._train_figure
@property
def split_figure(self):
"""(fig, axes) of train/test splitting diagnostic figure"""
return self._split_figure

@property
def train_figure(self):
"""(fig, axes) of most recent training diagnostic figure"""
return self._train_figure
def diagnostics(self):
"""Dict containing diagnostics of the cross-validation loop across all kfolds: models, regularizers, loss values, training figures"""
return self._diagnostics


class RandomCellSplitGridded:
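Taken together, this commit reorders the `CrossValidate` constructor around the training options (learning rate, regularizers, scheduler, dirty-image start) and moves the dataset-splitting and diagnostic options after them. A minimal sketch of driving the new interface follows; the function name and the construction of the input objects are assumed for illustration and are not part of this commit:

from mpol.crossval import CrossValidate

def crossval_demo(coords, imager, dataset, save_prefix="demo"):
    """Hypothetical driver: `coords` is an mpol.coordinates.GridCoords,
    `imager` an mpol.gridding.DirtyImager, and `dataset` an
    mpol.datasets.GriddedDataset, all built from the same observed
    visibilities (their construction is elided here)."""
    cv = CrossValidate(
        coords, imager,
        learn_rate=0.3,
        regularizers={"sparsity": {"lambda": 1e-3, "guess": False}},
        epochs=10000,
        convergence_tol=1e-5,
        schedule_factor=0.995,     # ReduceLROnPlateau reduction factor
        start_dirty_image=True,    # kfold 0 pre-trains to the dirty image
        kfolds=5,
        save_prefix=save_prefix,
        seed=42,
    )
    return cv.run_crossval(dataset)   # {"mean": ..., "std": ..., "all": [...]}

Note that `start_dirty_image=True` requires a usable `save_prefix`, since the kfold-0 model is written to `save_prefix + "_dirty_image_model.pt"` and reloaded by each later kfold.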
