diff --git a/handwriting_synthesis/tf/BaseModel.py b/handwriting_synthesis/tf/BaseModel.py
index 5c2496c..1903868 100644
--- a/handwriting_synthesis/tf/BaseModel.py
+++ b/handwriting_synthesis/tf/BaseModel.py
@@ -133,11 +133,11 @@ def __init__(
         self.checkpoint_dir_averaged = checkpoint_dir + '_avg'
 
         self.init_logging(self.log_dir)
-        logging.info('\nnew run with parameters:\n{}'.format(pp.pformat(self.__dict__)))
+        logging.info('\nNew run with parameters:\n{}'.format(pp.pformat(self.__dict__)))
 
         self.graph = self.build_graph()
         self.session = tfcompat.Session(graph=self.graph)
-        logging.info('built graph')
+        logging.info('Built graph.')
 
     def update_train_params(self):
         self.batch_size = self.batch_sizes[self.restart_idx]
@@ -146,7 +146,7 @@ def update_train_params(self):
         self.early_stopping_steps = self.patiences[self.restart_idx]
 
     def calculate_loss(self):
-        raise NotImplementedError('subclass must implement this')
+        raise NotImplementedError('Subclass must implement this.')
 
     def fit(self):
         with self.session.as_default():
@@ -170,6 +170,8 @@ def fit(self):
                 metric_name: deque(maxlen=self.loss_averaging_window)
                 for metric_name in self.metrics
             }
             best_validation_loss, best_validation_tstep = float('inf'), 0
 
+            checkpoint_created = False
+
             while step < self.num_training_steps:
@@ -260,31 +262,42 @@ def fit(self):
 
                     logging.info(metric_log)
 
+                    # Save the best step.
                     if early_stopping_metric < best_validation_loss:
+                        logging.info('Updating best validation loss {} with early stopping metric {}.'.format(round(best_validation_loss, 4), round(early_stopping_metric, 4)))
                         best_validation_loss = early_stopping_metric
                         best_validation_tstep = step
+                        # Take a snapshot if the minimum number of steps has been reached.
                         if step > self.min_steps_to_checkpoint:
                             self.save(step)
                             if self.enable_parameter_averaging:
                                 self.save(step, averaged=True)
+                            checkpoint_created = True
 
+                    # Stop training early and either restart with tighter training parameters or finish entirely.
                     if step - best_validation_tstep > self.early_stopping_steps:
-
+                        logging.info('Stopping early at step {}: best validation step {}, early stopping steps {}.'.format(step, best_validation_tstep, self.early_stopping_steps))
                         if self.num_restarts is None or self.restart_idx >= self.num_restarts:
-                            logging.info('best validation loss of {} at training step {}'.format(
-                                best_validation_loss, best_validation_tstep))
-                            logging.info('early stopping - ending training.')
+                            logging.info('Best validation loss of {} at training step {}'.format(best_validation_loss, best_validation_tstep))
+                            logging.info('Early stopping - ending training.')
                             return
 
-                        if self.restart_idx < self.num_restarts:
-                            self.restore(best_validation_tstep)
-                            step = best_validation_tstep
-                            self.restart_idx += 1
-                            self.update_train_params()
-                            train_generator = self.reader.train_batch_generator(self.batch_size)
+                        # Restart the training with tighter parameters if we have remaining restarts and a checkpoint has been created.
+                        if self.restart_idx < self.num_restarts and checkpoint_created:
+                            logging.info('Restarting training ({} of {} restarts).'.format(self.restart_idx + 1, self.num_restarts))
+                            try:
+                                self.restore(best_validation_tstep)
+                            except Exception as error:
+                                logging.warning('Failed to restore checkpoint; will continue training: {} - {}'.format(type(error).__name__, error))
+                            else:
+                                step = best_validation_tstep
+                                self.restart_idx += 1
+                                self.update_train_params()
+                                train_generator = self.reader.train_batch_generator(self.batch_size)
 
                 step += 1
 
+            # Make sure at least one model gets saved.
             if step <= self.min_steps_to_checkpoint:
                 # best_validation_tstep = step
                 self.save(step)
@@ -398,13 +411,13 @@ def update_parameters(self, loss):
         else:
             self.step = step
 
-        logging.info('all parameters:')
+        logging.info('All parameters:')
         logging.info(pp.pformat([(var.name, shape(var)) for var in tfcompat.global_variables()]))
 
-        logging.info('trainable parameters:')
+        logging.info('Trainable parameters:')
         logging.info(pp.pformat([(var.name, shape(var)) for var in tfcompat.trainable_variables()]))
 
-        logging.info('trainable parameter count:')
+        logging.info('Trainable parameter count:')
         logging.info(str(np.sum(np.prod(shape(var)) for var in tfcompat.trainable_variables())))
 
     def get_optimizer(self, learning_rate, beta1_decay):
@@ -415,7 +428,7 @@ def get_optimizer(self, learning_rate, beta1_decay):
         elif self.optimizer == 'rms':
             return tfcompat.train.RMSPropOptimizer(learning_rate, decay=beta1_decay, momentum=0.9)
         else:
-            assert False, 'optimizer must be adam, gd, or rms'
+            assert False, 'Optimizer must be adam, gd, or rms'
 
     def build_graph(self):
         with tf.Graph().as_default() as graph: