Skip to content

Commit

Permalink
Update CMSIS-NN Transpose conv function call (#2760)
Browse files Browse the repository at this point in the history
Also removes the OPTIMIZE_FOR_SIZE compiler option, as the new implementation is more space-efficient.

BUG=CMSIS-NN update
  • Loading branch information
AdrianLundell authored Nov 12, 2024
1 parent d7c42b9 commit 182c8c7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 58 deletions.
71 changes: 15 additions & 56 deletions tensorflow/lite/micro/kernels/cmsis_nn/transpose_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_dims.w = output_shape.Dims(2);
output_dims.c = output_depth;

#if defined(KERNELS_OPTIMIZED_FOR_SPEED)
cmsis_nn_transpose_conv_params conv_params;
conv_params.stride.w = params->stride_width;
conv_params.stride.h = params->stride_height;

const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);

cmsis_nn_dims input_dims;
Expand All @@ -234,18 +237,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
filter_dims.c = input_depth;

const size_t buf_size = arm_transpose_conv_s8_get_buffer_size(
&input_dims, &filter_dims, &output_dims);
&conv_params, &input_dims, &filter_dims, &output_dims);
TFLITE_DCHECK(context->RequestScratchBufferInArena(
context, buf_size, &(data->scratch_buffer_index)) ==
kTfLiteOk);
#endif

// Quantized 8-bit kernels use an int32 scratch buffer.
TFLITE_DCHECK(
context->RequestScratchBufferInArena(
context,
output_dims.h * output_dims.w * output_dims.c * sizeof(int32_t),
&(data->scratch_buffer_output_index)) == kTfLiteOk);

// Quantized 8-bit kernels use a second scratch buffer for reversing the
// filter for certain configurations.
const size_t reverse_buf_size =
arm_transpose_conv_s8_get_reverse_conv_buffer_size(
&conv_params, &input_dims, &filter_dims);
TFLITE_DCHECK(context->RequestScratchBufferInArena(
context, reverse_buf_size,
&(data->scratch_buffer_output_index)) == kTfLiteOk);
}

// Quantized 16x8 kernels use an int64 scratch buffer.
Expand Down Expand Up @@ -295,7 +299,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

#if defined(KERNELS_OPTIMIZED_FOR_SPEED)
TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams& params,
const OpData& data,
Expand Down Expand Up @@ -377,7 +380,7 @@ TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
context->GetScratchBuffer(context, data.scratch_buffer_output_index);

TFLITE_DCHECK_EQ(
arm_transpose_conv_s8(
arm_transpose_conv_wrapper_s8(
&ctx, &scratch_output_ctx, &conv_params, &quant_params, &input_dims,
tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
tflite::micro::GetTensorData<int8_t>(filter), &bias_dims,
Expand All @@ -387,7 +390,6 @@ TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,

return kTfLiteOk;
}
#endif

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
Expand Down Expand Up @@ -428,29 +430,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt8: {
#if defined(KERNELS_OPTIMIZED_FOR_SIZE)
int32_t* scratch_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, data.scratch_buffer_index));
reference_integer_ops::TransposeConv(
data.params, data.per_channel_output_multiplier,
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output),
tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
#elif defined(KERNELS_OPTIMIZED_FOR_SPEED)
return EvalQuantizedPerChannel(context, node, params, data, input, filter,
bias, output);
#else
MicroPrintf(
"Either KERNELS_OPTIMIZED_FOR_SIZE or KERNELS_OPTIMIZED_FOR_SPEED "
"must be defined");
return kTfLiteError;
#endif
break;
}
case kTfLiteInt16: {
Expand Down Expand Up @@ -514,33 +495,11 @@ TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));

#if defined(KERNELS_OPTIMIZED_FOR_SIZE)
int32_t* scratch_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, data.scratch_buffer_index));
reference_integer_ops::TransposeConv(
data.params, data.per_channel_output_multiplier,
data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output),
tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
#elif defined(KERNELS_OPTIMIZED_FOR_SPEED)
const auto& params =
*(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));

return EvalQuantizedPerChannel(context, node, params, data, input, filter,
bias, output);
#else
MicroPrintf(
"Either KERNELS_OPTIMIZED_FOR_SIZE or KERNELS_OPTIMIZED_FOR_SPEED must "
"be defined");
return kTfLiteError;
#endif
return kTfLiteOk;
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ source ${TENSORFLOW_ROOT}tensorflow/lite/micro/tools/make/bash_helpers.sh
DOWNLOADS_DIR=${1}
DOWNLOADED_CMSIS_NN_PATH=${DOWNLOADS_DIR}/cmsis_nn

ZIP_PREFIX_NN="f2cb41ca1450a4eb4307b2779dd5aae9028285a5"
ZIP_PREFIX_NN="22080c68d040c98139e6cb1549473e3149735f4d"
CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip"
CMSIS_NN_MD5="4d0e623432d6f8d3b201cbcd89218adf"
CMSIS_NN_MD5="32aa69692541060a76b18bd5d2d98956"

should_download=$(check_should_download ${DOWNLOADS_DIR})

Expand Down

0 comments on commit 182c8c7

Please sign in to comment.