diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc index 20cf0e104f3..0ece82793fe 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc @@ -218,7 +218,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_dims.w = output_shape.Dims(2); output_dims.c = output_depth; -#if defined(KERNELS_OPTIMIZED_FOR_SPEED) + cmsis_nn_transpose_conv_params conv_params; + conv_params.stride.w = params->stride_width; + conv_params.stride.h = params->stride_height; + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); cmsis_nn_dims input_dims; @@ -234,18 +237,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { filter_dims.c = input_depth; const size_t buf_size = arm_transpose_conv_s8_get_buffer_size( - &input_dims, &filter_dims, &output_dims); + &conv_params, &input_dims, &filter_dims, &output_dims); TFLITE_DCHECK(context->RequestScratchBufferInArena( context, buf_size, &(data->scratch_buffer_index)) == kTfLiteOk); -#endif - - // Quantized 8-bit kernels use an int32 scratch buffer. - TFLITE_DCHECK( - context->RequestScratchBufferInArena( - context, - output_dims.h * output_dims.w * output_dims.c * sizeof(int32_t), - &(data->scratch_buffer_output_index)) == kTfLiteOk); + + // Quantized 8-bit kernels use a second scratch buffer for reversing the + // filter for certain configurations. + const size_t reverse_buf_size = + arm_transpose_conv_s8_get_reverse_conv_buffer_size( + &conv_params, &input_dims, &filter_dims); + TFLITE_DCHECK(context->RequestScratchBufferInArena( + context, reverse_buf_size, + &(data->scratch_buffer_output_index)) == kTfLiteOk); } // Quantized 16x8 kernels use an int64 scratch buffer. @@ -295,7 +299,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -#if defined(KERNELS_OPTIMIZED_FOR_SPEED) TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, const TfLiteConvParams& params, const OpData& data, @@ -377,7 +380,7 @@ TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, context->GetScratchBuffer(context, data.scratch_buffer_output_index); TFLITE_DCHECK_EQ( - arm_transpose_conv_s8( + arm_transpose_conv_wrapper_s8( &ctx, &scratch_output_ctx, &conv_params, &quant_params, &input_dims, tflite::micro::GetTensorData(input), &filter_dims, tflite::micro::GetTensorData(filter), &bias_dims, @@ -387,7 +390,6 @@ TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -#endif TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteEvalTensor* input = @@ -428,29 +430,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; } case kTfLiteInt8: { -#if defined(KERNELS_OPTIMIZED_FOR_SIZE) - int32_t* scratch_buffer = static_cast( - context->GetScratchBuffer(context, data.scratch_buffer_index)); - reference_integer_ops::TransposeConv( - data.params, data.per_channel_output_multiplier, - data.per_channel_output_shift, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetOptionalTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output), - tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); -#elif defined(KERNELS_OPTIMIZED_FOR_SPEED) return EvalQuantizedPerChannel(context, node, params, data, input, filter, bias, output); -#else - MicroPrintf( - "Either KERNELS_OPTIMIZED_FOR_SIZE or KERNELS_OPTIMIZED_FOR_SPEED " - "must be defined"); - return kTfLiteError; -#endif break; } case kTfLiteInt16: { @@ -514,33 +495,11 @@ TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); const OpData& data = *(static_cast(node->user_data)); -#if defined(KERNELS_OPTIMIZED_FOR_SIZE) - int32_t* scratch_buffer = static_cast( - context->GetScratchBuffer(context, data.scratch_buffer_index)); - reference_integer_ops::TransposeConv( - data.params, data.per_channel_output_multiplier, - data.per_channel_output_shift, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetOptionalTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output), - tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer); -#elif defined(KERNELS_OPTIMIZED_FOR_SPEED) const auto& params = *(reinterpret_cast(node->builtin_data)); return EvalQuantizedPerChannel(context, node, params, data, input, filter, bias, output); -#else - MicroPrintf( - "Either KERNELS_OPTIMIZED_FOR_SIZE or KERNELS_OPTIMIZED_FOR_SPEED must " - "be defined"); - return kTfLiteError; -#endif - return kTfLiteOk; } } // namespace diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh index 04e76dd508c..a211a2b38a3 100755 --- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh +++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh @@ -38,9 +38,9 @@ source ${TENSORFLOW_ROOT}tensorflow/lite/micro/tools/make/bash_helpers.sh DOWNLOADS_DIR=${1} DOWNLOADED_CMSIS_NN_PATH=${DOWNLOADS_DIR}/cmsis_nn -ZIP_PREFIX_NN="f2cb41ca1450a4eb4307b2779dd5aae9028285a5" +ZIP_PREFIX_NN="22080c68d040c98139e6cb1549473e3149735f4d" CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip" -CMSIS_NN_MD5="4d0e623432d6f8d3b201cbcd89218adf" +CMSIS_NN_MD5="32aa69692541060a76b18bd5d2d98956" should_download=$(check_should_download ${DOWNLOADS_DIR})