Skip to content

Commit

Permalink
Merged PR 27051: Add an option for completely resetting validation metrics
Browse files Browse the repository at this point in the history

Added `--valid-reset-all`, which works like `--valid-reset-stalled` but also resets the last best saved validation metrics. This is useful when the validation sets change for continued training.

Added new regression test: marian-nmt/marian-regression-tests#89
  • Loading branch information
Roman Grundkiewicz committed Dec 20, 2022
1 parent b7205fc commit ee50d4a
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Fused inplace-dropout in FFN layer in Transformer
- `--force-decode` option for marian-decoder
- `--output-sampling` now works with ensembles (requires proper normalization via, e.g., `--weights 0.5 0.5`)
- `--valid-reset-all` option

### Fixed
- Make concat factors not break old vector implementation
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v1.11.14
v1.11.15
5 changes: 4 additions & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,10 @@ stages:

# The following packages are already installed on Azure-hosted runners: build-essential openssl libssl-dev
# No need to install libprotobuf{17,10,9v5} on Ubuntu {20,18,16}.04 because it is installed together with libprotobuf-dev
- bash: sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
# Installing libunwind-dev first fixes a bug on Ubuntu 22.04 (libunwind-14 and libunwind-dev conflict)
- bash: |
sudo apt-get install -y libunwind-dev
sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
displayName: Install packages
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html
Expand Down
6 changes: 4 additions & 2 deletions src/common/config_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
cli.add<int>("--transformer-dim-ffn",
"Size of position-wise feed-forward network (transformer)",
2048);
2048);
cli.add<int>("--transformer-decoder-dim-ffn",
"Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
0);
Expand Down Expand Up @@ -591,7 +591,9 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
"Multiple metrics can be specified",
{"cross-entropy"});
cli.add<bool>("--valid-reset-stalled",
"Reset all stalled validation metrics when the training is restarted");
"Reset stalled validation metrics when the training is restarted");
cli.add<bool>("--valid-reset-all",
"Reset all validation metrics when the training is restarted");
cli.add<size_t>("--early-stopping",
"Stop if the first validation metric does not improve for arg consecutive validation steps",
10);
Expand Down
13 changes: 9 additions & 4 deletions src/training/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,12 +494,17 @@ class Scheduler : public TrainingObserver {
state_->wordsDisp = 0;
}

if(options_->get<bool>("valid-reset-stalled")) {
if(options_->get<bool>("valid-reset-stalled") || options_->get<bool>("valid-reset-all")) {
state_->stalled = 0;
state_->maxStalled = 0;
for(const auto& validator : validators_) {
if(state_->validators[validator->type()])
if(state_->validators[validator->type()]) {
// reset the number of stalled validations, e.g. when the validation set is the same
state_->validators[validator->type()]["stalled"] = 0;
// reset last best results as well, e.g. when the validation set changes
if(options_->get<bool>("valid-reset-all"))
state_->validators[validator->type()]["last-best"] = validator->initScore();
}
}
}

Expand All @@ -512,10 +517,10 @@ class Scheduler : public TrainingObserver {
if(mpi_->isMainProcess())
if(filesystem::exists(nameYaml))
yamlStr = io::InputFileStream(nameYaml).readToString();

if(mpi_)
mpi_->bCast(yamlStr);

loadFromString(yamlStr);
}

Expand Down

0 comments on commit ee50d4a

Please sign in to comment.