diff --git a/train.py b/train.py index 09e18124..4485c013 100755 --- a/train.py +++ b/train.py @@ -24,14 +24,21 @@ trainer = Pix2PixTrainer(opt) # create tool for counting iterations -iter_counter = IterationCounter(opt, len(dataloader)) +iter_counter = IterationCounter(opt, len(dataloader) * opt.batchSize) # create tool for visualization visualizer = Visualizer(opt) +clear_iter = False for epoch in iter_counter.training_epochs(): - iter_counter.record_epoch_start(epoch) - for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): + iter_counter.record_epoch_start(epoch, clear_iter) + clear_iter = True + + start_batch_idx = iter_counter.epoch_iter // opt.batchSize + for i, data_i in enumerate(dataloader): + if i < start_batch_idx: + continue + iter_counter.record_one_iteration() # Training diff --git a/util/iter_counter.py b/util/iter_counter.py index 1a0182fa..770eb7cc 100644 --- a/util/iter_counter.py +++ b/util/iter_counter.py @@ -33,9 +33,10 @@ def __init__(self, opt, dataset_size): def training_epochs(self): return range(self.first_epoch, self.total_epochs + 1) - def record_epoch_start(self, epoch): + def record_epoch_start(self, epoch, clear_iter=True): self.epoch_start_time = time.time() - self.epoch_iter = 0 + if clear_iter: + self.epoch_iter = 0 self.last_iter_time = time.time() self.current_epoch = epoch