
Commit

Fix issue with checkpoint decoder metrics for distributed training (#997)

* Broadcast checkpoint decoder metrics from primary to secondary workers

* Only primary worker runs the checkpoint decoder

* Update version and changelog
mjdenkowski authored Dec 17, 2021
1 parent 0226871 commit 71527d7
Showing 3 changed files with 22 additions and 10 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.0.6]

### Fixed

- Fixed checkpoint decoder issue that prevented using `bleu` as `--optimized-metric` for distributed training ([#995](https://github.com/awslabs/sockeye/issues/995)).

## [3.0.5]

### Fixed
@@ -33,7 +39,7 @@ Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

### Changed

- `sockeye-translate`: Beam search now computes and returns secondary target factor scores. Secondary target factors
- `sockeye-translate`: Beam search now computes and returns secondary target factor scores. Secondary target factors
do not participate in beam search, but are greedily chosen at every time step. Accumulated scores for secondary factors
are not normalized by length. Factor scores are included in JSON output (``--output-type json``).
- `sockeye-score` now returns tab-separated scores for each target factor. Users can decide how to combine factor scores
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.0.5'
__version__ = '3.0.6'
22 changes: 14 additions & 8 deletions sockeye/training_pt.py
@@ -180,7 +180,7 @@ def fit(self,
checkpoint_decoder: Optional[checkpoint_decoder_pt.CheckpointDecoder] = None):
logger.info("Early stopping by optimizing '%s'", self.config.early_stopping_metric)

if self.config.early_stopping_metric in C.METRICS_REQUIRING_DECODER:
if utils.is_primary_worker() and self.config.early_stopping_metric in C.METRICS_REQUIRING_DECODER:
utils.check_condition(checkpoint_decoder is not None,
"%s requires CheckpointDecoder" % self.config.early_stopping_metric)

@@ -405,15 +405,21 @@ def _evaluate(self, checkpoint: int, data_iter,
for loss_metric, (loss_value, num_samples) in zip(val_metrics, loss_outputs):
loss_metric.update(loss_value.item(), num_samples.item())

# Optionally run the checkpoint decoder
if checkpoint_decoder is not None:
# Primary worker optionally runs the checkpoint decoder
decoder_metrics = {} # type: Dict[str, float]
if utils.is_primary_worker() and checkpoint_decoder is not None:
output_name = os.path.join(self.config.output_dir, C.DECODE_OUT_NAME.format(checkpoint=checkpoint))
decoder_metrics = checkpoint_decoder.decode_and_evaluate(output_name=output_name)
for metric_name, metric_value in decoder_metrics.items():
assert metric_name not in val_metrics, "Duplicate validation metric %s" % metric_name
metric = loss_pt.LossMetric(name=metric_name)
metric.update(metric_value, num_samples=1)
val_metrics.append(metric)
# Broadcast decoder metrics (if any) from primary worker to secondary
# workers
if utils.is_distributed():
decoder_metrics = utils.broadcast_object(decoder_metrics)
# Add decoder metrics (if any) to validation metrics
for metric_name, metric_value in decoder_metrics.items():
assert metric_name not in val_metrics, "Duplicate validation metric %s" % metric_name
metric = loss_pt.LossMetric(name=metric_name)
metric.update(metric_value, num_samples=1)
val_metrics.append(metric)

logger.info('Checkpoint [%d]\t%s',
self.state.checkpoint, "\t".join("Validation-%s" % str(lm) for lm in val_metrics))
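The pattern in the `_evaluate` change above is: only the primary worker runs the (expensive) checkpoint decoder, and the resulting metrics dictionary is then broadcast so that secondary workers append identical validation metrics, which keeps early stopping on `bleu` consistent across ranks. Below is a minimal sketch of how an object-broadcast helper like `utils.broadcast_object` could be built on `torch.distributed`; it is an illustration under the assumption that Sockeye's distributed training runs on PyTorch process groups, not the library's actual implementation.

```python
# Illustrative sketch only -- not Sockeye's actual utils implementation.
# Sends a picklable object (e.g. the checkpoint decoder's metrics dict)
# from the primary worker (rank 0) to all secondary workers.
from typing import Any

import torch.distributed as dist


def broadcast_object(obj: Any, src: int = 0) -> Any:
    """Broadcast `obj` from rank `src` to every rank and return the received value."""
    # Every rank passes a one-element list; on ranks other than `src`,
    # the element is overwritten in place with the object sent by `src`.
    obj_list = [obj]
    dist.broadcast_object_list(obj_list, src=src)
    return obj_list[0]
```

With a helper of this shape, every rank ends up with the same `decoder_metrics` after the broadcast, so the duplicate-metric assertion and the `LossMetric` updates behave identically on all workers.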
