Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Embed quality-scores as HTML tag attributes #358

Merged
merged 14 commits into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/tests/common-impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,11 @@ void TestSuite<Service>::qualityEstimatorWords(Ptr<TranslationModel> model) {
std::string source = readFromStdin();
const Response response = bridge_.translate(service_, model, std::move(source), responseOptions);

for (const auto &sentenceQualityEstimate : response.qualityScores) {
for (size_t sentenceIdx = 0; sentenceIdx < response.qualityScores.size(); ++sentenceIdx) {
const auto &sentenceQualityEstimate = response.qualityScores[sentenceIdx];
std::cout << "[SentenceBegin]\n";

for (const auto &wordByteRange : sentenceQualityEstimate.wordByteRanges) {
for (const auto &wordByteRange : getWordByteRanges(response, sentenceIdx)) {
const string_view word(response.target.text.data() + wordByteRange.begin, wordByteRange.size());
std::cout << word << "\n";
}
Expand Down
10 changes: 10 additions & 0 deletions src/translator/definitions.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ struct ByteRange {
bool operator==(ByteRange other) const { return begin == other.begin && end == other.end; }
};

/// A Subword range is mechanically the same as a `ByteRange`, but instead of
/// describing a span of bytes, it describes a span of Subword tokens. Using
/// `Annotation.word()` you can switch between the two.
struct SubwordRange {
  /// Index of the first subword token in the span.
  size_t begin;
  /// Upper bound of the span; `size()` is computed as `end - begin`.
  size_t end;

  /// Number of subword tokens in the span. Returned by value, so the return
  /// type carries no top-level `const` (it would be ignored by the compiler).
  size_t size() const { return end - begin; }

  /// Two ranges are equal when both their bounds are equal.
  bool operator==(SubwordRange other) const { return begin == other.begin && end == other.end; }
};

class Response;
using CallbackType = std::function<void(Response &&)>;

Expand Down
88 changes: 72 additions & 16 deletions src/translator/html.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <algorithm>

#include "response.h"
#include "translator/definitions.h"
#include "xh_scanner.h"

namespace {
Expand Down Expand Up @@ -544,7 +545,12 @@ void HTML::restore(Response &response) {
copyTagStack(response, alignments, sourceTokenSpans, targetTokenSpans);
assert(targetTokenSpans.size() == debugCountTokens(response.target));

AnnotatedText target = restoreTarget(response.target, targetTokenSpans);
// Take the spans, and use them to make a taint for every word in the
// translation. Optionally add extra tags, like quality score metadata.
std::vector<HTML::TagStack> targetTokenTags;
annotateTagStack(response, targetTokenSpans, targetTokenTags);

AnnotatedText target = restoreTarget(response.target, targetTokenSpans, targetTokenTags);

response.source = source;
response.target = target;
Expand Down Expand Up @@ -592,38 +598,37 @@ AnnotatedText HTML::restoreSource(AnnotatedText const &in, std::vector<SpanItera
});
}

AnnotatedText HTML::restoreTarget(AnnotatedText const &in, std::vector<SpanIterator> const &targetTokenSpans) {
auto prevSpan = spans_.cbegin();
AnnotatedText HTML::restoreTarget(AnnotatedText const &in, std::vector<SpanIterator> const &targetTokenSpans,
std::vector<TagStack> const &targetTokenTags) {
auto prevTags = spans_.cbegin()->tags;
auto stragglerSpanIt = spans_.cbegin();
auto targetSpanIt = targetTokenSpans.begin();
auto straggerSpanIt = spans_.cbegin();
auto targetTagIt = targetTokenTags.begin();

AnnotatedText out = in.apply([&]([[maybe_unused]] ByteRange range, string_view token, bool last) {
TokenFormatter formatter(token);

// First we scan through spans_ to catch up to the span assigned to this
// token. We're only interested in empty spans (empty and void elements)
for (; straggerSpanIt < *targetSpanIt; ++straggerSpanIt) {
for (; stragglerSpanIt < *targetSpanIt; stragglerSpanIt++) {
// We're only interested in empty spans or spans that would otherwise get
// lost because they didn't align with anything between the spans in
// targetSpanIt
// TODO That std::find makes this O(N*N) NOT GOOD NOT GOOD
if (straggerSpanIt->size() != 0 &&
std::find(targetTokenSpans.begin(), targetTokenSpans.end(), straggerSpanIt) != targetTokenSpans.end())
if (stragglerSpanIt->size() != 0 &&
std::find(targetTokenSpans.begin(), targetTokenSpans.end(), stragglerSpanIt) != targetTokenSpans.end())
continue;

formatter.append(prevSpan->tags, straggerSpanIt->tags);

// Note: here, not in 3rd part of for-statement because we don't want to
// set prevSpan if the continue clause at the beginning of this for-loop
// was hit.
prevSpan = straggerSpanIt;
formatter.append(prevTags, stragglerSpanIt->tags);
prevTags = stragglerSpanIt->tags;
}

// Now do the same thing but for our target set of tags. Note that we cannot
// combine this in the for-loop above (i.e. `span_it <= *targetSpanIt`)
// because there is no guarantee that the order in `targetTokenSpans` is
// the same as that of `spans`.
formatter.append(prevSpan->tags, (*targetSpanIt)->tags);

formatter.append(prevTags, *targetTagIt);

// If this is the last token of the response, close all open tags.
if (last) {
Expand All @@ -632,11 +637,12 @@ AnnotatedText HTML::restoreTarget(AnnotatedText const &in, std::vector<SpanItera
// the last token of the output. But lets assume someone someday changes
// HardAlignments(), and then this for-loop will be necessary.
// assert((*targetSpanIt)->tags.empty());
formatter.append((*targetSpanIt)->tags, HTML::TagStack());
formatter.append(*targetTagIt, HTML::TagStack());
}

prevSpan = *targetSpanIt;
prevTags = *targetTagIt;
++targetSpanIt;
++targetTagIt;

return std::move(formatter.html());
});
Expand Down Expand Up @@ -674,6 +680,56 @@ void HTML::copyTagStack(Response const &response, std::vector<std::vector<size_t
targetTokenSpans.push_back(sourceTokenSpans[offset]); // token_tag for ending whitespace
}

/// Builds the tag stack for every token of `response.target`.
///
/// For each sentence one entry is appended for the sentence prefix and one for
/// each of its `numWords()` tokens; a single entry for the suffix is appended
/// after the last sentence. Every entry starts out as a copy of the tags of the
/// span the corresponding element of `targetTokenSpans` points to.
///
/// When `response.qualityScores` is not empty, quality metadata is layered on
/// top of that: all tokens of a sentence (not its prefix) share one `<font>`
/// tag carrying `x-bergamot-sentence-index` / `x-bergamot-sentence-score`, and
/// all tokens of a word share one `<font>` tag carrying `x-bergamot-word-index`
/// / `x-bergamot-word-score`.
///
/// @param response         translation whose target annotation and quality scores are read
/// @param targetTokenSpans one span per target token, in token order
/// @param targetTokenTags  output; the tag stacks are appended to this vector
void HTML::annotateTagStack(Response const &response, std::vector<SpanIterator> const &targetTokenSpans,
                            std::vector<HTML::TagStack> &targetTokenTags) {
  // Exactly one tag stack is appended per target token span (checked by the
  // assert at the end), so a single allocation up front is enough.
  targetTokenTags.reserve(targetTokenTags.size() + targetTokenSpans.size());

  bool const hasQualityScores = !response.qualityScores.empty();

  // Quality scores are looked up by sentence index below; fewer entries than
  // sentences would be an out-of-bounds read.
  assert(!hasQualityScores || response.qualityScores.size() >= response.target.numSentences());

  auto spanIt = targetTokenSpans.begin();
  for (size_t sentenceIdx = 0; sentenceIdx < response.target.numSentences(); ++sentenceIdx) {
    // Sentence prefix
    targetTokenTags.push_back((*spanIt)->tags);
    ++spanIt;

    // Offset in targetTokenTags at which this sentence's tags start.
    size_t const tagOffset = targetTokenTags.size();
    size_t const sentenceTokenCount = response.target.numWords(sentenceIdx);

    // Initially, just copy the span's tags to this token
    for (size_t t = 0; t < sentenceTokenCount; ++t) {
      targetTokenTags.emplace_back((*spanIt)->tags);
      ++spanIt;
    }

    // Without quality score information there is nothing more to add.
    if (!hasQualityScores) continue;

    auto const &sentenceQuality = response.qualityScores[sentenceIdx];
    // Every word range needs a matching score.
    assert(sentenceQuality.wordScores.size() >= sentenceQuality.wordRanges.size());

    // Create a single <font> tag for this sentence with sentence level info
    Tag *sentenceTag = makeTag({Tag::ELEMENT, "font"});
    sentenceTag->attributes += format(" x-bergamot-sentence-index=\"{}\" x-bergamot-sentence-score=\"{}\"",
                                      sentenceIdx, sentenceQuality.sentenceScore);

    // Add that tag to all tokens in this sentence.
    for (size_t tokenIdx = 0; tokenIdx < sentenceTokenCount; ++tokenIdx) {
      targetTokenTags[tagOffset + tokenIdx].push_back(sentenceTag);
    }

    // Add word level <font> tags as well to all tokens that make up a word.
    // They are pushed after the sentence tag, so they sit later in the stack.
    for (size_t wordIdx = 0; wordIdx < sentenceQuality.wordRanges.size(); ++wordIdx) {
      Tag *wordTag = makeTag({Tag::ELEMENT, "font"});
      wordTag->attributes += format(" x-bergamot-word-index=\"{}\" x-bergamot-word-score=\"{}\"", wordIdx,
                                    sentenceQuality.wordScores[wordIdx]);

      auto const &range = sentenceQuality.wordRanges[wordIdx];
      // A word must not reach past the tokens pushed for this sentence,
      // otherwise the write below would land outside targetTokenTags.
      assert(range.end <= sentenceTokenCount);
      for (size_t tokenIdx = range.begin; tokenIdx < range.end; ++tokenIdx) {
        targetTokenTags[tagOffset + tokenIdx].push_back(wordTag);
      }
    }
  }

  // Suffix
  targetTokenTags.push_back((*spanIt)->tags);
  ++spanIt;

  assert(spanIt == targetTokenSpans.end());
}

// Reports if token `str` is likely to be a continuation of a word. This is used
// to determine whether we should share the markup, or whether we should see
// this token as a fresh start. This implementation will treat "hello[world]"
Expand Down
8 changes: 6 additions & 2 deletions src/translator/html.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class HTML {
void restore(Response &response);

private:
using SpanIterator = std::vector<HTML::Span>::const_iterator;
using SpanIterator = std::vector<HTML::Span>::iterator;
using AnnotatedText = marian::bergamot::AnnotatedText;

/// Reconstructs HTML in `response.source` (passed as `in`) and makes a list
Expand All @@ -175,7 +175,8 @@ class HTML {
/// Inserts the HTML into `response.target` (passed as `in`) based on
/// `targetTokenSpans`, which points to a `Span` for each token (subword) in
/// `response.target`.
AnnotatedText restoreTarget(AnnotatedText const &in, std::vector<SpanIterator> const &targetTokenSpans);
AnnotatedText restoreTarget(AnnotatedText const &in, std::vector<SpanIterator> const &targetTokenSpans,
std::vector<HTML::TagStack> const &targetTokenTags);

/// Utilities to test whether subword `str` is part of a word together with
/// the subword `prev`, or a separate word. Basically *does `str` start with
Expand All @@ -190,6 +191,9 @@ class HTML {
std::vector<HTML::SpanIterator> const &sourceTokenSpans,
std::vector<HTML::SpanIterator> &targetTokenSpans);

void annotateTagStack(Response const &response, std::vector<SpanIterator> const &targetTokenSpans,
std::vector<HTML::TagStack> &targetTokenTags);

/// Turns the alignment scores in `response.alignments` into one source token
/// per target token. Has some heuristics to keep all target tokens of a
/// single word pointing to the same span, and prefers spans with more markup
Expand Down
22 changes: 2 additions & 20 deletions src/translator/quality_estimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Response::SentenceQualityScore UnsupervisedQualityEstimator::computeSentenceScor
const float sentenceScore =
std::accumulate(std::begin(wordScores), std::end(wordScores), float(0.0)) / wordScores.size();

return {wordScores, subwordToWords(wordIndices, target, sentenceIdx), sentenceScore};
return {wordScores, wordIndices, sentenceScore};
}

LogisticRegressorQualityEstimator::Matrix::Matrix(const size_t rowsParam, const size_t colsParam)
Expand Down Expand Up @@ -160,7 +160,7 @@ Response::SentenceQualityScore LogisticRegressorQualityEstimator::computeSentenc
const float sentenceScore =
std::accumulate(std::begin(wordScores), std::end(wordScores), float(0.0)) / wordScores.size();

return {wordScores, subwordToWords(wordIndices, target, sentenceIdx), sentenceScore};
return {wordScores, wordIndices, sentenceScore};
jelmervdl marked this conversation as resolved.
Show resolved Hide resolved
}

std::vector<float> LogisticRegressorQualityEstimator::predict(const Matrix& features) const {
Expand Down Expand Up @@ -267,22 +267,4 @@ std::vector<SubwordRange> mapWords(const std::vector<float>& logProbs, const Ann
return wordIndices;
}

std::vector<ByteRange> subwordToWords(const std::vector<SubwordRange>& wordIndices, const AnnotatedText& target,
const size_t sentenceIdx) {
std::vector<ByteRange> words;

for (const SubwordRange& wordIndice : wordIndices) {
size_t wordBegin = target.wordAsByteRange(sentenceIdx, wordIndice.begin).begin;
size_t wordEnd = target.wordAsByteRange(sentenceIdx, wordIndice.end).begin;

if (isspace(target.text.at(wordBegin))) {
++wordBegin;
}

words.emplace_back(ByteRange{wordBegin, wordEnd});
}

return words;
}

} // namespace marian::bergamot
12 changes: 0 additions & 12 deletions src/translator/quality_estimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ class QualityEstimator {
virtual void computeQualityScores(const Histories &histories, Response &response) const = 0;
};

using SubwordRange = ByteRange;

/// Unsupervised Quality Estimator model. It uses the translator model's log probabilities (log probs) as a proxy for
/// quality scores. Then, for a given word, its quality score is computed by taking the mean of the log probs of the
/// tokens that make it up. The sentence score is the mean of all word's log probs.
Expand Down Expand Up @@ -209,14 +207,4 @@ inline std::shared_ptr<QualityEstimator> createQualityEstimator(const AlignedMem
std::vector<SubwordRange> mapWords(const std::vector<float> &logProbs, const AnnotatedText &target,
const size_t sentenceIdx);

/// Given a vector of subwordRanges, it maps the elements to be real words rather than sublevel tokens. The words are
/// represented through ByteRanges.

/// @param [in] wordIndices: A vector in which each element corresponds to one real word and holds the SubwordRange
/// (an alias of ByteRange) spanning the positions of the sublevel tokens that make up that word
/// @param [in] target: AnnotatedText target value
/// @param [in] sentenceIdx: the id of a candidate sentence
std::vector<ByteRange> subwordToWords(const std::vector<SubwordRange> &wordIndices, const AnnotatedText &target,
const size_t sentenceIdx);

} // namespace marian::bergamot
18 changes: 18 additions & 0 deletions src/translator/response.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,22 @@ std::vector<Alignment> remapAlignments(const Response &first, const Response &se
return alignments;
}

/// Converts the word ranges stored in `response.qualityScores[sentenceIdx]`,
/// which are expressed in subword tokens, into byte ranges in
/// `response.target.text`.
///
/// A word starts at the first byte of its first subword and ends at the first
/// byte of the subword at `end`. If the first byte is whitespace it is skipped,
/// so the returned range does not start with it.
///
/// @param response    response holding both the target text and the quality scores
/// @param sentenceIdx index of the sentence whose words are converted
/// @return one byte range per entry of `wordRanges`, in the same order
std::vector<ByteRange> getWordByteRanges(const Response &response, size_t sentenceIdx) {
  const auto &wordRanges = response.qualityScores[sentenceIdx].wordRanges;

  std::vector<ByteRange> wordByteRanges;
  wordByteRanges.reserve(wordRanges.size());

  for (const auto &word : wordRanges) {
    size_t wordBegin = response.target.wordAsByteRange(sentenceIdx, word.begin).begin;
    const size_t wordEnd = response.target.wordAsByteRange(sentenceIdx, word.end).begin;

    // std::isspace() is undefined for negative arguments other than EOF. Bytes
    // of multi-byte UTF-8 characters are negative where `char` is signed, so
    // go through unsigned char first.
    if (std::isspace(static_cast<unsigned char>(response.target.text.at(wordBegin)))) {
      ++wordBegin;
    }

    wordByteRanges.emplace_back(ByteRange{wordBegin, wordEnd});
  }

  return wordByteRanges;
}
jerinphilip marked this conversation as resolved.
Show resolved Hide resolved

} // namespace marian::bergamot
6 changes: 4 additions & 2 deletions src/translator/response.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ struct Response {
struct SentenceQualityScore {
/// Quality score of each translated word
std::vector<float> wordScores;
/// Each word position in the translated text
std::vector<ByteRange> wordByteRanges;
/// Position of start and end token of each word in the translated text
std::vector<SubwordRange> wordRanges;
/// Whole sentence quality score (it is composed by the mean of its words)
float sentenceScore = 0.0;
};
Expand Down Expand Up @@ -77,6 +77,8 @@ struct Response {

std::vector<Alignment> remapAlignments(const Response &first, const Response &second);

std::vector<ByteRange> getWordByteRanges(Response const &response, size_t sentenceIdx);
jerinphilip marked this conversation as resolved.
Show resolved Hide resolved

} // namespace bergamot
} // namespace marian

Expand Down
12 changes: 0 additions & 12 deletions wasm/bindings/response_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "response.h"

using Response = marian::bergamot::Response;
using SentenceQualityScore = marian::bergamot::Response::SentenceQualityScore;
using ByteRange = marian::bergamot::ByteRange;

using namespace emscripten;
Expand All @@ -20,25 +19,14 @@ EMSCRIPTEN_BINDINGS(byte_range) {
value_object<ByteRange>("ByteRange").field("begin", &ByteRange::begin).field("end", &ByteRange::end);
}

std::vector<SentenceQualityScore> getQualityScores(const Response& response) { return response.qualityScores; }

// Embind registrations for `Response` and the types reachable through it.
EMSCRIPTEN_BINDINGS(response) {
// The class is registered under the name "Response" with a default
// constructor; each .function() binds a callable under the given name.
class_<Response>("Response")
.constructor<>()
.function("size", &Response::size)
// Bound to the free function getQualityScores rather than a member.
.function("getQualityScores", &getQualityScores)
.function("getOriginalText", &Response::getOriginalText)
.function("getTranslatedText", &Response::getTranslatedText)
.function("getSourceSentence", &Response::getSourceSentenceAsByteRange)
.function("getTranslatedSentence", &Response::getTargetSentenceAsByteRange);

// SentenceQualityScore is registered as a value object with these three fields.
value_object<SentenceQualityScore>("SentenceQualityScore")
.field("wordScores", &SentenceQualityScore::wordScores)
.field("wordByteRanges", &SentenceQualityScore::wordByteRanges)
.field("sentenceScore", &SentenceQualityScore::sentenceScore);

// Vector types registered under the given names; the last three are the
// element containers used by getQualityScores and SentenceQualityScore.
register_vector<Response>("VectorResponse");
register_vector<SentenceQualityScore>("VectorSentenceQualityScore");
register_vector<float>("VectorFloat");
register_vector<ByteRange>("VectorByteRange");
}
21 changes: 20 additions & 1 deletion wasm/test_page/css/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ label {
align-self: center;
}

textarea {
textarea, .output-area {
padding: 1rem;
font-family: sans-serif;
font-size: 1rem;
Expand All @@ -97,3 +97,22 @@ button:hover {
#output {
background-color: #f4f4f4;
}

/* Word-level marker: elements inside .output-area that carry an
   x-bergamot-word-score attribute and the class "bad" get a thin red pattern
   near their bottom edge. Four diagonal gradients form one 8px by 2px tile
   that repeats horizontally (presumably meant to look like a spell-checker
   squiggle; confirm visually). */
.output-area [x-bergamot-word-score].bad {
background-image:
linear-gradient(45deg, transparent 65%, red 80%, transparent 90%),
linear-gradient(135deg, transparent 5%, red 15%, transparent 25%),
linear-gradient(135deg, transparent 45%, red 55%, transparent 65%),
linear-gradient(45deg, transparent 25%, red 35%, transparent 50%);
background-repeat:repeat-x;
background-size: 8px 2px;
/* Anchor the tile at 95% of the element's height. */
background-position:0 95%;
}

/* Sentence-level marker: elements with an x-bergamot-sentence-score attribute
   and the class "bad" get a translucent red background. */
.output-area [x-bergamot-sentence-score].bad {
background: rgba(255, 128, 128, 0.8);
}

/* Elements with an x-bergamot-sentence-index attribute and the class
   "highlight-sentence" get a translucent yellow background.
   NOTE(review): the "bad" and "highlight-sentence" classes are presumably
   toggled by the test page's script; confirm against the page's JavaScript. */
.output-area [x-bergamot-sentence-index].highlight-sentence {
background: rgba(255, 255, 128, 0.8);
}
2 changes: 1 addition & 1 deletion wasm/test_page/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
To
<select id="lang-to" name="to" class="lang-select"></select>
</label>
<textarea id="output" name="output" readonly></textarea>
<div id="output" class="output-area"></div>
</div>
<div class="footer" id="status"></div>
</div>
Expand Down
Loading