From cf541c68f9b43bce8c68e2292007a1573cfaa38e Mon Sep 17 00:00:00 2001 From: Jerin Philip Date: Tue, 21 Sep 2021 18:10:40 +0100 Subject: [PATCH] Multiple TranslationModels Implementation (#210) For outbound translation, we require having multiple models in the inventory at the same time and abstracting out the "how-to-translate" for a given model. Reorganization: TranslationModel + Service. The new entity which contains everything required to translate in one direction is `TranslationModel`. The how-to-translate, i.e. the blocking single-threaded mode of operation or the async multi-threaded mode of operation, is decoupled as `BlockingService` and `AsyncService`. A new regression test that uses multiple models in conjunction is added, also serving as a demonstration of using multiple models in Outbound Translation. WASM: WebAssembly, due to its inability to use threads, uses `BlockingService`. Bindings are provided with a new API to work with a Service, and multiple TranslationModels which the client (JS extension) can inventory and maintain. Ownership of a given `TranslationModel` is shared while translations using the model are active in the internal mechanism. Config-Parsing: So far bergamot-translator has been hijacking marian's config-parsing mechanisms. However, in order to support multiple models, it has become impractical to continue this approach, and a new config-parsing that is bergamot-specific is provisioned for command-line applications constituting tests. The original marian config-parsing tooling is only associated with a subset of `TranslationModel` now. The new config-parsing for the library manages workers and other common options (tentatively). Known issue: inefficient placement of workspaces, leading to more memory usage than necessary. This is to be fixed in a later pull request, trickling down from marian-dev. 
This PR also brings in BRT changes which fix speed-tests that were broken and also fixes some QE outputs which were different due to not using shortlist. --- app/bergamot.cpp | 26 ++- app/cli.h | 46 ++-- bergamot-translator-tests | 2 +- src/tests/apps.cpp | 68 ++++-- src/tests/apps.h | 14 +- src/tests/cli.cpp | 54 +++-- src/translator/CMakeLists.txt | 7 +- src/translator/aggregate_batching_pool.cpp | 34 +++ src/translator/aggregate_batching_pool.h | 68 ++++++ src/translator/batch_translator.cpp | 128 ----------- src/translator/batch_translator.h | 57 ----- .../{batcher.cpp => batching_pool.cpp} | 15 +- src/translator/{batcher.h => batching_pool.h} | 23 +- src/translator/definitions.h | 3 + src/translator/parser.cpp | 170 +++++++++++++++ src/translator/parser.h | 120 +++++------ src/translator/request.h | 6 +- src/translator/response_builder.h | 2 +- src/translator/service.cpp | 99 ++++----- src/translator/service.h | 198 ++++++++---------- src/translator/text_processor.cpp | 8 +- src/translator/text_processor.h | 6 +- src/translator/threadsafe_batcher.cpp | 38 ---- src/translator/threadsafe_batcher.h | 57 ----- src/translator/threadsafe_batching_pool.cpp | 49 +++++ src/translator/threadsafe_batching_pool.h | 71 +++++++ src/translator/translation_model.cpp | 173 +++++++++++++++ src/translator/translation_model.h | 122 +++++++++++ wasm/bindings/service_bindings.cpp | 45 ++-- 29 files changed, 1068 insertions(+), 641 deletions(-) create mode 100644 src/translator/aggregate_batching_pool.cpp create mode 100644 src/translator/aggregate_batching_pool.h delete mode 100644 src/translator/batch_translator.cpp delete mode 100644 src/translator/batch_translator.h rename src/translator/{batcher.cpp => batching_pool.cpp} (83%) rename src/translator/{batcher.h => batching_pool.h} (63%) create mode 100644 src/translator/parser.cpp delete mode 100644 src/translator/threadsafe_batcher.cpp delete mode 100644 src/translator/threadsafe_batcher.h create mode 100644 
src/translator/threadsafe_batching_pool.cpp create mode 100644 src/translator/threadsafe_batching_pool.h create mode 100644 src/translator/translation_model.cpp create mode 100644 src/translator/translation_model.h diff --git a/app/bergamot.cpp b/app/bergamot.cpp index 19dea1fcf..bffbbb112 100644 --- a/app/bergamot.cpp +++ b/app/bergamot.cpp @@ -1,18 +1,22 @@ #include "cli.h" int main(int argc, char *argv[]) { - auto cp = marian::bergamot::createConfigParser(); - auto options = cp.parseOptions(argc, argv, true); - const std::string mode = options->get("bergamot-mode"); + marian::bergamot::ConfigParser configParser; + configParser.parseArgs(argc, argv); + auto &config = configParser.getConfig(); using namespace marian::bergamot; - if (mode == "wasm") { - app::wasm(options); - } else if (mode == "native") { - app::native(options); - } else if (mode == "decoder") { - app::decoder(options); - } else { - ABORT("Unknown --mode {}. Use one of: {wasm,native,decoder}", mode); + switch (config.opMode) { + case OpMode::APP_WASM: + app::wasm(config); + break; + case OpMode::APP_NATIVE: + app::native(config); + break; + case OpMode::APP_DECODER: + app::decoder(config); + break; + default: + break; } return 0; } diff --git a/app/cli.h b/app/cli.h index 4afe8b9aa..9cb12dd28 100644 --- a/app/cli.h +++ b/app/cli.h @@ -34,34 +34,40 @@ namespace app { /// * Output: written to stdout as translations for the sentences supplied in corresponding lines /// /// @param [options]: Options to translate passed down to marian through Options. -void wasm(Ptr options) { +void wasm(const CLIConfig &config) { // Here, we take the command-line interface which is uniform across all apps. This is parsed into Ptr by // marian. However, mozilla does not allow a Ptr constructor and demands an std::string constructor since // std::string isn't marian internal unlike Ptr. 
Since this std::string path needs to be tested for mozilla // and since this class/CLI is intended at testing mozilla's path, we go from: // - // cmdline -> Ptr -> std::string -> Service(std::string) + // cmdline -> Ptr -> std::string -> TranslationModel(std::string) // // Overkill, yes. - std::string config = options->asYamlString(); - Service model(config); + const std::string &modelConfigPath = config.modelConfigPaths.front(); + + Ptr options = parseOptionsFromFilePath(modelConfigPath); + MemoryBundle memoryBundle = getMemoryBundleFromConfig(options); + + BlockingService::Config serviceConfig; + BlockingService service(serviceConfig); + + std::shared_ptr translationModel = + std::make_shared(options->asYamlString(), std::move(memoryBundle)); ResponseOptions responseOptions; std::vector texts; -#ifdef WASM_COMPATIBLE_SOURCE // Hide the translateMultiple operation for (std::string line; std::getline(std::cin, line);) { texts.emplace_back(line); } - auto results = model.translateMultiple(std::move(texts), responseOptions); + auto results = service.translateMultiple(translationModel, std::move(texts), responseOptions); for (auto &result : results) { std::cout << result.getTranslatedText() << std::endl; } -#endif } /// Application used to benchmark with marian-decoder from time-to-time. 
The implementation in this repository follows a @@ -82,9 +88,13 @@ void wasm(Ptr options) { /// * Output: to stdout, translations of the sentences supplied via stdin in corresponding lines /// /// @param [in] options: constructed from command-line supplied arguments -void decoder(Ptr options) { +void decoder(const CLIConfig &config) { marian::timer::Timer decoderTimer; - Service service(options); + AsyncService::Config asyncConfig{config.numWorkers}; + AsyncService service(asyncConfig); + auto options = parseOptionsFromFilePath(config.modelConfigPaths.front()); + MemoryBundle memoryBundle; + Ptr translationModel = service.createCompatibleModel(options, std::move(memoryBundle)); // Read a large input text blob from stdin std::ostringstream std_input; std_input << std::cin.rdbuf(); @@ -95,14 +105,15 @@ void decoder(Ptr options) { std::future responseFuture = responsePromise.get_future(); auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); }; - service.translate(std::move(input), std::move(callback)); + service.translate(translationModel, std::move(input), std::move(callback)); responseFuture.wait(); const Response &response = responseFuture.get(); for (size_t sentenceIdx = 0; sentenceIdx < response.size(); sentenceIdx++) { std::cout << response.target.sentence(sentenceIdx) << "\n"; } - LOG(info, "Total time: {:.5f}s wall", decoderTimer.elapsed()); + + std::cerr << "Total time: " << std::setprecision(5) << decoderTimer.elapsed() << "s wall" << std::endl; } /// Command line interface to the test the features being developed as part of bergamot C++ library on native platform. @@ -114,16 +125,19 @@ void decoder(Ptr options) { /// * Output: to stdout, translation of the source text faithful to source structure. 
/// /// @param [in] options: options to build translator -void native(Ptr options) { +void native(const CLIConfig &config) { + AsyncService::Config asyncConfig{config.numWorkers}; + AsyncService service(asyncConfig); + + auto options = parseOptionsFromFilePath(config.modelConfigPaths.front()); // Prepare memories for bytearrays (including model, shortlist and vocabs) MemoryBundle memoryBundle; - - if (options->get("bytearray")) { + if (config.byteArray) { // Load legit values into bytearrays. memoryBundle = getMemoryBundleFromConfig(options); } - Service service(options, std::move(memoryBundle)); + Ptr translationModel = service.createCompatibleModel(options, std::move(memoryBundle)); // Read a large input text blob from stdin std::ostringstream std_input; @@ -137,7 +151,7 @@ void native(Ptr options) { std::future responseFuture = responsePromise.get_future(); auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); }; - service.translate(std::move(input), std::move(callback), responseOptions); + service.translate(translationModel, std::move(input), std::move(callback), responseOptions); responseFuture.wait(); Response response = responseFuture.get(); diff --git a/bergamot-translator-tests b/bergamot-translator-tests index 53c6e42a9..9dc3c5e9a 160000 --- a/bergamot-translator-tests +++ b/bergamot-translator-tests @@ -1 +1 @@ -Subproject commit 53c6e42a97e512698711068d0be3c208359b1801 +Subproject commit 9dc3c5e9a1027c1d6b4a467a27bdff16d0d6a006 diff --git a/src/tests/apps.cpp b/src/tests/apps.cpp index 991d3c3fd..63febfaf0 100644 --- a/src/tests/apps.cpp +++ b/src/tests/apps.cpp @@ -2,30 +2,25 @@ namespace marian { namespace bergamot { -namespace testapp { - -// Utility function, common for all testapps. 
-Response translateFromStdin(Ptr options, ResponseOptions responseOptions) { - // Prepare memories for bytearrays (including model, shortlist and vocabs) - MemoryBundle memoryBundle; - if (options->get("bytearray")) { - // Load legit values into bytearrays. - memoryBundle = getMemoryBundleFromConfig(options); - } - - Service service(options, std::move(memoryBundle)); +namespace { +std::string readFromStdin() { // Read a large input text blob from stdin std::ostringstream inputStream; inputStream << std::cin.rdbuf(); std::string input = inputStream.str(); + return input; +} +// Utility function, common for all testapps. +Response translateForResponse(AsyncService &service, Ptr model, std::string &&source, + ResponseOptions responseOptions) { std::promise responsePromise; std::future responseFuture = responsePromise.get_future(); auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); }; - service.translate(std::move(input), callback, responseOptions); + service.translate(model, std::move(source), callback, responseOptions); responseFuture.wait(); @@ -33,10 +28,15 @@ Response translateFromStdin(Ptr options, ResponseOptions responseOption return response; } -void annotatedTextWords(Ptr options, bool source) { +} // namespace + +namespace testapp { + +void annotatedTextWords(AsyncService &service, Ptr model, bool sourceSide) { ResponseOptions responseOptions; - Response response = translateFromStdin(options, responseOptions); - AnnotatedText &annotatedText = source ? response.source : response.target; + std::string source = readFromStdin(); + Response response = translateForResponse(service, model, std::move(source), responseOptions); + AnnotatedText &annotatedText = sourceSide ? response.source : response.target; for (size_t s = 0; s < annotatedText.numSentences(); s++) { for (size_t w = 0; w < annotatedText.numWords(s); w++) { std::cout << (w == 0 ? 
"" : "\t"); @@ -46,19 +46,39 @@ void annotatedTextWords(Ptr options, bool source) { } } -void annotatedTextSentences(Ptr options, bool source) { +void annotatedTextSentences(AsyncService &service, Ptr model, bool sourceSide) { ResponseOptions responseOptions; - Response response = translateFromStdin(options, responseOptions); - AnnotatedText &annotatedText = source ? response.source : response.target; + std::string source = readFromStdin(); + Response response = translateForResponse(service, model, std::move(source), responseOptions); + AnnotatedText &annotatedText = sourceSide ? response.source : response.target; for (size_t s = 0; s < annotatedText.numSentences(); s++) { std::cout << annotatedText.sentence(s) << "\n"; } } -void qualityEstimatorWords(const Ptr &options) { +void forwardAndBackward(AsyncService &service, std::vector> &models) { + ABORT_IF(models.size() != 2, "Forward and backward test needs two models."); + ResponseOptions responseOptions; + std::string source = readFromStdin(); + Response forwardResponse = translateForResponse(service, models.front(), std::move(source), responseOptions); + + // Make a copy of target + std::string target = forwardResponse.target.text; + Response backwardResponse = translateForResponse(service, models.back(), std::move(target), responseOptions); + + // Print both onto the command-line + std::cout << forwardResponse.source.text; + std::cout << "----------------\n"; + std::cout << forwardResponse.target.text; + std::cout << "----------------\n"; + std::cout << backwardResponse.target.text; +} + +void qualityEstimatorWords(AsyncService &service, Ptr model) { ResponseOptions responseOptions; responseOptions.qualityScores = true; - const Response response = translateFromStdin(options, responseOptions); + std::string source = readFromStdin(); + const Response response = translateForResponse(service, model, std::move(source), responseOptions); for (const auto &sentenceQualityEstimate : response.qualityScores) { std::cout << 
"[SentenceBegin]\n"; @@ -71,10 +91,12 @@ void qualityEstimatorWords(const Ptr &options) { } } -void qualityEstimatorScores(const Ptr &options) { +void qualityEstimatorScores(AsyncService &service, Ptr model) { ResponseOptions responseOptions; responseOptions.qualityScores = true; - const Response response = translateFromStdin(options, responseOptions); + + std::string source = readFromStdin(); + const Response response = translateForResponse(service, model, std::move(source), responseOptions); for (const auto &sentenceQualityEstimate : response.qualityScores) { std::cout << std::fixed << std::setprecision(3) << sentenceQualityEstimate.sentenceScore << "\n"; diff --git a/src/tests/apps.h b/src/tests/apps.h index deb6a12dc..dee77a9be 100644 --- a/src/tests/apps.h +++ b/src/tests/apps.h @@ -21,23 +21,21 @@ namespace bergamot { namespace testapp { -// Utility function, common for all testapps. Reads content from stdin, builds a Service based on options and constructs -// a response containing translation data according responseOptions. -Response translateFromStdin(Ptr options, ResponseOptions responseOptions); - // Reads from stdin and translates. Prints the tokens separated by space for each sentence. Prints words from source // side text annotation if source=true, target annotation otherwise. -void annotatedTextWords(Ptr options, bool source = true); +void annotatedTextWords(AsyncService &service, Ptr model, bool source = true); // Reads from stdin and translates the read content. Prints the sentences in source or target in constructed response // in each line, depending on source = true or false respectively. -void annotatedTextSentences(Ptr options, bool source = true); +void annotatedTextSentences(AsyncService &service, Ptr model, bool source = true); + +void forwardAndBackward(AsyncService &service, std::vector> &models); // Reads from stdin and translates the read content. Prints the quality words for each sentence. 
-void qualityEstimatorWords(const Ptr& options); +void qualityEstimatorWords(AsyncService &service, Ptr model); // Reads from stdin and translates the read content. Prints the quality scores for each sentence. -void qualityEstimatorScores(const Ptr& options); +void qualityEstimatorScores(AsyncService &service, Ptr model); } // namespace testapp } // namespace bergamot diff --git a/src/tests/cli.cpp b/src/tests/cli.cpp index 0e9469ab0..90c386c84 100644 --- a/src/tests/cli.cpp +++ b/src/tests/cli.cpp @@ -1,23 +1,45 @@ - #include "apps.h" int main(int argc, char *argv[]) { - auto cp = marian::bergamot::createConfigParser(); - auto options = cp.parseOptions(argc, argv, true); - const std::string mode = options->get("bergamot-mode"); using namespace marian::bergamot; - if (mode == "test-response-source-sentences") { - testapp::annotatedTextSentences(options, /*source=*/true); - } else if (mode == "test-response-target-sentences") { - testapp::annotatedTextSentences(options, /*source=*/false); - } else if (mode == "test-response-source-words") { - testapp::annotatedTextWords(options, /*source=*/true); - } else if (mode == std::string("test-quality-estimator-words")) { - testapp::qualityEstimatorWords(options); - } else if (mode == std::string("test-quality-estimator-scores")) { - testapp::qualityEstimatorScores(options); - } else { - ABORT("Unknown --mode {}. 
Please run a valid test", mode); + marian::bergamot::ConfigParser configParser; + configParser.parseArgs(argc, argv); + auto &config = configParser.getConfig(); + AsyncService::Config serviceConfig{config.numWorkers}; + AsyncService service(serviceConfig); + std::vector> models; + + for (auto &modelConfigPath : config.modelConfigPaths) { + TranslationModel::Config modelConfig = parseOptionsFromFilePath(modelConfigPath); + std::shared_ptr model = service.createCompatibleModel(modelConfig); + models.push_back(model); + } + + switch (config.opMode) { + case OpMode::TEST_SOURCE_SENTENCES: + testapp::annotatedTextSentences(service, models.front(), /*source=*/true); + break; + case OpMode::TEST_TARGET_SENTENCES: + testapp::annotatedTextSentences(service, models.front(), /*source=*/false); + break; + case OpMode::TEST_SOURCE_WORDS: + testapp::annotatedTextWords(service, models.front(), /*source=*/true); + break; + case OpMode::TEST_TARGET_WORDS: + testapp::annotatedTextWords(service, models.front(), /*source=*/false); + break; + case OpMode::TEST_FORWARD_BACKWARD_FOR_OUTBOUND: + testapp::forwardAndBackward(service, models); + break; + case OpMode::TEST_QUALITY_ESTIMATOR_WORDS: + testapp::qualityEstimatorWords(service, models.front()); + break; + case OpMode::TEST_QUALITY_ESTIMATOR_SCORES: + testapp::qualityEstimatorScores(service, models.front()); + break; + default: + ABORT("Incompatible op-mode. 
Choose one of the test modes."); + break; } return 0; } diff --git a/src/translator/CMakeLists.txt b/src/translator/CMakeLists.txt index c0ee6be7a..ab1448800 100644 --- a/src/translator/CMakeLists.txt +++ b/src/translator/CMakeLists.txt @@ -5,15 +5,16 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/project_version.h.in add_library(bergamot-translator STATIC byte_array_util.cpp text_processor.cpp - batch_translator.cpp + translation_model.cpp request.cpp - batcher.cpp + batching_pool.cpp + aggregate_batching_pool.cpp response_builder.cpp quality_estimator.cpp batch.cpp annotation.cpp service.cpp - threadsafe_batcher.cpp + parser.cpp ) if (USE_WASM_COMPATIBLE_SOURCE) # Using wasm compatible sources should include this compile definition; diff --git a/src/translator/aggregate_batching_pool.cpp b/src/translator/aggregate_batching_pool.cpp new file mode 100644 index 000000000..38c55f1c4 --- /dev/null +++ b/src/translator/aggregate_batching_pool.cpp @@ -0,0 +1,34 @@ + +#include "aggregate_batching_pool.h" + +namespace marian { +namespace bergamot { + +AggregateBatchingPool::AggregateBatchingPool() { + // TODO(@jerinphilip): Set aggregate limits +} + +size_t AggregateBatchingPool::enqueueRequest(Ptr model, Ptr request) { + model->enqueueRequest(request); + aggregateQueue_.insert(model); + return request->numSegments(); +} + +size_t AggregateBatchingPool::generateBatch(Ptr& model, Batch& batch) { + while (!aggregateQueue_.empty()) { + auto candidateItr = aggregateQueue_.begin(); + Ptr candidate = *candidateItr; + size_t numSentences = candidate->generateBatch(batch); + if (numSentences > 0) { + model = candidate; + return numSentences; + } else { + // Try the next model's batching pool. 
+ aggregateQueue_.erase(candidateItr); + } + } + return /*numSentences=*/0; +} + +} // namespace bergamot +} // namespace marian diff --git a/src/translator/aggregate_batching_pool.h b/src/translator/aggregate_batching_pool.h new file mode 100644 index 000000000..5b5d4b17a --- /dev/null +++ b/src/translator/aggregate_batching_pool.h @@ -0,0 +1,68 @@ +#ifndef SRC_BERGAMOT_AGGREGATE_BATCHING_POOL_H_ +#define SRC_BERGAMOT_AGGREGATE_BATCHING_POOL_H_ + +#include +#include + +#include "data/types.h" +#include "translation_model.h" + +namespace marian { +namespace bergamot { + +/// Hashes a pointer to an object using the address the pointer points to. If two pointers point to the same address, +/// they hash to the same value. Useful to put widely shared_ptrs of entities (eg: TranslationModel, Vocab, Shortlist) +/// etc into containers which require the members to be hashable (std::unordered_set, std::unordered_map). +template +struct HashPtr { + size_t operator()(const std::shared_ptr& t) const { + size_t address = reinterpret_cast(t.get()); + return std::hash()(address); + } +}; + +/// Aggregates request queueing and generation of batches from multiple TranslationModels (BatchingPools within, +/// specifically), thereby acting as an intermediary to enable multiple translation model capability in BlockingService +/// and AsyncService. +/// +/// A simple queue containing shared owning references to TranslationModels are held here from which batches are +/// generated on demand. Since a queue is involved, the ordering is first-come first serve on requests except there are +/// leaks effectively doing priority inversion if an earlier request with the same TranslationModel is pending +/// to be consumed for translation. +// +/// Actual storage for the request and batch generation are within the respective TranslationModels, which owns its own +/// BatchingPool. +/// +/// Matches API provided by BatchingPool except arguments additionally parameterized by TranslationModel. 
+/// +/// Note: This class is not thread-safe. You may use this class wrapped with ThreadsafeBatchingPool for a thread-safe +/// equivalent of this class, if needed. +class AggregateBatchingPool { + public: + /// Create an AggregateBatchingPool with (tentatively) global (across all BatchingPools) limits + /// imposed here. + AggregateBatchingPool(); + + /// Enqueue an existing request onto model, also keep account of that this model and request are now pending. + /// + /// @param [in] model: Model to use in translation. A shared ownership to this model is accepted by this object to + /// keep the model alive until translation is complete. + /// @param [in] request: A request to be enqueued to model. + /// @returns number of sentences added for translation. + size_t enqueueRequest(Ptr model, Ptr request); + + /// Generate a batch from pending requests, obtained from available TranslationModels. + /// + /// @param [out] model: TranslationModel + /// @param [out] batch: Batch to write onto, which is consumed at translation elsewhere. + /// @returns Number of sentences in the generated batch. 
+ size_t generateBatch(Ptr& model, Batch& batch); + + private: + std::unordered_set, HashPtr> aggregateQueue_; +}; + +} // namespace bergamot +} // namespace marian + +#endif // SRC_BERGAMOT_AGGREGATE_BATCHING_POOL_H_ diff --git a/src/translator/batch_translator.cpp b/src/translator/batch_translator.cpp deleted file mode 100644 index 889ff0073..000000000 --- a/src/translator/batch_translator.cpp +++ /dev/null @@ -1,128 +0,0 @@ -#include "batch_translator.h" - -#include "batch.h" -#include "byte_array_util.h" -#include "common/logging.h" -#include "data/corpus.h" -#include "data/text_input.h" -#include "translator/beam_search.h" - -namespace marian { -namespace bergamot { - -BatchTranslator::BatchTranslator(DeviceId const device, Vocabs &vocabs, Ptr options, - const AlignedMemory *modelMemory, const AlignedMemory *shortlistMemory) - : device_(device), - options_(options), - vocabs_(vocabs), - modelMemory_(modelMemory), - shortlistMemory_(shortlistMemory) {} - -void BatchTranslator::initialize() { - // Initializes the graph. 
- if (options_->hasAndNotEmpty("shortlist")) { - int srcIdx = 0, trgIdx = 1; - bool shared_vcb = - vocabs_.sources().front() == - vocabs_.target(); // vocabs_->sources().front() is invoked as we currently only support one source vocab - if (shortlistMemory_->size() > 0 && shortlistMemory_->begin() != nullptr) { - slgen_ = New(shortlistMemory_->begin(), shortlistMemory_->size(), - vocabs_.sources().front(), vocabs_.target(), srcIdx, trgIdx, - shared_vcb, options_->get("check-bytearray")); - } else { - // Changed to BinaryShortlistGenerator to enable loading binary shortlist file - // This class also supports text shortlist file - slgen_ = New(options_, vocabs_.sources().front(), vocabs_.target(), srcIdx, - trgIdx, shared_vcb); - } - } - - graph_ = New(true); // set the graph to be inference only - auto prec = options_->get>("precision", {"float32"}); - graph_->setDefaultElementType(typeFromString(prec[0])); - graph_->setDevice(device_); - graph_->getBackend()->configureDevice(options_); - graph_->reserveWorkspaceMB(options_->get("workspace")); - if (modelMemory_->size() > 0 && - modelMemory_->begin() != - nullptr) { // If we have provided a byte array that contains the model memory, we can initialise the model - // from there, as opposed to from reading in the config file - ABORT_IF((uintptr_t)modelMemory_->begin() % 256 != 0, - "The provided memory is not aligned to 256 bytes and will crash when vector instructions are used on it."); - if (options_->get("check-bytearray")) { - ABORT_IF(!validateBinaryModel(*modelMemory_, modelMemory_->size()), - "The binary file is invalid. Incomplete or corrupted download?"); - } - const std::vector container = { - modelMemory_->begin()}; // Marian supports multiple models initialised in this manner hence std::vector. - // However we will only ever use 1 during decoding. 
- scorers_ = createScorers(options_, container); - } else { - scorers_ = createScorers(options_); - } - for (auto scorer : scorers_) { - scorer->init(graph_); - if (slgen_) { - scorer->setShortlistGenerator(slgen_); - } - } - graph_->forward(); -} - -void BatchTranslator::translate(Batch &batch) { - std::vector batchVector; - - auto &sentences = batch.sentences(); - size_t batchSequenceNumber{0}; - for (auto &sentence : sentences) { - data::SentenceTuple sentence_tuple(batchSequenceNumber); - Segment segment = sentence.getUnderlyingSegment(); - sentence_tuple.push_back(segment); - batchVector.push_back(sentence_tuple); - - ++batchSequenceNumber; - } - - size_t batchSize = batchVector.size(); - std::vector sentenceIds; - std::vector maxDims; - for (auto &ex : batchVector) { - if (maxDims.size() < ex.size()) maxDims.resize(ex.size(), 0); - for (size_t i = 0; i < ex.size(); ++i) { - if (ex[i].size() > (size_t)maxDims[i]) maxDims[i] = (int)ex[i].size(); - } - sentenceIds.push_back(ex.getId()); - } - - typedef marian::data::SubBatch SubBatch; - typedef marian::data::CorpusBatch CorpusBatch; - - std::vector> subBatches; - for (size_t j = 0; j < maxDims.size(); ++j) { - subBatches.emplace_back(New(batchSize, maxDims[j], vocabs_.sources().at(j))); - } - - std::vector words(maxDims.size(), 0); - for (size_t i = 0; i < batchSize; ++i) { - for (size_t j = 0; j < maxDims.size(); ++j) { - for (size_t k = 0; k < batchVector[i][j].size(); ++k) { - subBatches[j]->data()[k * batchSize + i] = batchVector[i][j][k]; - subBatches[j]->mask()[k * batchSize + i] = 1.f; - words[j]++; - } - } - } - - for (size_t j = 0; j < maxDims.size(); ++j) subBatches[j]->setWords(words[j]); - - auto corpus_batch = Ptr(new CorpusBatch(subBatches)); - corpus_batch->setSentenceIds(sentenceIds); - - auto search = New(options_, scorers_, vocabs_.target()); - - auto histories = std::move(search->search(graph_, corpus_batch)); - batch.completeBatch(histories); -} - -} // namespace bergamot -} // namespace 
marian diff --git a/src/translator/batch_translator.h b/src/translator/batch_translator.h deleted file mode 100644 index 6a7fa9842..000000000 --- a/src/translator/batch_translator.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef SRC_BERGAMOT_BATCH_TRANSLATOR_H_ -#define SRC_BERGAMOT_BATCH_TRANSLATOR_H_ - -#include -#include - -#include "batch.h" -#include "common/utils.h" -#include "data/shortlist.h" -#include "definitions.h" -#include "request.h" -#include "translator/history.h" -#include "translator/scorers.h" -#include "vocabs.h" - -namespace marian { -namespace bergamot { - -class BatchTranslator { - // Launches minimal marian-translation (only CPU at the moment) in individual - // threads. Constructor launches each worker thread running mainloop(). - // mainloop runs until until it receives poison from the PCQueue. Threads are - // shut down in Service which calls join() on the threads. - - public: - /** - * Initialise the marian translator. - * @param device DeviceId that performs translation. Could be CPU or GPU - * @param vocabs Vector that contains ptrs to two vocabs - * @param options Marian options object - * @param modelMemory byte array (aligned to 256!!!) that contains the bytes of a model.bin. Provide a nullptr if not - * used. - * @param shortlistMemory byte array of shortlist (aligned to 64) - */ - explicit BatchTranslator(DeviceId const device, Vocabs& vocabs, Ptr options, - const AlignedMemory* modelMemory, const AlignedMemory* shortlistMemory); - - // convenience function for logging. 
TODO(jerin) - std::string _identifier() { return "worker" + std::to_string(device_.no); } - void translate(Batch& batch); - void initialize(); - - private: - Ptr options_; - DeviceId device_; - const Vocabs& vocabs_; - Ptr graph_; - std::vector> scorers_; - Ptr slgen_; - const AlignedMemory* modelMemory_{nullptr}; - const AlignedMemory* shortlistMemory_{nullptr}; -}; - -} // namespace bergamot -} // namespace marian - -#endif // SRC_BERGAMOT_BATCH_TRANSLATOR_H_ diff --git a/src/translator/batcher.cpp b/src/translator/batching_pool.cpp similarity index 83% rename from src/translator/batcher.cpp rename to src/translator/batching_pool.cpp index 0a14459f1..83b5e00ab 100644 --- a/src/translator/batcher.cpp +++ b/src/translator/batching_pool.cpp @@ -1,4 +1,4 @@ -#include "batcher.h" +#include "batching_pool.h" #include @@ -8,7 +8,7 @@ namespace marian { namespace bergamot { -Batcher::Batcher(Ptr options) { +BatchingPool::BatchingPool(Ptr options) { miniBatchWords = options->get("mini-batch-words"); bucket_.resize(options->get("max-length-break") + 1); ABORT_IF(bucket_.size() - 1 > miniBatchWords, @@ -16,7 +16,7 @@ Batcher::Batcher(Ptr options) { "longer than what can fit in a batch."); } -bool Batcher::cleaveBatch(Batch &batch) { +size_t BatchingPool::generateBatch(Batch &batch) { // For now simply iterates on buckets and converts batches greedily. This // has to be enhanced with optimizing over priority. 
The baseline // implementation should at least be as fast as marian's maxi-batch with full @@ -35,22 +35,23 @@ bool Batcher::cleaveBatch(Batch &batch) { } else { // Check if elements exist assert(batch.size() > 0); - return true; + return batch.size(); } } } - bool isValidBatch = batch.size() > 0; - return isValidBatch; + return batch.size(); } -void Batcher::addWholeRequest(Ptr request) { +size_t BatchingPool::enqueueRequest(Ptr request) { for (size_t i = 0; i < request->numSegments(); i++) { RequestSentence sentence(i, request); size_t bucket_id = sentence.numTokens(); assert(bucket_id < bucket_.size()); bucket_[bucket_id].insert(sentence); } + + return request->numSegments(); } } // namespace bergamot diff --git a/src/translator/batcher.h b/src/translator/batching_pool.h similarity index 63% rename from src/translator/batcher.h rename to src/translator/batching_pool.h index 277bfc934..68b2cf0d0 100644 --- a/src/translator/batcher.h +++ b/src/translator/batching_pool.h @@ -1,5 +1,5 @@ -#ifndef SRC_BERGAMOT_BATCHER_H_ -#define SRC_BERGAMOT_BATCHER_H_ +#ifndef SRC_BERGAMOT_BATCHING_POOL_H_ +#define SRC_BERGAMOT_BATCHING_POOL_H_ #include #include @@ -12,24 +12,21 @@ namespace marian { namespace bergamot { -class Batcher { + +class BatchingPool { public: - explicit Batcher(Ptr options); + explicit BatchingPool(Ptr options); // RequestSentence incorporates (tentative) notions of priority with each // sentence. This method inserts the sentence into the internal data-structure // which maintains priority among sentences from multiple concurrent requests. - void addWholeRequest(Ptr request); - - // indicate no more sentences will be added. Does nothing here, for parity to threadsafe version. - void shutdown() {} - - bool operator>>(Batch &batch) { return cleaveBatch(batch); } + size_t enqueueRequest(Ptr request); - private: // Loads sentences with sentences compiled from (tentatively) multiple // requests optimizing for both padding and priority. 
- bool cleaveBatch(Batch &batch); + size_t generateBatch(Batch &batch); + + private: size_t miniBatchWords; std::vector> bucket_; size_t batchNumber_{0}; @@ -38,4 +35,4 @@ class Batcher { } // namespace bergamot } // namespace marian -#endif // SRC_BERGAMOT_BATCHER_H_ +#endif // SRC_BERGAMOT_BATCHING_POOL_H_ diff --git a/src/translator/definitions.h b/src/translator/definitions.h index a0f544ded..66ebb03b4 100644 --- a/src/translator/definitions.h +++ b/src/translator/definitions.h @@ -41,6 +41,9 @@ struct ByteRange { const size_t size() const { return end - begin; } }; +class Response; +using CallbackType = std::function; + } // namespace bergamot } // namespace marian diff --git a/src/translator/parser.cpp b/src/translator/parser.cpp new file mode 100644 index 000000000..d927409b5 --- /dev/null +++ b/src/translator/parser.cpp @@ -0,0 +1,170 @@ +#include "parser.h" + +#include + +#include "common/build_info.h" +#include "common/config.h" +#include "common/regex.h" +#include "common/version.h" + +namespace marian { +namespace bergamot { + +std::istringstream &operator>>(std::istringstream &in, OpMode &mode) { + std::string modeString; + in >> modeString; + std::unordered_map table = { + {"wasm", OpMode::APP_WASM}, + {"native", OpMode::APP_NATIVE}, + {"decoder", OpMode::APP_DECODER}, + {"test-response-source-sentences", OpMode::TEST_SOURCE_SENTENCES}, + {"test-response-target-sentences", OpMode::TEST_TARGET_SENTENCES}, + {"test-response-source-words", OpMode::TEST_SOURCE_WORDS}, + {"test-response-target-words", OpMode::TEST_TARGET_WORDS}, + {"test-quality-estimator-words", OpMode::TEST_QUALITY_ESTIMATOR_WORDS}, + {"test-quality-estimator-scores", OpMode::TEST_QUALITY_ESTIMATOR_SCORES}, + {"test-forward-backward", OpMode::TEST_FORWARD_BACKWARD_FOR_OUTBOUND}, + }; + + auto query = table.find(modeString); + if (query != table.end()) { + mode = query->second; + } else { + ABORT("Unknown mode {}", modeString); + } + + return in; +} + +ConfigParser::ConfigParser() : 
app_{"Bergamot Options"} { + addSpecialOptions(app_); + addOptionsBoundToConfig(app_, config_); +}; + +void ConfigParser::parseArgs(int argc, char *argv[]) { + try { + app_.parse(argc, argv); + handleSpecialOptions(); + } catch (const CLI::ParseError &e) { + exit(app_.exit(e)); + } +} + +void ConfigParser::addSpecialOptions(CLI::App &app) { + app.add_flag("--build-info", build_info_, "Print build-info and exit"); + app.add_flag("--version", version_, "Print version-info and exit"); +} + +void ConfigParser::handleSpecialOptions() { + if (build_info_) { +#ifndef _MSC_VER // cmake build options are not available on MSVC based build. + std::cerr << cmakeBuildOptionsAdvanced() << std::endl; + exit(0); +#else // _MSC_VER + ABORT("build-info is not available on MSVC based build."); +#endif // _MSC_VER + } + + if (version_) { + std::cerr << buildVersion() << std::endl; + exit(0); + } +} + +void ConfigParser::addOptionsBoundToConfig(CLI::App &app, CLIConfig &config) { + app.add_option("--model-config-paths", config.modelConfigPaths, + "Configuration files list, can be used for pivoting multiple models or multiple model workflows"); + + app.add_flag("--bytearray", config.byteArray, + "Flag holds whether to construct service from bytearrays, only for testing purpose"); + + app.add_flag("--check-bytearray", config.validateByteArray, + "Flag holds whether to check the content of the bytearrays (true by default)"); + + app.add_option("--cpu-threads", config.numWorkers, "Number of worker threads to use for translation"); + + app_.add_option("--bergamot-mode", config.opMode, "Operating mode for bergamot: [wasm, native, decoder]"); +} + +std::shared_ptr parseOptionsFromFilePath(const std::string &configPath, bool validate /*= true*/) { + // Read entire string and redirect to parseOptionsFromString + std::ifstream readStream(configPath); + std::stringstream buffer; + buffer << readStream.rdbuf(); + return parseOptionsFromString(buffer.str(), validate, 
/*pathsInSameDirAs=*/configPath); +}; + +std::shared_ptr parseOptionsFromString(const std::string &configAsString, bool validate /*= true*/, + std::string pathsInSameDirAs /*=""*/) { + marian::Options options; + + marian::ConfigParser configParser(cli::mode::translation); + + // These are additional options we use to hijack for our own marian-replacement layer (for batching, + // multi-request-compile etc) and hence goes into Ptr. + configParser.addOption("--max-length-break", "Bergamot Options", + "Maximum input tokens to be processed in a single sentence.", 128); + + // The following is a complete hijack of an existing option, so no need to add explicitly. + // configParser.addOption("--mini-batch-words", "Bergamot Options", + // "Maximum input tokens to be processed in a single sentence.", 1024); + + configParser.addOption("--ssplit-prefix-file", "Bergamot Options", + "File with nonbreaking prefixes for sentence splitting."); + + configParser.addOption("--ssplit-mode", "Bergamot Options", "[paragraph, sentence, wrapped_text]", + "paragraph"); + + configParser.addOption("--quality", "Bergamot Options", "File considering Quality Estimation model"); + + // Parse configs onto defaultConfig. The preliminary merge sets the YAML internal representation with legal values. + const YAML::Node &defaultConfig = configParser.getConfig(); + options.merge(defaultConfig); + options.parse(configAsString); + + // This is in a marian `.cpp` as of now, and requires explicit copy-here. 
+ // https://github.com/marian-nmt/marian-dev/blob/9fa166be885b025711f27b35453e0f2c00c9933e/src/common/config_parser.cpp#L28 + + // clang-format off + const std::set PATHS = { + "model", + "models", + "train-sets", + "vocabs", + "embedding-vectors", + "valid-sets", + "valid-script-path", + "valid-script-args", + "valid-log", + "valid-translation-output", + "input", // except: 'stdin', handled in makeAbsolutePaths and interpolateEnvVars + "output", // except: 'stdout', handled in makeAbsolutePaths and interpolateEnvVars + "pretrained-model", + "data-weighting", + "log", + "sqlite", // except: 'temporary', handled in the processPaths function + "shortlist", // except: only the first element in the sequence is a path, handled in the + // processPaths function + "ssplit-prefix-file", // added for bergamot + "quality", // added for bergamot + }; + // clang-format on + + if (!pathsInSameDirAs.empty()) { + YAML::Node configYAML = options.cloneToYamlNode(); + marian::cli::makeAbsolutePaths(configYAML, pathsInSameDirAs, PATHS); + options.merge(configYAML, /*overwrite=*/true); + } + + // Perform validation on parsed options only when requested + if (validate) { + YAML::Node configYAML = options.cloneToYamlNode(); + marian::ConfigValidator validator(configYAML); + validator.validateOptions(marian::cli::mode::translation); + } + + return std::make_shared(options); +} + +} // namespace bergamot +} // namespace marian diff --git a/src/translator/parser.h b/src/translator/parser.h index 54aaaf86a..c9fffcebf 100644 --- a/src/translator/parser.h +++ b/src/translator/parser.h @@ -1,6 +1,10 @@ #ifndef SRC_BERGAMOT_PARSER_H #define SRC_BERGAMOT_PARSER_H +#include +#include + +#include "3rd_party/marian-dev/src/3rd_party/CLI/CLI.hpp" #include "3rd_party/yaml-cpp/yaml.h" #include "common/config_parser.h" #include "common/config_validator.h" @@ -10,65 +14,63 @@ namespace marian { namespace bergamot { -inline marian::ConfigParser createConfigParser() { - marian::ConfigParser 
cp(marian::cli::mode::translation); - cp.addOption("--ssplit-prefix-file", "Bergamot Options", - "File with nonbreaking prefixes for sentence splitting."); - - cp.addOption("--ssplit-mode", "Server Options", "[paragraph, sentence, wrapped_text]", "paragraph"); - - cp.addOption("--max-length-break", "Bergamot Options", - "Maximum input tokens to be processed in a single sentence.", 128); - - cp.addOption("--bytearray", "Bergamot Options", - "Flag holds whether to construct service from bytearrays, only for testing purpose", false); - - cp.addOption("--check-bytearray", "Bergamot Options", - "Flag holds whether to check the content of the bytearrays (true by default)", true); - - cp.addOption("--bergamot-mode", "Bergamot Options", - "Operating mode for bergamot: [wasm, native, decoder]", "native"); - - cp.addOption("--quality", "Bergamot Options", "File considering Quality Estimation model"); - - return cp; -} - -inline std::shared_ptr parseOptions(const std::string &config, bool validate = true) { - marian::Options options; - - // @TODO(jerinphilip) There's something off here, @XapaJIaMnu suggests - // that should not be using the defaultConfig. This function only has access - // to std::string config and needs to be able to construct Options from the - // same. - - // Absent the following code-segment, there is a parsing exception thrown on - // rebuilding YAML. - // - // Error: Unhandled exception of type 'N4YAML11InvalidNodeE': invalid node; - // this may result from using a map iterator as a sequence iterator, or - // vice-versa - // - // Error: Aborted from void unhandledException() in - // 3rd_party/marian-dev/src/common/logging.cpp:113 - - marian::ConfigParser configParser = createConfigParser(); - const YAML::Node &defaultConfig = configParser.getConfig(); - - options.merge(defaultConfig); - - // Parse configs onto defaultConfig. 
- options.parse(config); - YAML::Node configCopy = options.cloneToYamlNode(); - - if (validate) { - // Perform validation on parsed options only when requested - marian::ConfigValidator validator(configCopy); - validator.validateOptions(marian::cli::mode::translation); - } - - return std::make_shared(options); -} +enum OpMode { + APP_WASM, + APP_NATIVE, + APP_DECODER, + TEST_SOURCE_SENTENCES, + TEST_TARGET_SENTENCES, + TEST_SOURCE_WORDS, + TEST_TARGET_WORDS, + TEST_QUALITY_ESTIMATOR_WORDS, + TEST_QUALITY_ESTIMATOR_SCORES, + TEST_FORWARD_BACKWARD_FOR_OUTBOUND, +}; + +/// Overload for CL11, convert a read from a stringstream into opmode. +std::istringstream &operator>>(std::istringstream &in, OpMode &mode); + +struct CLIConfig { + using ModelConfigPaths = std::vector; + ModelConfigPaths modelConfigPaths; + bool byteArray; + bool validateByteArray; + size_t numWorkers; + OpMode opMode; +}; + +/// ConfigParser for bergamot. Internally stores config options with CLIConfig. CLI11 parsing binds the parsing code to +/// write to the members of the CLIConfig instance owned by this class. Usage: +/// +/// ```cpp +/// ConfigParser configParser; +/// configParser.parseArgs(argc, argv); +/// auto &config = configParser.getConfig(); +/// ``` +class ConfigParser { + public: + ConfigParser(); + void parseArgs(int argc, char *argv[]); + const CLIConfig &getConfig() { return config_; } + + private: + // Special Options: build-info and version. These are not taken down further, the respective logic executed and + // program exits after. 
+ void addSpecialOptions(CLI::App &app); + void handleSpecialOptions(); + + void addOptionsBoundToConfig(CLI::App &app, CLIConfig &config); + + CLIConfig config_; + CLI::App app_; + + bool build_info_{false}; + bool version_{false}; +}; + +std::shared_ptr parseOptionsFromString(const std::string &config, bool validate = true, + std::string pathsInSameDirAs = ""); +std::shared_ptr parseOptionsFromFilePath(const std::string &config, bool validate = true); } // namespace bergamot } // namespace marian diff --git a/src/translator/request.h b/src/translator/request.h index a2ea1af86..d2645f6d8 100644 --- a/src/translator/request.h +++ b/src/translator/request.h @@ -19,7 +19,7 @@ namespace bergamot { /// A Request is an internal representation used to represent a request after /// processed by TextProcessor into sentences constituted by marian::Words. /// -/// The batching mechanism (Batcher) draws from multiple Requests and compiles +/// The batching mechanism (BatchingPool) draws from multiple Requests and compiles /// sentences into a batch. When a batch completes translation (at /// BatchTranslator, intended in a different thread), backward propogation /// happens through: @@ -60,7 +60,7 @@ class Request { Segment getSegment(size_t index) const; /// For notions of priority among requests, used to enable std::set in - /// Batcher. + /// BatchingPool. bool operator<(const Request &request) const; /// Processes a history obtained after translating in a heterogenous batch @@ -90,7 +90,7 @@ class Request { /// A RequestSentence provides a view to a sentence within a Request. Existence /// of this class allows the sentences and associated information to be kept -/// within Request, while batching mechanism (Batcher) compiles Batch from +/// within Request, while batching mechanism (BatchingPool) compiles Batch from /// RequestSentence-s coming from different Requests. 
class RequestSentence { public: diff --git a/src/translator/response_builder.h b/src/translator/response_builder.h index 614c7c282..36bae1e9e 100644 --- a/src/translator/response_builder.h +++ b/src/translator/response_builder.h @@ -29,7 +29,7 @@ class ResponseBuilder { /// @param [in] callback: callback with operates on the constructed Response. /// @param [in] qualityEstimator: the QualityEstimator model that can be used /// to provide translation quality probability. - ResponseBuilder(ResponseOptions responseOptions, AnnotatedText &&source, Vocabs &vocabs, + ResponseBuilder(ResponseOptions responseOptions, AnnotatedText &&source, const Vocabs &vocabs, std::function callback, const QualityEstimator &qualityEstimator) : responseOptions_(responseOptions), source_(std::move(source)), diff --git a/src/translator/service.cpp b/src/translator/service.cpp index f5996aa45..9de69ba8a 100644 --- a/src/translator/service.cpp +++ b/src/translator/service.cpp @@ -10,88 +10,59 @@ namespace marian { namespace bergamot { -Service::Service(Ptr options, MemoryBundle memoryBundle) - : requestId_(0), - options_(options), - vocabs_(options, std::move(memoryBundle.vocabs)), - text_processor_(options, vocabs_, std::move(memoryBundle.ssplitPrefixFile)), - batcher_(options), - numWorkers_(std::max(1, options->get("cpu-threads"))), - modelMemory_(std::move(memoryBundle.model)), - shortlistMemory_(std::move(memoryBundle.shortlist)), - qualityEstimator_(createQualityEstimator(getQualityEstimatorModel(memoryBundle, options))) -#ifdef WASM_COMPATIBLE_SOURCE - , - blocking_translator_(DeviceId(0, DeviceType::cpu), vocabs_, options_, &modelMemory_, &shortlistMemory_) -#endif -{ -#ifdef WASM_COMPATIBLE_SOURCE - blocking_translator_.initialize(); -#else - workers_.reserve(numWorkers_); - for (size_t cpuId = 0; cpuId < numWorkers_; cpuId++) { - workers_.emplace_back([cpuId, this] { - marian::DeviceId deviceId(cpuId, DeviceType::cpu); - BatchTranslator translator(deviceId, vocabs_, options_, 
&modelMemory_, &shortlistMemory_); - translator.initialize(); - Batch batch; - // Run thread mainloop - while (batcher_ >> batch) { - translator.translate(batch); - } - }); - } -#endif -} +BlockingService::BlockingService(const BlockingService::Config &config) : requestId_(0), batchingPool_() {} -#ifdef WASM_COMPATIBLE_SOURCE -std::vector Service::translateMultiple(std::vector &&inputs, ResponseOptions responseOptions) { - // We queue the individual Requests so they get compiled at batches to be - // efficiently translated. +std::vector BlockingService::translateMultiple(std::shared_ptr translationModel, + std::vector &&sources, + const ResponseOptions &responseOptions) { std::vector responses; - responses.resize(inputs.size()); + responses.resize(sources.size()); - for (size_t i = 0; i < inputs.size(); i++) { + for (size_t i = 0; i < sources.size(); i++) { auto callback = [i, &responses](Response &&response) { responses[i] = std::move(response); }; // - queueRequest(std::move(inputs[i]), std::move(callback), responseOptions); + Ptr request = + translationModel->makeRequest(requestId_++, std::move(sources[i]), callback, responseOptions); + batchingPool_.enqueueRequest(translationModel, request); } Batch batch; - // There's no need to do shutdown here because it's single threaded. 
- while (batcher_ >> batch) { - blocking_translator_.translate(batch); + Ptr model{nullptr}; + while (batchingPool_.generateBatch(model, batch)) { + model->translateBatch(/*deviceId=*/0, batch); } return responses; } -#endif - -void Service::queueRequest(std::string &&input, std::function &&callback, - ResponseOptions responseOptions) { - Segments segments; - AnnotatedText source; - - text_processor_.process(std::move(input), source, segments); - - ResponseBuilder responseBuilder(responseOptions, std::move(source), vocabs_, std::move(callback), *qualityEstimator_); - Ptr request = New(requestId_++, std::move(segments), std::move(responseBuilder)); - - batcher_.addWholeRequest(request); -} -void Service::translate(std::string &&input, std::function &&callback, - ResponseOptions responseOptions) { - queueRequest(std::move(input), std::move(callback), responseOptions); +AsyncService::AsyncService(const AsyncService::Config &config) : requestId_(0), config_(config), safeBatchingPool_() { + ABORT_IF(config_.numWorkers == 0, "Number of workers should be at least 1 in a threaded workflow"); + workers_.reserve(config_.numWorkers); + for (size_t cpuId = 0; cpuId < config_.numWorkers; cpuId++) { + workers_.emplace_back([cpuId, this] { + // Consumer thread main-loop. Note that this is an infinite-loop unless the monitor is explicitly told to + // shutdown, which happens in the destructor for this class. 
+ Batch batch; + Ptr translationModel{nullptr}; + while (safeBatchingPool_.generateBatch(translationModel, batch)) { + translationModel->translateBatch(cpuId, batch); + } + }); + } } -Service::~Service() { - batcher_.shutdown(); -#ifndef WASM_COMPATIBLE_SOURCE +AsyncService::~AsyncService() { + safeBatchingPool_.shutdown(); for (std::thread &worker : workers_) { assert(worker.joinable()); worker.join(); } -#endif +} + +void AsyncService::translate(std::shared_ptr translationModel, std::string &&source, + CallbackType callback, const ResponseOptions &responseOptions) { + // Producer thread, a call to this function adds new work items. If batches are available, notifies workers waiting. + Ptr request = translationModel->makeRequest(requestId_++, std::move(source), callback, responseOptions); + safeBatchingPool_.enqueueRequest(translationModel, request); } } // namespace bergamot diff --git a/src/translator/service.h b/src/translator/service.h index 3a3d616fc..d37f5c262 100644 --- a/src/translator/service.h +++ b/src/translator/service.h @@ -1,146 +1,116 @@ #ifndef SRC_BERGAMOT_SERVICE_H_ #define SRC_BERGAMOT_SERVICE_H_ -#include "batch_translator.h" +#include +#include +#include + #include "data/types.h" #include "quality_estimator.h" #include "response.h" #include "response_builder.h" #include "text_processor.h" -#include "threadsafe_batcher.h" +#include "threadsafe_batching_pool.h" +#include "translation_model.h" #include "translator/parser.h" #include "vocabs.h" -#ifndef WASM_COMPATIBLE_SOURCE -#include -#endif - -#include -#include - namespace marian { namespace bergamot { -/// This is intended to be similar to the ones provided for training or -/// decoding in ML pipelines with the following additional capabilities: -/// -/// 1. Provision of a request -> response based translation flow unlike the -/// usual a line based translation or decoding provided in most ML frameworks. -/// 2. 
Internal handling of normalization etc which changes source text to -/// provide to client translation meta-information like alignments consistent -/// with the unnormalized input text. -/// 3. The API splits each text entry into sentences internally, which are then -/// translated independent of each other. The translated sentences are then -/// joined back together and returned in Response. -/// -/// Service exposes methods to instantiate from a string configuration (which -/// can cover most translators) and to translate an incoming blob of text. -/// -/// Optionally Service can be initialized by also passing bytearray memories -/// for purposes of efficiency (which defaults to empty and then reads from -/// file supplied through config). +class BlockingService; +class AsyncService; + +/// See AsyncService. /// -class Service { +/// BlockingService is a not-threaded counterpart of AsyncService which can operate only in a blocking workflow (queue a +/// bunch of texts and optional args to translate, wait till the translation finishes). +class BlockingService { public: - /// Construct Service from Marian options. If memoryBundle is empty, Service is - /// initialized from file-based loading. Otherwise, Service is initialized from - /// the given bytearray memories. - /// @param options Marian options object - /// @param memoryBundle holds all byte-array memories. Can be a set/subset of - /// model, shortlist, vocabs and ssplitPrefixFile or QualityEstimation bytes. Optional. - explicit Service(Ptr options, MemoryBundle memoryBundle = {}); - - /// Construct Service from a string configuration. If memoryBundle is empty, Service is - /// initialized from file-based loading. Otherwise, Service is initialized from - /// the given bytearray memories. - /// @param [in] config string parsable as YAML expected to adhere with marian config - /// @param [in] memoryBundle holds all byte-array memories. 
Can be a set/subset of - /// model, shortlist, vocabs and ssplitPrefixFile or qualityEstimation bytes. Optional. - explicit Service(const std::string &config, MemoryBundle memoryBundle = {}) - : Service(parseOptions(config, /*validate=*/false), std::move(memoryBundle)) {} - - /// Explicit destructor to clean up after any threads initialized in - /// asynchronous operation mode. - ~Service(); - - /// Translate an input, providing Options to construct Response. This is - /// useful when one has to set/unset alignments or quality in the Response to - /// save compute spent in constructing these objects. - /// - /// @param [in] source: rvalue reference of the string to be translated - /// @param [in] callback: A callback function provided by the client which - /// accepts an rvalue of a Response. Called on successful construction of a - /// Response following completion of translation of source by worker threads. - /// @param [in] responseOptions: Options indicating whether or not to include - /// some member in the Response, also specify any additional configurable - /// parameters. - void translate(std::string &&source, std::function &&callback, - ResponseOptions options = ResponseOptions()); - -#ifdef WASM_COMPATIBLE_SOURCE - /// Translate multiple text-blobs in a single *blocking* API call, providing - /// ResponseOptions which applies across all text-blobs dictating how to - /// construct Response. ResponseOptions can be used to enable/disable - /// additional information like quality-scores, alignments etc. - /// - /// All texts are combined to efficiently construct batches together providing - /// speedups compared to calling translate() indepdently on individual - /// text-blob. Note that there will be minor differences in output when - /// text-blobs are individually translated due to approximations but similar - /// quality nonetheless. If you have async/multithread capabilities, it is - /// recommended to work with callbacks and translate() API. 
- /// - /// @param [in] source: rvalue reference of the string to be translated - /// @param [in] responseOptions: ResponseOptions indicating whether or not - /// to include some member in the Response, also specify any additional - /// configurable parameters. - std::vector translateMultiple(std::vector &&source, ResponseOptions responseOptions); -#endif + struct Config {}; + /// Construct a BlockingService with configuration loaded from an Options object. Does not require any keys, values to + /// be set. + BlockingService(const BlockingService::Config &config); + + /// Translate multiple text-blobs in a single *blocking* API call, providing ResponseOptions which applies across all + /// text-blobs dictating how to construct Response. ResponseOptions can be used to enable/disable additional + /// information like quality-scores, alignments etc. - /// Returns if model is alignment capable or not. - bool isAlignmentSupported() const { return options_->hasAndNotEmpty("alignment"); } + /// If you have async/multithread capabilities, it is recommended to work with AsyncService instead of this class. + /// Note that due to batching differences and consequent floating-point rounding differences, this is not guaranteed + /// to have the same output as AsyncService. + + /// @param [in] translationModel: TranslationModel to use for the request. + /// @param [in] source: rvalue reference of the string to be translated + /// @param [in] responseOptions: ResponseOptions indicating whether or not to include some member in the Response, + /// also specify any additional configurable parameters. + std::vector translateMultiple(std::shared_ptr translationModel, + std::vector &&source, const ResponseOptions &responseOptions); private: - /// Queue an input for translation. - void queueRequest(std::string &&input, std::function &&callback, ResponseOptions responseOptions); + /// Numbering requests processed through this instance. Used to keep account of arrival times of the request. 
This + /// allows for using this quantity in priority based ordering. + size_t requestId_; - /// Translates through direct interaction between batcher_ and translators_ + /// An aggregate batching pool associated with an async translating instance, which maintains an aggregate queue of + /// requests compiled from batching-pools of multiple translation models. Not thread-safe. + AggregateBatchingPool batchingPool_; - /// Number of workers to launch. - size_t numWorkers_; + Config config_; +}; - /// Options object holding the options Service was instantiated with. - Ptr options_; +/// Effectively a threadpool, providing an API to take a translation request of a source-text, paramaterized by +/// TranslationModel to be used for translation. Configurability on optional items for the Response corresponding to a +/// request is provisioned through ResponseOptions. +class AsyncService { + public: + struct Config { + size_t numWorkers; + }; + /// Construct an AsyncService with configuration loaded from Options. Expects positive integer value for + /// `cpu-threads`. Additionally requires options which configure AggregateBatchingPool. + AsyncService(const AsyncService::Config &config); + + /// Create a TranslationModel compatible with this instance of Service. Internally assigns how many replicas of + /// backend needed based on worker threads set. See TranslationModel for documentation on other params. + template + Ptr createCompatibleModel(const ConfigType &config, MemoryBundle &&memory = MemoryBundle{}) { + // @TODO: Remove this remove this dependency/coupling. + return New(config, std::move(memory), /*replicas=*/config_.numWorkers); + } + + /// With the supplied TranslationModel, translate an input. A Response is constructed with optional items set/unset + /// indicated via ResponseOptions. Upon completion translation of the input, the client supplied callback is triggered + /// with the constructed Response. Concurrent-calls to this function are safe. 
+ /// + /// @param [in] translationModel: TranslationModel to use for the request. + /// @param [in] source: rvalue reference of the string to be translated. This is available as-is to the client later + /// in the Response corresponding to this call along with the translated-text and meta-data. + /// @param [in] callback: A callback function provided by the client which accepts an rvalue of a Response. + /// @param [in] responseOptions: Options indicating whether or not to include some member in the Response, also + /// specify any additional configurable parameters. + void translate(std::shared_ptr translationModel, std::string &&source, CallbackType callback, + const ResponseOptions &options = ResponseOptions()); + + /// Thread joins and proper shutdown are required to be handled explicitly. + ~AsyncService(); - /// Model memory to load model passed as bytes. - AlignedMemory modelMemory_; // ORDER DEPENDENCY (translators_) - /// Shortlist memory passed as bytes. - AlignedMemory shortlistMemory_; // ORDER DEPENDENCY (translators_) + private: + AsyncService::Config config_; - std::shared_ptr qualityEstimator_; + std::vector workers_; /// Stores requestId of active request. Used to establish /// ordering among requests and logging/book-keeping. + /// Numbering requests processed through this instance. Used to keep account of arrival times of the request. This + /// allows for using this quantity in priority based ordering. size_t requestId_; - /// Store vocabs representing source and target. - Vocabs vocabs_; // ORDER DEPENDENCY (text_processor_) - - /// TextProcesser takes a blob of text and converts into format consumable by - /// the batch-translator and annotates sentences and words. - TextProcessor text_processor_; // ORDER DEPENDENCY (vocabs_) - - /// Batcher handles generation of batches from a request, subject to - /// packing-efficiency and priority optimization heuristics. 
- ThreadsafeBatcher batcher_; - - // The following constructs are available providing full capabilities on a non - // WASM platform, where one does not have to hide threads. -#ifdef WASM_COMPATIBLE_SOURCE - BatchTranslator blocking_translator_; // ORDER DEPENDENCY (modelMemory_, shortlistMemory_) -#else - std::vector workers_; -#endif // WASM_COMPATIBLE_SOURCE + + /// An aggregate batching pool associated with an async translating instance, which maintains an aggregate queue of + /// requests compiled from batching-pools of multiple translation models. The batching pool is wrapped around one + /// object for thread-safety. + ThreadsafeBatchingPool safeBatchingPool_; }; } // namespace bergamot diff --git a/src/translator/text_processor.cpp b/src/translator/text_processor.cpp index 249ce8cda..b747f79a5 100644 --- a/src/translator/text_processor.cpp +++ b/src/translator/text_processor.cpp @@ -52,7 +52,7 @@ ug::ssplit::SentenceSplitter loadSplitter(const AlignedMemory &memory) { } // namespace -Segment TextProcessor::tokenize(const string_view &segment, std::vector &wordRanges) { +Segment TextProcessor::tokenize(const string_view &segment, std::vector &wordRanges) const { // vocabs_->sources().front() is invoked as we currently only support one source vocab return vocabs_.sources().front()->encodeWithByteRanges(segment, wordRanges, /*addEOS=*/false, /*inference=*/true); } @@ -81,10 +81,10 @@ TextProcessor::TextProcessor(Ptr options, const Vocabs &vocabs, const A void TextProcessor::parseCommonOptions(Ptr options) { maxLengthBreak_ = options->get("max-length-break"); - ssplitMode_ = string2splitmode(options->get("ssplit-mode", "paragraph")); + ssplitMode_ = string2splitmode(options->get("ssplit-mode")); } -void TextProcessor::process(std::string &&input, AnnotatedText &source, Segments &segments) { +void TextProcessor::process(std::string &&input, AnnotatedText &source, Segments &segments) const { source = std::move(AnnotatedText(std::move(input))); std::string_view 
input_converted(source.text.data(), source.text.size()); auto sentenceStream = ug::ssplit::SentenceStream(input_converted, ssplit_, ssplitMode_); @@ -108,7 +108,7 @@ void TextProcessor::process(std::string &&input, AnnotatedText &source, Segments } void TextProcessor::wrap(Segment &segment, std::vector &wordRanges, Segments &segments, - AnnotatedText &source) { + AnnotatedText &source) const { // There's an EOS token added to the words, manually. SentencePiece/marian-vocab is set to not append EOS. Marian // requires EOS to be at the end as a marker to start translating. So while we're supplied maxLengthBreak_ from // outside, we need to ensure there's space for EOS in each wrapped segment. diff --git a/src/translator/text_processor.h b/src/translator/text_processor.h index 1dc5a4fa7..a6c918c0e 100644 --- a/src/translator/text_processor.h +++ b/src/translator/text_processor.h @@ -47,17 +47,17 @@ class TextProcessor { /// @param [out] segments: marian::Word equivalents of the sentences processed and stored in AnnotatedText for /// consumption of marian translation pipeline. - void process(std::string &&blob, AnnotatedText &source, Segments &segments); + void process(std::string &&blob, AnnotatedText &source, Segments &segments) const; private: void parseCommonOptions(Ptr options); /// Tokenizes an input string, returns Words corresponding. Loads the /// corresponding byte-ranges into tokenRanges. - Segment tokenize(const string_view &input, std::vector &tokenRanges); + Segment tokenize(const string_view &input, std::vector &tokenRanges) const; /// Wrap into sentences of at most maxLengthBreak_ tokens and add to source. 
- void wrap(Segment &sentence, std::vector &tokenRanges, Segments &segments, AnnotatedText &source); + void wrap(Segment &sentence, std::vector &tokenRanges, Segments &segments, AnnotatedText &source) const; const Vocabs &vocabs_; ///< Vocabularies used to tokenize a sentence size_t maxLengthBreak_; ///< Parameter used to wrap sentences to a maximum number of tokens diff --git a/src/translator/threadsafe_batcher.cpp b/src/translator/threadsafe_batcher.cpp deleted file mode 100644 index 38b6681a9..000000000 --- a/src/translator/threadsafe_batcher.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef WASM_COMPATIBLE_SOURCE -#include "threadsafe_batcher.h" - -#include - -namespace marian { -namespace bergamot { - -ThreadsafeBatcher::ThreadsafeBatcher(Ptr options) : backend_(options), enqueued_(0), shutdown_(false) {} - -ThreadsafeBatcher::~ThreadsafeBatcher() { shutdown(); } - -void ThreadsafeBatcher::addWholeRequest(Ptr request) { - std::unique_lock lock(mutex_); - assert(!shutdown_); - backend_.addWholeRequest(request); - enqueued_ += request->numSegments(); - work_.notify_all(); -} - -void ThreadsafeBatcher::shutdown() { - std::unique_lock lock(mutex_); - shutdown_ = true; - work_.notify_all(); -} - -bool ThreadsafeBatcher::operator>>(Batch &batch) { - std::unique_lock lock(mutex_); - work_.wait(lock, [this]() { return enqueued_ || shutdown_; }); - bool ret = backend_ >> batch; - assert(ret || shutdown_); - enqueued_ -= batch.size(); - return ret; -} - -} // namespace bergamot -} // namespace marian -#endif // WASM_COMPATIBLE_SOURCE diff --git a/src/translator/threadsafe_batcher.h b/src/translator/threadsafe_batcher.h deleted file mode 100644 index d0ab7b1cc..000000000 --- a/src/translator/threadsafe_batcher.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Thread-safe wrapper around batcher. 
*/ -#ifndef SRC_BERGAMOT_THREADSAFE_BATCHER_H_ -#define SRC_BERGAMOT_THREADSAFE_BATCHER_H_ - -#include "batcher.h" -#include "common/options.h" -#include "definitions.h" - -#ifndef WASM_COMPATIBLE_SOURCE -#include -#include -#endif - -namespace marian { -namespace bergamot { - -#ifdef WASM_COMPATIBLE_SOURCE -// No threads, no locks. -typedef Batcher ThreadsafeBatcher; -#else - -class ThreadsafeBatcher { - public: - explicit ThreadsafeBatcher(Ptr options); - - ~ThreadsafeBatcher(); - - // Add sentences to be translated by calling these (see Batcher). When - // done, call shutdown. - void addWholeRequest(Ptr request); - void shutdown(); - - // Get a batch out of the batcher. Return false to shutdown worker. - bool operator>>(Batch &batch); - - private: - Batcher backend_; - - // Number of sentences in backend_; - size_t enqueued_; - - // Are we shutting down? - bool shutdown_; - - // Lock on this object. - std::mutex mutex_; - - // Signaled when there are sentences to translate. - std::condition_variable work_; -}; - -#endif - -} // namespace bergamot -} // namespace marian - -#endif // SRC_BERGAMOT_THREADSAFE_BATCHER_H_ diff --git a/src/translator/threadsafe_batching_pool.cpp b/src/translator/threadsafe_batching_pool.cpp new file mode 100644 index 000000000..0c0d8d85a --- /dev/null +++ b/src/translator/threadsafe_batching_pool.cpp @@ -0,0 +1,49 @@ + +#ifndef SRC_BERGAMOT_THREADSAFE_BATCHING_POOL_IMPL +#error "This is an impl file and must not be included directly!" +#endif + +#include + +namespace marian { +namespace bergamot { + +template +template +ThreadsafeBatchingPool::ThreadsafeBatchingPool(Args &&... args) + : backend_(std::forward(args)...), enqueued_(0), shutdown_(false) {} + +template +ThreadsafeBatchingPool::~ThreadsafeBatchingPool() { + shutdown(); +} + +template +template +void ThreadsafeBatchingPool::enqueueRequest(Args &&... 
args) { + std::unique_lock lock(mutex_); + assert(!shutdown_); + enqueued_ += backend_.enqueueRequest(std::forward(args)...); + work_.notify_all(); +} + +template +void ThreadsafeBatchingPool::shutdown() { + std::unique_lock lock(mutex_); + shutdown_ = true; + work_.notify_all(); +} + +template +template +size_t ThreadsafeBatchingPool::generateBatch(Args &&... args) { + std::unique_lock lock(mutex_); + work_.wait(lock, [this]() { return enqueued_ || shutdown_; }); + size_t sentencesInBatch = backend_.generateBatch(std::forward(args)...); + assert(sentencesInBatch > 0 || shutdown_); + enqueued_ -= sentencesInBatch; + return sentencesInBatch; +} + +} // namespace bergamot +} // namespace marian diff --git a/src/translator/threadsafe_batching_pool.h b/src/translator/threadsafe_batching_pool.h new file mode 100644 index 000000000..96896eab3 --- /dev/null +++ b/src/translator/threadsafe_batching_pool.h @@ -0,0 +1,71 @@ +/* Thread-safe wrapper around BatchingPool or AggregateBatchingPool, made generic with templates. */ +#ifndef SRC_BERGAMOT_THREADSAFE_BATCHING_POOL_H_ +#define SRC_BERGAMOT_THREADSAFE_BATCHING_POOL_H_ + +#include +#include + +#include "aggregate_batching_pool.h" +#include "batching_pool.h" +#include "common/options.h" +#include "definitions.h" +#include "translation_model.h" + +namespace marian { +namespace bergamot { + +/// The following mechanism operates in a multithreaded async-workflow guarding access to the pushes to the structure +/// keeping sentences bucketed by length and sorted by priority. +/// +/// This is a wrap of a producer-consumer queue implemented as a monitor, where there is a mutex guarding the +/// underlying data structure (BatchingPoolType) and (worker/consumer) threads waiting on a condition variable and the +/// queuing thread producing and notifying waiting threads (consumers) through the same condition variable. 
+/// +/// Originally written for a single model (where items are produce: Request, consume: Batch), converted to +/// also work for multiple models where items are produce: (TranslationModel, Request), consume: (TranslationModel, +/// Batch). This is accomplished by template parameter packs. +/// +/// Requires BatchingPoolType to implement the following: +/// +/// * produce: `size_t enqueueRequest(...)` (returns number of elements produced) +/// * consume: `size_t generateBatch(...)` (returns number of elements available to be consumed) + +template +class ThreadsafeBatchingPool { + public: + template + ThreadsafeBatchingPool(Args &&... args); + ~ThreadsafeBatchingPool(); + + template + void enqueueRequest(Args &&... args); + + template + size_t generateBatch(Args &&... args); + + void shutdown(); + + private: + BatchingPoolType backend_; + + // Number of sentences in backend_; + size_t enqueued_; + + // Are we shutting down? + bool shutdown_; + + // Lock on this object. + std::mutex mutex_; + + // Signaled when there are sentences to translate. 
+ std::condition_variable work_; +}; + +} // namespace bergamot +} // namespace marian + +#define SRC_BERGAMOT_THREADSAFE_BATCHING_POOL_IMPL +#include "threadsafe_batching_pool.cpp" +#undef SRC_BERGAMOT_THREADSAFE_BATCHING_POOL_IMPL + +#endif // SRC_BERGAMOT_THREADSAFE_BATCHING_POOL_H_ diff --git a/src/translator/translation_model.cpp b/src/translator/translation_model.cpp new file mode 100644 index 000000000..5a2739542 --- /dev/null +++ b/src/translator/translation_model.cpp @@ -0,0 +1,173 @@ +#include "translation_model.h" + +#include "batch.h" +#include "byte_array_util.h" +#include "common/logging.h" +#include "data/corpus.h" +#include "data/text_input.h" +#include "parser.h" +#include "translator/beam_search.h" + +namespace marian { +namespace bergamot { + +TranslationModel::TranslationModel(const Config &options, MemoryBundle &&memory /*=MemoryBundle{}*/, + size_t replicas /*=1*/) + : options_(options), + memory_(std::move(memory)), + vocabs_(options, std::move(memory_.vocabs)), + textProcessor_(options, vocabs_, std::move(memory_.ssplitPrefixFile)), + batchingPool_(options), + qualityEstimator_(createQualityEstimator(getQualityEstimatorModel(memory_, options))) { // memory_, as `memory` is moved-from + ABORT_IF(replicas == 0, "At least one replica needs to be created."); + backend_.resize(replicas); + + if (options_->hasAndNotEmpty("shortlist")) { + int srcIdx = 0, trgIdx = 1; + bool shared_vcb = + vocabs_.sources().front() == + vocabs_.target(); // vocabs_->sources().front() is invoked as we currently only support one source vocab + if (memory_.shortlist.size() > 0 && memory_.shortlist.begin() != nullptr) { + bool check = options_->get("check-bytearray", false); + shortlistGenerator_ = New(memory_.shortlist.begin(), memory_.shortlist.size(), + vocabs_.sources().front(), vocabs_.target(), srcIdx, + trgIdx, shared_vcb, check); + } else { + // Changed to BinaryShortlistGenerator to enable loading binary shortlist file + // This class also supports text shortlist file + shortlistGenerator_ = 
New(options_, vocabs_.sources().front(), vocabs_.target(), + srcIdx, trgIdx, shared_vcb); + } + } + + for (size_t idx = 0; idx < replicas; idx++) { + loadBackend(idx); + } +} + +void TranslationModel::loadBackend(size_t idx) { + auto &graph = backend_[idx].graph; + auto &scorerEnsemble = backend_[idx].scorerEnsemble; + + marian::DeviceId device_(idx, DeviceType::cpu); + graph = New(/*inference=*/true); // set the graph to be inference only + auto prec = options_->get>("precision", {"float32"}); + graph->setDefaultElementType(typeFromString(prec[0])); + graph->setDevice(device_); + graph->getBackend()->configureDevice(options_); + graph->reserveWorkspaceMB(options_->get("workspace")); + + // Marian Model: Load from memoryBundle if provided, otherwise via options_ + if (memory_.model.size() > 0 && + memory_.model.begin() != + nullptr) { // If we have provided a byte array that contains the model memory, we can initialise the + // model from there, as opposed to from reading in the config file + ABORT_IF((uintptr_t)memory_.model.begin() % 256 != 0, + "The provided memory is not aligned to 256 bytes and will crash when vector instructions are used on it."); + if (options_->get("check-bytearray", false)) { + ABORT_IF(!validateBinaryModel(memory_.model, memory_.model.size()), + "The binary file is invalid. Incomplete or corrupted download?"); + } + const std::vector container = { + memory_.model.begin()}; // Marian supports multiple models initialised in this manner hence std::vector. + // However we will only ever use 1 during decoding. + scorerEnsemble = createScorers(options_, container); + } else { + scorerEnsemble = createScorers(options_); + } + for (auto scorer : scorerEnsemble) { + scorer->init(graph); + if (shortlistGenerator_) { + scorer->setShortlistGenerator(shortlistGenerator_); + } + } + graph->forward(); +} + +// Make request process is shared between Async and Blocking workflow of translating. 
+Ptr TranslationModel::makeRequest(size_t requestId, std::string &&source, CallbackType callback, + const ResponseOptions &responseOptions) { + Segments segments; + AnnotatedText annotatedSource; + + textProcessor_.process(std::move(source), annotatedSource, segments); + ResponseBuilder responseBuilder(responseOptions, std::move(annotatedSource), vocabs_, callback, *qualityEstimator_); + + Ptr request = New(requestId, std::move(segments), std::move(responseBuilder)); + return request; +} + +Ptr TranslationModel::convertToMarianBatch(Batch &batch) { + std::vector batchVector; + auto &sentences = batch.sentences(); + + size_t batchSequenceNumber{0}; + for (auto &sentence : sentences) { + data::SentenceTuple sentence_tuple(batchSequenceNumber); + Segment segment = sentence.getUnderlyingSegment(); + sentence_tuple.push_back(segment); + batchVector.push_back(sentence_tuple); + + ++batchSequenceNumber; + } + + // Usually one would expect inputs to be [B x T], where B = batch-size and T = max seq-len among the sentences in the + // batch. However, marian's library supports multi-source and ensembling through different source-vocabulary but same + // target vocabulary. This means the inputs are 3 dimensional when converted into marian's library formatted batches. + // + // Consequently B x T projects to N x B x T, where N = ensemble size. This adaptation does not fully force the idea of + // N = 1 (the code remains general, but N iterates only from 0-1 in the nested loop). 
+ + size_t batchSize = batchVector.size(); + + std::vector sentenceIds; + std::vector maxDims; + + for (auto &example : batchVector) { + if (maxDims.size() < example.size()) { + maxDims.resize(example.size(), 0); + } + for (size_t i = 0; i < example.size(); ++i) { + if (example[i].size() > static_cast(maxDims[i])) { + maxDims[i] = static_cast(example[i].size()); + } + } + sentenceIds.push_back(example.getId()); + } + + using SubBatch = marian::data::SubBatch; + std::vector> subBatches; + for (size_t j = 0; j < maxDims.size(); ++j) { + subBatches.emplace_back(New(batchSize, maxDims[j], vocabs_.sources().at(j))); + } + + std::vector words(maxDims.size(), 0); + for (size_t i = 0; i < batchSize; ++i) { + for (size_t j = 0; j < maxDims.size(); ++j) { + for (size_t k = 0; k < batchVector[i][j].size(); ++k) { + subBatches[j]->data()[k * batchSize + i] = batchVector[i][j][k]; + subBatches[j]->mask()[k * batchSize + i] = 1.f; + words[j]++; + } + } + } + + for (size_t j = 0; j < maxDims.size(); ++j) { + subBatches[j]->setWords(words[j]); + } + + using CorpusBatch = marian::data::CorpusBatch; + Ptr corpusBatch = New(subBatches); + corpusBatch->setSentenceIds(sentenceIds); + return corpusBatch; +} + +void TranslationModel::translateBatch(size_t deviceId, Batch &batch) { + auto &backend = backend_[deviceId]; + BeamSearch search(options_, backend.scorerEnsemble, vocabs_.target()); + Histories histories = search.search(backend.graph, convertToMarianBatch(batch)); + batch.completeBatch(histories); +} + +} // namespace bergamot +} // namespace marian diff --git a/src/translator/translation_model.h b/src/translator/translation_model.h new file mode 100644 index 000000000..599e6c707 --- /dev/null +++ b/src/translator/translation_model.h @@ -0,0 +1,122 @@ +#ifndef SRC_BERGAMOT_TRANSLATION_MODEL_H_ +#define SRC_BERGAMOT_TRANSLATION_MODEL_H_ + +#include +#include + +#include "batch.h" +#include "batching_pool.h" +#include "common/utils.h" +#include "data/shortlist.h" +#include 
"definitions.h" +#include "parser.h" +#include "request.h" +#include "text_processor.h" +#include "translator/history.h" +#include "translator/scorers.h" +#include "vocabs.h" + +namespace marian { +namespace bergamot { + +/// A TranslationModel is associated with the translation of a single language direction. Holds the graph and other +/// structures required to run the forward pass of the neural network, along with preprocessing logic (TextProcessor) +/// and a BatchingPool to create batches that are to be used in conjuction with an instance. +/// +/// Thread-safety is not handled here, but the methods are available at granularity enough to be used in threaded async +/// workflow for translation. + +class TranslationModel { + public: + using Config = Ptr; + using ShortlistGenerator = Ptr; + + /// Equivalent to options based constructor, where `options` is parsed from string configuration. Configuration can be + /// JSON or YAML. Keys expected correspond to those of `marian-decoder`, available at + /// https://marian-nmt.github.io/docs/cmd/marian-decoder/ + /// + /// Note that `replicas` is not stable. This is a temporary workaround while a more daunting task of separating + /// workspace from TranslationModel and binding it to threads is to be undertaken separately. Until the separation is + /// achieved, both TranslationModel and Service will need to be aware of workers. This is expected to be resolved + /// eventually, with only Service having the knowledge of how many workers are active. + /// + /// WebAssembly uses only single-thread, and we can hardcode replicas = 1 and use it anywhere and (client) needn't be + /// aware of this ugliness at the moment, thus providing a stable API solely for WebAssembly single-threaded modus + /// operandi. + /// + /// TODO(@jerinphilip): Clean this up. 
+ TranslationModel(const std::string& config, MemoryBundle&& memory, size_t replicas = 1) + : TranslationModel(parseOptionsFromString(config, /*validate=*/false), std::move(memory), replicas){}; + + /// Construct TranslationModel from marian-options. If memory is empty, TranslationModel is initialized from + /// paths available in the options object, backed by filesystem. Otherwise, TranslationModel is initialized from the + /// given MemoryBundle composed of AlignedMemory holding equivalent parameters. + /// + /// @param [in] options: Marian options object. + /// @param [in] memory: MemoryBundle object holding memory buffers containing parameters to build MarianBackend, + /// ShortlistGenerator, Vocabs and SentenceSplitter. + TranslationModel(const Config& options, MemoryBundle&& memory = MemoryBundle{}, size_t replicas = 1); + + /// Make a Request to be translated by this TranslationModel instance. + /// @param [in] requestId: Unique identifier associated with this request, available from Service. + /// @param [in] source: Source text to be translated. Ownership is accepted and eventually returned to the client in + /// Response corresponding to the Request created here. + /// @param [in] callback: Callback (from client) to be issued upon completion of translation of all sentences in the + /// created Request. + /// @param [in] responseOptions: Configuration used to prepare the Response corresponding to the created request. + /// @returns Request created from the query parameters wrapped within a shared-pointer. + Ptr makeRequest(size_t requestId, std::string&& source, CallbackType callback, + const ResponseOptions& responseOptions); + + /// Relays a request to the batching-pool specific to this translation model. 
+ /// @param [in] request: Request constructed through makeRequest + void enqueueRequest(Ptr request) { batchingPool_.enqueueRequest(request); }; + + /// Generates a batch from the batching-pool for this translation model, compiling from several active requests. Note + /// that it is possible that calls to this method can give empty-batches. + /// + /// @param [out] batch: Batch to write a generated batch on to. + /// @returns number of sentences that constitute the Batch. + size_t generateBatch(Batch& batch) { return batchingPool_.generateBatch(batch); } + + /// Translate a batch generated with generateBatch + /// + /// @param [in] deviceId: There are replicas of backend created for use in each worker thread. deviceId indicates + /// which replica to use. + /// @param [in] batch: A batch generated from generateBatch from the same TranslationModel instance. + void translateBatch(size_t deviceId, Batch& batch); + + private: + Config options_; + MemoryBundle memory_; + Vocabs vocabs_; + TextProcessor textProcessor_; + + /// Maintains sentences from multiple requests bucketed by length and sorted by priority in each bucket. + BatchingPool batchingPool_; + + /// A package of marian-entities which form a backend to translate. + struct MarianBackend { + using Graph = Ptr; + using ScorerEnsemble = std::vector>; + + Graph graph; + ScorerEnsemble scorerEnsemble; + }; + + // ShortlistGenerator is purely const, we don't need one per thread. + ShortlistGenerator shortlistGenerator_; + + /// Hold replicas of the backend (graph, scorers, shortlist) for use in each thread. 
+ /// Controlled and consistent external access via graph(id), scorerEnsemble(id), + std::vector backend_; + std::shared_ptr qualityEstimator_; + + void loadBackend(size_t idx); + Ptr convertToMarianBatch(Batch& batch); +}; + +} // namespace bergamot +} // namespace marian + +#endif // SRC_BERGAMOT_TRANSLATION_MODEL_H_ diff --git a/wasm/bindings/service_bindings.cpp b/wasm/bindings/service_bindings.cpp index 416a318ad..d05cf57cf 100644 --- a/wasm/bindings/service_bindings.cpp +++ b/wasm/bindings/service_bindings.cpp @@ -8,8 +8,10 @@ using namespace emscripten; -typedef marian::bergamot::Service Service; -typedef marian::bergamot::AlignedMemory AlignedMemory; +using BlockingService = marian::bergamot::BlockingService; +using TranslationModel = marian::bergamot::TranslationModel; +using AlignedMemory = marian::bergamot::AlignedMemory; +using MemoryBundle = marian::bergamot::MemoryBundle; val getByteArrayView(AlignedMemory& alignedMemory) { return val(typed_memory_view(alignedMemory.size(), alignedMemory.as())); @@ -42,9 +44,9 @@ std::vector> prepareVocabsSmartMemories(std::vect return vocabsSmartMemories; } -marian::bergamot::MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shortlistMemory, - std::vector uniqueVocabsMemories) { - marian::bergamot::MemoryBundle memoryBundle; +MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shortlistMemory, + std::vector uniqueVocabsMemories) { + MemoryBundle memoryBundle; memoryBundle.model = std::move(*modelMemory); memoryBundle.shortlist = std::move(*shortlistMemory); memoryBundle.vocabs = std::move(prepareVocabsSmartMemories(uniqueVocabsMemories)); @@ -52,18 +54,31 @@ marian::bergamot::MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, A return memoryBundle; } -Service* ServiceFactory(const std::string& config, AlignedMemory* modelMemory, AlignedMemory* shortlistMemory, - std::vector uniqueVocabsMemories) { - return new Service(config, 
std::move(prepareMemoryBundle(modelMemory, shortlistMemory, uniqueVocabsMemories))); +// This allows only shared_ptrs to be operational in JavaScript, according to emscripten. +// https://emscripten.org/docs/porting/connecting_cpp_and_javascript/embind.html#smart-pointers +std::shared_ptr TranslationModelFactory(const std::string& config, AlignedMemory* model, + AlignedMemory* shortlist, + std::vector vocabs) { + MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs); + return std::make_shared(config, std::move(memoryBundle)); } -EMSCRIPTEN_BINDINGS(translation_service) { - class_("Service") - .constructor(&ServiceFactory, allow_raw_pointers()) - .function("translate", &Service::translateMultiple) - .function("isAlignmentSupported", &Service::isAlignmentSupported); - // ^ We redirect Service::translateMultiple to WASMBound::translate instead. Sane API is - // translate. If and when async comes, we can be done with this inconsistency. +EMSCRIPTEN_BINDINGS(translation_model) { + class_("TranslationModel") + .smart_ptr_constructor("TranslationModel", &TranslationModelFactory, allow_raw_pointers()); +} + +EMSCRIPTEN_BINDINGS(blocking_service_config) { + value_object("BlockingServiceConfig"); + // .field("name", &BlockingService::Config::name") + // The above is a future hook. Note that more will come - for cache, for workspace-size or graph details limits on + // aggregate-batching etc. +} + +EMSCRIPTEN_BINDINGS(blocking_service) { + class_("BlockingService") + .constructor() + .function("translate", &BlockingService::translateMultiple); register_vector("VectorString"); }