From 85b6dce7aebd5c4305b217c1724f4264810e715f Mon Sep 17 00:00:00 2001
From: David Wu
Date: Sun, 17 Dec 2023 13:31:15 -0500
Subject: [PATCH] Make eigen threads all share the same copy of the model

---
 cpp/neuralnet/eigenbackend.cpp | 66 ++++++++++++++++++++++++----------
 1 file changed, 48 insertions(+), 18 deletions(-)

diff --git a/cpp/neuralnet/eigenbackend.cpp b/cpp/neuralnet/eigenbackend.cpp
index 1f0ecfeb3..88709cbaa 100644
--- a/cpp/neuralnet/eigenbackend.cpp
+++ b/cpp/neuralnet/eigenbackend.cpp
@@ -191,20 +191,32 @@ static size_t roundUpToMultiple(size_t size, size_t ofThis) {

 // --------------------------------------------------------------------------------------------------------------

+struct Model;
+
 struct ComputeContext {
   const int nnXLen;
   const int nnYLen;

+  std::mutex cachedModelsMutex;
+  std::map<std::string, std::shared_ptr<Model>> cachedModels;
+  std::map<std::string, int> cachedModelsRefCount;
+
   ComputeContext() = delete;
   ComputeContext(const ComputeContext&) = delete;
   ComputeContext& operator=(const ComputeContext&) = delete;

   ComputeContext(int nnX, int nnY)
     : nnXLen(nnX),
-      nnYLen(nnY)
+      nnYLen(nnY),
+      cachedModelsMutex(),
+      cachedModels(),
+      cachedModelsRefCount()
   {}
   ~ComputeContext()
-  {}
+  {
+    // This should only be freed after all the handles are freed
+    assert(cachedModels.size() == 0);
+  }
 };

 // --------------------------------------------------------------------------------------------------------------
@@ -1634,10 +1646,11 @@ void NeuralNet::freeComputeContext(ComputeContext* computeContext) {

 //------------------------------------------------------------------------------

 struct ComputeHandle {
-  const ComputeContext* context;
+  ComputeContext* context;
   bool inputsUseNHWC;
   ComputeHandleInternal handleInternal;
-  const Model model;
+  const std::string modelCacheKey;
+  std::shared_ptr<Model> model;
   std::unique_ptr<ScratchBuffers> scratch;
   std::unique_ptr<Buffers> buffers;
@@ -1645,17 +1658,34 @@
   ComputeHandle() = delete;
   ComputeHandle(const ComputeHandle&) = delete;
   ComputeHandle& operator=(const ComputeHandle&) = delete;

-  ComputeHandle(const ComputeContext* ctx, const LoadedModel& loadedModel, int maxBatchSize, bool iNHWC)
+  ComputeHandle(ComputeContext* ctx, const LoadedModel& loadedModel, int maxBatchSize, bool iNHWC)
     : context(ctx),
       inputsUseNHWC(iNHWC),
       handleInternal(ctx),
-      model(loadedModel.modelDesc,ctx->nnXLen,ctx->nnYLen)
+      modelCacheKey(loadedModel.modelDesc.name + "-" + loadedModel.modelDesc.sha256),
+      model(nullptr)
   {
-    scratch = std::make_unique<ScratchBuffers>(maxBatchSize,ctx->nnXLen,ctx->nnYLen);
-    buffers = std::make_unique<Buffers>(loadedModel.modelDesc,model,maxBatchSize,ctx->nnXLen,ctx->nnYLen);
+    {
+      std::lock_guard<std::mutex> lock(context->cachedModelsMutex);
+      if(context->cachedModels.find(modelCacheKey) == context->cachedModels.end()) {
+        context->cachedModels[modelCacheKey] = std::make_shared<Model>(loadedModel.modelDesc,context->nnXLen,context->nnYLen);
+      }
+      model = context->cachedModels[modelCacheKey];
+      context->cachedModelsRefCount[modelCacheKey] += 1;
+    }
+
+    scratch = std::make_unique<ScratchBuffers>(maxBatchSize,context->nnXLen,context->nnYLen);
+    buffers = std::make_unique<Buffers>(loadedModel.modelDesc,*model,maxBatchSize,context->nnXLen,context->nnYLen);
   }

   ~ComputeHandle() {
+    std::lock_guard<std::mutex> lock(context->cachedModelsMutex);
+    context->cachedModelsRefCount[modelCacheKey] -= 1;
+    assert(context->cachedModelsRefCount[modelCacheKey] >= 0);
+    if(context->cachedModelsRefCount[modelCacheKey] == 0) {
+      context->cachedModelsRefCount.erase(modelCacheKey);
+      context->cachedModels.erase(modelCacheKey);
+    }
   }
 };
@@ -1703,14 +1733,14 @@ void NeuralNet::getOutput(
   const int batchSize = numBatchEltsFilled;
   const int nnXLen = computeHandle->context->nnXLen;
   const int nnYLen = computeHandle->context->nnYLen;
-  const int modelVersion = computeHandle->model.modelVersion;
+  const int modelVersion = computeHandle->model->modelVersion;
   const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion);
   const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion);
-  assert(numSpatialFeatures == computeHandle->model.numInputChannels);
+  assert(numSpatialFeatures == computeHandle->model->numInputChannels);
   assert(numSpatialFeatures * nnXLen * nnYLen == inputBuffers->singleInputElts);
   assert(numGlobalFeatures == inputBuffers->singleInputGlobalElts);
-  const int numPolicyChannels = computeHandle->model.numPolicyChannels;
+  const int numPolicyChannels = computeHandle->model->numPolicyChannels;

   for(int nIdx = 0; nIdx<batchSize; nIdx++) {
     float* rowSpatialInput = inputBuffers->spatialInput.data() + (inputBuffers->singleInputElts * nIdx);
@@ -1742,7 +1772,7 @@
     computeMaskSum(&mask,maskSum.data());
     vector<float>& convWorkspace = buffers.convWorkspace;

-    computeHandle->model.apply(
+    computeHandle->model->apply(
       &computeHandle->handleInternal,
       computeHandle->scratch.get(),
       &input,
@@ -1801,7 +1831,7 @@
       policyProbs[inputBuffers->singlePolicyResultElts] = policyPassSrcBuf[0];
     }

-    int numValueChannels = computeHandle->model.numValueChannels;
+    int numValueChannels = computeHandle->model->numValueChannels;
     assert(numValueChannels == 3);
     output->whiteWinProb = valueData[row * numValueChannels];
     output->whiteLossProb = valueData[row * numValueChannels + 1];
@@ -1811,12 +1841,12 @@
     //As usual the client does the postprocessing.
     if(output->whiteOwnerMap != NULL) {
       const float* ownershipSrcBuf = ownershipData + row * nnXLen * nnYLen;
-      assert(computeHandle->model.numOwnershipChannels == 1);
+      assert(computeHandle->model->numOwnershipChannels == 1);
       SymmetryHelpers::copyOutputsWithSymmetry(ownershipSrcBuf, output->whiteOwnerMap, 1, nnYLen, nnXLen, inputBufs[row]->symmetry);
     }

     if(modelVersion >= 9) {
-      int numScoreValueChannels = computeHandle->model.numScoreValueChannels;
+      int numScoreValueChannels = computeHandle->model->numScoreValueChannels;
       assert(numScoreValueChannels == 6);
       output->whiteScoreMean = scoreValueData[row * numScoreValueChannels];
       output->whiteScoreMeanSq = scoreValueData[row * numScoreValueChannels + 1];
@@ -1826,7 +1856,7 @@
       output->shorttermScoreError = scoreValueData[row * numScoreValueChannels + 5];
     }
     else if(modelVersion >= 8) {
-      int numScoreValueChannels = computeHandle->model.numScoreValueChannels;
+      int numScoreValueChannels = computeHandle->model->numScoreValueChannels;
       assert(numScoreValueChannels == 4);
       output->whiteScoreMean = scoreValueData[row * numScoreValueChannels];
       output->whiteScoreMeanSq = scoreValueData[row * numScoreValueChannels + 1];
@@ -1836,7 +1866,7 @@
       output->shorttermScoreError = 0;
     }
     else if(modelVersion >= 4) {
-      int numScoreValueChannels = computeHandle->model.numScoreValueChannels;
+      int numScoreValueChannels = computeHandle->model->numScoreValueChannels;
       assert(numScoreValueChannels == 2);
       output->whiteScoreMean = scoreValueData[row * numScoreValueChannels];
       output->whiteScoreMeanSq = scoreValueData[row * numScoreValueChannels + 1];
@@ -1846,7 +1876,7 @@
       output->shorttermScoreError = 0;
     }
    else if(modelVersion >= 3) {
-      int numScoreValueChannels = computeHandle->model.numScoreValueChannels;
+      int numScoreValueChannels = computeHandle->model->numScoreValueChannels;
       assert(numScoreValueChannels == 1);
       output->whiteScoreMean = scoreValueData[row * numScoreValueChannels];
       //Version 3 neural nets don't have any second moment output, implicitly already folding it in, so we just use the mean squared
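
The patch boils down to a standard pattern: a mutex-guarded, ref-counted cache of shared_ptr-owned models keyed by model name plus sha256, so that N eigen backend threads hold one copy of the weights instead of N. Below is a minimal, self-contained sketch of that pattern for illustration only; BigModel, SharedModelCache, acquire, and release are hypothetical stand-ins, not KataGo APIs, and only the shape of the logic mirrors the patch.

#include <cassert>
#include <map>
#include <memory>
#include <mutex>
#include <string>

// Hypothetical stand-in for the Eigen backend's Model, whose weights are
// expensive to duplicate once per thread.
struct BigModel {
  explicit BigModel(const std::string& desc) { (void)desc; /* load weights here */ }
};

// Mutex-guarded, ref-counted cache: the first acquirer constructs the model,
// later acquirers share it, and the last releaser evicts the entry so the
// owning context can assert the cache is empty on destruction.
struct SharedModelCache {
  std::mutex mutex;
  std::map<std::string, std::shared_ptr<BigModel>> models;
  std::map<std::string, int> refCounts;

  std::shared_ptr<BigModel> acquire(const std::string& key, const std::string& desc) {
    std::lock_guard<std::mutex> lock(mutex);
    if(models.find(key) == models.end())
      models[key] = std::make_shared<BigModel>(desc);
    refCounts[key] += 1;
    return models[key];
  }

  void release(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex);
    refCounts[key] -= 1;
    assert(refCounts[key] >= 0);
    if(refCounts[key] == 0) {
      refCounts.erase(key);
      models.erase(key);
    }
  }
};

int main() {
  SharedModelCache cache;
  // Two handles keyed the same way (the patch uses name + "-" + sha256)
  // end up sharing one model instance rather than loading it twice.
  std::shared_ptr<BigModel> a = cache.acquire("name-sha256", "weights");
  std::shared_ptr<BigModel> b = cache.acquire("name-sha256", "weights");
  assert(a.get() == b.get());
  cache.release("name-sha256");
  cache.release("name-sha256");
  assert(cache.models.empty());  // mirrors the assert in ~ComputeContext
  return 0;
}

Keying on name plus sha256 means distinct nets loaded into the same context each get their own cached copy, while repeated loads of the same net are deduplicated.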