Skip to content

Commit

Permalink
Enable bool input for CAST (#2401)
Browse files Browse the repository at this point in the history
BUG=none
  • Loading branch information
rascani authored Jan 22, 2024
1 parent 4b2bdd1 commit 324ae1e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensorflow/lite/micro/kernels/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ TfLiteStatus CastEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteFloat32:
return copyToTensor(context, tflite::micro::GetTensorData<float>(input),
output, num_elements);
case kTfLiteBool:
return copyToTensor(context, tflite::micro::GetTensorData<bool>(input),
output, num_elements);
default:
// Unsupported type.
MicroPrintf("Input type %s (%d) not supported.",
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/lite/micro/kernels/cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,12 @@ TF_LITE_MICRO_TEST(CastUInt32ToInt32) {
tflite::testing::TestCast(input_dims, input_values, golden, output_data);
}

TF_LITE_MICRO_TEST(CastBoolToFloat) {
float output_data[6];
int input_dims[] = {2, 2, 3};
const bool input_values[] = {true, true, false, true, false, true};
const float golden[] = {1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f};
tflite::testing::TestCast(input_dims, input_values, golden, output_data);
}

TF_LITE_MICRO_TESTS_END

0 comments on commit 324ae1e

Please sign in to comment.