Enable optimized ADD/SUB for HiFi5 (#2399)
The HiFi5 nnlib provides the same optimized ADD and SUB kernels as the HiFi4 nnlib, but the corresponding defines were never added to the kernel implementation, so both operators fell back to the reference kernels on HiFi5.

BUG=none
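For context, here is a minimal, self-contained sketch of the guard-and-fallback pattern this commit extends. Everything in it is illustrative: AddInt8, ReferenceAddInt8, and OptimizedAddInt8 are hypothetical stand-ins, not real TFLM or nnlib symbols, and the "optimized" branch simply delegates to the reference loop so the example runs anywhere.

// guard_sketch.cc -- hypothetical illustration of the HiFi dispatch guards.
// None of these names come from TFLM or nnlib; see add.cc/sub.cc below for
// the real kernels.
#include <cstdint>
#include <cstdio>

// Portable reference path: always correct, not vectorized.
static int ReferenceAddInt8(const int8_t* a, const int8_t* b, int8_t* out,
                            int n) {
  for (int i = 0; i < n; ++i) out[i] = static_cast<int8_t>(a[i] + b[i]);
  return 0;
}

#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
// Stand-in for the vectorized kernel that HiFi3/4/5 all provide. On real
// hardware this would call into the Cadence nnlib.
static int OptimizedAddInt8(const int8_t* a, const int8_t* b, int8_t* out,
                            int n) {
  return ReferenceAddInt8(a, b, out, n);  // placeholder body
}
#endif

static int AddInt8(const int8_t* a, const int8_t* b, int8_t* out, int n) {
#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
  // Before this commit, HIFI5 was missing from guards like this one, so
  // HiFi5 builds silently compiled the reference branch below instead.
  return OptimizedAddInt8(a, b, out, n);
#else
  return ReferenceAddInt8(a, b, out, n);
#endif
}

int main() {
  const int8_t a[4] = {1, 2, 3, 4};
  const int8_t b[4] = {10, 20, 30, 40};
  int8_t out[4] = {};
  AddInt8(a, b, out, 4);
  for (int v : out) std::printf("%d ", v);  // prints: 11 21 31 41
  std::printf("\n");
  return 0;
}

Building with the core's define (e.g. g++ -DHIFI5 guard_sketch.cc) selects the optimized branch; without it the fallback compiles instead, which is exactly the silent-fallback failure mode the diffs below fix.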
rascani authored Jan 19, 2024
1 parent 2d80ee4 commit 73e419e
Showing 2 changed files with 16 additions and 16 deletions.
16 changes: 8 additions & 8 deletions tensorflow/lite/micro/kernels/xtensa/add.cc
@@ -113,11 +113,11 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
op_params.output_shift = data->output_shift;
SetActivationParams(data->output_activation_min, data->output_activation_max,
&op_params);
-#if !(defined(HIFI3) || defined(HIFI4))
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
bool need_broadcast = reference_ops::ProcessBroadcastShapes(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorShape(input2), &op_params);
-#endif  // !defined(HIFI3) && !defined(HIFI4)
+#endif  // !defined(HIFI3) && !defined(HIFI4) && !defined(HIFI5)

switch (output->type) {
case kTfLiteInt8: {
@@ -126,7 +126,7 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
*(reinterpret_cast<XtensaAddOpData*>(node->user_data));
AddEvalQuantizedVision(context, node, *params, op_data, input1, input2,
output);
-#elif defined(HIFI3) || defined(HIFI4)  // defined(VISION_P6)
+#elif defined(HIFI3) || defined(HIFI4) || defined(HIFI5)  // defined(VISION_P6)
int err;
const RuntimeShape extended_input1_shape =
RuntimeShape::ExtendedShape(4, tflite::micro::GetTensorShape(input1));
@@ -150,7 +150,7 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
op_params.left_shift);

TF_LITE_ENSURE(context, err == 0);
-#else  // defined(VISION_P6)
+#else  // defined(VISION_P6)
if (need_broadcast) {
reference_integer_ops::BroadcastAdd4DSlow(
op_params, tflite::micro::GetTensorShape(input1),
@@ -168,11 +168,11 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
}
-#endif  // defined(VISION_P6)
+#endif  // defined(VISION_P6)
break;
}
case kTfLiteInt16: {
-#if defined(HIFI3) || defined(HIFI4)
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
int err;
const RuntimeShape extended_input1_shape =
RuntimeShape::ExtendedShape(4, tflite::micro::GetTensorShape(input1));
@@ -196,7 +196,7 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
op_params.left_shift);

TF_LITE_ENSURE(context, err == 0);
-#else  // defined(HIFI3) || defined(HIFI4)
+#else  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
if (need_broadcast) {
reference_ops::BroadcastAdd4DSlow(
op_params, tflite::micro::GetTensorShape(input1),
@@ -214,7 +214,7 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
tflite::micro::GetTensorData<int16_t>(output),
false);
}
-#endif  // defined(HIFI3) || defined(HIFI4)
+#endif  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
break;
}
default:
16 changes: 8 additions & 8 deletions tensorflow/lite/micro/kernels/xtensa/sub.cc
@@ -83,15 +83,15 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
&op_params);
// TODO(b/259724572): vision_p6 and hifi code path is getting very confusing.
// Let's separate them into two different files.
-#if !(defined(HIFI3) || defined(HIFI4))
+#if !(defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
bool need_broadcast = reference_ops::ProcessBroadcastShapes(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorShape(input2), &op_params);
#endif // !(defined(HIFI3) || defined(HIFI4))

switch (output->type) {
case kTfLiteInt8: {
-#if defined(HIFI3) || defined(HIFI4)
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
int err;
const RuntimeShape extended_input1_shape =
RuntimeShape::ExtendedShape(5, tflite::micro::GetTensorShape(input1));
@@ -133,7 +133,7 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,

TF_LITE_ENSURE(context, err == 0);
}
-#else  // defined(HIFI3) || defined(HIFI4)
+#else  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
if (need_broadcast) {
tflite::reference_ops::BroadcastQuantSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
@@ -151,11 +151,11 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
}
-#endif  // defined(HIFI3) || defined(HIFI4)
+#endif  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
break;
}
case kTfLiteInt16: {
-#if defined(HIFI3) || defined(HIFI4)
+#if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
int err;
const RuntimeShape extended_input1_shape =
RuntimeShape::ExtendedShape(5, tflite::micro::GetTensorShape(input1));
@@ -196,7 +196,7 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,

TF_LITE_ENSURE(context, err == 0);
}
-#else  // defined(HIFI3) || defined(HIFI4)
+#else  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
if (need_broadcast) {
tflite::reference_ops::BroadcastQuantSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
@@ -214,7 +214,7 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
}
-#endif  // defined(HIFI3) || defined(HIFI4)
+#endif  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
break;
}
default:
@@ -256,4 +256,4 @@ TFLMRegistration Register_SUB() {
return tflite::micro::RegisterOp(SubInit, SubPrepare, SubEval);
}

-}  // namespace tflite
+}  // namespace tflite
