From a8d9368ab92b4594f7e0083b17178a0c7f74b4f9 Mon Sep 17 00:00:00 2001 From: lzhbrian Date: Mon, 13 May 2019 12:10:26 +0800 Subject: [PATCH 1/2] fix training iter count error --- train.py | 7 +++++-- util/iter_counter.py | 1 - 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 09e18124..ecbf19a2 100755 --- a/train.py +++ b/train.py @@ -24,14 +24,17 @@ trainer = Pix2PixTrainer(opt) # create tool for counting iterations -iter_counter = IterationCounter(opt, len(dataloader)) +iter_counter = IterationCounter(opt, len(dataloader) * opt.batchSize) # create tool for visualization visualizer = Visualizer(opt) for epoch in iter_counter.training_epochs(): iter_counter.record_epoch_start(epoch) - for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): + for i, data_i in enumerate(dataloader): + if i < iter_counter.epoch_iter // opt.batchSize: + continue + iter_counter.record_one_iteration() # Training diff --git a/util/iter_counter.py b/util/iter_counter.py index 1a0182fa..83b01e95 100644 --- a/util/iter_counter.py +++ b/util/iter_counter.py @@ -35,7 +35,6 @@ def training_epochs(self): def record_epoch_start(self, epoch): self.epoch_start_time = time.time() - self.epoch_iter = 0 self.last_iter_time = time.time() self.current_epoch = epoch From 570ff017dc362e8d3033af66a9ba872ce33ebb82 Mon Sep 17 00:00:00 2001 From: lzhbrian Date: Mon, 13 May 2019 14:31:13 +0800 Subject: [PATCH 2/2] bugfix --- train.py | 8 ++++++-- util/iter_counter.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index ecbf19a2..4485c013 100755 --- a/train.py +++ b/train.py @@ -29,10 +29,14 @@ # create tool for visualization visualizer = Visualizer(opt) +clear_iter = False for epoch in iter_counter.training_epochs(): - iter_counter.record_epoch_start(epoch) + iter_counter.record_epoch_start(epoch, clear_iter) + clear_iter = True + + start_batch_idx = iter_counter.epoch_iter // opt.batchSize for i, data_i in enumerate(dataloader): - if i < iter_counter.epoch_iter // opt.batchSize: + if i < start_batch_idx: continue iter_counter.record_one_iteration() diff --git a/util/iter_counter.py b/util/iter_counter.py index 83b01e95..770eb7cc 100644 --- a/util/iter_counter.py +++ b/util/iter_counter.py @@ -33,8 +33,10 @@ def __init__(self, opt, dataset_size): def training_epochs(self): return range(self.first_epoch, self.total_epochs + 1) - def record_epoch_start(self, epoch): + def record_epoch_start(self, epoch, clear_iter=True): self.epoch_start_time = time.time() + if clear_iter: + self.epoch_iter = 0 self.last_iter_time = time.time() self.current_epoch = epoch