Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separating workspace from TranslationModel #223

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions src/translator/service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,30 @@
namespace marian {
namespace bergamot {

BlockingService::BlockingService(const BlockingService::Config &config) : requestId_(0), batchingPool_() {}
/// Builds a minimal Options object carrying only the backend settings that
/// Backend::configureDevice() reads. This is a stopgap (hence the name) until
/// graph construction receives a properly populated Options instance.
///
/// @return Options with GEMM precision and related backend knobs set.
Ptr<Options> horribleOptionsHack() {
  Ptr<Options> options = std::make_shared<Options>();
  options->set<std::string>("gemm-precision", "int8shiftAlphaAll");
  options->set<bool>("dump-quantmult", false);
  // Float literal avoids an implicit double -> float conversion in set<float>.
  options->set<float>("clip-gemm", 1.0f);
  options->set<bool>("use-legacy-batching", false);
  return options;
}

/// Creates an inference-only ExpressionGraph bound to a CPU device and
/// pre-reserves its workspace memory.
///
/// @param workspaceSizeInMB workspace to reserve on the graph, in megabytes.
/// @param cpuId index of the CPU device the graph is bound to.
/// @return a fully configured graph ready for scorers to be initialized onto.
Graph createGraph(size_t workspaceSizeInMB, size_t cpuId) {
  marian::DeviceId device(cpuId, DeviceType::cpu);
  Graph graph = New<ExpressionGraph>(/*inference=*/true);  // inference-only graph
  // Precision is hardcoded for now; once a full Options object is threaded
  // through, this should come from options->get<std::vector<std::string>>("precision", {"float32"}).
  std::vector<std::string> prec = {"float32"};

  graph->setDefaultElementType(typeFromString(prec[0]));
  graph->setDevice(device);
  graph->getBackend()->configureDevice(horribleOptionsHack());
  graph->reserveWorkspaceMB(workspaceSizeInMB);
  return graph;
}

// Single-threaded service: one graph replica, pinned to cpu 0, with the
// configured workspace size.
BlockingService::BlockingService(const BlockingService::Config &config)
    : requestId_(0), batchingPool_(), graph_(createGraph(config.workspaceSizeInMB, /*cpuId=*/0)) {}

std::vector<Response> BlockingService::translateMultiple(std::shared_ptr<TranslationModel> translationModel,
std::vector<std::string> &&sources,
Expand All @@ -28,23 +51,27 @@ std::vector<Response> BlockingService::translateMultiple(std::shared_ptr<Transla
Batch batch;
Ptr<TranslationModel> model{nullptr};
while (batchingPool_.generateBatch(model, batch)) {
model->translateBatch(/*deviceId=*/0, batch);
model->translateBatch(graph_, batch);
}

return responses;
}

// Spawns numWorkers consumer threads. Each worker owns one graph replica
// (created on the worker's own thread so workspace allocation happens on that
// core) and loops pulling batches from the shared batching pool.
AsyncService::AsyncService(const AsyncService::Config &config) : requestId_(0), config_(config), safeBatchingPool_() {
  ABORT_IF(config_.numWorkers == 0, "Number of workers should be at least 1 in a threaded workflow");

  workers_.reserve(config_.numWorkers);
  graphs_.resize(config_.numWorkers, nullptr);
  for (size_t cpuId = 0; cpuId < config_.numWorkers; cpuId++) {
    workers_.emplace_back([cpuId, this] {
      // Consumer thread main-loop. Note that this is an infinite-loop unless the monitor is explicitly told to
      // shutdown, which happens in the destructor for this class.
      // Each worker writes only its own slot of graphs_ (distinct elements),
      // so no synchronization is needed here.
      graphs_[cpuId] = createGraph(config_.workspaceSizeInMB, cpuId);

      Batch batch;
      Ptr<TranslationModel> translationModel{nullptr};
      while (safeBatchingPool_.generateBatch(translationModel, batch)) {
        translationModel->translateBatch(graphs_[cpuId], batch);
      }
    });
  }
Expand Down
10 changes: 9 additions & 1 deletion src/translator/service.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
namespace marian {
namespace bergamot {

using Graph = Ptr<ExpressionGraph>;

class BlockingService;
class AsyncService;

Expand All @@ -27,7 +29,9 @@ class AsyncService;
/// bunch of texts and optional args to translate, wait till the translation finishes).
class BlockingService {
public:
struct Config {};
struct Config {
size_t workspaceSizeInMB{1024};
};
/// Construct a BlockingService with configuration loaded from an Options object. Does not require any keys, values to
/// be set.
BlockingService(const BlockingService::Config &config);
Expand Down Expand Up @@ -56,6 +60,8 @@ class BlockingService {
/// requests compiled from batching-pools of multiple translation models. Not thread-safe.
AggregateBatchingPool batchingPool_;

Graph graph_;

Config config_;
};

Expand All @@ -66,6 +72,7 @@ class AsyncService {
public:
struct Config {
size_t numWorkers;
size_t workspaceSizeInMB{1024};
};
/// Construct an AsyncService with configuration loaded from Options. Expects positive integer value for
/// `cpu-threads`. Additionally requires options which configure AggregateBatchingPool.
Expand Down Expand Up @@ -98,6 +105,7 @@ class AsyncService {
private:
AsyncService::Config config_;

std::vector<Graph> graphs_;
std::vector<std::thread> workers_;

/// Stores requestId of active request. Used to establish
Expand Down
52 changes: 24 additions & 28 deletions src/translator/translation_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ TranslationModel::TranslationModel(const Config &options, MemoryBundle &&memory
batchingPool_(options),
qualityEstimator_(createQualityEstimator(getQualityEstimatorModel(memory, options))) {
ABORT_IF(replicas == 0, "At least one replica needs to be created.");
backend_.resize(replicas);
scorerReplicas_.resize(replicas);

if (options_->hasAndNotEmpty("shortlist")) {
int srcIdx = 0, trgIdx = 1;
Expand All @@ -39,24 +39,10 @@ TranslationModel::TranslationModel(const Config &options, MemoryBundle &&memory
srcIdx, trgIdx, shared_vcb);
}
}

for (size_t idx = 0; idx < replicas; idx++) {
loadBackend(idx);
}
}

void TranslationModel::loadBackend(size_t idx) {
auto &graph = backend_[idx].graph;
auto &scorerEnsemble = backend_[idx].scorerEnsemble;

marian::DeviceId device_(idx, DeviceType::cpu);
graph = New<ExpressionGraph>(/*inference=*/true); // set the graph to be inference only
auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
graph->setDefaultElementType(typeFromString(prec[0]));
graph->setDevice(device_);
graph->getBackend()->configureDevice(options_);
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));

TranslationModel::ScorerEnsemble TranslationModel::buildScorerEnsemble() {
ScorerEnsemble scorerEnsemble;
// Marian Model: Load from memoryBundle or shortList
if (memory_.model.size() > 0 &&
memory_.model.begin() !=
Expand All @@ -75,13 +61,8 @@ void TranslationModel::loadBackend(size_t idx) {
} else {
scorerEnsemble = createScorers(options_);
}
for (auto scorer : scorerEnsemble) {
scorer->init(graph);
if (shortlistGenerator_) {
scorer->setShortlistGenerator(shortlistGenerator_);
}
}
graph->forward();

return scorerEnsemble;
}

// Make request process is shared between Async and Blocking workflow of translating.
Expand Down Expand Up @@ -162,10 +143,25 @@ Ptr<marian::data::CorpusBatch> TranslationModel::convertToMarianBatch(Batch &bat
return corpusBatch;
}

void TranslationModel::translateBatch(size_t deviceId, Batch &batch) {
auto &backend = backend_[deviceId];
BeamSearch search(options_, backend.scorerEnsemble, vocabs_.target());
Histories histories = search.search(backend.graph, convertToMarianBatch(batch));
/// Initializes every scorer in the ensemble onto the given graph, attaching
/// the shared shortlist generator when one is configured, then runs a forward
/// pass to materialize parameters on the graph.
///
/// @param scorerEnsemble scorers to bind; mutated in place by init().
/// @param graph target graph the scorers will execute on.
void TranslationModel::initScorerOntoGraph(ScorerEnsemble &scorerEnsemble, Graph &graph) {
  // const-ref avoids copying a Ptr<Scorer> (shared_ptr refcount churn) per iteration.
  for (const auto &scorer : scorerEnsemble) {
    scorer->init(graph);
    if (shortlistGenerator_) {
      scorer->setShortlistGenerator(shortlistGenerator_);
    }
  }
  graph->forward();
}

/// Translates a batch using the supplied graph. The scorer ensemble for the
/// graph's device is built lazily on first use and cached per device replica,
/// so repeated calls with the same graph reuse the bound scorers.
///
/// @param graph graph replica owned by the calling worker; its device id
///              selects which scorer replica to use.
/// @param batch batch produced by generateBatch() on this TranslationModel.
void TranslationModel::translateBatch(Graph &graph, Batch &batch) {
  DeviceId deviceId = graph->getDeviceId();
  // Guard against indexing past the replicas this model was constructed with;
  // an unchecked operator[] here would be undefined behavior.
  ABORT_IF(deviceId.no >= scorerReplicas_.size(),
           "Graph device id exceeds the number of scorer replicas available in this TranslationModel.");
  auto &scorerEnsemble = scorerReplicas_[deviceId.no];
  if (scorerEnsemble.empty()) {
    // Lazy initialization: build and bind scorers the first time this
    // (model, graph) pair translates together.
    scorerEnsemble = buildScorerEnsemble();
    initScorerOntoGraph(scorerEnsemble, graph);
  }
  BeamSearch search(options_, scorerEnsemble, vocabs_.target());
  Histories histories = search.search(graph, convertToMarianBatch(batch));
  batch.completeBatch(histories);
}

Expand Down
18 changes: 6 additions & 12 deletions src/translator/translation_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class TranslationModel {
public:
using Config = Ptr<Options>;
using ShortlistGenerator = Ptr<data::ShortlistGenerator const>;
using ScorerEnsemble = std::vector<Ptr<Scorer>>;
using Graph = Ptr<ExpressionGraph>;

/// Equivalent to options based constructor, where `options` is parsed from string configuration. Configuration can be
/// JSON or YAML. Keys expected correspond to those of `marian-decoder`, available at
Expand Down Expand Up @@ -84,7 +86,7 @@ class TranslationModel {
/// @param [in] deviceId: There are replicas of backend created for use in each worker thread. deviceId indicates
/// which replica to use.
/// @param [in] batch: A batch generated from generateBatch from the same TranslationModel instance.
void translateBatch(size_t deviceId, Batch& batch);
void translateBatch(Graph& graph, Batch& batch);

private:
Config options_;
Expand All @@ -95,24 +97,16 @@ class TranslationModel {
/// Maintains sentences from multiple requests bucketed by length and sorted by priority in each bucket.
BatchingPool batchingPool_;

/// A package of marian-entities which form a backend to translate.
struct MarianBackend {
using Graph = Ptr<ExpressionGraph>;
using ScorerEnsemble = std::vector<Ptr<Scorer>>;

Graph graph;
ScorerEnsemble scorerEnsemble;
};

// ShortlistGenerator is purely const, we don't need one per thread.
ShortlistGenerator shortlistGenerator_;

/// Hold replicas of the backend (graph, scorers, shortlist) for use in each thread.
/// Controlled and consistent external access via graph(id), scorerEnsemble(id),
std::vector<MarianBackend> backend_;
std::vector<ScorerEnsemble> scorerReplicas_;
std::shared_ptr<QualityEstimator> qualityEstimator_;

void loadBackend(size_t idx);
ScorerEnsemble buildScorerEnsemble();
void initScorerOntoGraph(ScorerEnsemble& scorerEnsemble, Graph& graph);
Ptr<marian::data::CorpusBatch> convertToMarianBatch(Batch& batch);
};

Expand Down