Commit

Merge pull request #16 from bryandam/master
BaseModel.Fit: Add Checkpoint Create Bool
otuva authored Oct 2, 2023
2 parents 8784848 + 31ae173 commit cf5ec03
Showing 1 changed file with 30 additions and 17 deletions.
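
For orientation before reading the diff: the change introduces a checkpoint_created flag so that an early-stopping restart is only attempted once at least one checkpoint has actually been written, and it wraps the restore call in a try/except so a failed restore does not abort training. The following is a minimal standalone sketch of that control flow, using hypothetical callables (evaluate, save, restore) in place of the class's own methods; it is an illustration of the pattern, not the repository's actual code.

import logging

def train_with_guarded_restarts(num_steps, min_steps_to_checkpoint,
                                early_stopping_steps, num_restarts,
                                evaluate, save, restore):
    best_loss, best_step = float('inf'), 0
    checkpoint_created = False  # set only after save() has actually run
    restart_idx, step = 0, 0
    while step < num_steps:
        loss = evaluate(step)
        if loss < best_loss:
            best_loss, best_step = loss, step
            # Snapshot once the minimum number of steps has been reached.
            if step > min_steps_to_checkpoint:
                save(step)
                checkpoint_created = True
        if step - best_step > early_stopping_steps:
            if num_restarts is None or restart_idx >= num_restarts:
                return best_loss  # out of restarts: stop early
            # Restart only if there is a checkpoint to fall back to.
            if checkpoint_created:
                try:
                    restore(best_step)
                except Exception as error:
                    logging.warning('Restore failed, continuing: %s', error)
                else:
                    step = best_step
                    restart_idx += 1
        step += 1
    # Guarantee at least one saved model even for very short runs.
    if step <= min_steps_to_checkpoint:
        save(step)
    return best_loss

In the commit itself the flag lives inside BaseModel.fit and the save/restore calls are the model's own checkpoint methods; the sketch only shows the guard-and-restart shape of the change.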
47 changes: 30 additions & 17 deletions handwriting_synthesis/tf/BaseModel.py
@@ -133,11 +133,11 @@ def __init__(
self.checkpoint_dir_averaged = checkpoint_dir + '_avg'

self.init_logging(self.log_dir)
logging.info('\nnew run with parameters:\n{}'.format(pp.pformat(self.__dict__)))
logging.info('\nNew run with parameters:\n{}'.format(pp.pformat(self.__dict__)))

self.graph = self.build_graph()
self.session = tfcompat.Session(graph=self.graph)
logging.info('built graph')
logging.info('Built Graph')

def update_train_params(self):
self.batch_size = self.batch_sizes[self.restart_idx]
@@ -146,7 +146,7 @@ def update_train_params(self):
self.early_stopping_steps = self.patiences[self.restart_idx]

def calculate_loss(self):
raise NotImplementedError('subclass must implement this')
raise NotImplementedError('Subclass must implement this.')

def fit(self):
with self.session.as_default():
@@ -170,6 +170,8 @@ def fit(self):
metric_name: deque(maxlen=self.loss_averaging_window) for metric_name in self.metrics
}
best_validation_loss, best_validation_tstep = float('inf'), 0
checkpoint_created = False


while step < self.num_training_steps:

@@ -260,31 +262,42 @@ def fit(self):

logging.info(metric_log)

# Save the best step.
if early_stopping_metric < best_validation_loss:
logging.info('Updating best validation loss {} with early stopping metric {}.'.format(round(best_validation_loss, 4), round(early_stopping_metric, 4)))
best_validation_loss = early_stopping_metric
best_validation_tstep = step
# Take a snapshot if the minimum number of steps have been reached.
if step > self.min_steps_to_checkpoint:
self.save(step)
if self.enable_parameter_averaging:
self.save(step, averaged=True)
checkpoint_created = True

# Stop training early and either restart with tighter training parameters or finish entirely.
if step - best_validation_tstep > self.early_stopping_steps:

logging.info('Stopping early at step {}: Best Validation Step: {} Early Stopping Steps: {}'.format(step, best_validation_tstep, self.early_stopping_steps))
if self.num_restarts is None or self.restart_idx >= self.num_restarts:
logging.info('best validation loss of {} at training step {}'.format(
best_validation_loss, best_validation_tstep))
logging.info('early stopping - ending training.')
logging.info('Best validation loss of {} at training step {}'.format(best_validation_loss, best_validation_tstep))
logging.info('Early stopping - ending training.')
return

if self.restart_idx < self.num_restarts:
self.restore(best_validation_tstep)
step = best_validation_tstep
self.restart_idx += 1
self.update_train_params()
train_generator = self.reader.train_batch_generator(self.batch_size)
# Restart the training with tighter parameters if we have remaining restarts and a checkpoint has been created.
if self.restart_idx < self.num_restarts and checkpoint_created:
logging.info('Restarting ({} of {} restarts).'.format(self.restart_idx + 1, self.num_restarts))
try:
self.restore(best_validation_tstep)
except Exception as error:
logging.warning('Failed to restore checkpoint; will continue training: {} - {}'.format(type(error).__name__, error))
else:
step = best_validation_tstep
self.restart_idx += 1
self.update_train_params()
train_generator = self.reader.train_batch_generator(self.batch_size)

step += 1

# Make sure at least one model gets saved.
if step <= self.min_steps_to_checkpoint:
# best_validation_tstep = step
self.save(step)
@@ -398,13 +411,13 @@ def update_parameters(self, loss):
else:
self.step = step

logging.info('all parameters:')
logging.info('All parameters:')
logging.info(pp.pformat([(var.name, shape(var)) for var in tfcompat.global_variables()]))

logging.info('trainable parameters:')
logging.info('Trainable parameters:')
logging.info(pp.pformat([(var.name, shape(var)) for var in tfcompat.trainable_variables()]))

logging.info('trainable parameter count:')
logging.info('Trainable parameter count:')
logging.info(str(np.sum(np.prod(shape(var)) for var in tfcompat.trainable_variables())))

def get_optimizer(self, learning_rate, beta1_decay):
@@ -415,7 +428,7 @@ def get_optimizer(self, learning_rate, beta1_decay):
elif self.optimizer == 'rms':
return tfcompat.train.RMSPropOptimizer(learning_rate, decay=beta1_decay, momentum=0.9)
else:
assert False, 'optimizer must be adam, gd, or rms'
assert False, 'Optimizer must be adam, gd, or rms'

def build_graph(self):
with tf.Graph().as_default() as graph:
