Skip to content

Commit

Permalink
Allow explicit casts to any type.
Browse files Browse the repository at this point in the history
... as long as `UnderlyingTy` allows such a cast.

XLA has some generic code that casts between types
such as `std::complex` or `absl::int128` which are not
marked as arithmetic types.

PiperOrigin-RevId: 563459153
  • Loading branch information
cantonios authored and The ml_dtypes Authors committed Sep 7, 2023
1 parent 780b6d0 commit a9c93c0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ml_dtypes/include/int4.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct i4 {
return std::is_signed<UnderlyingTy>::value ? i4(7) : i4(15);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
template <typename T>
explicit constexpr operator T() const {
return static_cast<T>(v);
}
Expand Down
13 changes: 12 additions & 1 deletion ml_dtypes/tests/int4_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,23 @@ TYPED_TEST(Int4Test, ToString) {
}
}

struct CustomInt {
constexpr CustomInt() : x(0) {}
constexpr CustomInt(int x) : x(x) {}
// NOLINTNEXTLINE(google-explicit-constructor)
constexpr operator int() const { return x; }
constexpr bool operator==(const CustomInt& other) const { return x == other.x; }
private:
int x;
};

#define GEN_DEST_TYPES(Type) \
std::pair<Type, bool>, std::pair<Type, uint4>, std::pair<Type, uint8_t>, \
std::pair<Type, uint16_t>, std::pair<Type, uint32_t>, \
std::pair<Type, uint64_t>, std::pair<Type, int4>, \
std::pair<Type, int8_t>, std::pair<Type, int16_t>, \
std::pair<Type, int32_t>, std::pair<Type, int64_t>
std::pair<Type, int32_t>, std::pair<Type, int64_t>, \
std::pair<Type, CustomInt>

#define GEN_TYPE_PAIRS() GEN_DEST_TYPES(int4), GEN_DEST_TYPES(uint4)

Expand Down

0 comments on commit a9c93c0

Please sign in to comment.