Make eigen threads all share the same copy of the model
lightvector committed Dec 17, 2023
1 parent ce4e410 commit 85b6dce
66 changes: 48 additions & 18 deletions cpp/neuralnet/eigenbackend.cpp
@@ -191,20 +191,32 @@ static size_t roundUpToMultiple(size_t size, size_t ofThis) {
 
 // --------------------------------------------------------------------------------------------------------------
 
+struct Model;
+
 struct ComputeContext {
   const int nnXLen;
   const int nnYLen;
 
+  std::mutex cachedModelsMutex;
+  std::map<std::string,std::shared_ptr<const Model>> cachedModels;
+  std::map<std::string,int> cachedModelsRefCount;
+
   ComputeContext() = delete;
   ComputeContext(const ComputeContext&) = delete;
   ComputeContext& operator=(const ComputeContext&) = delete;
 
   ComputeContext(int nnX, int nnY)
     : nnXLen(nnX),
-      nnYLen(nnY)
+      nnYLen(nnY),
+      cachedModelsMutex(),
+      cachedModels(),
+      cachedModelsRefCount()
   {}
   ~ComputeContext()
-  {}
+  {
+    // This should only be freed after all the handles are freed
+    assert(cachedModels.size() == 0);
+  }
 };
@@ -1634,28 +1646,46 @@ void NeuralNet::freeComputeContext(ComputeContext* computeContext) {
 //------------------------------------------------------------------------------
 
 struct ComputeHandle {
-  const ComputeContext* context;
+  ComputeContext* context;
   bool inputsUseNHWC;
   ComputeHandleInternal handleInternal;
-  const Model model;
+  const std::string modelCacheKey;
+  std::shared_ptr<const Model> model;
   std::unique_ptr<ScratchBuffers> scratch;
   std::unique_ptr<Buffers> buffers;
 
   ComputeHandle() = delete;
   ComputeHandle(const ComputeHandle&) = delete;
   ComputeHandle& operator=(const ComputeHandle&) = delete;
 
-  ComputeHandle(const ComputeContext* ctx, const LoadedModel& loadedModel, int maxBatchSize, bool iNHWC)
+  ComputeHandle(ComputeContext* ctx, const LoadedModel& loadedModel, int maxBatchSize, bool iNHWC)
     : context(ctx),
       inputsUseNHWC(iNHWC),
       handleInternal(ctx),
-      model(loadedModel.modelDesc,ctx->nnXLen,ctx->nnYLen)
+      modelCacheKey(loadedModel.modelDesc.name + "-" + loadedModel.modelDesc.sha256),
+      model(nullptr)
   {
-    scratch = std::make_unique<ScratchBuffers>(maxBatchSize,ctx->nnXLen,ctx->nnYLen);
-    buffers = std::make_unique<Buffers>(loadedModel.modelDesc,model,maxBatchSize,ctx->nnXLen,ctx->nnYLen);
+    {
+      std::lock_guard<std::mutex> lock(context->cachedModelsMutex);
+      if(context->cachedModels.find(modelCacheKey) == context->cachedModels.end()) {
+        context->cachedModels[modelCacheKey] = std::make_shared<const Model>(loadedModel.modelDesc,context->nnXLen,context->nnYLen);
+      }
+      model = context->cachedModels[modelCacheKey];
+      context->cachedModelsRefCount[modelCacheKey] += 1;
+    }
+
+    scratch = std::make_unique<ScratchBuffers>(maxBatchSize,context->nnXLen,context->nnYLen);
+    buffers = std::make_unique<Buffers>(loadedModel.modelDesc,*model,maxBatchSize,context->nnXLen,context->nnYLen);
   }
+
+  ~ComputeHandle() {
+    // The lock_guard must be a named local; an unnamed temporary would release the mutex immediately.
+    std::lock_guard<std::mutex> lock(context->cachedModelsMutex);
+    context->cachedModelsRefCount[modelCacheKey] -= 1;
+    assert(context->cachedModelsRefCount[modelCacheKey] >= 0);
+    if(context->cachedModelsRefCount[modelCacheKey] == 0) {
+      context->cachedModelsRefCount.erase(modelCacheKey);
+      context->cachedModels.erase(modelCacheKey);
+    }
+  }
 };

@@ -1703,14 +1733,14 @@ void NeuralNet::getOutput(
   const int batchSize = numBatchEltsFilled;
   const int nnXLen = computeHandle->context->nnXLen;
   const int nnYLen = computeHandle->context->nnYLen;
-  const int modelVersion = computeHandle->model.modelVersion;
+  const int modelVersion = computeHandle->model->modelVersion;
 
   const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion);
   const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion);
-  assert(numSpatialFeatures == computeHandle->model.numInputChannels);
+  assert(numSpatialFeatures == computeHandle->model->numInputChannels);
   assert(numSpatialFeatures * nnXLen * nnYLen == inputBuffers->singleInputElts);
   assert(numGlobalFeatures == inputBuffers->singleInputGlobalElts);
-  const int numPolicyChannels = computeHandle->model.numPolicyChannels;
+  const int numPolicyChannels = computeHandle->model->numPolicyChannels;
 
   for(int nIdx = 0; nIdx<batchSize; nIdx++) {
     float* rowSpatialInput = inputBuffers->spatialInput.data() + (inputBuffers->singleInputElts * nIdx);
@@ -1742,7 +1772,7 @@ void NeuralNet::getOutput(
     computeMaskSum(&mask,maskSum.data());
     vector<float>& convWorkspace = buffers.convWorkspace;
 
-    computeHandle->model.apply(
+    computeHandle->model->apply(
       &computeHandle->handleInternal,
       computeHandle->scratch.get(),
       &input,
@@ -1801,7 +1831,7 @@ void NeuralNet::getOutput(
       policyProbs[inputBuffers->singlePolicyResultElts] = policyPassSrcBuf[0];
     }
 
-    int numValueChannels = computeHandle->model.numValueChannels;
+    int numValueChannels = computeHandle->model->numValueChannels;
     assert(numValueChannels == 3);
     output->whiteWinProb = valueData[row * numValueChannels];
     output->whiteLossProb = valueData[row * numValueChannels + 1];
@@ -1811,12 +1841,12 @@ void NeuralNet::getOutput(
     //As usual the client does the postprocessing.
     if(output->whiteOwnerMap != NULL) {
       const float* ownershipSrcBuf = ownershipData + row * nnXLen * nnYLen;
-      assert(computeHandle->model.numOwnershipChannels == 1);
+      assert(computeHandle->model->numOwnershipChannels == 1);
       SymmetryHelpers::copyOutputsWithSymmetry(ownershipSrcBuf, output->whiteOwnerMap, 1, nnYLen, nnXLen, inputBufs[row]->symmetry);
     }
 
     if(modelVersion >= 9) {
-      int numScoreValueChannels = computeHandle->model.numScoreValueChannels;
+      int numScoreValueChannels = computeHandle->model->numScoreValueChannels;
       assert(numScoreValueChannels == 6);
       output->whiteScoreMean = scoreValueData[row * numScoreValueChannels];
       output->whiteScoreMeanSq = scoreValueData[row * numScoreValueChannels + 1];
@@ -1826,7 +1856,7 @@ void NeuralNet::getOutput(
       output->shorttermScoreError = scoreValueData[row * numScoreValueChannels + 5];
     }
     else if(modelVersion >= 8) {
-      int numScoreValueChannels = computeHandle->model.numScoreValueChannels;
+      int numScoreValueChannels = computeHandle->model->numScoreValueChannels;
       assert(numScoreValueChannels == 4);
       output->whiteScoreMean = scoreValueData[row * numScoreValueChannels];
       output->whiteScoreMeanSq = scoreValueData[row * numScoreValueChannels + 1];
@@ -1836,7 +1866,7 @@ void NeuralNet::getOutput(
       output->shorttermScoreError = 0;
     }
     else if(modelVersion >= 4) {
-      int numScoreValueChannels = computeHandle->model.numScoreValueChannels;
+      int numScoreValueChannels = computeHandle->model->numScoreValueChannels;
       assert(numScoreValueChannels == 2);
       output->whiteScoreMean = scoreValueData[row * numScoreValueChannels];
       output->whiteScoreMeanSq = scoreValueData[row * numScoreValueChannels + 1];
@@ -1846,7 +1876,7 @@ void NeuralNet::getOutput(
       output->shorttermScoreError = 0;
     }
    else if(modelVersion >= 3) {
-      int numScoreValueChannels = computeHandle->model.numScoreValueChannels;
+      int numScoreValueChannels = computeHandle->model->numScoreValueChannels;
       assert(numScoreValueChannels == 1);
       output->whiteScoreMean = scoreValueData[row * numScoreValueChannels];
       //Version 3 neural nets don't have any second moment output, implicitly already folding it in, so we just use the mean squared

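The pattern the commit introduces, distilled: ComputeContext keeps a mutex-guarded cache of Model objects keyed by model name plus sha256, each ComputeHandle takes a shared_ptr from that cache on construction and drops its reference on destruction, and an explicit refcount map evicts the cache entry when the last handle goes away. Below is a minimal self-contained sketch of that pattern. The names ModelCache, acquire, and release are hypothetical, for illustration only; in the actual commit this logic lives inline in ComputeHandle's constructor and destructor, with the maps stored on ComputeContext.

// Minimal sketch of the shared-model cache pattern from this commit.
// ModelCache, acquire, and release are hypothetical names, not KataGo's API.
#include <cassert>
#include <map>
#include <memory>
#include <mutex>
#include <string>

struct Model {
  std::string name;
  explicit Model(const std::string& n) : name(n) {}
};

struct ModelCache {
  std::mutex mutex;
  std::map<std::string, std::shared_ptr<const Model>> models;
  std::map<std::string, int> refCount;

  // What each handle's constructor does: first caller loads, later callers share.
  std::shared_ptr<const Model> acquire(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex);
    if(models.find(key) == models.end())
      models[key] = std::make_shared<const Model>(key);
    refCount[key] += 1;
    return models[key];
  }

  // What each handle's destructor does: the last releaser evicts the entry.
  void release(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex);
    refCount[key] -= 1;
    assert(refCount[key] >= 0);
    if(refCount[key] == 0) {
      refCount.erase(key);
      models.erase(key);
    }
  }
};

int main() {
  ModelCache cache;
  std::shared_ptr<const Model> a = cache.acquire("net-sha256");
  std::shared_ptr<const Model> b = cache.acquire("net-sha256");
  assert(a.get() == b.get());  // both handles see the same copy of the weights
  cache.release("net-sha256");
  cache.release("net-sha256");
  return 0;
}

Holding the shared_ptr in the cache map keeps the weights alive while any handle uses them; the separate refcount map exists so the entry can be erased eagerly once the last handle is freed, which is what lets the assert added to ~ComputeContext verify that no cached models outlive their handles.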