diff --git a/src/translator/service.cpp b/src/translator/service.cpp
index 9de69ba8a..150310fc1 100644
--- a/src/translator/service.cpp
+++ b/src/translator/service.cpp
@@ -10,7 +10,30 @@
 namespace marian {
 namespace bergamot {

-BlockingService::BlockingService(const BlockingService::Config &config) : requestId_(0), batchingPool_() {}
+Ptr<Options> horribleOptionsHack() {
+  Ptr<Options> options = std::make_shared<Options>();
+  options->set("gemm-precision", "int8shiftAlphaAll");
+  options->set("dump-quantmult", false);
+  options->set("clip-gemm", 1.0);
+  options->set("use-legacy-batching", false);
+  return options;
+}
+
+Graph createGraph(size_t workspaceSizeInMB, size_t cpuId) {
+  marian::DeviceId device_(cpuId, DeviceType::cpu);
+  Graph graph = New<ExpressionGraph>(/*inference=*/true);  // set the graph to be inference only
+  // auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
+  std::vector<std::string> prec = {"float32"};  // Hardcode now.
+
+  graph->setDefaultElementType(typeFromString(prec[0]));
+  graph->setDevice(device_);
+  graph->getBackend()->configureDevice(horribleOptionsHack());
+  graph->reserveWorkspaceMB(workspaceSizeInMB);
+  return graph;
+}
+
+BlockingService::BlockingService(const BlockingService::Config &config)
+    : requestId_(0), batchingPool_(), graph_(createGraph(config.workspaceSizeInMB, /*cpuId=*/0)) {}

 std::vector<Response> BlockingService::translateMultiple(std::shared_ptr<TranslationModel> translationModel,
                                                          std::vector<std::string> &&sources,
@@ -28,7 +51,7 @@ std::vector<Response> BlockingService::translateMultiple(std::shared_ptr<Transl
   Batch batch;
   Ptr<TranslationModel> model{nullptr};
   while (batchingPool_.generateBatch(model, batch)) {
-    model->translateBatch(/*deviceId=*/0, batch);
+    model->translateBatch(graph_, batch);
   }

   return responses;
@@ -36,15 +59,19 @@ std::vector<Response> BlockingService::translateMultiple(std::shared_ptr<Transl
       Ptr<TranslationModel> translationModel{nullptr};
       while (safeBatchingPool_.generateBatch(translationModel, batch)) {
-        translationModel->translateBatch(cpuId, batch);
+        translationModel->translateBatch(graphs_[cpuId], batch);
       }
     });
   }
diff --git a/src/translator/service.h b/src/translator/service.h
index d37f5c262..ac4996d0d 100644
--- a/src/translator/service.h
+++ b/src/translator/service.h
@@ -18,6 +18,8 @@
 namespace marian {
 namespace bergamot {

+using Graph = Ptr<ExpressionGraph>;
+
 class BlockingService;
 class AsyncService;

@@ -27,7 +29,9 @@ class AsyncService;
 /// bunch of texts and optional args to translate, wait till the translation finishes).
 class BlockingService {
  public:
-  struct Config {};
+  struct Config {
+    size_t workspaceSizeInMB{1024};
+  };
   /// Construct a BlockingService with configuration loaded from an Options object. Does not require any keys, values to
   /// be set.
   BlockingService(const BlockingService::Config &config);
@@ -56,6 +60,8 @@ class BlockingService {
   /// requests compiled from batching-pools of multiple translation models. Not thread-safe.
   AggregateBatchingPool batchingPool_;

+  Graph graph_;
+
   Config config_;
 };

@@ -66,6 +72,7 @@ class AsyncService {
  public:
   struct Config {
     size_t numWorkers;
+    size_t workspaceSizeInMB{1024};
   };
   /// Construct an AsyncService with configuration loaded from Options. Expects positive integer value for
   /// `cpu-threads`. Additionally requires options which configure AggregateBatchingPool.
@@ -98,6 +105,7 @@ class AsyncService {
  private:
   AsyncService::Config config_;

+  std::vector<Graph> graphs_;
   std::vector<std::thread> workers_;

   /// Stores requestId of active request. Used to establish
diff --git a/src/translator/translation_model.cpp b/src/translator/translation_model.cpp
index 5a2739542..8f417e4d6 100644
--- a/src/translator/translation_model.cpp
+++ b/src/translator/translation_model.cpp
@@ -20,7 +20,7 @@ TranslationModel::TranslationModel(const Config &options, MemoryBundle &&memory
       batchingPool_(options),
       qualityEstimator_(createQualityEstimator(getQualityEstimatorModel(memory, options))) {
   ABORT_IF(replicas == 0, "At least one replica needs to be created.");
-  backend_.resize(replicas);
+  scorerReplicas_.resize(replicas);

   if (options_->hasAndNotEmpty("shortlist")) {
     int srcIdx = 0, trgIdx = 1;
@@ -39,24 +39,10 @@ TranslationModel::TranslationModel(const Config &options, MemoryBundle &&memory
                                                               srcIdx, trgIdx, shared_vcb);
     }
   }
-
-  for (size_t idx = 0; idx < replicas; idx++) {
-    loadBackend(idx);
-  }
 }

-void TranslationModel::loadBackend(size_t idx) {
-  auto &graph = backend_[idx].graph;
-  auto &scorerEnsemble = backend_[idx].scorerEnsemble;
-
-  marian::DeviceId device_(idx, DeviceType::cpu);
-  graph = New<ExpressionGraph>(/*inference=*/true);  // set the graph to be inference only
-  auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
-  graph->setDefaultElementType(typeFromString(prec[0]));
-  graph->setDevice(device_);
-  graph->getBackend()->configureDevice(options_);
-  graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
-
+TranslationModel::ScorerEnsemble TranslationModel::buildScorerEnsemble() {
+  ScorerEnsemble scorerEnsemble;
   // Marian Model: Load from memoryBundle or shortList
   if (memory_.model.size() > 0 &&
       memory_.model.begin() !=
@@ -75,13 +61,8 @@ void TranslationModel::loadBackend(size_t idx) {
   } else {
     scorerEnsemble = createScorers(options_);
   }
-  for (auto scorer : scorerEnsemble) {
-    scorer->init(graph);
-    if (shortlistGenerator_) {
-      scorer->setShortlistGenerator(shortlistGenerator_);
-    }
-  }
-  graph->forward();
+
+  return scorerEnsemble;
 }

 // Make request process is shared between Async and Blocking workflow of translating.
@@ -162,10 +143,25 @@ Ptr<marian::data::CorpusBatch> TranslationModel::convertToMarianBatch(Batch &bat
   return corpusBatch;
 }

-void TranslationModel::translateBatch(size_t deviceId, Batch &batch) {
-  auto &backend = backend_[deviceId];
-  BeamSearch search(options_, backend.scorerEnsemble, vocabs_.target());
-  Histories histories = search.search(backend.graph, convertToMarianBatch(batch));
+void TranslationModel::initScorerOntoGraph(ScorerEnsemble &scorerEnsemble, Graph &graph) {
+  for (auto scorer : scorerEnsemble) {
+    scorer->init(graph);
+    if (shortlistGenerator_) {
+      scorer->setShortlistGenerator(shortlistGenerator_);
+    }
+  }
+  graph->forward();
+}
+
+void TranslationModel::translateBatch(Graph &graph, Batch &batch) {
+  DeviceId deviceId = graph->getDeviceId();
+  auto &scorerEnsemble = scorerReplicas_[deviceId.no];
+  if (scorerEnsemble.empty()) {
+    scorerEnsemble = buildScorerEnsemble();
+    initScorerOntoGraph(scorerEnsemble, graph);
+  }
+  BeamSearch search(options_, scorerEnsemble, vocabs_.target());
+  Histories histories = search.search(graph, convertToMarianBatch(batch));
   batch.completeBatch(histories);
 }
diff --git a/src/translator/translation_model.h b/src/translator/translation_model.h
index 599e6c707..ef7b9d34d 100644
--- a/src/translator/translation_model.h
+++ b/src/translator/translation_model.h
@@ -30,6 +30,8 @@ class TranslationModel {
  public:
   using Config = Ptr<Options>;
   using ShortlistGenerator = Ptr<const data::ShortlistGenerator>;
+  using ScorerEnsemble = std::vector<Ptr<Scorer>>;
+  using Graph = Ptr<ExpressionGraph>;

   /// Equivalent to options based constructor, where `options` is parsed from string configuration. Configuration can be
   /// JSON or YAML. Keys expected correspond to those of `marian-decoder`, available at
@@ -84,7 +86,7 @@ class TranslationModel {
   /// @param [in] deviceId: There are replicas of backend created for use in each worker thread. deviceId indicates
   /// which replica to use.
   /// @param [in] batch: A batch generated from generateBatch from the same TranslationModel instance.
-  void translateBatch(size_t deviceId, Batch& batch);
+  void translateBatch(Graph& graph, Batch& batch);

  private:
   Config options_;
@@ -95,24 +97,16 @@ class TranslationModel {
   /// Maintains sentences from multiple requests bucketed by length and sorted by priority in each bucket.
   BatchingPool batchingPool_;

-  /// A package of marian-entities which form a backend to translate.
-  struct MarianBackend {
-    using Graph = Ptr<ExpressionGraph>;
-    using ScorerEnsemble = std::vector<Ptr<Scorer>>;
-
-    Graph graph;
-    ScorerEnsemble scorerEnsemble;
-  };
-
   // ShortlistGenerator is purely const, we don't need one per thread.
   ShortlistGenerator shortlistGenerator_;

   /// Hold replicas of the backend (graph, scorers, shortlist) for use in each thread.
   /// Controlled and consistent external access via graph(id), scorerEnsemble(id),
-  std::vector<MarianBackend> backend_;
+  std::vector<ScorerEnsemble> scorerReplicas_;

   std::shared_ptr<QualityEstimator> qualityEstimator_;

-  void loadBackend(size_t idx);
+  ScorerEnsemble buildScorerEnsemble();
+  void initScorerOntoGraph(ScorerEnsemble& scorerEnsemble, Graph& graph);
   Ptr<marian::data::CorpusBatch> convertToMarianBatch(Batch& batch);
 };
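For reviewers, a minimal caller-side sketch of how the reworked `BlockingService` is intended to be used after this change: the graph and its workspace are now owned by the service (sized via `Config::workspaceSizeInMB`) rather than by each `TranslationModel`, and the model's scorers are bound lazily to the service-owned graph on the first `translateBatch` call. This sketch is not part of the patch; the include paths, the `ResponseOptions` overload of `translateMultiple`, and the pre-built `model` are assumptions drawn from the surrounding codebase rather than from this diff.

```cpp
// Illustrative sketch only -- not part of this patch. Assumes a TranslationModel
// has already been constructed elsewhere (decoder-style config + MemoryBundle),
// and that the include paths below match this repository's src/ layout.
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "translator/response_options.h"  // assumed header for ResponseOptions
#include "translator/service.h"
#include "translator/translation_model.h"

using namespace marian::bergamot;

std::vector<Response> translateAll(std::shared_ptr<TranslationModel> model,
                                   std::vector<std::string> texts) {
  // The workspace now lives in the service: one graph for BlockingService,
  // numWorkers graphs for AsyncService. 256 MB here instead of the 1024 MB default.
  BlockingService::Config config;
  config.workspaceSizeInMB = 256;
  BlockingService service(config);

  // One ResponseOptions per input; defaults suffice for a smoke test.
  std::vector<ResponseOptions> responseOptions(texts.size());

  // The first batch routed through `model` triggers buildScorerEnsemble() and
  // initScorerOntoGraph() on the service-owned graph; later batches reuse them.
  return service.translateMultiple(model, std::move(texts), responseOptions);
}
```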