Skip to content

Commit

Permalink
Merged PR 22524: Optimize guided alignment training speed via sparse …
Browse files Browse the repository at this point in the history
…alignments - part 1

This replaces dense alignment storage and training with a sparse representation. Training speed with guided alignment now nearly matches normal training speed, regaining about 25% speed.

This is no. 1 of 2 PRs. The next one will introduce a new guided-alignment training scheme with better alignment accuracy.
  • Loading branch information
emjotde committed Feb 11, 2022
1 parent 3b21ff3 commit 4b51dcb
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 113 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Fixed

### Changed

- Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce.
- Changed minimal C++ standard to C++-17
- Faster LSH top-k search on CPU

Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v1.11.2
v1.11.3
2 changes: 1 addition & 1 deletion src/common/config_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"none");
cli.add<std::string>("--guided-alignment-cost",
"Cost type for guided alignment: ce (cross-entropy), mse (mean square error), mult (multiplication)",
"mse");
"ce");
cli.add<double>("--guided-alignment-weight",
"Weight for guided alignment cost",
0.1);
Expand Down
39 changes: 35 additions & 4 deletions src/data/alignment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#include "common/utils.h"

#include <algorithm>
#include <cmath>
#include <set>

namespace marian {
namespace data {
Expand All @@ -10,10 +12,11 @@ WordAlignment::WordAlignment() {}

WordAlignment::WordAlignment(const std::vector<Point>& align) : data_(align) {}

// Constructs word alignments from a textual representation such as "0-0 1-1 1-2",
// then appends one extra alignment point linking the source and target EOS symbols
// at the externally supplied positions.
WordAlignment::WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos) {
  // splitAny on " -" flattens "s-t" pairs into [s0, t0, s1, t1, ...], hence the stride-2 loop.
  std::vector<std::string> atok = utils::splitAny(line, " -");
  for(size_t i = 0; i < atok.size(); i += 2)
    data_.push_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f });
  data_.push_back(Point{ srcEosPos, tgtEosPos, 1.f }); // add alignment point for both EOS symbols
}

void WordAlignment::sort() {
Expand All @@ -22,6 +25,35 @@ void WordAlignment::sort() {
});
}

void WordAlignment::normalize(bool reverse/*=false*/) {
std::vector<size_t> counts;
counts.reserve(data_.size());

// reverse==false : normalize target word prob by number of source words
// reverse==true : normalize source word prob by number of target words
auto srcOrTgt = [](const Point& p, bool reverse) {
return reverse ? p.srcPos : p.tgtPos;
};

for(const auto& a : data_) {
size_t pos = srcOrTgt(a, reverse);
if(counts.size() <= pos)
counts.resize(pos + 1, 0);
counts[pos]++;
}

// a.prob at this point is either 1 or normalized to a different value,
// but we just set it to 1 / count, so multiple calls result in re-normalization
// regardless of forward or reverse direction. We also set the remaining values to 1.
for(auto& a : data_) {
size_t pos = srcOrTgt(a, reverse);
if(counts[pos] > 1)
a.prob = 1.f / counts[pos];
else
a.prob = 1.f;
}
}

std::string WordAlignment::toString() const {
std::stringstream str;
for(auto p = begin(); p != end(); ++p) {
Expand All @@ -32,7 +64,7 @@ std::string WordAlignment::toString() const {
return str.str();
}

WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
float threshold /*= 1.f*/) {
WordAlignment align;
// Alignments by maximum value
Expand All @@ -58,7 +90,6 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
}
}
}

// Sort alignment pairs in ascending order
align.sort();

Expand Down
21 changes: 16 additions & 5 deletions src/data/alignment.h
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
#pragma once

#include <sstream>
#include <tuple>
#include <vector>

namespace marian {
namespace data {

class WordAlignment {
struct Point
{
public:
struct Point {
size_t srcPos;
size_t tgtPos;
float prob;
};
private:
std::vector<Point> data_;

public:
WordAlignment();

Expand All @@ -28,11 +30,14 @@ class WordAlignment {
public:

/**
* @brief Constructs word alignments from textual representation.
* @brief Constructs word alignments from textual representation. Adds alignment point for externally
* supplied EOS positions in source and target string.
*
* @param line String in the form of "0-0 1-1 1-2", etc.
*/
WordAlignment(const std::string& line);
WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos);

Point& operator[](size_t i) { return data_[i]; }

auto begin() const -> decltype(data_.begin()) { return data_.begin(); }
auto end() const -> decltype(data_.end()) { return data_.end(); }
Expand All @@ -46,6 +51,12 @@ class WordAlignment {
*/
void sort();

/**
* @brief Normalizes alignment probabilities of target words to sum to 1 over aligned source words
* (or source word probabilities over aligned target words if reverse=true).
* This is needed for correct cost computation for guided alignment training with CE cost criterion.
*/
void normalize(bool reverse=false);

/**
* @brief Returns textual representation.
*/
Expand All @@ -56,7 +67,7 @@ class WordAlignment {
// Also used on QuickSAND boundary where beam and batch size is 1. Then it is simply [t][s] -> P(s|t)
typedef std::vector<std::vector<float>> SoftAlignment; // [trg pos][beam depth * max src length * batch size]

WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
float threshold = 1.f);

std::string SoftAlignToString(SoftAlignment align);
Expand Down
2 changes: 1 addition & 1 deletion src/data/batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Batch {
const std::vector<size_t>& getSentenceIds() const { return sentenceIds_; }
void setSentenceIds(const std::vector<size_t>& ids) { sentenceIds_ = ids; }

virtual void setGuidedAlignment(std::vector<float>&&) = 0;
virtual void setGuidedAlignment(std::vector<WordAlignment>&&) = 0;
virtual void setDataWeights(const std::vector<float>&) = 0;
virtual ~Batch() {};
protected:
Expand Down
13 changes: 6 additions & 7 deletions src/data/corpus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,13 @@ SentenceTuple Corpus::next() {
tup.markAltered();
addWordsToSentenceTuple(fields[i], vocabId, tup);
}

// weights are added last to the sentence tuple, because this runs a validation that needs
// length of the target sequence
if(alignFileIdx_ > -1)
addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
if(weightFileIdx_ > -1)
addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
}
// weights are added last to the sentence tuple, because this runs a validation that needs
// length of the target sequence
if(alignFileIdx_ > -1)
addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
if(weightFileIdx_ > -1)
addWeightsToSentenceTuple(fields[weightFileIdx_], tup);

// check if all streams are valid, that is, non-empty and no longer than maximum allowed length
if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
Expand Down
25 changes: 11 additions & 14 deletions src/data/corpus_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,11 +429,13 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line,

// Parses a guided-alignment line and attaches the resulting WordAlignment to the
// sentence tuple. Requires exactly one source and one target stream; incompatible
// with right-left models.
void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
                                             SentenceTupleImpl& tup) const {
  ABORT_IF(rightLeft_, "Guided alignment and right-left model cannot be used together at the moment");
  ABORT_IF(tup.size() != 2, "Using alignment between source and target, but sentence tuple has {} elements??", tup.size());

  // EOS is the last token of each stream; pass its position so WordAlignment can
  // add an EOS-EOS alignment point.
  // NOTE(review): assumes both streams are non-empty — size()==0 would underflow;
  // confirm upstream validation guarantees this.
  size_t srcEosPos = tup[0].size() - 1;
  size_t tgtEosPos = tup[1].size() - 1;

  auto align = WordAlignment(line, srcEosPos, tgtEosPos);
  tup.setAlignment(align);
}

Expand All @@ -457,22 +459,17 @@ void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupl

void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch,
const std::vector<Sample>& batchVector) {
int srcWords = (int)batch->front()->batchWidth();
int trgWords = (int)batch->back()->batchWidth();
std::vector<WordAlignment> aligns;

int dimBatch = (int)batch->getSentenceIds().size();

std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);

aligns.reserve(dimBatch);

for(int b = 0; b < dimBatch; ++b) {

// If the batch vector is altered within marian by, for example, case augmentation,
// the guided alignments we received for this tuple cease to be valid.
// Hence skip setting alignments for that sentence tuple..
if (!batchVector[b].isAltered()) {
for(auto p : batchVector[b].getAlignment()) {
size_t idx = p.srcPos * dimBatch * trgWords + b * trgWords + p.tgtPos;
aligns[idx] = 1.f;
}
aligns.push_back(std::move(batchVector[b].getAlignment()));
}
}
batch->setGuidedAlignment(std::move(aligns));
Expand Down
46 changes: 18 additions & 28 deletions src/data/corpus_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class SubBatch {
class CorpusBatch : public Batch {
protected:
std::vector<Ptr<SubBatch>> subBatches_;
std::vector<float> guidedAlignment_; // [max source len, batch size, max target len] flattened
std::vector<WordAlignment> guidedAlignment_; // one sparse WordAlignment per sentence in the batch
std::vector<float> dataWeights_;

public:
Expand Down Expand Up @@ -444,8 +444,17 @@ class CorpusBatch : public Batch {

if(options->get("guided-alignment", std::string("none")) != "none") {
// @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths
std::vector<float> alignment(batchSize * lengths.front() * lengths.back(),
0.f);

std::vector<data::WordAlignment> alignment;
for(size_t k = 0; k < batchSize; ++k) {
data::WordAlignment perSentence;
// fill with random alignment points, add more twice the number of words to be safe.
for(size_t j = 0; j < lengths.back() * 2; ++j) {
size_t i = rand() % lengths.back();
perSentence.push_back(i, j, 1.0f);
}
alignment.push_back(std::move(perSentence));
}
batch->setGuidedAlignment(std::move(alignment));
}

Expand Down Expand Up @@ -501,29 +510,14 @@ class CorpusBatch : public Batch {
}

if(!guidedAlignment_.empty()) {
size_t oldTrgWords = back()->batchWidth();
size_t oldSize = size();

pos = 0;
for(auto split : splits) {
auto cb = std::static_pointer_cast<CorpusBatch>(split);
size_t srcWords = cb->front()->batchWidth();
size_t trgWords = cb->back()->batchWidth();
size_t dimBatch = cb->size();

std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);

for(size_t i = 0; i < dimBatch; ++i) {
size_t bi = i + pos;
for(size_t sid = 0; sid < srcWords; ++sid) {
for(size_t tid = 0; tid < trgWords; ++tid) {
size_t bidx = sid * oldSize * oldTrgWords + bi * oldTrgWords + tid; // [sid, bi, tid]
size_t idx = sid * dimBatch * trgWords + i * trgWords + tid;
aligns[idx] = guidedAlignment_[bidx];
}
}
}
cb->setGuidedAlignment(std::move(aligns));
std::vector<WordAlignment> batchAlignment;
for(size_t i = 0; i < dimBatch; ++i)
batchAlignment.push_back(std::move(guidedAlignment_[i + pos]));
cb->setGuidedAlignment(std::move(batchAlignment));
pos += dimBatch;
}
}
Expand Down Expand Up @@ -556,15 +550,11 @@ class CorpusBatch : public Batch {
return splits;
}

const std::vector<float>& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened
void setGuidedAlignment(std::vector<float>&& aln) override {
const std::vector<WordAlignment>& getGuidedAlignment() const { return guidedAlignment_; } // one sparse WordAlignment per sentence in the batch
// Takes ownership of the per-sentence sparse guided alignments for this batch.
void setGuidedAlignment(std::vector<WordAlignment>&& aln) override {
guidedAlignment_ = std::move(aln);
}

size_t locateInGuidedAlignments(size_t b, size_t s, size_t t) {
return ((s * size()) + b) * widthTrg() + t;
}

std::vector<float>& getDataWeights() { return dataWeights_; }
void setDataWeights(const std::vector<float>& weights) override {
dataWeights_ = weights;
Expand Down
2 changes: 1 addition & 1 deletion src/examples/mnist/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class DataBatch : public Batch {

size_t size() const override { return inputs_.front().shape()[0]; }

void setGuidedAlignment(std::vector<float>&&) override {
// Guided alignment is not supported for MNIST example batches; always aborts.
void setGuidedAlignment(std::vector<WordAlignment>&&) override {
ABORT("Guided alignment in DataBatch is not implemented");
}
void setDataWeights(const std::vector<float>&) override {
Expand Down
2 changes: 2 additions & 0 deletions src/graph/expression_operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ Expr operator/(float a, Expr b) {
/*********************************************************/

// Concatenates the given expressions along axis `ax`. A single input needs no
// concatenation node, so it is returned as-is, avoiding a redundant graph op.
Expr concatenate(const std::vector<Expr>& concats, int ax) {
  return concats.size() == 1 ? concats[0]
                             : Expression<ConcatenateNodeOp>(concats, ax);
}

Expand Down
Loading

0 comments on commit 4b51dcb

Please sign in to comment.