cmsis-nn: refactor conv to get rid of duplicate code (#2402)
BUG=Some duplicate code in cmsis-nn/conv.cc
mansnils authored Jan 29, 2024
1 parent 324ae1e commit 0c4738a
Showing 1 changed file with 63 additions and 211 deletions.
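
Note: the refactor below collapses the three near-identical functions EvalQuantizedPerChannelInt4, EvalQuantizedPerChannel, and EvalQuantizedPerChannel16x8 into a single templated EvalQuantizedPerChannel that forwards to a convolve_wrapper specialization chosen by activation type, bias type, and a runtime TfLiteType weight tag. The following is a minimal, standalone sketch of that dispatch pattern only; the names DispatchConv and FakeType are illustrative and are not CMSIS-NN or TFLM symbols, and the bodies just print which kernel would be called.

// Sketch of dispatch via function-template specialization (illustrative only).
#include <cstdint>
#include <cstdio>

enum class FakeType { kInt4, kInt8, kInt16 };

// Primary template: combinations without an explicit specialization are errors.
template <class ActType, class BiasType>
int DispatchConv(const ActType* input, const BiasType* bias, ActType* output,
                 FakeType weights_type) {
  (void)input; (void)bias; (void)output; (void)weights_type;
  std::puts("unsupported activation/bias combination");
  return -1;
}

// int8 activations + int32 bias: pick the s8 or s4 path from the weight tag.
template <>
int DispatchConv(const int8_t* input, const int32_t* bias, int8_t* output,
                 FakeType weights_type) {
  (void)input; (void)bias; (void)output;
  std::puts(weights_type == FakeType::kInt4 ? "s4 kernel path" : "s8 kernel path");
  return 0;
}

// int16 activations + int64 bias: a single kernel exists, so the tag is unused.
template <>
int DispatchConv(const int16_t* input, const int64_t* bias, int16_t* output,
                 FakeType weights_type) {
  (void)input; (void)bias; (void)output; (void)weights_type;
  std::puts("s16 kernel path");
  return 0;
}

int main() {
  int8_t in8[4] = {}, out8[4] = {};
  int32_t b32[1] = {};
  DispatchConv(in8, b32, out8, FakeType::kInt4);    // prints "s4 kernel path"

  int16_t in16[4] = {}, out16[4] = {};
  int64_t b64[1] = {};
  DispatchConv(in16, b64, out16, FakeType::kInt16); // prints "s16 kernel path"
  return 0;
}

Choosing specializations over an if/else chain keeps one shared Eval body for all quantized variants while the per-type call sites remain statically checked.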
274 changes: 63 additions & 211 deletions tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -159,104 +159,52 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus EvalQuantizedPerChannelInt4(
TfLiteContext* context, TfLiteNode* node, const TfLiteConvParams& params,
const OpData& data, const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
cmsis_nn_conv_params conv_params;
conv_params.dilation.h = params.dilation_height_factor;
conv_params.dilation.w = params.dilation_width_factor;

// Initialize cmsis_nn convolution parameters
conv_params.input_offset = -data.reference_op_data.input_zero_point;
conv_params.output_offset = data.reference_op_data.output_zero_point;
conv_params.stride.h = params.stride_height;
conv_params.stride.w = params.stride_width;
conv_params.padding.h = data.reference_op_data.padding.height;
conv_params.padding.w = data.reference_op_data.padding.width;
conv_params.activation.min = data.reference_op_data.output_activation_min;
conv_params.activation.max = data.reference_op_data.output_activation_max;

// Initialize cmsis_nn per channel quantization parameters
cmsis_nn_per_channel_quant_params quant_params;
quant_params.multiplier = const_cast<int32_t*>(
data.reference_op_data.per_channel_output_multiplier);
quant_params.shift =
const_cast<int32_t*>(data.reference_op_data.per_channel_output_shift);

RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias);

// Consistency check.
TFLITE_DCHECK_LE(conv_params.activation.min, conv_params.activation.max);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (tflite::micro::GetOptionalTensorData<int32_t>(bias)) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
}

// Initialize cmsis_nn dimensions
// Input
cmsis_nn_dims input_dims;
input_dims.n = batch_size;
input_dims.h = input_shape.Dims(1);
input_dims.w = input_shape.Dims(2);
input_dims.c = input_depth;

// Filter
cmsis_nn_dims filter_dims;
filter_dims.n = output_depth;
filter_dims.h = filter_shape.Dims(1);
filter_dims.w = filter_shape.Dims(2);
filter_dims.c = input_depth;

// Bias
cmsis_nn_dims bias_dims;
bias_dims.n = 1;
bias_dims.h = 1;
bias_dims.w = 1;
bias_dims.c = output_depth;

// Output
cmsis_nn_dims output_dims;
output_dims.n = batch_size;
output_dims.h = output_shape.Dims(1);
output_dims.w = output_shape.Dims(2);
output_dims.c = output_depth;

// Initialize cmsis_nn context
cmsis_nn_context ctx;
ctx.buf = nullptr;
ctx.size = 0;
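// Primary template: any activation/bias/weight combination without an
// explicit specialization below falls through to an argument error.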
template <class ActType, class BiasType, class WeightsType>
arm_cmsis_nn_status convolve_wrapper(
const cmsis_nn_context* ctx, const cmsis_nn_conv_params* conv_params,
const cmsis_nn_per_channel_quant_params* quant_params,
const cmsis_nn_dims* input_dims, const ActType* input,
const cmsis_nn_dims* filter_dims, const int8_t* filter,
const cmsis_nn_dims* bias_dims, const BiasType* bias,
const cmsis_nn_dims* output_dims, ActType* output, WeightsType weightsT) {
return ARM_CMSIS_NN_ARG_ERROR;
}

if (data.buffer_idx > -1) {
ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
// Note: ctx.size is currently not used in cmsis_nn.
// The buffer should be allocated in the Prepare function through
// arm_convolve_wrapper_s8_get_buffer_size
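// Specialization for int8 activations with int32 bias: dispatches to the
// s8 or s4 kernel depending on the weight type.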
template <>
arm_cmsis_nn_status convolve_wrapper(
const cmsis_nn_context* ctx, const cmsis_nn_conv_params* conv_params,
const cmsis_nn_per_channel_quant_params* quant_params,
const cmsis_nn_dims* input_dims, const int8_t* input,
const cmsis_nn_dims* filter_dims, const int8_t* filter,
const cmsis_nn_dims* bias_dims, const int32_t* bias,
const cmsis_nn_dims* output_dims, int8_t* output, TfLiteType weightsT) {
if (weightsT == kTfLiteInt8) {
return arm_convolve_wrapper_s8(ctx, conv_params, quant_params, input_dims,
input, filter_dims, filter, bias_dims, bias,
output_dims, output);
} else if (weightsT == kTfLiteInt4) {
return arm_convolve_wrapper_s4(ctx, conv_params, quant_params, input_dims,
input, filter_dims, filter, bias_dims, bias,
output_dims, output);
} else {
return ARM_CMSIS_NN_ARG_ERROR;
}
}

// arm_convolve_wrapper_s4 dispatches the optimized kernel accordingly with
// the parameters passed for convolutions with 4 bit weights
TFLITE_DCHECK_EQ(
arm_convolve_wrapper_s4(
&ctx, &conv_params, &quant_params, &input_dims,
tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
tflite::micro::GetTensorData<int8_t>(filter), &bias_dims,
tflite::micro::GetOptionalTensorData<int32_t>(bias), &output_dims,
tflite::micro::GetTensorData<int8_t>(output)),
ARM_CMSIS_NN_SUCCESS);

return kTfLiteOk;
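// Specialization for int16 activations with int64 bias: only int8 weights
// are supported, so the weight type argument is not inspected.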
template <>
arm_cmsis_nn_status convolve_wrapper(
const cmsis_nn_context* ctx, const cmsis_nn_conv_params* conv_params,
const cmsis_nn_per_channel_quant_params* quant_params,
const cmsis_nn_dims* input_dims, const int16_t* input,
const cmsis_nn_dims* filter_dims, const int8_t* filter,
const cmsis_nn_dims* bias_dims, const int64_t* bias,
const cmsis_nn_dims* output_dims, int16_t* output, TfLiteType weightsT) {
return arm_convolve_wrapper_s16(ctx, conv_params, quant_params, input_dims,
input, filter_dims, filter, bias_dims, bias,
output_dims, output);
}

template <typename ActType, typename BiasType, TfLiteType type>
TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams& params,
const OpData& data,
@@ -298,105 +246,7 @@ TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (tflite::micro::GetOptionalTensorData<int32_t>(bias)) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
}

// Initialize cmsis_nn dimensions
// Input
cmsis_nn_dims input_dims;
input_dims.n = batch_size;
input_dims.h = input_shape.Dims(1);
input_dims.w = input_shape.Dims(2);
input_dims.c = input_depth;

// Filter
cmsis_nn_dims filter_dims;
filter_dims.n = output_depth;
filter_dims.h = filter_shape.Dims(1);
filter_dims.w = filter_shape.Dims(2);
filter_dims.c = input_depth;

// Bias
cmsis_nn_dims bias_dims;
bias_dims.n = 1;
bias_dims.h = 1;
bias_dims.w = 1;
bias_dims.c = output_depth;

// Output
cmsis_nn_dims output_dims;
output_dims.n = batch_size;
output_dims.h = output_shape.Dims(1);
output_dims.w = output_shape.Dims(2);
output_dims.c = output_depth;

// Initialize cmsis_nn context
cmsis_nn_context ctx;
ctx.buf = nullptr;
ctx.size = 0;

if (data.buffer_idx > -1) {
ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
// Note: ctx.size is currently not used in cmsis_nn.
// The buffer should be allocated in the Prepare function through
// arm_convolve_wrapper_s8_get_buffer_size
}

// arm_convolve_wrapper_s8 dispatches the optimized kernel accordingly with
// the parameters passed
TFLITE_DCHECK_EQ(
arm_convolve_wrapper_s8(
&ctx, &conv_params, &quant_params, &input_dims,
tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
tflite::micro::GetTensorData<int8_t>(filter), &bias_dims,
tflite::micro::GetOptionalTensorData<int32_t>(bias), &output_dims,
tflite::micro::GetTensorData<int8_t>(output)),
ARM_CMSIS_NN_SUCCESS);

return kTfLiteOk;
}

TfLiteStatus EvalQuantizedPerChannel16x8(
TfLiteContext* context, TfLiteNode* node, const TfLiteConvParams& params,
const OpData& data, const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
cmsis_nn_conv_params conv_params;
conv_params.dilation.h = params.dilation_height_factor;
conv_params.dilation.w = params.dilation_width_factor;

// Initialize cmsis_nn convolution parameters
conv_params.input_offset = -data.reference_op_data.input_zero_point;
conv_params.output_offset = data.reference_op_data.output_zero_point;
conv_params.stride.h = params.stride_height;
conv_params.stride.w = params.stride_width;
conv_params.padding.h = data.reference_op_data.padding.height;
conv_params.padding.w = data.reference_op_data.padding.width;
conv_params.activation.min = data.reference_op_data.output_activation_min;
conv_params.activation.max = data.reference_op_data.output_activation_max;

// Initialize cmsis_nn per channel quantization parameters
cmsis_nn_per_channel_quant_params quant_params;
quant_params.multiplier = const_cast<int32_t*>(
data.reference_op_data.per_channel_output_multiplier);
quant_params.shift =
const_cast<int32_t*>(data.reference_op_data.per_channel_output_shift);

RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias);

// Consistency check.
TFLITE_DCHECK_LE(conv_params.activation.min, conv_params.activation.max);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (tflite::micro::GetOptionalTensorData<int64_t>(bias)) {
if (tflite::micro::GetOptionalTensorData<BiasType>(bias)) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
}

@@ -437,17 +287,19 @@ TfLiteStatus EvalQuantizedPerChannel16x8(
if (data.buffer_idx > -1) {
ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
// Note: ctx.size is currently not used in cmsis_nn.
// The buffer should be allocated in the Prepare function through
// arm_convolve_wrapper_s8_get_buffer_size
// The buffer should be allocated in the prepare function through
// the corresponding arm_convolve_wrapper_[type]_get_buffer_size
}

// arm_convolve_wrapper_[type] dispatches the optimized kernel according to
// the parameters passed
TFLITE_DCHECK_EQ(
arm_convolve_wrapper_s16(
convolve_wrapper(
&ctx, &conv_params, &quant_params, &input_dims,
tflite::micro::GetTensorData<int16_t>(input), &filter_dims,
tflite::micro::GetTensorData<ActType>(input), &filter_dims,
tflite::micro::GetTensorData<int8_t>(filter), &bias_dims,
tflite::micro::GetOptionalTensorData<int64_t>(bias), &output_dims,
tflite::micro::GetTensorData<int16_t>(output)),
tflite::micro::GetOptionalTensorData<BiasType>(bias), &output_dims,
tflite::micro::GetTensorData<ActType>(output), type),
ARM_CMSIS_NN_SUCCESS);

return kTfLiteOk;
@@ -471,8 +323,8 @@ TfLiteStatus EvalInt4(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));

return EvalQuantizedPerChannelInt4(context, node, params, data, input, filter,
bias, output);
return EvalQuantizedPerChannel<int8_t, int32_t, kTfLiteInt4>(
context, node, params, data, input, filter, bias, output);
}

TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
@@ -493,8 +345,8 @@ TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));

return EvalQuantizedPerChannel(context, node, params, data, input, filter,
bias, output);
return EvalQuantizedPerChannel<int8_t, int32_t, kTfLiteInt8>(
context, node, params, data, input, filter, bias, output);
}

TfLiteStatus EvalInt16x8(TfLiteContext* context, TfLiteNode* node) {
@@ -516,8 +368,8 @@ TfLiteStatus EvalInt16x8(TfLiteContext* context, TfLiteNode* node) {
const OpData& data = *(static_cast<const OpData*>(node->user_data));

if (bias == nullptr || bias->type == kTfLiteInt64) {
return EvalQuantizedPerChannel16x8(context, node, params, data, input,
filter, bias, output);
return EvalQuantizedPerChannel<int16_t, int64_t, kTfLiteInt16>(
context, node, params, data, input, filter, bias, output);
} else {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data.reference_op_data),
@@ -580,12 +432,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8: {
switch (filter->type) {
case kTfLiteInt4: {
return EvalQuantizedPerChannelInt4(context, node, params, data, input,
filter, bias, output);
return EvalQuantizedPerChannel<int8_t, int32_t, kTfLiteInt4>(
context, node, params, data, input, filter, bias, output);
}
case kTfLiteInt8: {
return EvalQuantizedPerChannel(context, node, params, data, input,
filter, bias, output);
return EvalQuantizedPerChannel<int8_t, int32_t, kTfLiteInt8>(
context, node, params, data, input, filter, bias, output);
}
default: {
MicroPrintf("Filter type %s (%d) not supported.",
@@ -597,8 +449,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
case kTfLiteInt16: {
if (bias == nullptr || bias->type == kTfLiteInt64) {
return EvalQuantizedPerChannel16x8(context, node, params, data, input,
filter, bias, output);
return EvalQuantizedPerChannel<int16_t, int64_t, kTfLiteInt16>(
context, node, params, data, input, filter, bias, output);
} else if (bias->type == kTfLiteInt32) {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data.reference_op_data),
