Skip to content

Commit

Permalink
[luci-interpreter] Support bool type in NotEqual and ReduceMax op (Sa…
Browse files Browse the repository at this point in the history
…msung#13343)

This commit supports bool type in NotEqual and ReduceMax op.

ONE-DCO-1.0-Signed-off-by: seongwoo <[email protected]>
  • Loading branch information
mhs4670go authored Jul 3, 2024
1 parent f34fc8f commit f75846d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 0 deletions.
26 changes: 26 additions & 0 deletions compiler/luci-interpreter/src/kernels/NotEqual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ void NotEqual::execute() const
case DataType::U8:
evalQuantized();
break;
case DataType::BOOL:
evalBool();
break;
default:
throw std::runtime_error("luci-intp NotEqual Unsupported type.");
}
Expand Down Expand Up @@ -138,5 +141,28 @@ void NotEqual::evalQuantized() const
}
}

void NotEqual::evalBool() const
{
const auto x_data = getTensorData<bool>(x());
const auto y_data = getTensorData<bool>(y());
auto output_data = getTensorData<bool>(output());

tflite::ComparisonParams op_params;
op_params.is_broadcast = x()->shape() != y()->shape();

if (op_params.is_broadcast)
{
tflite::reference_ops::Broadcast4DSlowNotEqualNoScaling(op_params, getTensorShape(x()), x_data,
getTensorShape(y()), y_data,
getTensorShape(output()), output_data);
}
else
{
tflite::reference_ops::NotEqualNoScaling(op_params, getTensorShape(x()), x_data,
getTensorShape(y()), y_data, getTensorShape(output()),
output_data);
}
}

} // namespace kernels
} // namespace luci_interpreter
1 change: 1 addition & 0 deletions compiler/luci-interpreter/src/kernels/NotEqual.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class NotEqual : public Kernel
void evalFloat() const;
template <typename T> void evalInteger() const;
void evalQuantized() const;
void evalBool() const;

private:
int32_t _x_multiplier = 0;
Expand Down
25 changes: 25 additions & 0 deletions compiler/luci-interpreter/src/kernels/ReduceMax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ void ReduceMax::execute() const
case DataType::FLOAT32:
evalFloat();
break;
case DataType::BOOL:
evalBool();
break;
// TODO Support quantized kernels
default:
throw std::runtime_error("luci-intp ReduceMax Unsupported type.");
Expand Down Expand Up @@ -177,5 +180,27 @@ void ReduceMax::evalFloat() const
[](const float current, const float in) -> float { return (in > current) ? in : current; });
}

void ReduceMax::evalBool() const
{
const auto *axes_data = getTensorData<int32_t>(axes());
int num_axes = axes()->shape().num_elements();

auto temp_index = getOutputTensors()[1];
auto resolved_axes = getOutputTensors()[2];

int num_resolved_axis = 0;
LUCI_INTERPRETER_CHECK(
tflite::reference_ops::ResolveAxis(input()->shape().num_dims(), axes_data, num_axes,
getTensorData<int>(resolved_axes), &num_resolved_axis));

bool init_value = std::numeric_limits<bool>::lowest();
tflite::reference_ops::ReduceGeneric<bool>(
getTensorData<bool>(input()), getTensorShape(input()).DimsData(), input()->shape().num_dims(),
getTensorData<bool>(output()), getTensorShape(output()).DimsData(),
output()->shape().num_dims(), axes_data, num_axes, _params.keep_dims,
getTensorData<int>(temp_index), getTensorData<int>(resolved_axes), init_value,
[](const bool current, const bool in) -> bool { return (in > current) ? in : current; });
}

} // namespace kernels
} // namespace luci_interpreter
1 change: 1 addition & 0 deletions compiler/luci-interpreter/src/kernels/ReduceMax.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class ReduceMax : public KernelWithParams<ReducerParams>

private:
void evalFloat() const;
void evalBool() const;
};

} // namespace kernels
Expand Down

0 comments on commit f75846d

Please sign in to comment.