Add per channel quantization support for fully connected (#2673)
Adds per-channel quantization support for the fully connected reference and CMSIS-NN kernels.

BUG=missing support for FC per channel quantization
mansnils authored Sep 27, 2024
1 parent 4ed97bf commit 47b5450
Showing 5 changed files with 180 additions and 87 deletions.
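With per-channel quantization, each output channel of the fully connected filter carries its own scale, so the int32 accumulator is requantized with a per-channel effective scale input_scale * filter_scale[c] / output_scale instead of a single per-tensor value. Below is a minimal float-requantization sketch of that semantics, assuming symmetric int8 weights (filter zero point 0); the function and parameter names are illustrative only and are not the kernel's actual API.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference-style int8 fully connected with one filter scale per output channel.
// input: accum_depth values; filter: out_ch rows of accum_depth values, row-major.
std::vector<int8_t> FullyConnectedPerChannelSketch(
    const std::vector<int8_t>& input, int32_t input_zero_point, float input_scale,
    const std::vector<int8_t>& filter, const std::vector<float>& filter_scales,
    const std::vector<int32_t>& bias, float output_scale,
    int32_t output_zero_point, int accum_depth, int out_ch) {
  std::vector<int8_t> output(out_ch);
  for (int c = 0; c < out_ch; ++c) {
    // Bias is assumed pre-quantized with scale input_scale * filter_scales[c].
    int32_t acc = bias[c];
    for (int d = 0; d < accum_depth; ++d) {
      // Weights are symmetric (zero point 0), so only the input offset applies.
      acc += (static_cast<int32_t>(input[d]) - input_zero_point) *
             static_cast<int32_t>(filter[c * accum_depth + d]);
    }
    const float effective_scale = input_scale * filter_scales[c] / output_scale;
    const int32_t out = static_cast<int32_t>(std::lround(acc * effective_scale)) +
                        output_zero_point;
    output[c] = static_cast<int8_t>(std::clamp<int32_t>(out, -128, 127));
  }
  return output;
}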
100 changes: 62 additions & 38 deletions tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
@@ -13,18 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"

#include "Include/arm_nnfunctions.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_log.h"
@@ -35,12 +35,8 @@ namespace {
struct OpData {
OpDataFullyConnected reference_op_data;

// Conv 1x1 that may be invoked in some cases currently need per channel
// quantization.
int32_t* per_channel_output_multiplier;
int32_t* per_channel_output_shift;

// Index to buffer for optimizations if applicable.
// Index to buffers for optimizations if applicable.
int buffer_conv_1x1_idx;
int buffer_idx;

int32_t* kernel_sums;
@@ -94,6 +90,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const RuntimeShape output_shape = GetTensorShape(output);
const int filter_dim_count = filter_shape.DimensionsCount();
const int output_dim_count = output_shape.DimensionsCount();

TFLITE_DCHECK_GE(output_dim_count, 2);
TFLITE_DCHECK_LE(output_dim_count, 4);

cmsis_nn_dims filter_dims;
filter_dims.n = filter_shape.Dims(filter_dim_count - 1);
filter_dims.h = 1;
@@ -106,36 +106,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

// Set buffer index to a reset value
data->buffer_idx = -1;
data->buffer_conv_1x1_idx = -1;

TF_LITE_ENSURE_STATUS(CalculateOpDataFullyConnected(
context, params->activation, input->type, input, filter, bias, output,
&(data->reference_op_data)));

// Currently only Int8 is supported for per channel quantization.
TF_LITE_ENSURE(
context, !data->reference_op_data.is_per_channel ||
(data->reference_op_data.is_per_channel &&
input->type == kTfLiteInt8 && filter->type != kTfLiteInt4));

int32_t buf_size = 0;

if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims);
} else if (input->type == kTfLiteInt8 && filter->type != kTfLiteInt4) {
const RuntimeShape input_shape = GetTensorShape(input);

TFLITE_DCHECK_GE(output_dim_count, 2);
TFLITE_DCHECK_LE(output_dim_count, 4);

if (output_dim_count > 2 && data->accum_depth % 4 == 0) {
data->per_channel_output_multiplier =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, data->output_depth * sizeof(int32_t)));
data->per_channel_output_shift =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, data->output_depth * sizeof(int32_t)));
const bool is_conv_1x1_possible =
output_dim_count > 2 && data->accum_depth % 4 == 0;

if (is_conv_1x1_possible) {
      // In case of per-tensor quantization, we use a scratch buffer to fake
      // conv1x1 per-channel quantization.
if (!data->reference_op_data.is_per_channel) {
const int total_per_channel_quantization_size =
data->output_depth * sizeof(int32_t) * 2;
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, total_per_channel_quantization_size,
&data->buffer_conv_1x1_idx));
}

cmsis_nn_dims input_dims;
input_dims.n = data->batches;
input_dims.h = 1;
input_dims.w = 1;
input_dims.c = data->accum_depth;

buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(&input_dims);
} else if (input->type == kTfLiteInt8) {
buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
@@ -218,9 +226,6 @@ TfLiteStatus EvalQuantizedInt4(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
const int output_dim_count = output_shape.DimensionsCount();
TFLITE_DCHECK_GE(output_dim_count, 2);
TFLITE_DCHECK_LE(output_dim_count, 4);

cmsis_nn_per_tensor_quant_params quant_params;
cmsis_nn_dims input_dims;
@@ -262,18 +267,16 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
TfLiteEvalTensor* output) {
const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
const int output_dim_count = output_shape.DimensionsCount();
TFLITE_DCHECK_GE(output_dim_count, 2);
TFLITE_DCHECK_LE(output_dim_count, 4);

cmsis_nn_per_tensor_quant_params quant_params;
cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
cmsis_nn_dims input_dims;
cmsis_nn_dims filter_dims;
cmsis_nn_dims bias_dims;
cmsis_nn_dims output_dims;
cmsis_nn_context ctx;

PopulateCommonParams(context, &quant_params, &input_dims, &filter_dims,
&bias_dims, &output_dims, &ctx, data);
PopulateCommonParams(context, &per_tensor_quant_params, &input_dims,
&filter_dims, &bias_dims, &output_dims, &ctx, data);

const int32_t* bias_data =
tflite::micro::GetOptionalTensorData<int32_t>(bias);
@@ -292,14 +295,23 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
conv_params.activation.max = data.reference_op_data.output_activation_max;

cmsis_nn_per_channel_quant_params per_channel_quant_params;
per_channel_quant_params.multiplier =
const_cast<int32_t*>(data.per_channel_output_multiplier);
per_channel_quant_params.shift =
const_cast<int32_t*>(data.per_channel_output_shift);

for (int i = 0; i < data.output_depth; i++) {
per_channel_quant_params.multiplier[i] = quant_params.multiplier;
per_channel_quant_params.shift[i] = quant_params.shift;
if (data.reference_op_data.is_per_channel) {
per_channel_quant_params.multiplier =
data.reference_op_data.per_channel_output_multiplier;
per_channel_quant_params.shift =
data.reference_op_data.per_channel_output_shift;
} else {
    TFLITE_DCHECK_GE(data.buffer_conv_1x1_idx, 0);
per_channel_quant_params.multiplier = static_cast<int32_t*>(
context->GetScratchBuffer(context, data.buffer_conv_1x1_idx));
per_channel_quant_params.shift =
per_channel_quant_params.multiplier + data.output_depth;

for (int i = 0; i < data.output_depth; i++) {
per_channel_quant_params.multiplier[i] =
per_tensor_quant_params.multiplier;
per_channel_quant_params.shift[i] = per_tensor_quant_params.shift;
}
}

TF_LITE_ENSURE_EQ(
@@ -318,6 +330,18 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
fc_params.activation.min = data.reference_op_data.output_activation_min;
fc_params.activation.max = data.reference_op_data.output_activation_max;

cmsis_nn_quant_params quant_params;
quant_params.is_per_channel = data.reference_op_data.is_per_channel;

if (quant_params.is_per_channel) {
quant_params.multiplier =
data.reference_op_data.per_channel_output_multiplier;
quant_params.shift = data.reference_op_data.per_channel_output_shift;
} else {
quant_params.multiplier = &per_tensor_quant_params.multiplier;
quant_params.shift = &per_tensor_quant_params.shift;
}

if (data.kernel_sums != nullptr) {
ctx.buf = data.kernel_sums;
} else if (ctx.buf != nullptr) {
@@ -330,7 +354,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,

TF_LITE_ENSURE_EQ(
context,
arm_fully_connected_s8(
arm_fully_connected_wrapper_s8(
&ctx, &fc_params, &quant_params, &input_dims,
tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
tflite::micro::GetTensorData<int8_t>(filter), &bias_dims, bias_data,
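The conv-1x1 fast path used above expects per-channel multiplier and shift arrays, so for per-tensor-quantized models the kernel fills a scratch buffer (requested in Prepare) with output_depth copies of the single multiplier and shift, multipliers first and shifts directly after them. Below is a standalone sketch of that broadcast, using a std::vector in place of the TFLM scratch arena; the struct and function names are assumptions for illustration, not CMSIS-NN or TFLM APIs.

#include <cstdint>
#include <vector>

struct PerChannelQuantView {
  int32_t* multiplier;  // output_depth entries
  int32_t* shift;       // output_depth entries, stored right after the multipliers
};

// Expand one per-tensor (multiplier, shift) pair into per-channel arrays so a
// kernel that only understands per-channel parameters can be reused.
PerChannelQuantView BroadcastPerTensorQuant(int32_t multiplier, int32_t shift,
                                            int output_depth,
                                            std::vector<int32_t>& scratch) {
  scratch.assign(2 * output_depth, 0);
  PerChannelQuantView view{scratch.data(), scratch.data() + output_depth};
  for (int c = 0; c < output_depth; ++c) {
    view.multiplier[c] = multiplier;
    view.shift[c] = shift;
  }
  return view;
}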
35 changes: 24 additions & 11 deletions tensorflow/lite/micro/kernels/fully_connected.cc
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -144,16 +144,29 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt8: {
tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
data.is_per_channel
? tflite::reference_integer_ops::FullyConnectedPerChannel(
FullyConnectedParamsQuantized(data),
data.per_channel_output_multiplier,
reinterpret_cast<const int*>(data.per_channel_output_shift),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output))
: tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
}
default: {
6 changes: 5 additions & 1 deletion tensorflow/lite/micro/kernels/fully_connected.h
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -45,6 +45,10 @@ struct OpDataFullyConnected {
// A buffer used to store unpacked filter values. This is used if the source
// tensor is of n-bit precision that cannot be easily processed by kernels.
int filter_buffer_index;

int32_t* per_channel_output_multiplier;
int32_t* per_channel_output_shift;
bool is_per_channel;
#endif
};

104 changes: 78 additions & 26 deletions tensorflow/lite/micro/kernels/fully_connected_common.cc
@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -57,46 +57,98 @@ TfLiteStatus CalculateOpDataFullyConnected(
TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output,
OpDataFullyConnected* data) {
// TODO(b/324385802): Support per-channel quantization for FullyConnected.
// If you have hit this failure message, you will need to disable this
// behavior. This can be done by setting the following flag to true:
// TfLiteConverter._experimental_disable_per_channel_quantization_for_dense_layers
// https://github.com/tensorflow/tensorflow/blob/377f47694fa790e98db6665b9adecde00b5e0d68/tensorflow/lite/python/lite.py#L674
#ifndef HEXAGON
data->is_per_channel = false;
#endif

if (data_type == kTfLiteFloat32) {
return kTfLiteOk;
}

bool is_per_channel = false;
if (filter->quantization.type == kTfLiteAffineQuantization &&
filter->quantization.params != nullptr) {
TfLiteAffineQuantization* affine_quantization =
const auto* affine_quantization =
reinterpret_cast<TfLiteAffineQuantization*>(
filter->quantization.params);
TF_LITE_ENSURE(context, affine_quantization);
TF_LITE_ENSURE(context, affine_quantization->scale);
TF_LITE_ENSURE_MSG(
context, affine_quantization->scale->size == 1,
"FullyConnected per-channel quantization not yet supported. Please set "
"converter._experimental_disable_per_channel_quantization_for_dense_"
"layers = True.");
is_per_channel = affine_quantization->scale->size > 1;
}

if (data_type != kTfLiteFloat32) {
if (is_per_channel) {
// Hexagon currently does not support per-channel fully connected, and the
// existing hexagon support library is intolerant of data members being added to
// OpDataFullyConnected. As such, we have to be careful not to reference newer
// data members. This is why we use a local variable is_per_channel in common
// code, and only reference the data->is_per_channel in non-HEXAGON code.
#ifdef HEXAGON
TF_LITE_ENSURE_MSG(
context, !is_per_channel,
"FullyConnected per-channel quantization not yet supported on Hexagon. "
"Please set converter._experimental_disable_per_channel_quantization_"
"for_dense_layers = True.");
#else
data->is_per_channel = is_per_channel;
const auto* affine_quantization =
reinterpret_cast<TfLiteAffineQuantization*>(
filter->quantization.params);
const int per_channel_quantization_size = affine_quantization->scale->size;

// Currently only Int8 is supported for per channel quantization.
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 && filter->type != kTfLiteInt4);

TF_LITE_ENSURE_EQ(
context, per_channel_quantization_size,
filter->dims->data[affine_quantization->quantized_dimension]);

data->per_channel_output_multiplier =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, per_channel_quantization_size * sizeof(int32_t)));
data->per_channel_output_shift =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, per_channel_quantization_size * sizeof(int32_t)));

// Populate multiplier and shift using affine quantization.
const float input_scale = input->params.scale;
const float output_scale = output->params.scale;
const float* filter_scales = affine_quantization->scale->data;

for (int i = 0; i < per_channel_quantization_size; ++i) {
const float scale = filter_scales[i];
const double filter_scale = static_cast<double>(scale);
const double effective_output_scale = static_cast<double>(input_scale) *
filter_scale /
static_cast<double>(output_scale);
int32_t significand;
int channel_shift;
QuantizeMultiplier(effective_output_scale, &significand, &channel_shift);
data->per_channel_output_multiplier[i] = significand;
data->per_channel_output_shift[i] = channel_shift;
}
#endif
} else {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
QuantizeMultiplier(real_multiplier, &data->output_multiplier,
&data->output_shift);
}

// Filter weights will always be symmetric quantized since we only support
// int8 quantization. See
// https://github.com/tensorflow/tensorflow/issues/44912 for additional
// context.
TFLITE_DCHECK(filter->params.zero_point == 0);
// Filter weights will always be symmetric quantized since we only support
// int8 quantization. See
// https://github.com/tensorflow/tensorflow/issues/44912 for additional
// context.
TFLITE_DCHECK(filter->params.zero_point == 0);

data->input_zero_point = input->params.zero_point;
data->filter_zero_point = filter->params.zero_point;
data->output_zero_point = output->params.zero_point;
data->input_zero_point = input->params.zero_point;
data->filter_zero_point = filter->params.zero_point;
data->output_zero_point = output->params.zero_point;

return CalculateActivationRangeQuantized(context, activation, output,
&data->output_activation_min,
&data->output_activation_max);
}
return kTfLiteOk;
return CalculateActivationRangeQuantized(context, activation, output,
&data->output_activation_min,
&data->output_activation_max);
}

} // namespace tflite
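The QuantizeMultiplier() calls above convert each effective floating-point scale into a Q31 fixed-point significand plus a power-of-two shift, so the kernels can requantize using integer arithmetic only. Below is a simplified sketch of that decomposition, assuming positive multipliers and omitting the shift clamping that the real TFLite helper performs.

#include <cmath>
#include <cstdint>

// Decompose real_multiplier so that real_multiplier ~= significand * 2^(shift - 31),
// with significand a Q31 value. Mirrors the frexp-based approach used by TFLite.
void QuantizeMultiplierSketch(double real_multiplier, int32_t* significand,
                              int* shift) {
  if (real_multiplier == 0.0) {
    *significand = 0;
    *shift = 0;
    return;
  }
  const double fraction = std::frexp(real_multiplier, shift);  // in [0.5, 1)
  int64_t fixed = static_cast<int64_t>(std::llround(fraction * (1LL << 31)));
  if (fixed == (1LL << 31)) {  // rounding reached 1.0; renormalize
    fixed /= 2;
    ++(*shift);
  }
  *significand = static_cast<int32_t>(fixed);
}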