From d90cf75488cd101fa438d33203ff88f4e9ee1ac9 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 30 Sep 2024 11:20:54 -0700 Subject: [PATCH] PR #16585: Add support for float8_e4m3 and float8_e3m4 types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](https://github.com/llvm/llvm-project/pull/97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](https://github.com/llvm/llvm-project/pull/97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](https://github.com/llvm/llvm-project/pull/99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](https://github.com/llvm/llvm-project/pull/101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](https://github.com/openxla/stablehlo/pull/2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](https://github.com/openxla/stablehlo/pull/2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](https://github.com/jax-ml/ml_dtypes/pull/161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](https://github.com/jax-ml/ml_dtypes/pull/171/) Add float8_e3m4 (Merged) - XLA [PR-17075](https://github.com/openxla/xla/pull/17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](https://github.com/openxla/xla/pull/3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](https://github.com/google/jax/pull/23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov : Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6 PiperOrigin-RevId: 680651037 --- third_party/tsl/tsl/platform/ml_dtypes.h | 2 + xla/BUILD | 2 + xla/array2d_test.cc | 28 ++ xla/ffi/api/api.h | 4 + xla/ffi/api/c_api.h | 2 + xla/ffi/api/ffi.h | 6 + xla/ffi/api/ffi_test.cc | 6 + xla/ffi/call_frame.cc | 2 + xla/fp_util_test.cc | 62 ++- xla/hlo/builder/lib/math.cc | 10 +- .../evaluator/hlo_evaluator_typed_visitor.h | 2 + .../hlo_evaluator_typed_visitor_float8.cc | 2 + .../translate/hlo_to_mhlo/tests/import.hlo | 24 +- .../translate/mhlo_to_hlo/tests/export.mlir | 18 +- xla/literal.cc | 41 +- xla/literal_comparison.cc | 12 +- xla/literal_comparison_test.cc | 37 +- xla/literal_test.cc | 66 ++- xla/mlir/utils/type_util.cc | 8 + xla/mlir/utils/type_util_test.cc | 2 + .../mhlo/hlo-legalize-to-stablehlo.mlir | 14 + xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 14 + .../mhlo/stablehlo-legalize-to-hlo.mlir | 14 + xla/pjrt/c/CHANGELOG.md | 3 + xla/pjrt/c/pjrt_c_api.h | 6 +- xla/pjrt/c/pjrt_c_api_helpers.cc | 8 + xla/primitive_util.h | 31 +- xla/primitive_util_test.cc | 104 ++++- xla/python/ifrt/dtype.cc | 10 + xla/python/ifrt/dtype.h | 4 +- xla/python/ifrt/dtype.proto | 2 + xla/python/ifrt/dtype_test.cc | 4 + xla/python/pjrt_ifrt/pjrt_dtype.cc | 4 + xla/python/py_values.cc | 17 + xla/python/types.cc | 40 ++ xla/python/types.h | 4 +- xla/python/xla.cc | 3 + xla/python/xla_client.py | 6 + xla/python/xla_client.pyi | 3 + xla/python/xla_client_test.py | 5 + xla/python/xla_extension/__init__.pyi | 2 + xla/service/BUILD | 8 +- xla/service/cpu/cpu_compiler.cc | 4 + xla/service/cpu/onednn_memory_util.h | 2 +- xla/service/elemental_ir_emitter.cc | 407 ++++++++++++++++-- xla/service/elemental_ir_emitter_test.cc | 20 +- xla/service/float8_fnuz_ir_emitter.cc | 7 + xla/service/float_normalization_test.cc | 2 +- .../fusions/transforms/expand_float_ops.cc | 22 +- xla/service/gpu/gpu_compiler.cc | 4 + xla/service/gpu/ir_emission_utils.cc | 3 +- .../gpu/tests/float_conversions_test.cc | 5 +- xla/service/llvm_ir/llvm_util.cc | 2 + xla/stream_executor/data_type.h | 8 + xla/stream_executor/dnn.cc | 2 + xla/stream_executor/gpu/gpu_blas_lt.cc | 10 + xla/stream_executor/rocm/BUILD | 4 +- xla/stream_executor/rocm/hip_blas_utils.cc | 6 +- xla/tests/BUILD | 2 + xla/tests/array_elementwise_ops_test.cc | 3 +- xla/tests/constants_test.cc | 11 +- xla/tests/convert_test.cc | 392 ++++++++++++++++- xla/tests/float8_test.cc | 5 +- xla/tools/driver.cc | 28 +- xla/tsl/framework/type_traits.h | 2 + xla/tsl/protobuf/dnn.proto | 2 + xla/tsl/python/lib/core/ml_dtypes.cc | 6 + xla/tsl/python/lib/core/ml_dtypes.h | 2 + xla/util.cc | 14 +- xla/util.h | 6 + xla/util_test.cc | 37 ++ xla/xla_data.proto | 14 +- 72 files changed, 1516 insertions(+), 158 deletions(-) diff --git a/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/tsl/tsl/platform/ml_dtypes.h index 916be8db4f6998..89a40bd891e106 100644 --- a/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/tsl/tsl/platform/ml_dtypes.h @@ -20,6 +20,8 @@ limitations under the License. #include "ml_dtypes/include/intn.h" // from @ml_dtypes namespace tsl { +using float8_e3m4 = ::ml_dtypes::float8_e3m4; +using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; diff --git a/xla/BUILD b/xla/BUILD index 44459e0e6df7d9..fb6cdbe281e4a3 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -316,6 +316,7 @@ xla_cc_test( ":util", "@com_google_absl//absl/base", "@com_google_absl//absl/numeric:bits", + "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test_main", ], @@ -373,6 +374,7 @@ xla_cc_test( ":test", ":types", ":util", + "@ml_dtypes//:float8", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test_main", diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc index 4d0fbf3732ff9a..4686e2ec5c1ac6 100644 --- a/xla/array2d_test.cc +++ b/xla/array2d_test.cc @@ -162,6 +162,20 @@ TEST(Array2dTest, LinspaceF8E5M2) { EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); } +TEST(Array2dTest, LinspaceF8E4M3) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); +} + TEST(Array2dTest, LinspaceF8E4M3Fn) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); @@ -190,6 +204,20 @@ TEST(Array2dTest, LinspaceF8E4M3FnNoNan) { } } +TEST(Array2dTest, LinspaceF8E3M4) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); +} + TEST(Array2dTest, Stringification) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); const std::string expected = R"([[1, 1.5], diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index 0e142c42286e12..31a84a7d929e60 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -133,6 +133,10 @@ inline std::ostream& operator<<(std::ostream& os, return os << "TOKEN"; case XLA_FFI_DataType_F8E5M2: return os << "F8E5M2"; + case XLA_FFI_DataType_F8E3M4: + return os << "F8E3M4"; + case XLA_FFI_DataType_F8E4M3: + return os << "F8E4M3"; case XLA_FFI_DataType_F8E4M3FN: return os << "F8E4M3FN"; case XLA_FFI_DataType_F8E4M3B11FNUZ: diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index d5e2b11538133f..f0c4f40e78ea7a 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -195,6 +195,8 @@ typedef enum { XLA_FFI_DataType_C128 = 18, XLA_FFI_DataType_TOKEN = 17, XLA_FFI_DataType_F8E5M2 = 19, + XLA_FFI_DataType_F8E3M4 = 29, + XLA_FFI_DataType_F8E4M3 = 28, XLA_FFI_DataType_F8E4M3FN = 20, XLA_FFI_DataType_F8E4M3B11FNUZ = 23, XLA_FFI_DataType_F8E5M2FNUZ = 24, diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index b31da22175333d..e6560833cb4aeb 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -73,10 +73,12 @@ enum class DataType : uint8_t { C128 = XLA_FFI_DataType_C128, TOKEN = XLA_FFI_DataType_TOKEN, F8E5M2 = XLA_FFI_DataType_F8E5M2, + F8E4M3 = XLA_FFI_DataType_F8E4M3, F8E4M3FN = XLA_FFI_DataType_F8E4M3FN, F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ, F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ, F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, + F8E3M4 = XLA_FFI_DataType_F8E3M4, }; // Create aliases in ::xla::ffi namespace for all DataTypes, for consistency @@ -98,10 +100,12 @@ inline constexpr DataType C64 = DataType::C64; inline constexpr DataType C128 = DataType::C128; inline constexpr DataType TOKEN = DataType::TOKEN; inline constexpr DataType F8E5M2 = DataType::F8E5M2; +inline constexpr DataType F8E4M3 = DataType::F8E4M3; inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN; inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ; inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ; +inline constexpr DataType F8E3M4 = DataType::F8E3M4; inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { return os << static_cast(dtype); @@ -117,10 +121,12 @@ constexpr size_t ByteWidth(DataType dtype) { case DataType::S8: case DataType::U8: case DataType::F8E5M2: + case DataType::F8E4M3: case DataType::F8E4M3FN: case DataType::F8E4M3B11FNUZ: case DataType::F8E5M2FNUZ: case DataType::F8E4M3FNUZ: + case DataType::F8E3M4: return 1; case DataType::S16: case DataType::U16: diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index 74837790c8449c..315587b94463da 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -130,11 +130,13 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2)); + EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ), encoded(DataType::F8E4M3B11FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ)); + EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4)); } TEST(FfiTest, DataTypeByteWidth) { @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) { EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2), ByteWidth(DataType::F8E5M2)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3), + ByteWidth(DataType::F8E4M3)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN), ByteWidth(DataType::F8E4M3FN)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ), @@ -187,6 +191,8 @@ TEST(FfiTest, DataTypeByteWidth) { ByteWidth(DataType::F8E5M2FNUZ)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FNUZ), ByteWidth(DataType::F8E4M3FNUZ)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4), + ByteWidth(DataType::F8E3M4)); } TEST(FfiTest, ErrorEnumValue) { diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index 12fed1ba745440..3fb2ac3c7786fa 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -265,10 +265,12 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::C128: case PrimitiveType::TOKEN: case PrimitiveType::F8E5M2: + case PrimitiveType::F8E4M3: case PrimitiveType::F8E4M3FN: case PrimitiveType::F8E4M3B11FNUZ: case PrimitiveType::F8E5M2FNUZ: case PrimitiveType::F8E4M3FNUZ: + case PrimitiveType::F8E3M4: return static_cast(primitive_type); default: DCHECK(false) << "Unsupported primitive type " diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index 36f0c5be9d5bde..3eb7c54f919b0a 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "absl/base/casts.h" #include "absl/numeric/bits.h" #include "xla/bit_cast.h" @@ -111,21 +112,74 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest, 0x1.fffffffffffffp-127, 0x1.aaaaaaaaaaaaap-127)); -// Test F8E4M3 floating-point types (F8E4M3FN) +// Test F8E4M3 floating-point types (F8E4M3, F8E4M3FN) template class FP8E4M3DistanceTest : public ::testing::Test {}; -using F8E4M3Types = ::testing::Types; +using F8E4M3Types = ::testing::Types; TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types); +TEST(FPDistanceTest, F8E3M4Distance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), + tsl::float8_e3m4(8.0)), + 0); + + // a & b have the same exponents + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), + tsl::float8_e3m4(15.5)), + 15); + + // a & b have different exponents + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), + tsl::float8_e3m4(6)), + 8); + + // 1 from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::denorm_min(), + tsl::float8_e3m4(0)), + 1); + + // 1 from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + tsl::float8_e3m4(0)), + 1); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + std::numeric_limits::denorm_min()), + 2); + + // 1 non denorm from 0 in the positive direction + EXPECT_EQ( + CalculateDistanceInFloats( + std::numeric_limits::min(), tsl::float8_e3m4(0)), + 16); + + // 1 non denorm from 0 in the negative direction + EXPECT_EQ( + CalculateDistanceInFloats( + -std::numeric_limits::min(), tsl::float8_e3m4(0)), + 16); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + std::numeric_limits::min()), + 32); +} + TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) { // a & b are equal, distance should be 0 EXPECT_EQ( CalculateDistanceInFloats(TypeParam(8.0), TypeParam(8.0)), 0); // a & b have the same exponents - EXPECT_EQ(CalculateDistanceInFloats(TypeParam(8.0), TypeParam(13)), - 5); + EXPECT_EQ( + CalculateDistanceInFloats(TypeParam(8.0), TypeParam(15.0)), 7); // a & b have different exponents EXPECT_EQ( diff --git a/xla/hlo/builder/lib/math.cc b/xla/hlo/builder/lib/math.cc index f7c00aece14d0e..e7792c65b7370a 100644 --- a/xla/hlo/builder/lib/math.cc +++ b/xla/hlo/builder/lib/math.cc @@ -175,6 +175,8 @@ XlaOp IsNegZero(XlaOp operand) { case F32: return Eq(BitcastConvertType(operand, U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); + case F8E3M4: + case F8E4M3: case F8E5M2: case F8E4M3FN: case F8E4M3B11FNUZ: @@ -973,8 +975,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1026,8 +1028,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index d7bc1ba49a9bf7..8157fe34baee0e 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1743,10 +1743,12 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc index 7c97c210aa36a5..d425d33c2feab5 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc @@ -19,8 +19,10 @@ limitations under the License. namespace xla { template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo index 68ad73882e02a7..3a1e7ceabb160f 100644 --- a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo +++ b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo @@ -410,11 +410,17 @@ add { // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3B11FNUZ> %constant.9 = f8e4m3b11fnuz[4] constant({1, 2, 3, 4}) - // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ> + // CHECK: %[[VAL_10:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ> %constant.10 = f8e4m3fnuz[4] constant({1, 2, 3, 4}) - // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> + // CHECK: %[[VAL_11:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> %constant.11 = f8e5m2fnuz[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3> + %constant.12 = f8e4m3[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> + %constant.13 = f8e3m4[4] constant({1, 2, 3, 4}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -524,7 +530,19 @@ add { %convert.11 = f8e5m2fnuz[4] convert(f32[4] %convert.10) // CHECK-NEXT: %9 = mhlo.convert %8 : (tensor<4xf8E5M2FNUZ>) -> tensor<4xf32> - ROOT %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11) + %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11) + + // CHECK-NEXT: %10 = mhlo.convert %9 : (tensor<4xf32>) -> tensor<4xf8E4M3> + %convert.13 = f8e4m3[4] convert(f32[4] %convert.12) + + // CHECK-NEXT: %11 = mhlo.convert %10 : (tensor<4xf8E4M3>) -> tensor<4xf32> + %convert.14 = f32[4] convert(f8e4m3[4] %convert.13) + + // CHECK-NEXT: %12 = mhlo.convert %11 : (tensor<4xf32>) -> tensor<4xf8E3M4> + %convert.15 = f8e3m4[4] convert(f32[4] %convert.14) + + // CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32> + ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index b4e7a128a5d1ed..c41792b1556338 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -600,6 +600,12 @@ func.func @main() { // CHECK: f8e5m2fnuz[4] constant({1, 2, 3, 4}) %cst_15 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> + // CHECK: f8e4m3[4] constant({1, 2, 3, 4}) + %cst_16 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3> + + // CHECK: f8e3m4[4] constant({1, 2, 3, 4}) + %cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> + func.return } @@ -729,7 +735,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %5 = "mhlo.convert"(%4) : (tensor<2xf8E4M3FNUZ>) -> tensor<2xf32> %6 = "mhlo.convert"(%5) : (tensor<2xf32>) -> tensor<2xf8E5M2FNUZ> %7 = "mhlo.convert"(%6) : (tensor<2xf8E5M2FNUZ>) -> tensor<2xf32> - func.return %7 : tensor<2xf32> + %8 = "mhlo.convert"(%7) : (tensor<2xf32>) -> tensor<2xf8E4M3> + %9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32> + %10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4> + %11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32> + func.return %11 : tensor<2xf32> } // CHECK: ENTRY @@ -741,7 +751,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(f32[2] %[[F32_VAL2]]) // CHECK: %[[F32_VAL3:.*]] = f32[2] convert(f8e4m3fnuz[2] %[[E4M3FNUZ_VAL]]) // CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(f32[2] %[[F32_VAL3]]) -// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) +// CHECK: %[[F32_VAL4:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) +// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]]) +// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]]) +// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]]) +// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) // ----- diff --git a/xla/literal.cc b/xla/literal.cc index 971b1d48ac563b..ed0716f8ec50e8 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -91,9 +91,10 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { !proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() || !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || - !proto.f8e5m2s().empty() || !proto.f8e4m3fns().empty() || - !proto.f8e4m3b11fnuzs().empty() || !proto.f8e5m2fnuzs().empty() || - !proto.f8e4m3fnuzs().empty() || !proto.f16s().empty() || + !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || + !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || + !proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() || + !proto.f8e3m4s().empty() || !proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() || proto.c64s_size() || proto.c128s_size() || proto.preds_size() || proto.tuple_literals_size(); @@ -1684,7 +1685,15 @@ void ConvertBetweenNativeTypes(absl::Span src_data, return std::numeric_limits::lowest(); } } - return static_cast(src); + // TODO(b/370786669): Once ml_dtypes is updated to include + // https://github.com/jax-ml/ml_dtypes/pull/205, do not special-case e3m4 by + // casting to half first. + if constexpr (sizeof(src) == 1 && + std::is_same_v) { + return static_cast(static_cast(src)); + } else { + return static_cast(src); + } }; NativeDestT* dest_data = static_cast(dst_base); @@ -2258,6 +2267,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E4M3: + *proto->mutable_f8e4m3s() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F8E4M3FN: *proto->mutable_f8e4m3fns() = std::string( reinterpret_cast(data().data()), @@ -2278,6 +2292,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E3M4: + *proto->mutable_f8e3m4s() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F16: *proto->mutable_f16s() = std::string(reinterpret_cast(data().data()), @@ -2436,6 +2455,13 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E4M3: { + const std::string& s(proto.f8e4m3s()); + TF_RET_CHECK(data().size() * sizeof(tsl::float8_e4m3) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F8E4M3FN: { const std::string& s(proto.f8e4m3fns()); TF_RET_CHECK(data().size() * @@ -2468,6 +2494,13 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E3M4: { + const std::string& s(proto.f8e3m4s()); + TF_RET_CHECK(data().size() * sizeof(tsl::float8_e3m4) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/xla/literal_comparison.cc b/xla/literal_comparison.cc index fa3a7cda9824cd..c97629594122bb 100644 --- a/xla/literal_comparison.cc +++ b/xla/literal_comparison.cc @@ -354,8 +354,16 @@ class NearComparator { return primitive_util::FloatingPointTypeSwitch( [&](const auto kType) -> int { using NarrowNativeT = primitive_util::NativeTypeOf; - return CalculateDistanceInFloats(NarrowNativeT(expected), - NarrowNativeT(actual)); + // TODO(b/370786669): Once ml_dtypes is updated to include + // https://github.com/jax-ml/ml_dtypes/pull/205, do not special-case + // e3m4 by casting to half first. + if constexpr (std::is_same_v) { + return CalculateDistanceInFloats(NarrowNativeT(half(expected)), + NarrowNativeT(half(actual))); + } else { + return CalculateDistanceInFloats(NarrowNativeT(expected), + NarrowNativeT(actual)); + } }, error_.low_precision_fp_error_spec.type); } diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index 37b7c31f267104..7713aceaaa3bc5 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -29,14 +29,15 @@ namespace { template class LiteralComparisonTest : public ::testing::Test {}; -using TestedTypes = ::testing::Types; +using TestedTypes = + ::testing::Types; TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes); TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); auto expected = LiteralUtil::CreateR0(TypeParam(8.0)); - TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0), + TF_EXPECT_OK(literal_comparison::Near(expected, actual, ErrorSpec(0.0, 0.0), /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } @@ -44,15 +45,19 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = type == F8E5M2 ? 10.0 : 9.0; + float expV = 9.0; // F8E4M3* + if (type == F8E5M2) + expV = 10.0; + else if (type == F8E3M4) + expV = 8.5; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } @@ -60,17 +65,21 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = type == F8E5M2 ? 14.0 : 12.0; + float expV = 12.0; // F8E4M3* + if (type == F8E5M2) + expV = 14.0; + else if (type == F8E3M4) + expV = 10.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } @@ -78,17 +87,21 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(8.0); - float expV = type == F8E5M2 ? 13.0 : 12.1; + float expV = 12.1; // F8E4M3* + if (type == F8E5M2) + expV = 13.0; + else if (type == F8E3M4) + expV = 10.125; auto expected = LiteralUtil::CreateR0(expV); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 767fe581121db3..65aa09040668fb 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -125,9 +125,10 @@ template class LiteralUtilFloatTest : public LiteralUtilTest {}; using FloatTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(LiteralUtilFloatTest, FloatTypes); @@ -184,8 +185,12 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { EXPECT_EQ("f8e5m2[] 3", f8e5m2_lit_truncated.ToString()); auto f8e4m3_lit = + LiteralUtil::CreateR0(tsl::float8_e4m3(0.5)); + EXPECT_EQ("f8e4m3[] 0.5", f8e4m3_lit.ToString()); + + auto f8e4m3fn_lit = LiteralUtil::CreateR0(tsl::float8_e4m3fn(0.5)); - EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3_lit.ToString()); + EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3fn_lit.ToString()); auto f8e4m3b11fnuz_lit = LiteralUtil::CreateR0( tsl::float8_e4m3b11fnuz(0.5)); @@ -198,6 +203,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f8e5m2fnuz_lit = LiteralUtil::CreateR0(tsl::float8_e5m2fnuz(0.5)); EXPECT_EQ("f8e5m2fnuz[] 0.5", f8e5m2fnuz_lit.ToString()); + + auto f8e3m4_lit = + LiteralUtil::CreateR0(tsl::float8_e3m4(0.5)); + EXPECT_EQ("f8e3m4[] 0.5", f8e3m4_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -650,20 +659,24 @@ TEST_F(LiteralUtilTest, IsAll) { bfloat16 b90(9.00f); EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); - tsl::float8_e5m2 q16(8); - EXPECT_TRUE(LiteralUtil::CreateR1({q16}).IsAll(8)); + tsl::float8_e5m2 p16(8); + EXPECT_TRUE(LiteralUtil::CreateR1({p16}).IsAll(8)); // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false - EXPECT_FALSE(LiteralUtil::CreateR1({q16}).IsAll(9)); + EXPECT_FALSE(LiteralUtil::CreateR1({p16}).IsAll(9)); - tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3 + tsl::float8_e4m3 q16(9); // Exactly representable in e4m3 + EXPECT_FALSE(LiteralUtil::CreateR1({q16}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({q16}).IsAll(9)); + + tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3fn EXPECT_FALSE(LiteralUtil::CreateR1({r16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({r16}).IsAll(9)); - tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3 + tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3b11fnuz EXPECT_FALSE(LiteralUtil::CreateR1({s16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({s16}).IsAll(9)); - tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3 + tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3fnuz EXPECT_FALSE(LiteralUtil::CreateR1({t16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({t16}).IsAll(9)); @@ -672,6 +685,10 @@ TEST_F(LiteralUtilTest, IsAll) { // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false EXPECT_FALSE(LiteralUtil::CreateR1({u16}).IsAll(9)); + tsl::float8_e3m4 v16(9); // Exactly representable in e3m4 + EXPECT_FALSE(LiteralUtil::CreateR1({v16}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({v16}).IsAll(9)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -2200,9 +2217,12 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { using e5 = tsl::float8_e5m2; auto vector_f8e5m2 = LiteralUtil::CreateR1({e5{10.0}, e5{20.0}, e5{-32.0}}); - using e4 = tsl::float8_e4m3fn; + using e4 = tsl::float8_e4m3; auto vector_f8e4m3 = LiteralUtil::CreateR1({e4{10.0}, e4{20.0}, e4{-32.0}}); + using e4fn = tsl::float8_e4m3fn; + auto vector_f8e4m3fn = + LiteralUtil::CreateR1({e4fn{10.0}, e4fn{20.0}, e4fn{-32.0}}); using b11 = tsl::float8_e4m3b11fnuz; auto vector_f8e4m3b11 = LiteralUtil::CreateR1({b11{10.0}, b11{20.0}, b11{-30.0}}); @@ -2212,6 +2232,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { using e4f = tsl::float8_e4m3fnuz; auto vector_f8e4m3fnuz = LiteralUtil::CreateR1({e4f{10.0}, e4f{20.0}, e4f{-30.0}}); + using e3 = tsl::float8_e3m4; + auto vector_f8e3m4 = LiteralUtil::CreateR1({e3{2.5}, e3{5.0}, e3{-8.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto vector_s4 = LiteralUtil::CreateR1({s4{-1}, s4{3}, s4{7}}); @@ -2234,9 +2256,11 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); + EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); + EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); @@ -2521,6 +2545,18 @@ TEST_F(LiteralUtilTest, IsEqualAt) { tsl::float8_e4m3fnuz{val_double}); EXPECT_TRUE(c6.IsEqualAt({}, val_double)); EXPECT_TRUE(c6.IsEqualAt({}, val_integral)); + Literal c8 = + LiteralUtil::CreateR0(tsl::float8_e4m3{val_double}); + EXPECT_TRUE(c8.IsEqualAt({}, val_double)); + EXPECT_TRUE(c8.IsEqualAt({}, val_integral)); + Literal c9 = + LiteralUtil::CreateR0(tsl::float8_e4m3fn{val_double}); + EXPECT_TRUE(c9.IsEqualAt({}, val_double)); + EXPECT_TRUE(c9.IsEqualAt({}, val_integral)); + Literal c10 = + LiteralUtil::CreateR0(tsl::float8_e3m4{val_double}); + EXPECT_TRUE(c10.IsEqualAt({}, val_double)); + EXPECT_TRUE(c10.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { @@ -2846,10 +2882,10 @@ class LiteralSerializationTest : public ::testing::Test, static std::vector GenerateSimpleParams() { std::vector params; for (PrimitiveType element_type : - {PRED, S4, U4, S8, U8, S16, - U16, S32, U32, S64, U64, F16, - F32, F64, BF16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, - F8E5M2FNUZ, F8E4M3FNUZ, C64, C128}) { + {PRED, S4, U4, S8, U8, S16, + U16, S32, U32, S64, U64, F16, + F32, F64, BF16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, C64, C128}) { for (const DimensionVector& dimensions : { DimensionVector{}, DimensionVector{0}, diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc index 59b19c34611412..2581390a1e13d7 100644 --- a/xla/mlir/utils/type_util.cc +++ b/xla/mlir/utils/type_util.cc @@ -34,6 +34,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getI1Type(); case xla::PrimitiveType::F8E5M2: return b.getFloat8E5M2Type(); + case xla::PrimitiveType::F8E4M3: + return b.getFloat8E4M3Type(); case xla::PrimitiveType::F8E4M3FN: return b.getFloat8E4M3FNType(); case xla::PrimitiveType::F8E4M3B11FNUZ: @@ -42,6 +44,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getFloat8E5M2FNUZType(); case xla::PrimitiveType::F8E4M3FNUZ: return b.getFloat8E4M3FNUZType(); + case xla::PrimitiveType::F8E3M4: + return b.getFloat8E3M4Type(); case xla::PrimitiveType::F16: return b.getF16Type(); case xla::PrimitiveType::BF16: @@ -76,6 +80,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { if (type.isFloat8E5M2()) { return xla::PrimitiveType::F8E5M2; + } else if (type.isFloat8E4M3()) { + return xla::PrimitiveType::F8E4M3; } else if (type.isFloat8E4M3FN()) { return xla::PrimitiveType::F8E4M3FN; } else if (type.isFloat8E4M3B11FNUZ()) { @@ -84,6 +90,8 @@ xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { return xla::PrimitiveType::F8E4M3FNUZ; } else if (type.isFloat8E5M2FNUZ()) { return xla::PrimitiveType::F8E5M2FNUZ; + } else if (type.isFloat8E3M4()) { + return xla::PrimitiveType::F8E3M4; } else if (type.isBF16()) { return xla::PrimitiveType::BF16; } else if (type.isF16()) { diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc index 6c19098574dec5..a8043ab0b5f140 100644 --- a/xla/mlir/utils/type_util_test.cc +++ b/xla/mlir/utils/type_util_test.cc @@ -102,6 +102,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(std::vector( {{PRED, [](mlir::Builder b) { return b.getI1Type(); }}, {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }}, + {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }}, {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }}, {F8E4M3B11FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3B11FNUZType(); }}, @@ -109,6 +110,7 @@ INSTANTIATE_TEST_SUITE_P( [](mlir::Builder b) { return b.getFloat8E5M2FNUZType(); }}, {F8E4M3FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3FNUZType(); }}, + {F8E3M4, [](mlir::Builder b) { return b.getFloat8E3M4Type(); }}, {F16, [](mlir::Builder b) { return b.getF16Type(); }}, {BF16, [](mlir::Builder b) { return b.getBF16Type(); }}, {F32, [](mlir::Builder b) { return b.getF32Type(); }}, diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 8baa3e0d3298df..59618001c2d7cc 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1805,6 +1805,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E3M4" +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 65594f55fd979d..03b6a21e07210c 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6832,6 +6832,20 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +func.func @f8e4m3(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @f8e4m3fn(%arg0: tensor) -> tensor { %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor func.return %0 : tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 0f2e1b108a710f..66c388b9ed373e 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1787,6 +1787,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E3M4" +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 477045dfb93adb..20b660ed6eecfd 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,8 @@ # PJRT C API changelog +## 0.55 +* Added types F8E4M3 and F8E3M4. + ## 0.54 * Deprecated PJRT_Buffer_GetMemoryLayout. diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 1d5b44c60201c5..a96f35920b9fa1 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -79,7 +79,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 54 +#define PJRT_API_MINOR 55 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -644,6 +644,10 @@ typedef enum { // 2-bit integer types PJRT_Buffer_Type_S2, PJRT_Buffer_Type_U2, + + // More truncated 8 bit floating-point formats. + PJRT_Buffer_Type_F8E4M3, + PJRT_Buffer_Type_F8E3M4, } PJRT_Buffer_Type; typedef enum { diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index b9508cf24950b4..5121877e1f4dd0 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -295,6 +295,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F64; case xla::PrimitiveType::F8E5M2: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2; + case xla::PrimitiveType::F8E4M3: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3; case xla::PrimitiveType::F8E4M3FN: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FN; case xla::PrimitiveType::F8E4M3B11FNUZ: @@ -303,6 +305,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2FNUZ; case xla::PrimitiveType::F8E4M3FNUZ: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ; + case xla::PrimitiveType::F8E3M4: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4; case xla::PrimitiveType::C64: return PJRT_Buffer_Type::PJRT_Buffer_Type_C64; case xla::PrimitiveType::C128: @@ -358,6 +362,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::C128; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2: return xla::PrimitiveType::F8E5M2; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3: + return xla::PrimitiveType::F8E4M3; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FN: return xla::PrimitiveType::F8E4M3FN; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3B11FNUZ: @@ -366,6 +372,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::F8E5M2FNUZ; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ: return xla::PrimitiveType::F8E4M3FNUZ; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4: + return xla::PrimitiveType::F8E3M4; case PJRT_Buffer_Type::PJRT_Buffer_Type_INVALID: CHECK(false) << "Buffer type is not supported in C API layer."; } diff --git a/xla/primitive_util.h b/xla/primitive_util.h index 8fbeedbff94dad..de5ee4fde11d7b 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -180,6 +180,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E5M2; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E4M3; +} + template <> constexpr PrimitiveType NativeToPrimitiveType() { return F8E4M3FN; @@ -200,6 +205,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E4M3FNUZ; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E3M4; +} + // Complex template <> constexpr PrimitiveType NativeToPrimitiveType() { @@ -309,6 +319,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e5m2; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e4m3; +}; + template <> struct PrimitiveTypeToNative { using type = tsl::float8_e4m3fn; @@ -329,6 +344,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e4m3fnuz; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e3m4; +}; + // Complex template <> struct PrimitiveTypeToNative { @@ -362,8 +382,9 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) { } constexpr bool IsF8Type(PrimitiveType type) { - return type == F8E5M2 || type == F8E4M3FN || type == F8E4M3B11FNUZ || - type == F8E5M2FNUZ || type == F8E4M3FNUZ; + return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN || + type == F8E4M3B11FNUZ || type == F8E5M2FNUZ || type == F8E4M3FNUZ || + type == F8E3M4; } constexpr bool IsFloatingPointType(PrimitiveType type) { @@ -428,6 +449,12 @@ template constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { switch (type) { + case F8E3M4: + return std::forward(f)( + PrimitiveTypeConstant()); + case F8E4M3: + return std::forward(f)( + PrimitiveTypeConstant()); case F8E4M3FN: return std::forward(f)( PrimitiveTypeConstant()); diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index e8c9dc77087062..850203f17379a4 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -76,10 +76,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][BF16] = true; expecteds[PRED][C128] = true; expecteds[PRED][F8E5M2] = true; + expecteds[PRED][F8E4M3] = true; expecteds[PRED][F8E4M3FN] = true; expecteds[PRED][F8E4M3B11FNUZ] = true; expecteds[PRED][F8E5M2FNUZ] = true; expecteds[PRED][F8E4M3FNUZ] = true; + expecteds[PRED][F8E3M4] = true; expecteds[S2][PRED] = false; expecteds[S2][S2] = true; expecteds[S2][S4] = true; @@ -100,10 +102,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][BF16] = true; expecteds[S2][C128] = true; expecteds[S2][F8E5M2] = true; + expecteds[S2][F8E4M3] = true; expecteds[S2][F8E4M3FN] = true; expecteds[S2][F8E4M3B11FNUZ] = true; expecteds[S2][F8E5M2FNUZ] = true; expecteds[S2][F8E4M3FNUZ] = true; + expecteds[S2][F8E3M4] = true; expecteds[S4][PRED] = false; expecteds[S4][S2] = false; expecteds[S4][S4] = true; @@ -124,10 +128,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][BF16] = true; expecteds[S4][C128] = true; expecteds[S4][F8E5M2] = true; + expecteds[S4][F8E4M3] = true; expecteds[S4][F8E4M3FN] = true; expecteds[S4][F8E4M3B11FNUZ] = true; expecteds[S4][F8E5M2FNUZ] = true; expecteds[S4][F8E4M3FNUZ] = true; + expecteds[S4][F8E3M4] = true; expecteds[S8][PRED] = false; expecteds[S8][S2] = false; expecteds[S8][S4] = false; @@ -148,10 +154,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][BF16] = true; expecteds[S8][C128] = true; expecteds[S8][F8E5M2] = false; + expecteds[S8][F8E4M3] = false; expecteds[S8][F8E4M3FN] = false; expecteds[S8][F8E4M3B11FNUZ] = false; expecteds[S8][F8E5M2FNUZ] = false; expecteds[S8][F8E4M3FNUZ] = false; + expecteds[S8][F8E3M4] = false; expecteds[S16][PRED] = false; expecteds[S16][S2] = false; expecteds[S16][S4] = false; @@ -172,10 +180,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][BF16] = false; expecteds[S16][C128] = true; expecteds[S16][F8E5M2] = false; + expecteds[S16][F8E4M3] = false; expecteds[S16][F8E4M3FN] = false; expecteds[S16][F8E4M3B11FNUZ] = false; expecteds[S16][F8E5M2FNUZ] = false; expecteds[S16][F8E4M3FNUZ] = false; + expecteds[S16][F8E3M4] = false; expecteds[S32][PRED] = false; expecteds[S32][S2] = false; expecteds[S32][S4] = false; @@ -196,10 +206,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][BF16] = false; expecteds[S32][C128] = true; expecteds[S32][F8E5M2] = false; + expecteds[S32][F8E4M3] = false; expecteds[S32][F8E4M3FN] = false; expecteds[S32][F8E4M3B11FNUZ] = false; expecteds[S32][F8E5M2FNUZ] = false; expecteds[S32][F8E4M3FNUZ] = false; + expecteds[S32][F8E3M4] = false; expecteds[S64][PRED] = false; expecteds[S64][S2] = false; expecteds[S64][S4] = false; @@ -220,10 +232,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][BF16] = false; expecteds[S64][C128] = false; expecteds[S64][F8E5M2] = false; + expecteds[S64][F8E4M3] = false; expecteds[S64][F8E4M3FN] = false; expecteds[S64][F8E4M3B11FNUZ] = false; expecteds[S64][F8E5M2FNUZ] = false; expecteds[S64][F8E4M3FNUZ] = false; + expecteds[S64][F8E3M4] = false; expecteds[U2][PRED] = false; expecteds[U2][S2] = false; expecteds[U2][S4] = true; @@ -246,10 +260,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][BF16] = true; expecteds[U2][C128] = true; expecteds[U2][F8E5M2] = true; + expecteds[U2][F8E4M3] = true; expecteds[U2][F8E4M3FN] = true; expecteds[U2][F8E4M3B11FNUZ] = true; expecteds[U2][F8E5M2FNUZ] = true; expecteds[U2][F8E4M3FNUZ] = true; + expecteds[U2][F8E3M4] = true; expecteds[U4][PRED] = false; expecteds[U4][S2] = false; expecteds[U4][S4] = false; @@ -272,10 +288,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][BF16] = true; expecteds[U4][C128] = true; expecteds[U4][F8E5M2] = false; + expecteds[U4][F8E4M3] = true; expecteds[U4][F8E4M3FN] = true; expecteds[U4][F8E4M3B11FNUZ] = true; expecteds[U4][F8E5M2FNUZ] = false; expecteds[U4][F8E4M3FNUZ] = true; + expecteds[U4][F8E3M4] = true; expecteds[U8][PRED] = false; expecteds[U8][S2] = false; expecteds[U8][S4] = false; @@ -298,10 +316,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][BF16] = true; expecteds[U8][C128] = true; expecteds[U8][F8E5M2] = false; + expecteds[U8][F8E4M3] = false; expecteds[U8][F8E4M3FN] = false; expecteds[U8][F8E4M3B11FNUZ] = false; expecteds[U8][F8E5M2FNUZ] = false; expecteds[U8][F8E4M3FNUZ] = false; + expecteds[U8][F8E3M4] = false; expecteds[U16][PRED] = false; expecteds[U16][S2] = false; expecteds[U16][S4] = false; @@ -322,10 +342,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][BF16] = false; expecteds[U16][C128] = true; expecteds[U16][F8E5M2] = false; + expecteds[U16][F8E4M3] = false; expecteds[U16][F8E4M3FN] = false; expecteds[U16][F8E4M3B11FNUZ] = false; expecteds[U16][F8E5M2FNUZ] = false; expecteds[U16][F8E4M3FNUZ] = false; + expecteds[U16][F8E3M4] = false; expecteds[U32][PRED] = false; expecteds[U32][S2] = false; expecteds[U32][S4] = false; @@ -346,10 +368,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][BF16] = false; expecteds[U32][C128] = true; expecteds[U32][F8E5M2] = false; + expecteds[U32][F8E4M3] = false; expecteds[U32][F8E4M3FN] = false; expecteds[U32][F8E4M3B11FNUZ] = false; expecteds[U32][F8E5M2FNUZ] = false; expecteds[U32][F8E4M3FNUZ] = false; + expecteds[U32][F8E3M4] = false; expecteds[U64][PRED] = false; expecteds[U64][S2] = false; expecteds[U64][S4] = false; @@ -370,10 +394,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][BF16] = false; expecteds[U64][C128] = false; expecteds[U64][F8E5M2] = false; + expecteds[U64][F8E4M3] = false; expecteds[U64][F8E4M3FN] = false; expecteds[U64][F8E4M3B11FNUZ] = false; expecteds[U64][F8E5M2FNUZ] = false; expecteds[U64][F8E4M3FNUZ] = false; + expecteds[U64][F8E3M4] = false; expecteds[F16][PRED] = false; expecteds[F16][S2] = false; expecteds[F16][S4] = false; @@ -394,10 +420,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][BF16] = false; expecteds[F16][C128] = true; expecteds[F16][F8E5M2] = false; + expecteds[F16][F8E4M3] = false; expecteds[F16][F8E4M3FN] = false; expecteds[F16][F8E4M3B11FNUZ] = false; expecteds[F16][F8E5M2FNUZ] = false; expecteds[F16][F8E4M3FNUZ] = false; + expecteds[F16][F8E3M4] = false; expecteds[F32][PRED] = false; expecteds[F32][S2] = false; expecteds[F32][S4] = false; @@ -418,10 +446,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][BF16] = false; expecteds[F32][C128] = true; expecteds[F32][F8E5M2] = false; + expecteds[F32][F8E4M3] = false; expecteds[F32][F8E4M3FN] = false; expecteds[F32][F8E4M3B11FNUZ] = false; expecteds[F32][F8E5M2FNUZ] = false; expecteds[F32][F8E4M3FNUZ] = false; + expecteds[F32][F8E3M4] = false; expecteds[F64][PRED] = false; expecteds[F64][S2] = false; expecteds[F64][S4] = false; @@ -442,10 +472,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][BF16] = false; expecteds[F64][C128] = true; expecteds[F64][F8E5M2] = false; + expecteds[F64][F8E4M3] = false; expecteds[F64][F8E4M3FN] = false; expecteds[F64][F8E4M3B11FNUZ] = false; expecteds[F64][F8E5M2FNUZ] = false; expecteds[F64][F8E4M3FNUZ] = false; + expecteds[F64][F8E3M4] = false; expecteds[C64][PRED] = false; expecteds[C64][S2] = false; expecteds[C64][S4] = false; @@ -466,10 +498,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][BF16] = false; expecteds[C64][C128] = true; expecteds[C64][F8E5M2] = false; + expecteds[C64][F8E4M3] = false; expecteds[C64][F8E4M3FN] = false; expecteds[C64][F8E4M3B11FNUZ] = false; expecteds[C64][F8E5M2FNUZ] = false; expecteds[C64][F8E4M3FNUZ] = false; + expecteds[C64][F8E3M4] = false; expecteds[BF16][PRED] = false; expecteds[BF16][S2] = false; expecteds[BF16][S4] = false; @@ -490,10 +524,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][BF16] = true; expecteds[BF16][C128] = true; expecteds[BF16][F8E5M2] = false; + expecteds[BF16][F8E4M3] = false; expecteds[BF16][F8E4M3FN] = false; expecteds[BF16][F8E4M3B11FNUZ] = false; expecteds[BF16][F8E5M2FNUZ] = false; expecteds[BF16][F8E4M3FNUZ] = false; + expecteds[BF16][F8E3M4] = false; expecteds[C128][PRED] = false; expecteds[C128][S2] = false; expecteds[C128][S4] = false; @@ -514,10 +550,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][BF16] = false; expecteds[C128][C128] = true; expecteds[C128][F8E5M2] = false; + expecteds[C128][F8E4M3] = false; expecteds[C128][F8E4M3FN] = false; expecteds[C128][F8E4M3B11FNUZ] = false; expecteds[C128][F8E5M2FNUZ] = false; expecteds[C128][F8E4M3FNUZ] = false; + expecteds[C128][F8E3M4] = false; expecteds[F8E5M2][PRED] = false; expecteds[F8E5M2][S2] = false; expecteds[F8E5M2][S4] = false; @@ -538,10 +576,38 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][BF16] = true; expecteds[F8E5M2][C128] = true; expecteds[F8E5M2][F8E5M2] = true; + expecteds[F8E5M2][F8E4M3] = false; expecteds[F8E5M2][F8E4M3FN] = false; expecteds[F8E5M2][F8E4M3B11FNUZ] = false; expecteds[F8E5M2][F8E5M2FNUZ] = false; expecteds[F8E5M2][F8E4M3FNUZ] = false; + expecteds[F8E5M2][F8E3M4] = false; + expecteds[F8E4M3][PRED] = false; + expecteds[F8E4M3][S2] = false; + expecteds[F8E4M3][S4] = false; + expecteds[F8E4M3][S8] = false; + expecteds[F8E4M3][S16] = false; + expecteds[F8E4M3][S32] = false; + expecteds[F8E4M3][S64] = false; + expecteds[F8E4M3][U2] = false; + expecteds[F8E4M3][U4] = false; + expecteds[F8E4M3][U8] = false; + expecteds[F8E4M3][U16] = false; + expecteds[F8E4M3][U32] = false; + expecteds[F8E4M3][U64] = false; + expecteds[F8E4M3][F16] = true; + expecteds[F8E4M3][F32] = true; + expecteds[F8E4M3][F64] = true; + expecteds[F8E4M3][C64] = true; + expecteds[F8E4M3][BF16] = true; + expecteds[F8E4M3][C128] = true; + expecteds[F8E4M3][F8E5M2] = false; + expecteds[F8E4M3][F8E5M2FNUZ] = false; + expecteds[F8E4M3][F8E4M3] = true; + expecteds[F8E4M3][F8E4M3FN] = false; + expecteds[F8E4M3][F8E4M3FNUZ] = false; + expecteds[F8E4M3][F8E4M3B11FNUZ] = false; + expecteds[F8E4M3][F8E3M4] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S2] = false; expecteds[F8E4M3FN][S4] = false; @@ -562,8 +628,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][BF16] = true; expecteds[F8E4M3FN][C128] = true; expecteds[F8E4M3FN][F8E5M2] = false; + expecteds[F8E4M3FN][F8E5M2FNUZ] = false; + expecteds[F8E4M3FN][F8E4M3] = false; expecteds[F8E4M3FN][F8E4M3FN] = true; + expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; + expecteds[F8E4M3FN][F8E3M4] = false; expecteds[F8E4M3B11FNUZ][PRED] = false; expecteds[F8E4M3B11FNUZ][S2] = false; expecteds[F8E4M3B11FNUZ][S4] = false; @@ -584,12 +654,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][BF16] = true; expecteds[F8E4M3B11FNUZ][C128] = true; expecteds[F8E4M3B11FNUZ][F8E5M2] = false; + expecteds[F8E4M3B11FNUZ][F8E4M3] = false; expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false; expecteds[F8E4M3B11FNUZ][F8E4M3B11FNUZ] = true; expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; - expecteds[F8E4M3FN][F8E5M2FNUZ] = false; - expecteds[F8E4M3FN][F8E4M3FNUZ] = false; + expecteds[F8E4M3B11FNUZ][F8E3M4] = false; expecteds[F8E5M2FNUZ][PRED] = false; expecteds[F8E5M2FNUZ][S2] = false; expecteds[F8E5M2FNUZ][S4] = false; @@ -610,10 +680,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][BF16] = true; expecteds[F8E5M2FNUZ][C128] = true; expecteds[F8E5M2FNUZ][F8E5M2] = false; + expecteds[F8E5M2FNUZ][F8E4M3] = false; expecteds[F8E5M2FNUZ][F8E4M3FN] = false; expecteds[F8E5M2FNUZ][F8E4M3B11FNUZ] = false; expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; + expecteds[F8E5M2FNUZ][F8E3M4] = false; expecteds[F8E4M3FNUZ][PRED] = false; expecteds[F8E4M3FNUZ][S2] = false; expecteds[F8E4M3FNUZ][S4] = false; @@ -634,10 +706,38 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][BF16] = true; expecteds[F8E4M3FNUZ][C128] = true; expecteds[F8E4M3FNUZ][F8E5M2] = false; + expecteds[F8E4M3FNUZ][F8E4M3] = false; expecteds[F8E4M3FNUZ][F8E4M3FN] = false; expecteds[F8E4M3FNUZ][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; + expecteds[F8E4M3FNUZ][F8E3M4] = false; + expecteds[F8E3M4][PRED] = false; + expecteds[F8E3M4][S2] = false; + expecteds[F8E3M4][S4] = false; + expecteds[F8E3M4][S8] = false; + expecteds[F8E3M4][S16] = false; + expecteds[F8E3M4][S32] = false; + expecteds[F8E3M4][S64] = false; + expecteds[F8E3M4][U2] = false; + expecteds[F8E3M4][U4] = false; + expecteds[F8E3M4][U8] = false; + expecteds[F8E3M4][U16] = false; + expecteds[F8E3M4][U32] = false; + expecteds[F8E3M4][U64] = false; + expecteds[F8E3M4][F16] = true; + expecteds[F8E3M4][F32] = true; + expecteds[F8E3M4][F64] = true; + expecteds[F8E3M4][C64] = true; + expecteds[F8E3M4][BF16] = true; + expecteds[F8E3M4][C128] = true; + expecteds[F8E3M4][F8E5M2] = false; + expecteds[F8E3M4][F8E5M2FNUZ] = false; + expecteds[F8E3M4][F8E4M3] = false; + expecteds[F8E3M4][F8E4M3FN] = false; + expecteds[F8E3M4][F8E4M3FNUZ] = false; + expecteds[F8E3M4][F8E4M3B11FNUZ] = false; + expecteds[F8E3M4][F8E3M4] = true; for (int from_type_int = PrimitiveType_MIN; from_type_int < PrimitiveType_ARRAYSIZE; ++from_type_int) { diff --git a/xla/python/ifrt/dtype.cc b/xla/python/ifrt/dtype.cc index 1de5702b6cc8df..17e2cfa281d251 100644 --- a/xla/python/ifrt/dtype.cc +++ b/xla/python/ifrt/dtype.cc @@ -37,6 +37,8 @@ std::optional DType::byte_size() const { case kPred: case kS8: case kU8: + case kF8E3M4: + case kF8E4M3: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -78,6 +80,8 @@ std::optional DType::bit_size() const { case kPred: case kS8: case kU8: + case kF8E3M4: + case kF8E4M3: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -133,6 +137,9 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { CASE(BF16); CASE(C64); CASE(C128); + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // CASE(F8E3M4); + // CASE(F8E4M3); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); @@ -175,6 +182,9 @@ DTypeProto DType::ToProto() const { CASE(BF16); CASE(C64); CASE(C128); + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // CASE(F8E3M4); + // CASE(F8E4M3); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index 06a92b67f863c8..911702512c0501 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -78,13 +78,15 @@ class DType { // dtype will have empty dimensions. kToken = 17, + kF8E3M4 = 29, + kF8E4M3 = 28, kF8E4M3FN = 20, kF8E4M3B11FNUZ = 23, kF8E4M3FNUZ = 25, kF8E5M2 = 19, kF8E5M2FNUZ = 24, - // Next = 26 + // Next = 30 // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto index eadfd42a3550cd..37976833e7e8c7 100644 --- a/xla/python/ifrt/dtype.proto +++ b/xla/python/ifrt/dtype.proto @@ -60,6 +60,8 @@ message DTypeProto { // dtype will have empty dimensions. KIND_TOKEN = 17; + KIND_F8E3M4 = 29; + KIND_F8E4M3 = 28; KIND_F8E4M3FN = 20; KIND_F8E4M3B11FNUZ = 23; KIND_F8E4M3FNUZ = 25; diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc index 5ac531dabcb9ce..57fec6702d277d 100644 --- a/xla/python/ifrt/dtype_test.cc +++ b/xla/python/ifrt/dtype_test.cc @@ -49,6 +49,8 @@ TEST(DTypeTest, ByteSize) { {DType::kPred, 1}, {DType::kS8, 1}, {DType::kU8, 1}, + {DType::kF8E3M4, 1}, + {DType::kF8E4M3, 1}, {DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1}, {DType::kF8E4M3FNUZ, 1}, @@ -85,6 +87,8 @@ TEST(DTypeTest, BitSize) { {DType::kPred, 8}, {DType::kS8, 8}, {DType::kU8, 8}, + {DType::kF8E3M4, 8}, + {DType::kF8E4M3, 8}, {DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8}, {DType::kF8E4M3FNUZ, 8}, diff --git a/xla/python/pjrt_ifrt/pjrt_dtype.cc b/xla/python/pjrt_ifrt/pjrt_dtype.cc index 36d492f27569a9..10a293778bd467 100644 --- a/xla/python/pjrt_ifrt/pjrt_dtype.cc +++ b/xla/python/pjrt_ifrt/pjrt_dtype.cc @@ -44,6 +44,8 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kU16, xla::PrimitiveType::U16); CASE(DType::kU32, xla::PrimitiveType::U32); CASE(DType::kU64, xla::PrimitiveType::U64); + CASE(DType::kF8E3M4, xla::PrimitiveType::F8E3M4); + CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3); CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); CASE(DType::kF8E4M3B11FNUZ, xla::PrimitiveType::F8E4M3B11FNUZ); CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); @@ -80,6 +82,8 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::U16: case xla::PrimitiveType::U32: case xla::PrimitiveType::U64: + case xla::PrimitiveType::F8E3M4: + case xla::PrimitiveType::F8E4M3: case xla::PrimitiveType::F8E4M3FN: case xla::PrimitiveType::F8E4M3B11FNUZ: case xla::PrimitiveType::F8E4M3FNUZ: diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index db7a38d0a50363..9a9c63a922e90d 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -185,6 +185,12 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E3M4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3; } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E4M3FN; @@ -394,6 +400,14 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = + HandleNumpyScalar; + } (*p)[dtypes.np_float8_e4m3fn.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = @@ -583,6 +597,9 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; diff --git a/xla/python/types.cc b/xla/python/types.cc index 4a1de389cd5b5d..125f96a75fdf25 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -59,6 +59,8 @@ namespace { struct CustomDtypes { nb_dtype bfloat16; + std::optional float8_e3m4; + std::optional float8_e4m3; nb_dtype float8_e4m3fn; nb_dtype float8_e4m3b11fnuz; nb_dtype float8_e4m3fnuz; @@ -75,6 +77,12 @@ const CustomDtypes& GetCustomDtypes() { nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); auto* dtypes = new CustomDtypes; dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float8_e3m4")) { + dtypes->float8_e3m4 = nb_dtype::from_args(ml_dtypes.attr("float8_e3m4")); + } + if (nb::hasattr(ml_dtypes, "float8_e4m3")) { + dtypes->float8_e4m3 = nb_dtype::from_args(ml_dtypes.attr("float8_e4m3")); + } dtypes->float8_e4m3fn = nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fn")); dtypes->float8_e5m2 = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2")); @@ -140,6 +148,12 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { auto* map = new absl::flat_hash_map(); map->emplace(custom_dtypes.bfloat16, BF16); + if (custom_dtypes.float8_e3m4.has_value()) { + map->emplace(*custom_dtypes.float8_e3m4, F8E3M4); + } + if (custom_dtypes.float8_e4m3.has_value()) { + map->emplace(*custom_dtypes.float8_e4m3, F8E4M3); + } map->emplace(custom_dtypes.float8_e4m3fn, F8E4M3FN); map->emplace(custom_dtypes.float8_e4m3b11fnuz, F8E4M3B11FNUZ); map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); @@ -204,6 +218,16 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return to_nb_dtype(NPY_UINT32); case U64: return to_nb_dtype(NPY_UINT64); + case F8E3M4: + if (custom_dtypes.float8_e3m4.has_value()) { + return *custom_dtypes.float8_e3m4; + } + break; + case F8E4M3: + if (custom_dtypes.float8_e4m3.has_value()) { + return *custom_dtypes.float8_e4m3; + } + break; case F8E4M3FN: return custom_dtypes.float8_e4m3fn; case F8E4M3B11FNUZ: @@ -284,6 +308,16 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return to_nb_dtype(NPY_COMPLEX64); case ifrt::DType::kC128: return to_nb_dtype(NPY_COMPLEX128); + case ifrt::DType::kF8E3M4: + if (custom_dtypes.float8_e3m4.has_value()) { + return *custom_dtypes.float8_e3m4; + } + break; + case ifrt::DType::kF8E4M3: + if (custom_dtypes.float8_e4m3.has_value()) { + return *custom_dtypes.float8_e4m3; + } + break; case ifrt::DType::kF8E4M3FN: return custom_dtypes.float8_e4m3fn; case ifrt::DType::kF8E4M3B11FNUZ: @@ -347,6 +381,12 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_uint32 = nb::object(numpy.attr("uint32")); dtypes->np_uint64 = nb::object(numpy.attr("uint64")); dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float8_e3m4")) { + dtypes->np_float8_e3m4 = nb::object(ml_dtypes.attr("float8_e3m4")); + } + if (nb::hasattr(ml_dtypes, "float8_e4m3")) { + dtypes->np_float8_e4m3 = nb::object(ml_dtypes.attr("float8_e4m3")); + } dtypes->np_float8_e4m3fn = nb::object(ml_dtypes.attr("float8_e4m3fn")); dtypes->np_float8_e4m3b11fnuz = nb::object(ml_dtypes.attr("float8_e4m3b11fnuz")); diff --git a/xla/python/types.h b/xla/python/types.h index ed7ca847b1a7f7..fece926edd3017 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -79,6 +79,9 @@ struct NumpyScalarTypes { nanobind::object np_uint32; nanobind::object np_uint64; nanobind::object np_bfloat16; + // Remove std::optional once the minimum ml_dtypes in JAX is >= 0.5.0. + std::optional np_float8_e3m4; + std::optional np_float8_e4m3; nanobind::object np_float8_e4m3fn; nanobind::object np_float8_e4m3b11fnuz; nanobind::object np_float8_e4m3fnuz; @@ -128,7 +131,6 @@ nanobind::tuple SpanToNbTuple(absl::Span xs) { // references to the objects. nanobind::tuple MutableSpanToNbTuple(absl::Span xs); - template std::vector IterableToVector(const nanobind::iterable& iterable) { std::vector output; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 868a3aa9d74016..70c8c90c0a04e4 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -202,6 +202,9 @@ NB_MODULE(xla_extension, m_nb) { .value("U32", U32) .value("U64", U64) .value("F16", F16) + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // .value("F8E3M4", F8E3M4) + // .value("F8E4M3", F8E4M3) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 5cc12efa93709b..51d879814b2e68 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -274,6 +274,9 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType = _xla.PrimitiveType bfloat16 = ml_dtypes.bfloat16 +# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float8_e3m4 = ml_dtypes.float8_e3m4 +# float8_e4m3 = ml_dtypes.float8_e4m3 float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz @@ -292,6 +295,9 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U16: np.dtype('uint16'), PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), + # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), + # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index 898a632ab340a7..5a1df08e736f64 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -59,6 +59,9 @@ _version: int mlir_api_version: int bfloat16: type[numpy.generic] +# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float8_e3m4: type[numpy.generic] +# float8_e4m3: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index edd582e405c265..441d5fbf450fa4 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -54,6 +54,9 @@ xla_client._xla.jax_jit.global_state().enable_memories = False bfloat16 = xla_client.bfloat16 +# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float8_e3m4 = xla_client.float8_e3m4 +# float8_e4m3 = xla_client.float8_e4m3 float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz @@ -150,6 +153,8 @@ def TestFactory(xla_backend, # standard_dtypes is only used for BufferProtocolTest so we only test fp8 # round trip tests. standard_dtypes += [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + # standard_dtypes += [float8_e3m4, float8_e4m3] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index b5ae4c6431ca66..e363d8d82471cb 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -73,6 +73,8 @@ class PrimitiveType(enum.IntEnum): U16: PrimitiveType U32: PrimitiveType U64: PrimitiveType + F8E3M4: PrimitiveType + F8E4M3: PrimitiveType F8E4M3FN: PrimitiveType F8E4M3B11FNUZ: PrimitiveType F8E4M3FNUZ: PrimitiveType diff --git a/xla/service/BUILD b/xla/service/BUILD index 5f1c0407d99fd0..e72fff7038d06c 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -5917,6 +5917,7 @@ cc_library( "//xla:xla_data_proto_cc", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", "@tsl//tsl/platform:statusor", ], ) @@ -5971,22 +5972,21 @@ xla_test( deps = [ ":elemental_ir_emitter", ":hlo_module_config", - ":hlo_parser", "//xla:error_spec", - "//xla:execution_options_util", "//xla:literal", "//xla:literal_util", - "//xla:status_macros", "//xla:test", + "//xla:types", "//xla/hlo/ir:hlo", "//xla/service/llvm_ir:ir_array", - "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:ir_headers", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 1a77e4ffb8eb13..1b5da63773d85f 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -600,6 +600,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( #endif FloatSupport f8e5m2_support(F8E5M2, F16); pipeline.AddPass(&f8e5m2_support); + FloatSupport f8e4m3_support(F8E4M3, F16); + pipeline.AddPass(&f8e4m3_support); FloatSupport f8e4m3fn_support(F8E4M3FN, F16); pipeline.AddPass(&f8e4m3fn_support); FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); @@ -608,6 +610,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(&f8e5m2fnuz_support); FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16); pipeline.AddPass(&f8e4m3fnuz_support); + FloatSupport f8e3m4_support(F8E3M4, F16); + pipeline.AddPass(&f8e3m4_support); // After canonicalization, there may be more batch dots that can be // simplified. pipeline.AddPass(); diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h index c0c956a32dc0b1..2fef54861722f1 100644 --- a/xla/service/cpu/onednn_memory_util.h +++ b/xla/service/cpu/onednn_memory_util.h @@ -71,7 +71,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) { // TODO(intel-tf): properly handle not supported types: // S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4, - // F8E4M3B11FNUZ + // F8E4M3B11FNUZ, F8E4M3, F8E3M4 default: return dt::undef; } diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index aff56f92b15601..b04e4e554a8a8e 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -220,6 +220,90 @@ absl::StatusOr EmitReducePrecisionIR( return result; } +template +llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits, + llvm::Value* f8_bits, + llvm::IRBuilder<>* b) { + using llvm::APInt; + using llvm::Value; + static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i8_const = [i8_type](int val) { + return llvm::ConstantInt::get(i8_type, val); + }; + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + // F16 values that are halfway between denormal F8 values. This is used to + // determine how to round to denormal F8 values. + const int halfway_points_e4[8] = { + 0x1400, // 0x1.0p-10 ; halfway between [0/8 * 2^-6, 1/8 * 2^-6] + 0x1A00, // 0x1.8p-9 ; halfway between [1/8 * 2^-6, 2/8 * 2^-6] + 0x1D00, // 0x1.4p-8 ; halfway between [2/8 * 2^-6, 3/8 * 2^-6] + 0x1F00, // 0x1.Cp-8 ; halfway between [3/8 * 2^-6, 4/8 * 2^-6] + 0x2080, // 0x1.2p-7 ; halfway between [4/8 * 2^-6, 5/8 * 2^-6] + 0x2180, // 0x1.6p-7 ; halfway between [5/8 * 2^-6, 6/8 * 2^-6] + 0x2280, // 0x1.Ap-7 ; halfway between [6/8 * 2^-6, 7/8 * 2^-6] + 0x2380, // 0x1.Ep-7 ; halfway between [7/8 * 2^-6, 8/8 * 2^-6] + }; + + const int halfway_points_e3[16] = { + 0x2000, // 0x1.0p-7; halfway between [0/16 * 2^-2, 1/16 * 2^-2] + 0x2600, // 0x1.8p-6; halfway between [1/16 * 2^-2, 2/16 * 2^-2] + 0x2900, // 0x1.4p-5; halfway between [2/16 * 2^-2, 3/16 * 2^-2] + 0x2B00, // 0x1.Cp-5; halfway between [3/16 * 2^-2, 4/16 * 2^-2] + 0x2C80, // 0x1.2p-4; halfway between [4/16 * 2^-2, 5/16 * 2^-2] + 0x2D80, // 0x1.6p-4; halfway between [5/16 * 2^-2, 6/16 * 2^-2] + 0x2E80, // 0x1.Ap-4; halfway between [6/16 * 2^-2, 7/16 * 2^-2] + 0x2F80, // 0x1.Ep-4; halfway between [7/16 * 2^-2, 8/16 * 2^-2] + 0x3040, // 0x1.1p-3; halfway between [8/16 * 2^-2, 9/16 * 2^-2] + 0x30C0, // 0x1.3p-3; halfway between [9/16 * 2^-2, 10/16 * 2^-2] + 0x3140, // 0x1.5p-3; halfway between [10/16 * 2^-2, 11/16 * 2^-2] + 0x31C0, // 0x1.7p-3; halfway between [11/16 * 2^-2, 12/16 * 2^-2] + 0x3240, // 0x1.9p-3; halfway between [12/16 * 2^-2, 13/16 * 2^-2] + 0x32C0, // 0x1.Bp-3; halfway between [13/16 * 2^-2, 14/16 * 2^-2] + 0x3340, // 0x1.Dp-3; halfway between [14/16 * 2^-2, 15/16 * 2^-2] + 0x33C0, // 0x1.Fp-3; halfway between [15/16 * 2^-2, 16/16 * 2^-2] + }; + + const int* halfway_points; + int arr_sz; + if constexpr (f8_exponent_bits == 4) { + halfway_points = halfway_points_e4; + arr_sz = 8; + } else if constexpr (f8_exponent_bits == 3) { + halfway_points = halfway_points_e3; + arr_sz = 16; + } + + // Handle case where output is denormal. If we're rounding to a denormal + // value, ignore the current value of f8_bits and set it to the correct + // denormal value. We emit the equivalent of the following: + // + // if (f16_abs_bits <= halfway_points[0]) { + // f8_bits = 0; + // } else if (f16_abs_bits < halfway_points[1]) { + // f8_bits = 1; + // } else if (f16_abs_bits <= halfway_points[2]) { + // ... // More if-else statements. The comparisons alternate between <= + // ... // and < to handle round-to-even properly. + // } else if (f16_abs_bits < halfway_points[7]) { + // f8_bits = 7; + // } + for (int i = arr_sz - 1; i >= 0; i--) { + Value* comparison; + if (i % 2 == 0) { + comparison = b->CreateICmpULE(f16_abs_bits, i16_const(halfway_points[i])); + } else { + comparison = b->CreateICmpULT(f16_abs_bits, i16_const(halfway_points[i])); + } + f8_bits = b->CreateSelect(comparison, i8_const(i), f8_bits); + } + return f8_bits; +} + absl::StatusOr EmitF16ToF8e5m2(llvm::Value* f16_value, llvm::IRBuilder<>* b) { TF_ASSIGN_OR_RETURN( @@ -242,6 +326,223 @@ llvm::Value* EmitF8e5m2ToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { return b->CreateBitCast(shifted, b->getHalfTy()); } +template +absl::StatusOr EmitF16ToF8e(llvm::Value* f16_value, + llvm::IRBuilder<>* b) { + static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + constexpr int f8_mantissa_bits = 7 - f8_exponent_bits; + using llvm::APInt; + using llvm::Value; + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + + // Cast the input value to an integer for bitwise manipulation. Get the + // absolute value of the input value. + // f16_as_int = bitcast(f16_value, int) + // f16_abs_bits = f16_as_int & 0x7FFF + Value* f16_as_int = b->CreateBitCast(f16_value, i16_type); + llvm::Value* f16_abs_bits = b->CreateAnd(f16_as_int, i16_const(0x7FFF)); + + // Get the sign. + // f8_sign = (f16_as_int & 0x8000) >> 8 + Value* f16_sign = b->CreateAnd(f16_as_int, i16_const(0x8000)); + f16_sign = b->CreateLShr(f16_sign, i16_const(8)); + Value* f8_sign = b->CreateTrunc(f16_sign, i8_type); + + // Truncate the mantissa to f8 mantissa bits and exponent to f8 exponent bits + // Denormal values are not handled properly here and are + // dealt with later in this function. + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, f16_value, + /*dest_exponent_bits=*/f8_exponent_bits, + /*dest_mantissa_bits=*/f8_mantissa_bits, + /*quiet_nans=*/true, b); + CHECK_OK(f16_reduced_statusor.status()); // Crash OK + Value* f16_reduced = f16_reduced_statusor.value(); + f16_reduced = b->CreateBitCast(f16_reduced, i16_type); + + // Remove the sign bit. + // f16_reduced = f16_reduced & 0x7FFF + f16_reduced = b->CreateAnd(f16_reduced, i16_const(0x7FFF)); + + // F16 inf in binary: 0 11111 0000000000 + constexpr int f16_inf_value = 0x7C00; + constexpr int f8_bias = (1 << (f8_exponent_bits - 1)) - 1; + constexpr int exponent_bias_difference = 15 - f8_bias; + constexpr int f16_mantissa_bits = 10; // e5m10 + constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; + constexpr int min_normal_value = (exponent_bias_difference + 1) + << f16_mantissa_bits; + + // Round values smaller than the smallest F8 normal value up to the smallest + // F8 normal value. The case where we round to a denormal value is handled + // later. + // f16_reduced = max(f16_reduced, min_normal_value) + f16_reduced = b->CreateSelect( + b->CreateICmpULT(f16_reduced, i16_const(min_normal_value)), + i16_const(min_normal_value), f16_reduced); + + // Adjust the exponent by subtracting the difference in exponent bias: + // f16_reduced -= (exponent_bias_difference << f16_mantissa_bits) + // For infinity/NaN values, subtract twice the difference in exponent bias + // to ensure the leading exponent bit(s) of f16_reduced are set to zero. + f16_reduced = b->CreateSub( + f16_reduced, + b->CreateSelect( + b->CreateICmpULT(f16_reduced, i16_const(f16_inf_value)), + i16_const(exponent_bias_difference << f16_mantissa_bits), + i16_const(exponent_bias_difference << (f16_mantissa_bits + 1)))); + + // Shift to convert to F8. + // f16_reduced = f16_reduced >> mantissa_bits_difference; + f16_reduced = b->CreateLShr(f16_reduced, i16_const(mantissa_bits_difference)); + + Value* f8_bits = b->CreateTrunc(f16_reduced, i8_type); + + // Handle F16 values that are halfway between denormal F8 values. + f8_bits = + handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); + + // Set the sign bit. + // f8_bits |= f8_sign + f8_bits = b->CreateOr(f8_bits, f8_sign); + return f8_bits; +} + +template +llvm::Value* EmitToF16F8e(llvm::Value* f8_value, llvm::IRBuilder<>* b) { + using llvm::APInt; + using llvm::Value; + static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i8_const = [i8_type](int val) { + return llvm::ConstantInt::get(i8_type, val); + }; + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + + // Map from F8 denormal value to F16 value. + const int f8_denormal_to_f16_e4[8] = { + 0x0000, // 0 + 0x1800, // 1/8 * 2^-6 + 0x1C00, // 2/8 * 2^-6 + 0x1E00, // 3/8 * 2^-6 + 0x2000, // 4/8 * 2^-6 + 0x2100, // 5/8 * 2^-6 + 0x2200, // 6/8 * 2^-6 + 0x2300, // 7/8 * 2^-6 + }; + + // Map from F8 denormal value to F16 value. + const int f8_denormal_to_f16_e3[16] = { + 0x0000, // 0 + 0x2400, // 1/16 * 2^-2 + 0x2800, // 2/16 * 2^-2 + 0x2A00, // 3/16 * 2^-2 + 0x2C00, // 4/16 * 2^-2 + 0x2D00, // 5/16 * 2^-2 + 0x2E00, // 6/16 * 2^-2 + 0x2F00, // 7/16 * 2^-2 + 0x3000, // 8/16 * 2^-2 + 0x3080, // 9/16 * 2^-2 + 0x3100, // 10/16 * 2^-2 + 0x3180, // 11/16 * 2^-2 + 0x3200, // 12/16 * 2^-2 + 0x3280, // 13/16 * 2^-2 + 0x3300, // 14/16 * 2^-2 + 0x3380, // 15/16 * 2^-2 + }; + + // Cast the input value to an integer for bitwise manipulation. Get the + // absolute value of the input value. + // f8_as_int = bitcast(f16_value, int) + // f8_abs_bits = f8_as_int & 0x7F + Value* f8_as_int = b->CreateBitCast(f8_value, i8_type); + Value* f8_abs_bits = b->CreateAnd(f8_as_int, i8_const(0x7F)); + + // We assume below that the value is neither NaN nor denormal. If it NaN or + // denormal, the output is set to NaN or zero at the end using Select + // instructions. + + // Get the sign: + // f16_sign = (f8_as_int & 0x80) << 8 + Value* f8_sign = b->CreateAnd(f8_as_int, i8_const(0x80)); + Value* f16_sign = b->CreateZExt(f8_sign, i16_type); + f16_sign = b->CreateShl(f16_sign, i16_const(8)); + + int exponent_mask; + const int* f8_denormal_to_f16; + int f8_denormal_size; + if constexpr (f8_exponent_bits == 4) { + exponent_mask = 0x78; + f8_denormal_to_f16 = f8_denormal_to_f16_e4; + f8_denormal_size = 8; + } else if constexpr (f8_exponent_bits == 3) { + exponent_mask = 0x70; + f8_denormal_to_f16 = f8_denormal_to_f16_e3; + f8_denormal_size = 16; + } + constexpr int f8_bias = (1 << (f8_exponent_bits - 1)) - 1; + constexpr int exponent_bias_difference = 15 - f8_bias; + constexpr int f16_mantissa_bits = 10; // e5m10 + constexpr int f8_mantissa_bits = 7 - f8_exponent_bits; + constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; + constexpr int f8_mantissa_mask = (1 << f8_mantissa_bits) - 1; + + // Get the exponent: + // f8_exponent = (f8_as_int & exponent_mask) >> f8_mantissa_bits + Value* f8_exponent_bits_v = b->CreateAnd(f8_as_int, i8_const(exponent_mask)); + Value* f8_exponent = + b->CreateLShr(f8_exponent_bits_v, i8_const(f8_mantissa_bits)); + + // Adjust the exponent by adding the difference in exponent bias: + // f16_exponent = (f8_exponent + exponent_bias_difference) + // << f16_mantissa_bits + Value* f16_exponent = + b->CreateAdd(f8_exponent, i8_const(exponent_bias_difference)); + f16_exponent = b->CreateZExt(f16_exponent, i16_type); + f16_exponent = b->CreateShl(f16_exponent, i16_const(f16_mantissa_bits)); + + // Set output exponent to 11111 if input exponent is 111 (Inf or NaN) + // 0.11111.0000000000 is 0x7C00 + Value* is_exp_1111 = + b->CreateICmpEQ(f8_exponent_bits_v, i8_const(exponent_mask)); + f16_exponent = b->CreateSelect(is_exp_1111, i16_const(0x7C00), f16_exponent); + + // Get the mantissa: + // f16_mantissa = (f8_mantissa & f8_mantissa_mask) + // << mantissa_bits_difference + Value* f8_mantissa = b->CreateAnd(f8_as_int, i8_const(f8_mantissa_mask)); + Value* f16_mantissa = b->CreateZExt(f8_mantissa, i16_type); + f16_mantissa = + b->CreateShl(f16_mantissa, i16_const(mantissa_bits_difference)); + + // Combine the exponent and mantissa: + // f16_as_int = f16_exponent | f16_mantissa + Value* f16_as_int = b->CreateOr(f16_exponent, f16_mantissa); + + // If the F8 value is denormal, use the map above to determine the correct F16 + // value. + // if (f8_abs_bits < 8) { f16_as_int = f8_denormal_to_f16[f8_abs_bits]; } + for (int i = 0; i < f8_denormal_size; i++) { + Value* is_denormal_value = b->CreateICmpEQ(f8_abs_bits, i8_const(i)); + f16_as_int = b->CreateSelect(is_denormal_value, + i16_const(f8_denormal_to_f16[i]), f16_as_int); + } + + // Set the sign bit. + // f16_as_int |= f16_sign + f16_as_int = b->CreateOr(f16_as_int, f16_sign); + return b->CreateBitCast(f16_as_int, b->getHalfTy()); +} + llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { using llvm::APInt; using llvm::Value; @@ -297,6 +598,7 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { i16_const(min_normal_value), f16_reduced); constexpr int exponent_bias_difference = 15 - 7; + constexpr int f8_exponent_bits = 4; constexpr int f16_mantissa_bits = 10; constexpr int f8_mantissa_bits = 3; constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; @@ -322,42 +624,9 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { b->CreateICmpUGT(f16_abs_bits, i16_const(max_finite_value)), i8_const(0x7F), f8_bits); - // F16 values that are halfway between denormal F8 values. This is used to - // determine how to round to denormal F8 values. - const int halfway_points[8] = { - 0x1400, // 2**-10; halfway between [0, 2**-9] - 0x1A00, // 1.5 * 2**-9; halfway between [2**-9, 2**-8] - 0x1D00, // 1.25 * 2**-8; halfway between [2**-8, 1.5 * 2**-8] - 0x1F00, // 1.75 * 2**-8; halfway between [1.5 * 2**-8, 2**-7] - 0x2080, // 1.125 * 2**-7; halfway between [2**-7, 1.25 * 2**-7] - 0x2180, // 1.375 * 2**-7; halfway between [1.25 * 2**-7, 1.5 * 2**-7] - 0x2280, // 1.625 * 2**-7; halfway between [1.5 * 2**-7, 1.75 * 2**-7] - 0x2380, // 1.875 * 2**-7; halfway between [1.75 * 2**-7, 2**-6] - }; - - // Handle case where output is denormal. If we're rounding to a denormal - // value, ignore the current value of f8_bits and set it to the correct - // denormal value. We emit the equivalent of the following: - // - // if (f16_abs_bits <= halfway_points[0]) { - // f8_bits = 0; - // } else if (f16_abs_bits < halfway_points[1]) { - // f8_bits = 1; - // } else if (f16_abs_bits <= halfway_points[2]) { - // ... // More if-else statements. The comparisons alternate between <= - // ... // and < to handle round-to-even properly. - // } else if (f16_abs_bits < halfway_points[7]) { - // f8_bits = 7; - // } - for (int i = ABSL_ARRAYSIZE(halfway_points) - 1; i >= 0; i--) { - Value* comparison; - if (i % 2 == 0) { - comparison = b->CreateICmpULE(f16_abs_bits, i16_const(halfway_points[i])); - } else { - comparison = b->CreateICmpULT(f16_abs_bits, i16_const(halfway_points[i])); - } - f8_bits = b->CreateSelect(comparison, i8_const(i), f8_bits); - } + // Handle F16 values that are halfway between denormal F8 values. + f8_bits = + handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); // Set the sign bit. // f8_bits |= f8_sign @@ -408,7 +677,7 @@ llvm::Value* EmitF8e4m3fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { b->CreateLShr(f8_exponent_bits, i8_const(f8_mantissa_bits)); // Adjust the exponent by adding the difference in exponent bias: - // f16_exponent = (f8_exopnent + exponent_bias_difference) + // f16_exponent = (f8_exponent + exponent_bias_difference) // << f16_mantissa_bits Value* f16_exponent = b->CreateAdd(f8_exponent, i8_const(exponent_bias_difference)); @@ -435,13 +704,13 @@ llvm::Value* EmitF8e4m3fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { // Map from F8 denormal value to F16 value. int f8_denormal_to_f16[8] = { 0x0000, // 0 - 0x1800, // 2**-9 - 0x1C00, // 2**-8 - 0x1E00, // 1.5 * 2**-8 - 0x2000, // 2**-7 - 0x2100, // 1.25 * 2**-7 - 0x2200, // 1.5 * 2**-7 - 0x2300, // 1.75 * 2**-7 + 0x1800, // 1/8 * 2^-6 + 0x1C00, // 2/8 * 2^-6 + 0x1E00, // 3/8 * 2^-6 + 0x2000, // 4/8 * 2^-6 + 0x2100, // 5/8 * 2^-6 + 0x2200, // 6/8 * 2^-6 + 0x2300, // 7/8 * 2^-6 }; // If the F8 value is denormal, use the map above to determine the correct F16 @@ -604,6 +873,12 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F8E4M3) { + return EmitF16ToF8e<4>( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } if (to_type == F8E4M3FN) { return EmitF16ToF8e4m3fn( EmitIntegralToFloating(operand_value, from_type, F16, module_, @@ -623,6 +898,12 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), to_type, b_); } + if (to_type == F8E3M4) { + return EmitF16ToF8e<3>( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } return EmitIntegralToFloating(operand_value, from_type, to_type, module_, b_); } @@ -789,6 +1070,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F8E4M3) { + TF_RET_CHECK(to_type != F8E4M3); + operand_value = EmitToF16F8e<4>(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } if (from_type == F8E4M3FN) { TF_RET_CHECK(to_type != F8E4M3FN); operand_value = EmitF8e4m3fnToF16(operand_value, b_); @@ -817,6 +1106,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F8E3M4) { + TF_RET_CHECK(to_type != F8E3M4); + operand_value = EmitToF16F8e<3>(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } if (primitive_util::IsComplexType(to_type)) { PrimitiveType to_component_type = primitive_util::ComplexComponentType(to_type); @@ -844,6 +1141,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e5m2(operand_value, b_); } + if (to_type == F8E4M3) { + // Cast to F16 first. Casts to F8E4M3 must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + } + return EmitF16ToF8e<4>(operand_value, b_); + } if (to_type == F8E4M3FN) { // Cast to F16 first. Casts to F8E4M3FN must be from F16. if (from_type != F16) { @@ -863,6 +1168,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } + if (to_type == F8E3M4) { + // Cast to F16 first. Casts to F8E3M4 must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + } + return EmitF16ToF8e<3>(operand_value, b_); + } if (to_type == PRED) { return b_->CreateZExt( FCmpUNE(operand_value, @@ -1391,6 +1704,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( if (operand_type == F8E5M2) { lhs_value = EmitF8e5m2ToF16(lhs_value, b_); rhs_value = EmitF8e5m2ToF16(rhs_value, b_); + } else if (operand_type == F8E4M3) { + lhs_value = EmitToF16F8e<4>(lhs_value, b_); + rhs_value = EmitToF16F8e<4>(rhs_value, b_); } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); @@ -1401,6 +1717,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( TF_ASSIGN_OR_RETURN( rhs_value, EmitF8fnuzToFloating(operand_type, rhs_value, F16, b_, module_)); + } else if (operand_type == F8E3M4) { + lhs_value = EmitToF16F8e<3>(lhs_value, b_); + rhs_value = EmitToF16F8e<3>(rhs_value, b_); } switch (op->comparison_direction()) { case ComparisonDirection::kEq: diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index 7c73dd3a1d0dd7..60c4535909d158 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include #include +#include #include +#include #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -37,6 +39,8 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/types.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" namespace xla { @@ -89,9 +93,9 @@ class ElementalIrEmitterExecutionTypedTest }; using FloatTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); @@ -249,8 +253,10 @@ XLA_TEST_F(ElementalIrEmitterExecutionTest, TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatsToFloat) { auto tname = this->TypeName(); - if (std::is_same() || - std::is_same()) { + if (std::is_same() || + std::is_same() || + std::is_same() || + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( @@ -413,8 +419,10 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, CompareFloat) { TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { auto tname = this->TypeName(); if (std::is_same() || + std::is_same() || std::is_same() || - std::is_same()) { + std::is_same() || + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index fe3a1041933cb5..22916aa084fc47 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/ADT/APFloat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Intrinsics.h" #include "xla/primitive_util.h" @@ -39,6 +40,10 @@ namespace { absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { + case F8E3M4: + return &llvm::APFloat::Float8E3M4(); + case F8E4M3: + return &llvm::APFloat::Float8E4M3(); case F8E4M3B11FNUZ: return &llvm::APFloat::Float8E4M3B11FNUZ(); case F8E4M3FN: @@ -67,6 +72,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilder<>* b, PrimitiveType type) { switch (type) { + case F8E3M4: + case F8E4M3: case F8E4M3B11FNUZ: case F8E4M3FN: case F8E4M3FNUZ: diff --git a/xla/service/float_normalization_test.cc b/xla/service/float_normalization_test.cc index a140d2e933af9a..d62443a7d9ff06 100644 --- a/xla/service/float_normalization_test.cc +++ b/xla/service/float_normalization_test.cc @@ -144,7 +144,7 @@ class FloatNormalizationF8Test public ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test, - ::testing::Values(F8E5M2)); + ::testing::Values(F8E3M4, F8E4M3, F8E5M2)); TEST_F(FloatNormalizationTest, NoopIfSupported) { auto builder = HloComputation::Builder(TestName()); diff --git a/xla/service/gpu/fusions/transforms/expand_float_ops.cc b/xla/service/gpu/fusions/transforms/expand_float_ops.cc index aebb7f44608559..e9b9731756f7db 100644 --- a/xla/service/gpu/fusions/transforms/expand_float_ops.cc +++ b/xla/service/gpu/fusions/transforms/expand_float_ops.cc @@ -175,12 +175,19 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { } assert(ty.getIntOrFloatBitWidth() == 8); - if (!ty.isFloat8E5M2()) { - // F8E5M2 is the only 8 bit float with infinities. + // F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities. + if (ty.isFloat8E5M2()) { + Val bits{b.create(b.getI8Type(), value), &b}; + return (bits & 0x7F) == 0x7C; + } else if (ty.isFloat8E4M3()) { + Val bits{b.create(b.getI8Type(), value), &b}; + return (bits & 0x7F) == 0x78; + } else if (ty.isFloat8E3M4()) { + Val bits{b.create(b.getI8Type(), value), &b}; + return (bits & 0x7F) == 0x70; + } else { return b.create(false, b.getI1Type()); } - Val bits{b.create(b.getI8Type(), value), &b}; - return (bits & 0x7F) == 0x7C; } Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { @@ -193,8 +200,12 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { Val bits{b.create(b.getI8Type(), value), &b}; if (ty.isFloat8E5M2()) { return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'1100); + } else if (ty.isFloat8E4M3()) { + return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'1000); } else if (ty.isFloat8E4M3FN()) { return (bits & 0b0111'1111) == 0b0111'1111; + } else if (ty.isFloat8E3M4()) { + return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000); } return bits == 0x80; } @@ -544,7 +555,8 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue(); // If we're comparing to +-0, compare the absolute values. if (rhs_cst.isZero() && - (lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { + (lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() || + lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { int_value = int_value & 0x7f; constant &= 0x7f; } diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index f574d64a2290c4..afb1002cc19461 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1441,19 +1441,23 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // Lambdas and related constants: const GpuFloatSupport bf16_support(gpu_version, BF16); const GpuFloatSupport f8e5m2_support(gpu_version, F8E5M2, F16); + const GpuFloatSupport f8e4m3_support(gpu_version, F8E4M3, F16); const GpuFloatSupport f8e4m3fn_support(gpu_version, F8E4M3FN, F16); const FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); const GpuFloatSupport f8e5m2fnuz_support(gpu_version, F8E5M2FNUZ, F16); const GpuFloatSupport f8e4m3fnuz_support(gpu_version, F8E4M3FNUZ, F16); + const GpuFloatSupport f8e3m4_support(gpu_version, F8E3M4, F16); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); sub_pipeline.AddPass(&bf16_support); sub_pipeline.AddPass(&f8e5m2_support); + sub_pipeline.AddPass(&f8e4m3_support); sub_pipeline.AddPass(&f8e4m3fn_support); sub_pipeline.AddPass(&f8e4m3b11fnuz_support); sub_pipeline.AddPass(&f8e5m2fnuz_support); sub_pipeline.AddPass(&f8e4m3fnuz_support); + sub_pipeline.AddPass(&f8e3m4_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_allow_excess_precision()) { sub_pipeline.AddPass(); diff --git a/xla/service/gpu/ir_emission_utils.cc b/xla/service/gpu/ir_emission_utils.cc index 406fcd9534a9dc..01b32bf8301d34 100644 --- a/xla/service/gpu/ir_emission_utils.cc +++ b/xla/service/gpu/ir_emission_utils.cc @@ -101,7 +101,8 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { PrimitiveType output_primitive_type = dot.shape().element_type(); bool type_is_allowed = - (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || + (output_primitive_type == F8E3M4 || output_primitive_type == F8E4M3 || + output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || output_primitive_type == F8E4M3FNUZ || output_primitive_type == F8E5M2FNUZ || output_primitive_type == F16 || output_primitive_type == BF16 || output_primitive_type == F32 || diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc index b5d571e4c7be3f..16383324dfb016 100644 --- a/xla/service/gpu/tests/float_conversions_test.cc +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -29,8 +29,9 @@ class FloatConversionParamTest INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, ::testing::Values("f64", "f32", "f16", "bf16", - "f8e5m2", "f8e5m2fnuz", "f8e4m3fn", - "f8e4m3fnuz", "f8e4m3b11fnuz")); + "f8e5m2", "f8e5m2fnuz", "f8e4m3", + "f8e4m3fn", "f8e4m3fnuz", + "f8e4m3b11fnuz", "f8e3m4")); TEST_P(FloatConversionParamTest, FloatToF16) { auto type_name = GetParam(); diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index 27630b674d2ce4..036d742c05158e 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -200,9 +200,11 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt16Ty(module->getContext()); case F8E5M2: case F8E5M2FNUZ: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E4M3FNUZ: + case F8E3M4: // We represent F8 as an int since there is no LLVM F8 dtype. return llvm::Type::getInt8Ty(module->getContext()); case BF16: diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index 03b09f9b644f07..f5246389e485c3 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -37,6 +37,14 @@ struct ToDataType; // Note: If you add a new specialization below, make sure to add the // corresponding definition in stream_executor/dnn.cc. template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E3M4; +}; +template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E4M3; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kF8E4M3FN; }; diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 10270d0b3c1be2..f46506c0fda2c7 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -66,6 +66,8 @@ bool ProtoMapsEqual(const google::protobuf::Map& x, } // namespace +constexpr DataType ToDataType::value; +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 6a604e20619455..6aee86bf2cbc19 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -46,12 +46,16 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { switch (dtype) { case PrimitiveType::F8E5M2: return DataType::kF8E5M2; + case PrimitiveType::F8E4M3: + return DataType::kF8E4M3; case PrimitiveType::F8E4M3FN: return DataType::kF8E4M3FN; case PrimitiveType::F8E5M2FNUZ: return DataType::kF8E5M2FNUZ; case PrimitiveType::F8E4M3FNUZ: return DataType::kF8E4M3FNUZ; + case PrimitiveType::F8E3M4: + return DataType::kF8E3M4; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -79,12 +83,16 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { switch (dtype) { case DataType::kF8E5M2: return PrimitiveType::F8E5M2; + case DataType::kF8E4M3: + return PrimitiveType::F8E4M3; case DataType::kF8E4M3FN: return PrimitiveType::F8E4M3FN; case DataType::kF8E5M2FNUZ: return PrimitiveType::F8E5M2FNUZ; case DataType::kF8E4M3FNUZ: return PrimitiveType::F8E4M3FNUZ; + case DataType::kF8E3M4: + return PrimitiveType::F8E3M4; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -141,9 +149,11 @@ absl::StatusOr GetBlasComputationType( if (algorithm == xla::PrecisionConfig::ALG_UNSET) { switch (output_dtype) { case PrimitiveType::F8E5M2: // fall-through + case PrimitiveType::F8E4M3: // fall-through case PrimitiveType::F8E4M3FN: // fall-through case PrimitiveType::F8E5M2FNUZ: // fall-through case PrimitiveType::F8E4M3FNUZ: // fall-through + case PrimitiveType::F8E3M4: // fall-through case PrimitiveType::F16: // fall-through case PrimitiveType::BF16: // Accumulate in f32 precision. diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 57e9f9651e0f50..f6e8a867e2b484 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -745,13 +745,11 @@ cc_library( deps = [ ":hipblas_lt_header", ":rocblas_plugin", - "//xla/stream_executor", "//xla/stream_executor:blas", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", ], ) diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index a59c935614cd8f..e5730121addd8d 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/stream_executor/rocm/hip_blas_utils.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "xla/stream_executor/blas.h" @@ -35,8 +36,11 @@ absl::Status ToStatus(hipblasStatus_t status, const char* prefix) { hipDataType AsHipblasDataType(blas::DataType type) { switch (type) { case blas::DataType::kF8E5M2: + case blas::DataType::kF8E4M3: case blas::DataType::kF8E4M3FN: - LOG(FATAL) << "hipblaslt does not support F8E5M2 and F8E4M3FN"; + case blas::DataType::kF8E3M4: + LOG(FATAL) + << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN and F8E3M4"; #if TF_ROCM_VERSION >= 60000 case blas::DataType::kF8E5M2FNUZ: return HIP_R_8F_E5M2_FNUZ; diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 34f124c69fb640..9472d3f5b6f31d 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -1182,10 +1182,12 @@ xla_test( "//xla:array3d", "//xla:array4d", "//xla:literal_util", + "//xla:types", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:constants", "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", ], diff --git a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc index b8b1613768c2bc..c12ce79a06e8fa 100644 --- a/xla/tests/array_elementwise_ops_test.cc +++ b/xla/tests/array_elementwise_ops_test.cc @@ -1423,7 +1423,8 @@ class TotalOrderTest : public ClientLibraryTestBase { } }; -using Types = ::testing::Types #include +#include #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" @@ -32,6 +33,7 @@ limitations under the License. #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/types.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" @@ -46,10 +48,11 @@ class ConstantsTest : public ClientLibraryTestBase { template class ConstantsFloatTest : public ConstantsTest {}; -typedef ::testing::Types - FloatTypes; +using FloatTypes = + ::testing::Types; TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes); diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index f5a68c32886410..4f06ea0cc290c7 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -54,10 +54,9 @@ class ConvertTestT : public ConvertTest { using ConvertTest::ConvertTest; }; using FloatingPointTypeList = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); template @@ -873,22 +872,200 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive) { XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive2) { // Convert from supported floating point type to FP8. XlaBuilder builder(this->TestName()); + if constexpr (std::is_same_v) { + // TODO(b/370786669): Enable this test. + GTEST_SKIP() << "Skipping test for E3M4 as it requires an ml_dtypes " + "release with https://github.com/jax-ml/ml_dtypes/pull/205"; + } else { + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + } +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e5m2; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2F16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +// ----- F8E4M3 + +XLA_TEST_F(ConvertTest, ConvertF16F8e4m3Roundtrip) { + // Convert from FP16 to FP8, then back to FP16 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep7, 0x1.Ep7}, // Max value + {0x1.EFCp7, 0x1.Ep7}, // Largest number that doesn't overflow + {0x1.Fp7, inf}, // Smallest number that overflows + {0x1p8, inf}, // Overflow + {0x1p-6, 0x1p-6}, // Smallest F8 normal + {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.2p-6, 0x0.2p-6}, // Smallest denormal + {0x0.Ep-6, 0x0.Ep-6}, // Largest denormal + {0x0.8p-6, 0x0.8p-6}, // Denormal without rounding + {0x0.9p-6, 0x0.8p-6}, // Round-to-even down + {0x0.Fp-6, 0x0.8p-5}, // Round-to-even up + {0x0.8Fp-6, 0x0.8p-6}, // Round-to-nearest down + {0x0.91p-6, 0x0.Ap-6}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.004p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x0.EFCp-6, 0x0.Ep-6}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E4M3); + ConvertElementType(f8, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e4m3Roundtrip)) { + // Convert from FP32 to FP8, then back to FP32. + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep7, 0x1.Ep7}, // Max value + {0x1.EFFFFEp7, 0x1.Ep7}, // Largest number that doesn't overflow + {0x1.Fp7, inf}, // Smallest number that overflows + {0x1p8, inf}, // Overflow + {0x1p-6, 0x1p-6}, // Smallest F8 normal + {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.2p-6, 0x0.2p-6}, // Smallest denormal + {0x0.Ep-6, 0x0.Ep-6}, // Largest denormal + {0x0.8p-6, 0x0.8p-6}, // Denormal without rounding + {0x0.9p-6, 0x0.8p-6}, // Round-to-even down + {0x0.Fp-6, 0x0.8p-5}, // Round-to-even up + {0x0.8Fp-6, 0x0.8p-6}, // Round-to-nearest down + {0x0.91p-6, 0x0.Ap-6}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.000002p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x0.EFFFFEp-6, 0x0.Ep-6}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3RoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e4m3; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E4M3); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3RoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); std::vector all_f8; for (int i = 0; i < 256; i++) { all_f8.push_back(static_cast( - Eigen::numext::bit_cast(static_cast(i)))); + Eigen::numext::bit_cast(static_cast(i)))); } - ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2); + ConvertElementType(ConstantR1(&builder, all_f8), F8E4M3); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive3) { +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3RoundtripExhaustive3) { // Convert from FP8 to supported floating point type. XlaBuilder builder(this->TestName()); - using From = tsl::float8_e5m2; + using From = tsl::float8_e4m3; std::vector all_f8; for (int i = 0; i < 256; i++) { all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); @@ -899,7 +1076,7 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive3) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2F16RoundtripExhaustive4) { +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3F16RoundtripExhaustive4) { // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); @@ -910,7 +1087,7 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2F16RoundtripExhaustive4) { } xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E5M2); + ConvertElementType(all_f16_to_f8, F8E4M3); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } @@ -1366,15 +1543,21 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive2) { // Convert from supported floating point type to FP8. XlaBuilder builder(this->TestName()); - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - static_cast(Eigen::numext::bit_cast( - static_cast(i)))); - } + if constexpr (std::is_same_v) { + // TODO(b/370786669): Enable this test. + GTEST_SKIP() << "Skipping test for E3M4 as it requires an ml_dtypes " + "release with https://github.com/jax-ml/ml_dtypes/pull/205"; + } else { + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + static_cast(Eigen::numext::bit_cast( + static_cast(i)))); + } - ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2FNUZ); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + } } XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive3) { @@ -1569,5 +1752,178 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3fnuzF16RoundtripExhaustive4) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E3M4 + +XLA_TEST_F(ConvertTest, ConvertF16F8e3m4Roundtrip) { + // Convert from FP16 to FP8, then back to FP16 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.08p0, 0x1p0}, // Round-to-even down + {0x1.18p0, 0x1.2p0}, // Round-to-even up + {0x1.Fp3, 0x1.Fp3}, // Max value + {0x1.F7Cp3, 0x1.Fp3}, // Largest number that doesn't overflow + {0x1.F8p3, inf}, // Smallest number that overflows + {0x1p4, inf}, // Overflow + {0x1p-2, 0x1p-2}, // Smallest F8 normal + {0x1.Fp-3, 0x1p-2}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.1p-2, 0x0.1p-2}, // Smallest denormal + {0x0.Fp-2, 0x0.Fp-2}, // Largest denormal + {0x0.8p-2, 0x0.8p-2}, // Denormal without rounding + {0x0.88p-2, 0x0.8p-2}, // Round-to-even down + {0x0.F8p-2, 0x0.8p-1}, // Round-to-even up + {0x0.87p-2, 0x0.8p-2}, // Round-to-nearest down + {0x0.89p-2, 0x0.9p-2}, // Round-to-nearest up + {0x1p-7, 0}, // Largest number that underflows + {0x1.004p-7, 0x0.1p-2}, // Smallest number that doesn't underflow + {0x0.F7Cp-2, 0x0.Fp-2}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E3M4); + ConvertElementType(f8, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e3m4Roundtrip)) { + // Convert from FP32 to FP8, then back to FP32. + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.08p0, 0x1p0}, // Round-to-even down + {0x1.18p0, 0x1.2p0}, // Round-to-even up + {0x1.Fp3, 0x1.Fp3}, // Max value + {0x1.F7FFFEp3, 0x1.Fp3}, // Largest number that doesn't overflow + {0x1.F8p3, inf}, // Smallest number that overflows + {0x1p4, inf}, // Overflow + {0x1p-2, 0x1p-2}, // Smallest F8 normal + {0x1.Fp-3, 0x1p-2}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.1p-2, 0x0.1p-2}, // Smallest denormal + {0x0.Fp-2, 0x0.Fp-2}, // Largest denormal + {0x0.8p-2, 0x0.8p-2}, // Denormal without rounding + {0x0.88p-2, 0x0.8p-2}, // Round-to-even down + {0x0.F8p-2, 0x0.8p-1}, // Round-to-even up + {0x0.87p-2, 0x0.8p-2}, // Round-to-nearest down + {0x0.89p-2, 0x0.9p-2}, // Round-to-nearest up + {0x1p-7, 0}, // Largest number that underflows + {0x1.000002p-7, 0x0.1p-2}, // Smallest number that doesn't underflow + {0x0.F7FFFEp-2, 0x0.Fp-2}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E3M4); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e3m4RoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e3m4; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E3M4); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e3m4RoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E3M4); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e3m4RoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e3m4; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e3m4F16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E3M4); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + } // namespace } // namespace xla diff --git a/xla/tests/float8_test.cc b/xla/tests/float8_test.cc index 02be9bfa9356ea..648c718d7cd958 100644 --- a/xla/tests/float8_test.cc +++ b/xla/tests/float8_test.cc @@ -27,11 +27,12 @@ limitations under the License. namespace xla { namespace { -// Test FP8 floating-point types (F8E5M2, F8E4M3FN) +// Test FP8 floating-point types template class Float8Test : public ClientLibraryTestBase {}; -using DataTypes = ::testing::Types; +using DataTypes = ::testing::Types; TYPED_TEST_SUITE(Float8Test, DataTypes); XLA_TYPED_TEST(Float8Test, ScalarOperation) { diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index 780968098cf32b..4f4895b57123ae 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -120,22 +120,28 @@ enum PrimitiveType { C64, C128, F8E5M2, + F8E4M3, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, + F8E3M4, }; const std::vector& primitive_strings() { - static auto vec = - new std::vector({"s2", "s4", "s8", - "s16", "s32", "s64", - "u2", "u4", "u8", - "u16", "u32", "u64", - "f16", "bf16", "f32", - "f64", "c64", "c128", - "f8e5m2", "f8e4m3fn", "f8e4m3b11fnuz", - "f8e5m2fnuz", "f8e4m3fnuz"}); + static auto vec = new std::vector({"s2", "s4", + "s8", "s16", + "s32", "s64", + "u2", "u4", + "u8", "u16", + "u32", "u64", + "f16", "bf16", + "f32", "f64", + "c64", "c128", + "f8e5m2", "f8e4m3", + "f8e4m3fn", "f8e4m3b11fnuz", + "f8e5m2fnuz", "f8e4m3fnuz", + "f8e3m4"}); return *vec; } @@ -413,10 +419,12 @@ void Fill(void* buffer, const ArrayShape& shape) { return FillFloatT(buffer, num_elements); case F8E5M2: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E5M2FNUZ: case F8E4M3FNUZ: + case F8E3M4: case F16: case BF16: case C64: @@ -469,10 +477,12 @@ void Display(const void* buffer, const ArrayShape& shape) { return DisplayT(buffer, num_elements); case F8E5M2: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E5M2FNUZ: case F8E4M3FNUZ: + case F8E3M4: case F16: case BF16: case C64: diff --git a/xla/tsl/framework/type_traits.h b/xla/tsl/framework/type_traits.h index 46fa640ee62298..39644589d309e6 100644 --- a/xla/tsl/framework/type_traits.h +++ b/xla/tsl/framework/type_traits.h @@ -70,6 +70,8 @@ struct is_simple_type { std::is_trivial::value || std::is_same::value || std::is_same::value || std::is_same::value || is_quantized::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || diff --git a/xla/tsl/protobuf/dnn.proto b/xla/tsl/protobuf/dnn.proto index 695db935f6a0b4..2ac31005c16629 100644 --- a/xla/tsl/protobuf/dnn.proto +++ b/xla/tsl/protobuf/dnn.proto @@ -22,6 +22,8 @@ enum DataType { kF8E5M2FNUZ = 10; kF8E4M3FNUZ = 11; kInt64 = 12; + kF8E4M3 = 13; + kF8E3M4 = 14; } // Describes how a convolution input or output layer's data is formatted. diff --git a/xla/tsl/python/lib/core/ml_dtypes.cc b/xla/tsl/python/lib/core/ml_dtypes.cc index 717ab3e462a7bf..e2c5eb295c6b12 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.cc +++ b/xla/tsl/python/lib/core/ml_dtypes.cc @@ -61,6 +61,10 @@ struct MlDtypesInitInfo { numpy_dtypes.bfloat16 = py::dtype::from_args(ml_dtypes.attr("bfloat16")).num(); + numpy_dtypes.float8_e3m4 = + py::dtype::from_args(ml_dtypes.attr("float8_e3m4")).num(); + numpy_dtypes.float8_e4m3 = + py::dtype::from_args(ml_dtypes.attr("float8_e4m3")).num(); numpy_dtypes.float8_e4m3fn = py::dtype::from_args(ml_dtypes.attr("float8_e4m3fn")).num(); numpy_dtypes.float8_e5m2 = @@ -81,6 +85,8 @@ struct MlDtypesInitInfo { // Verify all types were successfully loaded. if (numpy_dtypes.bfloat16 == NPY_NOTYPE || + numpy_dtypes.float8_e3m4 == NPY_NOTYPE || + numpy_dtypes.float8_e4m3 == NPY_NOTYPE || numpy_dtypes.float8_e4m3fn == NPY_NOTYPE || numpy_dtypes.float8_e4m3fnuz == NPY_NOTYPE || numpy_dtypes.float8_e4m3b11fnuz == NPY_NOTYPE || diff --git a/xla/tsl/python/lib/core/ml_dtypes.h b/xla/tsl/python/lib/core/ml_dtypes.h index bf9eab2200a76b..b3aa94e430239a 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.h +++ b/xla/tsl/python/lib/core/ml_dtypes.h @@ -24,6 +24,8 @@ namespace ml_dtypes { struct NumpyDtypes { int bfloat16; + int float8_e3m4; + int float8_e4m3; int float8_e4m3fn; int float8_e4m3b11fnuz; int float8_e4m3fnuz; diff --git a/xla/util.cc b/xla/util.cc index 9b1a6db1fa22c0..0c92df2c69e76b 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -137,7 +137,7 @@ std::string Reindent(absl::string_view original, template static void RoundTripNanPayload(FloatT value, std::string* result) { static_assert(!std::is_same::value, - "RoundTripNanPayload does not support E4M3"); + "RoundTripNanPayload does not support E4M3FN"); static_assert(!std::is_same::value, "RoundTripNanPayload does not support E4M3FNUZ"); static_assert(!std::is_same::value, @@ -168,6 +168,12 @@ std::string RoundTripFpToString(tsl::float8_e5m2 value) { return result; } +std::string RoundTripFpToString(tsl::float8_e4m3 value) { + std::string result = GenericRoundTripFpToString(value); + RoundTripNanPayload(value, &result); + return result; +} + std::string RoundTripFpToString(tsl::float8_e4m3fnuz value) { std::string result = GenericRoundTripFpToString(value); return result; @@ -188,6 +194,12 @@ std::string RoundTripFpToString(tsl::float8_e4m3b11fnuz value) { return result; } +std::string RoundTripFpToString(tsl::float8_e3m4 value) { + std::string result = GenericRoundTripFpToString(value); + RoundTripNanPayload(value, &result); + return result; +} + std::string RoundTripFpToString(bfloat16 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); diff --git a/xla/util.h b/xla/util.h index a6e74601a809f7..a62096c866d1f5 100644 --- a/xla/util.h +++ b/xla/util.h @@ -420,6 +420,9 @@ std::string VectorString(const std::initializer_list& c) { std::string RoundTripFpToString(tsl::float8_e5m2 value); // Returns a string which can losslessly round trip to a float8 E4M3. +std::string RoundTripFpToString(tsl::float8_e4m3 value); + +// Returns a string which can losslessly round trip to a float8 E4M3FN. std::string RoundTripFpToString(tsl::float8_e4m3fn value); // Returns a string which can losslessly round trip to a float8 E4M3B11. @@ -431,6 +434,9 @@ std::string RoundTripFpToString(tsl::float8_e5m2fnuz value); // Returns a string which can losslessly round trip to a float8 E4M3FNUZ. std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); +// Returns a string which can losslessly round trip to a float8 E3M4. +std::string RoundTripFpToString(tsl::float8_e3m4 value); + // Returns a string which can losslessly round trip to a bfloat. std::string RoundTripFpToString(tsl::bfloat16 value); diff --git a/xla/util_test.cc b/xla/util_test.cc index 707696ea1c3a99..83b1b149c6916d 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "ml_dtypes/include/float8.h" #include "xla/maybe_owning.h" #include "xla/test.h" #include "xla/types.h" @@ -130,6 +131,18 @@ TEST(UtilTest, RoundTripFpToString) { EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( true, QuietNanWithoutPayload())), "-nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + false, QuietNanWithoutPayload())), + "nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + true, QuietNanWithoutPayload())), + "-nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + false, QuietNanWithoutPayload())), + "nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + true, QuietNanWithoutPayload())), + "-nan"); EXPECT_EQ( RoundTripFpToString(std::numeric_limits::quiet_NaN()), "nan"); @@ -237,6 +250,18 @@ TEST(UtilTest, TotalOrder_F8E5M2) { } } +TEST(UtilTest, TotalOrder_F8E4M3) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e4m3 x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e4m3 y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + TEST(UtilTest, TotalOrder_F8E4M3FN) { for (int a = 0; a < 256; ++a) { tsl::float8_e4m3fn x = @@ -287,6 +312,18 @@ TEST(UtilTest, TotalOrder_F8E5M2FNUZ) { } } +TEST(UtilTest, TotalOrder_F8E3M4) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e3m4 x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e3m4 y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + void PackInt4(absl::Span input, absl::Span output) { CHECK_EQ(output.size(), CeilOfRatio(input.size(), size_t{2})); for (size_t i = 0; i < input.size(); ++i) { diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 57e8f5a93a7073..c67116a167eea5 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -66,6 +66,9 @@ enum PrimitiveType { // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the // existing IEEE types. // + // F8E4M3 has 4 exponent bits and 3 mantissa bits, and is similar to the + // existing IEEE types. + // // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only // Finite and NaN values are supported. Unlike IEEE types, infinities are not // supported. NaN is represented when the exponent and mantissa bits are all @@ -77,12 +80,17 @@ enum PrimitiveType { // the exponent and mantissa bits are all 0s with a sign bit of 1. All other // values are finite. // + // F8E3M4 has 3 exponent bits and 4 mantissa bits, and is similar to the + // existing IEEE types. + // // Support for these dtypes is under development. They do not yet work // properly in most cases. // TODO(b/259609697): Fully support FP8. F8E5M2 = 19; + F8E4M3 = 28; F8E4M3FN = 20; F8E4M3B11FNUZ = 23; + F8E3M4 = 29; // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915 // @@ -126,7 +134,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 28 + // Next = 30 } // LINT.ThenChange( // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc @@ -572,12 +580,14 @@ message LiteralProto { bytes u16s = 16; bytes s16s = 17; bytes f8e5m2s = 19; + bytes f8e4m3s = 28; bytes f8e4m3fns = 20; bytes f8e4m3b11fnuzs = 23; bytes f8e5m2fnuzs = 24; bytes f8e4m3fnuzs = 25; + bytes f8e3m4s = 29; repeated int64 sparse_indices = 14; - // Next = 28 + // Next = 30 } message WindowDimension {