Skip to content

Commit

Permalink
Added MinMagnitude and MaxMagnitude ops
Browse files Browse the repository at this point in the history
  • Loading branch information
johnplatts committed Oct 11, 2024
1 parent 0a3c901 commit 18c9d93
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 0 deletions.
18 changes: 18 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,24 @@ is qNaN, and NaN if both are.

* <code>V **Max**(V a, V b)</code>: returns `max(a[i], b[i])`.

* <code>V **MinMagnitude**(V a, V b)</code>: returns the number with the
smaller magnitude if `a[i]` and `b[i]` are both non-NaN values.

If `a[i]` and `b[i]` are both non-NaN, `MinMagnitude(a, b)` returns
`(|a[i]| < |b[i]| || (|a[i]| == |b[i]| && a[i] < b[i])) ? a[i] : b[i]`.

Otherwise, the results of `MinMagnitude(a, b)` are implementation-defined
if `a[i]` is NaN or `b[i]` is NaN.

* <code>V **MaxMagnitude**(V a, V b)</code>: returns the number with the
larger magnitude if `a[i]` and `b[i]` are both non-NaN values.

If `a[i]` and `b[i]` are both non-NaN, `MaxMagnitude(a, b)` returns
`(|a[i]| < |b[i]| || (|a[i]| == |b[i]| && a[i] < b[i])) ? b[i] : a[i]`.

Otherwise, the results of `MaxMagnitude(a, b)` are implementation-defined
if `a[i]` is NaN or `b[i]` is NaN.

All other ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:

* `V`: `u64` \
Expand Down
57 changes: 57 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,63 @@ HWY_API V InterleaveEven(V a, V b) {
}
#endif

// ------------------------------ MinMagnitude/MaxMagnitude

#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
#undef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
#else
#define HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
#endif

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MinMagnitude(V a, V b) {
const auto abs_a = Abs(a);
const auto abs_b = Abs(b);
return IfThenElse(Lt(abs_a, abs_b), a,
Min(IfThenElse(Eq(abs_a, abs_b), a, b), b));
}

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MaxMagnitude(V a, V b) {
const auto abs_a = Abs(a);
const auto abs_b = Abs(b);
return IfThenElse(Lt(abs_a, abs_b), b,
Max(IfThenElse(Eq(abs_a, abs_b), b, a), a));
}

#endif // HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE

template <class V, HWY_IF_SIGNED_V(V)>
HWY_API V MinMagnitude(V a, V b) {
const DFromV<V> d;
const RebindToUnsigned<decltype(d)> du;
const auto abs_a = BitCast(du, Abs(a));
const auto abs_b = BitCast(du, Abs(b));
return IfThenElse(RebindMask(d, Lt(abs_a, abs_b)), a,
Min(IfThenElse(RebindMask(d, Eq(abs_a, abs_b)), a, b), b));
}

template <class V, HWY_IF_SIGNED_V(V)>
HWY_API V MaxMagnitude(V a, V b) {
const DFromV<V> d;
const RebindToUnsigned<decltype(d)> du;
const auto abs_a = BitCast(du, Abs(a));
const auto abs_b = BitCast(du, Abs(b));
return IfThenElse(RebindMask(d, Lt(abs_a, abs_b)), b,
Max(IfThenElse(RebindMask(d, Eq(abs_a, abs_b)), b, a), a));
}

template <class V, HWY_IF_UNSIGNED_V(V)>
HWY_API V MinMagnitude(V a, V b) {
return Min(a, b);
}

template <class V, HWY_IF_UNSIGNED_V(V)>
HWY_API V MaxMagnitude(V a, V b) {
return Max(a, b);
}

// ------------------------------ AddSub

template <class V, HWY_IF_LANES_D(DFromV<V>, 1)>
Expand Down
101 changes: 101 additions & 0 deletions hwy/tests/minmax_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,106 @@ HWY_NOINLINE void TestAllMinMax128Upper() {
ForGEVectors<128, TestMinMax128Upper>()(uint64_t());
}

struct TestMinMaxMagnitude {
template <class T>
static constexpr MakeSigned<T> MaxPosIotaVal(hwy::FloatTag /*type_tag*/) {
return static_cast<MakeSigned<T>>(MantissaMask<T>() + 1);
}
template <class T>
static constexpr MakeSigned<T> MaxPosIotaVal(hwy::NonFloatTag /*type_tag*/) {
return static_cast<MakeSigned<T>>(((LimitsMax<MakeSigned<T>>()) >> 1) + 1);
}

template <class D>
HWY_NOINLINE static void VerifyMinMaxMagnitude(
D d, const TFromD<D>* HWY_RESTRICT in1_lanes,
const TFromD<D>* HWY_RESTRICT in2_lanes, const int line) {
using T = TFromD<D>;
using TAbs = If<IsFloat<T>() || IsSpecialFloat<T>(), T, MakeUnsigned<T>>;

const char* file = __FILE__;
const size_t N = Lanes(d);
auto expected_min_mag = AllocateAligned<T>(N);
auto expected_max_mag = AllocateAligned<T>(N);
HWY_ASSERT(expected_min_mag && expected_max_mag);

for (size_t i = 0; i < N; i++) {
const T val1 = in1_lanes[i];
const T val2 = in2_lanes[i];
const TAbs abs_val1 = static_cast<TAbs>(ScalarAbs(val1));
const TAbs abs_val2 = static_cast<TAbs>(ScalarAbs(val2));
if (abs_val1 < abs_val2 || (abs_val1 == abs_val2 && val1 < val2)) {
expected_min_mag[i] = val1;
expected_max_mag[i] = val2;
} else {
expected_min_mag[i] = val2;
expected_max_mag[i] = val1;
}
}

const auto in1 = Load(d, in1_lanes);
const auto in2 = Load(d, in2_lanes);
AssertVecEqual(d, expected_min_mag.get(), MinMagnitude(in1, in2), file,
line);
AssertVecEqual(d, expected_min_mag.get(), MinMagnitude(in2, in1), file,
line);
AssertVecEqual(d, expected_max_mag.get(), MaxMagnitude(in1, in2), file,
line);
AssertVecEqual(d, expected_max_mag.get(), MaxMagnitude(in2, in1), file,
line);
}

template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
using TI = MakeSigned<T>;
using TU = MakeSigned<T>;
constexpr TI kMaxPosIotaVal = MaxPosIotaVal<T>(hwy::IsFloatTag<T>());
static_assert(kMaxPosIotaVal > 0, "kMaxPosIotaVal > 0 must be true");

constexpr size_t kPositiveIotaMask = static_cast<size_t>(
static_cast<TU>(kMaxPosIotaVal - 1) & (HWY_MAX_LANES_D(D) - 1));

const size_t N = Lanes(d);
auto in1_lanes = AllocateAligned<T>(N);
auto in2_lanes = AllocateAligned<T>(N);
auto in3_lanes = AllocateAligned<T>(N);
auto in4_lanes = AllocateAligned<T>(N);
HWY_ASSERT(in1_lanes && in2_lanes && in3_lanes && in4_lanes);

for (size_t i = 0; i < N; i++) {
const TI x1 = static_cast<TI>((i & kPositiveIotaMask) + 1);
const TI x2 = static_cast<TI>(kMaxPosIotaVal - x1);
const TI x3 = static_cast<TI>(-x1);
const TI x4 = static_cast<TI>(-x2);

in1_lanes[i] = ConvertScalarTo<T>(x1);
in2_lanes[i] = ConvertScalarTo<T>(x2);
in3_lanes[i] = ConvertScalarTo<T>(x3);
in4_lanes[i] = ConvertScalarTo<T>(x4);
}

VerifyMinMaxMagnitude(d, in1_lanes.get(), in2_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in1_lanes.get(), in3_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in1_lanes.get(), in4_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in2_lanes.get(), in3_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in2_lanes.get(), in4_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in3_lanes.get(), in4_lanes.get(), __LINE__);

in2_lanes[0] = HighestValue<T>();
in4_lanes[0] = LowestValue<T>();

VerifyMinMaxMagnitude(d, in1_lanes.get(), in2_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in1_lanes.get(), in4_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in2_lanes.get(), in3_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in2_lanes.get(), in4_lanes.get(), __LINE__);
VerifyMinMaxMagnitude(d, in3_lanes.get(), in4_lanes.get(), __LINE__);
}
};

HWY_NOINLINE void TestAllMinMaxMagnitude() {
ForAllTypes(ForPartialVectors<TestMinMaxMagnitude>());
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
Expand All @@ -269,6 +369,7 @@ HWY_BEFORE_TEST(HwyMinMaxTest);
HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMax);
HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMax128);
HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMax128Upper);
HWY_EXPORT_AND_TEST_P(HwyMinMaxTest, TestAllMinMaxMagnitude);
HWY_AFTER_TEST();
} // namespace hwy

Expand Down

0 comments on commit 18c9d93

Please sign in to comment.