Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for float8_e4m3 and float8_e3m4 types #16585

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions third_party/tsl/tools/def_file_filter/symbols_pybind.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,16 @@ tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
[//external/local_xla/xla/tsl/python/lib/core:ml_dtypes_lib] # bfloat16, float8
tsl::ml_dtypes::RegisterTypes
tsl::ml_dtypes::GetBfloat16Dtype
tsl::ml_dtypes::GetFloat8E3m4Dtype
tsl::ml_dtypes::GetFloat8E4m3b11fnuzDtype
tsl::ml_dtypes::GetFloat8E4m3fnDtype
tsl::ml_dtypes::GetFloat8E4m3Dtype
tsl::ml_dtypes::GetFloat8E5m2Dtype
tsl::ml_dtypes::GetBfloat16TypeNum
tsl::ml_dtypes::GetFloat8E3m4TypeNum
tsl::ml_dtypes::GetFloat8E4m3b11fnuzTypeNum
tsl::ml_dtypes::GetFloat8E4m3fnTypeNum
tsl::ml_dtypes::GetFloat8E4m3TypeNum
tsl::ml_dtypes::GetFloat8E5m2TypeNum

[//tensorflow/python:py_func_lib] # py_func
Expand Down
4 changes: 3 additions & 1 deletion third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ limitations under the License.
#define TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_

#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes

namespace tsl {
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
Expand Down
28 changes: 28 additions & 0 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,20 @@ TEST(Array2dTest, LinspaceF8E5M2) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

// Builds a 3x2 linspace array of tsl::float8_e4m3 spanning 1.0 to 3.5 and
// checks its dimensions and every element after conversion back to float.
TEST(Array2dTest, LinspaceF8E4M3) {
  auto linspace = MakeLinspaceArray2D<tsl::float8_e4m3>(1.0, 3.5, 3, 2);

  EXPECT_EQ(linspace->n1(), 3);
  EXPECT_EQ(linspace->n2(), 2);

  // Expected contents, indexed as [first index][second index].
  const float kExpected[3][2] = {{1.0f, 1.5f}, {2.0f, 2.5f}, {3.0f, 3.5f}};
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 2; ++col) {
      EXPECT_FLOAT_EQ(static_cast<float>((*linspace)(row, col)),
                      kExpected[row][col]);
    }
  }
}

TEST(Array2dTest, LinspaceF8E4M3Fn) {
auto arr = MakeLinspaceArray2D<tsl::float8_e4m3fn>(1.0, 3.5, 3, 2);

Expand Down Expand Up @@ -190,6 +204,20 @@ TEST(Array2dTest, LinspaceF8E4M3FnNoNan) {
}
}

// Builds a 3x2 linspace array of tsl::float8_e3m4 spanning 1.0 to 3.5 and
// checks its dimensions and every element after conversion back to float.
TEST(Array2dTest, LinspaceF8E3M4) {
  auto linspace = MakeLinspaceArray2D<tsl::float8_e3m4>(1.0, 3.5, 3, 2);

  EXPECT_EQ(linspace->n1(), 3);
  EXPECT_EQ(linspace->n2(), 2);

  // Expected contents, indexed as [first index][second index].
  const float kExpected[3][2] = {{1.0f, 1.5f}, {2.0f, 2.5f}, {3.0f, 3.5f}};
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 2; ++col) {
      EXPECT_FLOAT_EQ(static_cast<float>((*linspace)(row, col)),
                      kExpected[row][col]);
    }
  }
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
Expand Down
4 changes: 4 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "TOKEN";
case XLA_FFI_DataType_F8E5M2:
return os << "F8E5M2";
case XLA_FFI_DataType_F8E3M4:
return os << "F8E3M4";
case XLA_FFI_DataType_F8E4M3:
return os << "F8E4M3";
case XLA_FFI_DataType_F8E4M3FN:
return os << "F8E4M3FN";
case XLA_FFI_DataType_F8E4M3B11FNUZ:
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ typedef enum {
XLA_FFI_DataType_C128 = 18,
XLA_FFI_DataType_TOKEN = 17,
XLA_FFI_DataType_F8E5M2 = 19,
XLA_FFI_DataType_F8E3M4 = 29,
XLA_FFI_DataType_F8E4M3 = 28,
XLA_FFI_DataType_F8E4M3FN = 20,
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
XLA_FFI_DataType_F8E5M2FNUZ = 24,
Expand Down
6 changes: 6 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ enum class DataType : uint8_t {
C128 = XLA_FFI_DataType_C128,
TOKEN = XLA_FFI_DataType_TOKEN,
F8E5M2 = XLA_FFI_DataType_F8E5M2,
F8E4M3 = XLA_FFI_DataType_F8E4M3,
F8E4M3FN = XLA_FFI_DataType_F8E4M3FN,
F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ,
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
F8E3M4 = XLA_FFI_DataType_F8E3M4,
};

// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
Expand All @@ -98,10 +100,12 @@ inline constexpr DataType C64 = DataType::C64;
inline constexpr DataType C128 = DataType::C128;
inline constexpr DataType TOKEN = DataType::TOKEN;
inline constexpr DataType F8E5M2 = DataType::F8E5M2;
inline constexpr DataType F8E4M3 = DataType::F8E4M3;
inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN;
inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
inline constexpr DataType F8E3M4 = DataType::F8E3M4;

inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
return os << static_cast<XLA_FFI_DataType>(dtype);
Expand All @@ -117,10 +121,12 @@ constexpr size_t ByteWidth(DataType dtype) {
case DataType::S8:
case DataType::U8:
case DataType::F8E5M2:
case DataType::F8E4M3:
case DataType::F8E4M3FN:
case DataType::F8E4M3B11FNUZ:
case DataType::F8E5M2FNUZ:
case DataType::F8E4M3FNUZ:
case DataType::F8E3M4:
return 1;
case DataType::S16:
case DataType::U16:
Expand Down
6 changes: 6 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ TEST(FfiTest, DataTypeEnumValue) {
EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));

EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ),
encoded(DataType::F8E4M3B11FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4));
}

TEST(FfiTest, DataTypeByteWidth) {
Expand Down Expand Up @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) {

EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
ByteWidth(DataType::F8E5M2));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
ByteWidth(DataType::F8E4M3));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN),
ByteWidth(DataType::F8E4M3FN));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ),
Expand All @@ -187,6 +191,8 @@ TEST(FfiTest, DataTypeByteWidth) {
ByteWidth(DataType::F8E5M2FNUZ));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FNUZ),
ByteWidth(DataType::F8E4M3FNUZ));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4),
ByteWidth(DataType::F8E3M4));
}

TEST(FfiTest, ErrorEnumValue) {
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,12 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
case PrimitiveType::C128:
case PrimitiveType::TOKEN:
case PrimitiveType::F8E5M2:
case PrimitiveType::F8E4M3:
case PrimitiveType::F8E4M3FN:
case PrimitiveType::F8E4M3B11FNUZ:
case PrimitiveType::F8E5M2FNUZ:
case PrimitiveType::F8E4M3FNUZ:
case PrimitiveType::F8E3M4:
return static_cast<XLA_FFI_DataType>(primitive_type);
default:
DCHECK(false) << "Unsupported primitive type "
Expand Down
61 changes: 57 additions & 4 deletions xla/fp_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,74 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest,
0x1.fffffffffffffp-127,
0x1.aaaaaaaaaaaaap-127));

// Test F8E4M3 floating-point types (F8E4M3FN)
// Test F8E4M3 floating-point types (F8E4M3, F8E4M3FN)
template <typename T>
class FP8E4M3DistanceTest : public ::testing::Test {};

using F8E4M3Types = ::testing::Types<tsl::float8_e4m3fn>;
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);

// Checks CalculateDistanceInFloats for tsl::float8_e3m4 operands: identical
// values, values with the same exponent, values with different exponents, and
// values at or across zero (smallest denormal and smallest normal).
TEST(FPDistanceTest, F8E3M4Distance) {
  using F8 = tsl::float8_e3m4;
  using Limits = std::numeric_limits<F8>;

  const F8 zero(0);
  const F8 eight(8.0);
  const F8 smallest_denorm = Limits::denorm_min();
  const F8 smallest_normal = Limits::min();

  // Identical operands.
  EXPECT_EQ(CalculateDistanceInFloats<F8>(eight, F8(8.0)), 0);

  // Operands sharing an exponent.
  EXPECT_EQ(CalculateDistanceInFloats<F8>(eight, F8(15.5)), 15);

  // Operands with different exponents.
  EXPECT_EQ(CalculateDistanceInFloats<F8>(eight, F8(6)), 8);

  // One step above zero.
  EXPECT_EQ(CalculateDistanceInFloats<F8>(smallest_denorm, zero), 1);

  // One step below zero.
  EXPECT_EQ(CalculateDistanceInFloats<F8>(-smallest_denorm, zero), 1);

  // Smallest denormals of opposite sign.
  EXPECT_EQ(CalculateDistanceInFloats<F8>(-smallest_denorm, smallest_denorm),
            2);

  // Smallest positive normal value versus zero.
  EXPECT_EQ(CalculateDistanceInFloats<F8>(smallest_normal, zero), 16);

  // Smallest negative normal value versus zero.
  EXPECT_EQ(CalculateDistanceInFloats<F8>(-smallest_normal, zero), 16);

  // Smallest normal values of opposite sign.
  EXPECT_EQ(CalculateDistanceInFloats<F8>(-smallest_normal, smallest_normal),
            32);
}

TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) {
// a & b are equal, distance should be 0
EXPECT_EQ(
CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(8.0)), 0);

// a & b have the same exponents
EXPECT_EQ(CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(13)),
5);
EXPECT_EQ(
CalculateDistanceInFloats<TypeParam>(TypeParam(8.0), TypeParam(15.0)), 7);

// a & b have different exponents
EXPECT_EQ(
Expand Down
10 changes: 6 additions & 4 deletions xla/hlo/builder/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ XlaOp IsNegZero(XlaOp operand) {
case F32:
return Eq(BitcastConvertType(operand, U32),
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
case F8E3M4:
case F8E4M3:
case F8E5M2:
case F8E4M3FN:
case F8E4M3B11FNUZ:
Expand Down Expand Up @@ -973,8 +975,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
PrimitiveType a_x_type = a_shape.element_type();
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down Expand Up @@ -1026,8 +1028,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
bool needs_upcast = false;
for (PrimitiveType type :
{BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1743,10 +1743,12 @@ extern template class HloEvaluatorTypedVisitor<complex64>;
extern template class HloEvaluatorTypedVisitor<complex128>;
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e3m4, float>;

} // namespace xla

Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ limitations under the License.

namespace xla {
// Explicit instantiation definitions of HloEvaluatorTypedVisitor, one per
// 8-bit floating-point type, each with float as the second template argument.
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e3m4, float>;
} // namespace xla
24 changes: 21 additions & 3 deletions xla/hlo/translate/hlo_to_mhlo/tests/import.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,17 @@ add {
// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3B11FNUZ>
%constant.9 = f8e4m3b11fnuz[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ>
// CHECK: %[[VAL_10:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ>
%constant.10 = f8e4m3fnuz[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ>
// CHECK: %[[VAL_11:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ>
%constant.11 = f8e5m2fnuz[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3>
%constant.12 = f8e4m3[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>
%constant.13 = f8e3m4[4] constant({1, 2, 3, 4})
}

// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
Expand Down Expand Up @@ -524,7 +530,19 @@ add {
%convert.11 = f8e5m2fnuz[4] convert(f32[4] %convert.10)

// CHECK-NEXT: %9 = mhlo.convert %8 : (tensor<4xf8E5M2FNUZ>) -> tensor<4xf32>
ROOT %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11)
%convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11)

// CHECK-NEXT: %10 = mhlo.convert %9 : (tensor<4xf32>) -> tensor<4xf8E4M3>
%convert.13 = f8e4m3[4] convert(f32[4] %convert.12)

// CHECK-NEXT: %11 = mhlo.convert %10 : (tensor<4xf8E4M3>) -> tensor<4xf32>
%convert.14 = f32[4] convert(f8e4m3[4] %convert.13)

// CHECK-NEXT: %12 = mhlo.convert %11 : (tensor<4xf32>) -> tensor<4xf8E3M4>
%convert.15 = f8e3m4[4] convert(f32[4] %convert.14)

// CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32>
ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15)
}

// CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8>
Expand Down
18 changes: 16 additions & 2 deletions xla/hlo/translate/mhlo_to_hlo/tests/export.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,12 @@ func.func @main() {
// CHECK: f8e5m2fnuz[4] constant({1, 2, 3, 4})
%cst_15 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ>

// CHECK: f8e4m3[4] constant({1, 2, 3, 4})
%cst_16 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3>

// CHECK: f8e3m4[4] constant({1, 2, 3, 4})
%cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>

func.return
}

Expand Down Expand Up @@ -729,7 +735,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%5 = "mhlo.convert"(%4) : (tensor<2xf8E4M3FNUZ>) -> tensor<2xf32>
%6 = "mhlo.convert"(%5) : (tensor<2xf32>) -> tensor<2xf8E5M2FNUZ>
%7 = "mhlo.convert"(%6) : (tensor<2xf8E5M2FNUZ>) -> tensor<2xf32>
func.return %7 : tensor<2xf32>
%8 = "mhlo.convert"(%7) : (tensor<2xf32>) -> tensor<2xf8E4M3>
%9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32>
%10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4>
%11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32>
func.return %11 : tensor<2xf32>
}

// CHECK: ENTRY
Expand All @@ -741,7 +751,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(f32[2] %[[F32_VAL2]])
// CHECK: %[[F32_VAL3:.*]] = f32[2] convert(f8e4m3fnuz[2] %[[E4M3FNUZ_VAL]])
// CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(f32[2] %[[F32_VAL3]])
// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]])
// CHECK: %[[F32_VAL4:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]])
// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]])
// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]])
// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]])
// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]])

// -----

Expand Down
Loading
Loading