Skip to content

Commit

Permalink
Multiple TranslationModels Implementation (#210)
Browse files Browse the repository at this point in the history
For outbound translation, we require having multiple models in the
inventory at the same time and abstracting the "how-to-translate" 
using a model out.

Reorganization: TranslationModel + Service. The new entity which
contains everything required to translate in one direction is
`TranslationModel`. The how-to-translate blocking single-threaded mode
of operation or async multi-threaded mode of operation is decoupled as
`BlockingService` and `AsyncService`. There is a new regression-test
using multiple models in conjunction added, also serving as
a demonstration for using multiple models in Outbound Translation.

WASM: WebAssembly due to the inability to use threads uses
`BlockingService.  Bindings are provided with a new API to work with a
Service, and multiple TranslationModels which the client (JS extension)
can inventory and maintain.  Ownership of a given `TranslationModel` is
shared while translations using the model are active in the internal
mechanism.

Config-Parsing: So far bergamot-translator has been hijacking marian's
config-parsing mechanisms. However, in order to support multiple models,
it has become impractical to continue this approach and a new
config-parsing that is bergamot specific is provisioned for
command-line applications constituting tests. The original marian
config-parsing tooling is only associated with a subset of
`TranslationModel` now. The new config-parsing for the library manages
workers and other common options (tentatively).

There is a known issue of: Inefficient placing of workspaces, leading to
more memory usage than what's necessary. This is to be fixed trickling
down from marian-dev in a later pull request. 

This PR also brings in BRT changes which fix speed-tests that were
broken and also fixes some QE outputs which were different due to not
using shortlist.
  • Loading branch information
jerinphilip authored Sep 21, 2021
1 parent 63120c1 commit cf541c6
Show file tree
Hide file tree
Showing 29 changed files with 1,068 additions and 641 deletions.
26 changes: 15 additions & 11 deletions app/bergamot.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
#include "cli.h"

int main(int argc, char *argv[]) {
auto cp = marian::bergamot::createConfigParser();
auto options = cp.parseOptions(argc, argv, true);
const std::string mode = options->get<std::string>("bergamot-mode");
marian::bergamot::ConfigParser configParser;
configParser.parseArgs(argc, argv);
auto &config = configParser.getConfig();
using namespace marian::bergamot;
if (mode == "wasm") {
app::wasm(options);
} else if (mode == "native") {
app::native(options);
} else if (mode == "decoder") {
app::decoder(options);
} else {
ABORT("Unknown --mode {}. Use one of: {wasm,native,decoder}", mode);
switch (config.opMode) {
case OpMode::APP_WASM:
app::wasm(config);
break;
case OpMode::APP_NATIVE:
app::native(config);
break;
case OpMode::APP_DECODER:
app::decoder(config);
break;
default:
break;
}
return 0;
}
46 changes: 30 additions & 16 deletions app/cli.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,40 @@ namespace app {
/// * Output: written to stdout as translations for the sentences supplied in corresponding lines
///
/// @param [options]: Options to translate passed down to marian through Options.
void wasm(Ptr<Options> options) {
void wasm(const CLIConfig &config) {
// Here, we take the command-line interface which is uniform across all apps. This is parsed into Ptr<Options> by
// marian. However, mozilla does not allow a Ptr<Options> constructor and demands an std::string constructor since
// std::string isn't marian internal unlike Ptr<Options>. Since this std::string path needs to be tested for mozilla
// and since this class/CLI is intended at testing mozilla's path, we go from:
//
// cmdline -> Ptr<Options> -> std::string -> Service(std::string)
// cmdline -> Ptr<Options> -> std::string -> TranslationModel(std::string)
//
// Overkill, yes.

std::string config = options->asYamlString();
Service model(config);
const std::string &modelConfigPath = config.modelConfigPaths.front();

Ptr<Options> options = parseOptionsFromFilePath(modelConfigPath);
MemoryBundle memoryBundle = getMemoryBundleFromConfig(options);

BlockingService::Config serviceConfig;
BlockingService service(serviceConfig);

std::shared_ptr<TranslationModel> translationModel =
std::make_shared<TranslationModel>(options->asYamlString(), std::move(memoryBundle));

ResponseOptions responseOptions;
std::vector<std::string> texts;

#ifdef WASM_COMPATIBLE_SOURCE
// Hide the translateMultiple operation
for (std::string line; std::getline(std::cin, line);) {
texts.emplace_back(line);
}

auto results = model.translateMultiple(std::move(texts), responseOptions);
auto results = service.translateMultiple(translationModel, std::move(texts), responseOptions);

for (auto &result : results) {
std::cout << result.getTranslatedText() << std::endl;
}
#endif
}

/// Application used to benchmark with marian-decoder from time-to-time. The implementation in this repository follows a
Expand All @@ -82,9 +88,13 @@ void wasm(Ptr<Options> options) {
/// * Output: to stdout, translations of the sentences supplied via stdin in corresponding lines
///
/// @param [in] options: constructed from command-line supplied arguments
void decoder(Ptr<Options> options) {
void decoder(const CLIConfig &config) {
marian::timer::Timer decoderTimer;
Service service(options);
AsyncService::Config asyncConfig{config.numWorkers};
AsyncService service(asyncConfig);
auto options = parseOptionsFromFilePath(config.modelConfigPaths.front());
MemoryBundle memoryBundle;
Ptr<TranslationModel> translationModel = service.createCompatibleModel(options, std::move(memoryBundle));
// Read a large input text blob from stdin
std::ostringstream std_input;
std_input << std::cin.rdbuf();
Expand All @@ -95,14 +105,15 @@ void decoder(Ptr<Options> options) {
std::future<Response> responseFuture = responsePromise.get_future();
auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); };

service.translate(std::move(input), std::move(callback));
service.translate(translationModel, std::move(input), std::move(callback));
responseFuture.wait();
const Response &response = responseFuture.get();

for (size_t sentenceIdx = 0; sentenceIdx < response.size(); sentenceIdx++) {
std::cout << response.target.sentence(sentenceIdx) << "\n";
}
LOG(info, "Total time: {:.5f}s wall", decoderTimer.elapsed());

std::cerr << "Total time: " << std::setprecision(5) << decoderTimer.elapsed() << "s wall" << std::endl;
}

/// Command line interface to the test the features being developed as part of bergamot C++ library on native platform.
Expand All @@ -114,16 +125,19 @@ void decoder(Ptr<Options> options) {
/// * Output: to stdout, translation of the source text faithful to source structure.
///
/// @param [in] options: options to build translator
void native(Ptr<Options> options) {
void native(const CLIConfig &config) {
AsyncService::Config asyncConfig{config.numWorkers};
AsyncService service(asyncConfig);

auto options = parseOptionsFromFilePath(config.modelConfigPaths.front());
// Prepare memories for bytearrays (including model, shortlist and vocabs)
MemoryBundle memoryBundle;

if (options->get<bool>("bytearray")) {
if (config.byteArray) {
// Load legit values into bytearrays.
memoryBundle = getMemoryBundleFromConfig(options);
}

Service service(options, std::move(memoryBundle));
Ptr<TranslationModel> translationModel = service.createCompatibleModel(options, std::move(memoryBundle));

// Read a large input text blob from stdin
std::ostringstream std_input;
Expand All @@ -137,7 +151,7 @@ void native(Ptr<Options> options) {
std::future<Response> responseFuture = responsePromise.get_future();
auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); };

service.translate(std::move(input), std::move(callback), responseOptions);
service.translate(translationModel, std::move(input), std::move(callback), responseOptions);
responseFuture.wait();
Response response = responseFuture.get();

Expand Down
68 changes: 45 additions & 23 deletions src/tests/apps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,41 @@

namespace marian {
namespace bergamot {
namespace testapp {

// Utility function, common for all testapps.
Response translateFromStdin(Ptr<Options> options, ResponseOptions responseOptions) {
// Prepare memories for bytearrays (including model, shortlist and vocabs)
MemoryBundle memoryBundle;

if (options->get<bool>("bytearray")) {
// Load legit values into bytearrays.
memoryBundle = getMemoryBundleFromConfig(options);
}

Service service(options, std::move(memoryBundle));
namespace {

std::string readFromStdin() {
// Read a large input text blob from stdin
std::ostringstream inputStream;
inputStream << std::cin.rdbuf();
std::string input = inputStream.str();
return input;
}

// Utility function, common for all testapps.
Response translateForResponse(AsyncService &service, Ptr<TranslationModel> model, std::string &&source,
ResponseOptions responseOptions) {
std::promise<Response> responsePromise;
std::future<Response> responseFuture = responsePromise.get_future();

auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); };
service.translate(std::move(input), callback, responseOptions);
service.translate(model, std::move(source), callback, responseOptions);

responseFuture.wait();

Response response = responseFuture.get();
return response;
}

void annotatedTextWords(Ptr<Options> options, bool source) {
} // namespace

namespace testapp {

void annotatedTextWords(AsyncService &service, Ptr<TranslationModel> model, bool sourceSide) {
ResponseOptions responseOptions;
Response response = translateFromStdin(options, responseOptions);
AnnotatedText &annotatedText = source ? response.source : response.target;
std::string source = readFromStdin();
Response response = translateForResponse(service, model, std::move(source), responseOptions);
AnnotatedText &annotatedText = sourceSide ? response.source : response.target;
for (size_t s = 0; s < annotatedText.numSentences(); s++) {
for (size_t w = 0; w < annotatedText.numWords(s); w++) {
std::cout << (w == 0 ? "" : "\t");
Expand All @@ -46,19 +46,39 @@ void annotatedTextWords(Ptr<Options> options, bool source) {
}
}

void annotatedTextSentences(Ptr<Options> options, bool source) {
void annotatedTextSentences(AsyncService &service, Ptr<TranslationModel> model, bool sourceSide) {
ResponseOptions responseOptions;
Response response = translateFromStdin(options, responseOptions);
AnnotatedText &annotatedText = source ? response.source : response.target;
std::string source = readFromStdin();
Response response = translateForResponse(service, model, std::move(source), responseOptions);
AnnotatedText &annotatedText = sourceSide ? response.source : response.target;
for (size_t s = 0; s < annotatedText.numSentences(); s++) {
std::cout << annotatedText.sentence(s) << "\n";
}
}

void qualityEstimatorWords(const Ptr<Options> &options) {
void forwardAndBackward(AsyncService &service, std::vector<Ptr<TranslationModel>> &models) {
ABORT_IF(models.size() != 2, "Forward and backward test needs two models.");
ResponseOptions responseOptions;
std::string source = readFromStdin();
Response forwardResponse = translateForResponse(service, models.front(), std::move(source), responseOptions);

// Make a copy of target
std::string target = forwardResponse.target.text;
Response backwardResponse = translateForResponse(service, models.back(), std::move(target), responseOptions);

// Print both onto the command-line
std::cout << forwardResponse.source.text;
std::cout << "----------------\n";
std::cout << forwardResponse.target.text;
std::cout << "----------------\n";
std::cout << backwardResponse.target.text;
}

void qualityEstimatorWords(AsyncService &service, Ptr<TranslationModel> model) {
ResponseOptions responseOptions;
responseOptions.qualityScores = true;
const Response response = translateFromStdin(options, responseOptions);
std::string source = readFromStdin();
const Response response = translateForResponse(service, model, std::move(source), responseOptions);

for (const auto &sentenceQualityEstimate : response.qualityScores) {
std::cout << "[SentenceBegin]\n";
Expand All @@ -71,10 +91,12 @@ void qualityEstimatorWords(const Ptr<Options> &options) {
}
}

void qualityEstimatorScores(const Ptr<Options> &options) {
void qualityEstimatorScores(AsyncService &service, Ptr<TranslationModel> model) {
ResponseOptions responseOptions;
responseOptions.qualityScores = true;
const Response response = translateFromStdin(options, responseOptions);

std::string source = readFromStdin();
const Response response = translateForResponse(service, model, std::move(source), responseOptions);

for (const auto &sentenceQualityEstimate : response.qualityScores) {
std::cout << std::fixed << std::setprecision(3) << sentenceQualityEstimate.sentenceScore << "\n";
Expand Down
14 changes: 6 additions & 8 deletions src/tests/apps.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,21 @@ namespace bergamot {

namespace testapp {

// Utility function, common for all testapps. Reads content from stdin, builds a Service based on options and constructs
// a response containing translation data according responseOptions.
Response translateFromStdin(Ptr<Options> options, ResponseOptions responseOptions);

// Reads from stdin and translates. Prints the tokens separated by space for each sentence. Prints words from source
// side text annotation if source=true, target annotation otherwise.
void annotatedTextWords(Ptr<Options> options, bool source = true);
void annotatedTextWords(AsyncService &service, Ptr<TranslationModel> model, bool source = true);

// Reads from stdin and translates the read content. Prints the sentences in source or target in constructed response
// in each line, depending on source = true or false respectively.
void annotatedTextSentences(Ptr<Options> options, bool source = true);
void annotatedTextSentences(AsyncService &service, Ptr<TranslationModel> model, bool source = true);

void forwardAndBackward(AsyncService &service, std::vector<Ptr<TranslationModel>> &models);

// Reads from stdin and translates the read content. Prints the quality words for each sentence.
void qualityEstimatorWords(const Ptr<Options>& options);
void qualityEstimatorWords(AsyncService &service, Ptr<TranslationModel> model);

// Reads from stdin and translates the read content. Prints the quality scores for each sentence.
void qualityEstimatorScores(const Ptr<Options>& options);
void qualityEstimatorScores(AsyncService &service, Ptr<TranslationModel> model);

} // namespace testapp
} // namespace bergamot
Expand Down
54 changes: 38 additions & 16 deletions src/tests/cli.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,45 @@

#include "apps.h"

int main(int argc, char *argv[]) {
auto cp = marian::bergamot::createConfigParser();
auto options = cp.parseOptions(argc, argv, true);
const std::string mode = options->get<std::string>("bergamot-mode");
using namespace marian::bergamot;
if (mode == "test-response-source-sentences") {
testapp::annotatedTextSentences(options, /*source=*/true);
} else if (mode == "test-response-target-sentences") {
testapp::annotatedTextSentences(options, /*source=*/false);
} else if (mode == "test-response-source-words") {
testapp::annotatedTextWords(options, /*source=*/true);
} else if (mode == std::string("test-quality-estimator-words")) {
testapp::qualityEstimatorWords(options);
} else if (mode == std::string("test-quality-estimator-scores")) {
testapp::qualityEstimatorScores(options);
} else {
ABORT("Unknown --mode {}. Please run a valid test", mode);
marian::bergamot::ConfigParser configParser;
configParser.parseArgs(argc, argv);
auto &config = configParser.getConfig();
AsyncService::Config serviceConfig{config.numWorkers};
AsyncService service(serviceConfig);
std::vector<std::shared_ptr<TranslationModel>> models;

for (auto &modelConfigPath : config.modelConfigPaths) {
TranslationModel::Config modelConfig = parseOptionsFromFilePath(modelConfigPath);
std::shared_ptr<TranslationModel> model = service.createCompatibleModel(modelConfig);
models.push_back(model);
}

switch (config.opMode) {
case OpMode::TEST_SOURCE_SENTENCES:
testapp::annotatedTextSentences(service, models.front(), /*source=*/true);
break;
case OpMode::TEST_TARGET_SENTENCES:
testapp::annotatedTextSentences(service, models.front(), /*source=*/false);
break;
case OpMode::TEST_SOURCE_WORDS:
testapp::annotatedTextWords(service, models.front(), /*source=*/true);
break;
case OpMode::TEST_TARGET_WORDS:
testapp::annotatedTextWords(service, models.front(), /*source=*/false);
break;
case OpMode::TEST_FORWARD_BACKWARD_FOR_OUTBOUND:
testapp::forwardAndBackward(service, models);
break;
case OpMode::TEST_QUALITY_ESTIMATOR_WORDS:
testapp::qualityEstimatorWords(service, models.front());
break;
case OpMode::TEST_QUALITY_ESTIMATOR_SCORES:
testapp::qualityEstimatorScores(service, models.front());
break;
default:
ABORT("Incompatible op-mode. Choose one of the test modes.");
break;
}
return 0;
}
7 changes: 4 additions & 3 deletions src/translator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/project_version.h.in
add_library(bergamot-translator STATIC
byte_array_util.cpp
text_processor.cpp
batch_translator.cpp
translation_model.cpp
request.cpp
batcher.cpp
batching_pool.cpp
aggregate_batching_pool.cpp
response_builder.cpp
quality_estimator.cpp
batch.cpp
annotation.cpp
service.cpp
threadsafe_batcher.cpp
parser.cpp
)
if (USE_WASM_COMPATIBLE_SOURCE)
# Using wasm compatible sources should include this compile definition;
Expand Down
Loading

0 comments on commit cf541c6

Please sign in to comment.