From 1fd27d520f6aaa54c20cd6cc4439fb9c9f60c508 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Fri, 20 Oct 2023 08:00:45 -0500 Subject: [PATCH] Fix bf8 conversion issues (#1003) * Fix the conversion * Add bf8 functionality * Enable example on MI200 as well --- .../20_grouped_conv_bwd_weight/CMakeLists.txt | 6 +- .../element/unary_element_wise_operation.hpp | 3 +- include/ck/utility/data_type.hpp | 7 +- include/ck/utility/f8_utils.hpp | 147 ++++++++++++------ include/ck/utility/type_convert.hpp | 24 +-- 5 files changed, 114 insertions(+), 73 deletions(-) diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index 2b0c4a28ce..c28fca6fa2 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -10,10 +10,8 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) - if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) - endif() + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) set(target 1) endif() diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index dabdf649e4..af48c6c1af 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -207,7 +207,8 @@ struct ConvertF8SR __host__ __device__ void operator()(Y& y, const X& x) const { // check Y datatype - static_assert(is_same::value, "Data type is not supported by this operation!"); + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); // check X datatype static_assert(is_same::value || is_same::value, diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index ceaca27a45..d367ad8df5 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1075,6 +1075,7 @@ struct NumericUtils { static constexpr int exp = 8; static constexpr int mant = 23; + static constexpr int bias = 127; static constexpr uint32_t nan_mask = 0x7F800000; static constexpr uint32_t head_mask = 0xFF800000; static constexpr uint32_t mant_mask = 0x7FFFFF; @@ -1091,6 +1092,7 @@ struct NumericUtils { static constexpr int exp = 5; static constexpr int mant = 10; + static constexpr int bias = 15; static constexpr uint16_t nan_mask = 0x7C00; static constexpr uint16_t head_mask = 0xFC00; static constexpr uint16_t mant_mask = 0x3FF; @@ -1107,6 +1109,8 @@ struct NumericUtils { static constexpr int exp = 4; static constexpr int mant = 3; + static constexpr int bias = 8; // negative zero nan mode + // static constexpr int bias = 7; // ieee mode }; template <> @@ -1114,6 +1118,7 @@ struct NumericUtils { static constexpr int exp = 5; static constexpr int mant = 2; + static constexpr int bias = 16; // negative zero nan mode + // static constexpr int bias = 15; // ieee mode }; -// } // namespace ck diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index b63c82fe9a..1960667732 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -5,7 +5,6 @@ #include "ck/utility/data_type.hpp" -// these conversions are disabled if native conversions available namespace ck { // fp8 rounding modes @@ -17,6 +16,9 @@ enum class f8_rounding_mode stochastic }; +__host__ inline int clz(uint32_t x) { return __builtin_clz(x); } +__device__ inline int clz(uint32_t x) { return __clz(x); } + } // namespace ck namespace ck::utils { @@ -34,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) constexpr int in_exp = NumericUtils::exp; constexpr int in_mant = NumericUtils::mant; - int exponent; + int exponent, bias; uint32_t head, mantissa, sign; // nan code is same for float and half constexpr Y nan_code = 0x80; @@ -49,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) mantissa = x_bitwise & NumericUtils::mant_mask; exponent = (head >> in_mant) & NumericUtils::exp_mask; sign = head >> (in_exp + in_mant); + bias = NumericUtils::bias; uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); - constexpr int exp_low_cutoff = - (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); if constexpr(negative_zero_nan) { @@ -67,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) return signed_inf + (mantissa != 0 ? 1 : 0); } - // if input is half and output is bf8 - if((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && negative_zero_nan && - exponent == 0) - { - exponent += 1; - while(mantissa < (1 << in_mant)) - { - mantissa <<= 1; - exponent -= 1; - } - mantissa &= ~(1 << in_mant); - } - // check if x is 0.0 if(x_bitwise == 0) return 0; - exponent -= exp_low_cutoff - 1; - if(exponent <= 0) - drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1; - mantissa += 1 << in_mant; - // apply random number if needed - mantissa += (stoch ? rng : mantissa) & drop_mask; - if(mantissa >= (2 << in_mant)) - { - mantissa >>= 1; - exponent++; + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again3 + + // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits + const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // out_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, out_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. +In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = out_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= out_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = out_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << in_mant); // Add the implicit 1 into mantissa } - mantissa >>= (in_mant - out_mant); - // check negative exponent - if(exponent <= 0) + bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == + (1 << (in_mant - out_mant + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << in_mant); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + out_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + bool odd = + mantissa & + (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(out_exponent == 0) { - if(x_bitwise == 0) - return 0; - else + if((1 << in_mant) & mantissa) { - // subnormal range; represented by a subnormal float8 (exponent 0) - // and involves loss of accuracy - mantissa >>= 1 - exponent; - exponent = 0; + out_exponent = 1; // denormal overflow to become normal, promote exponent + // No need to make 1 implicit now as it will be addressed later } } - // above range: quantize to maximum possible float of the same sign - else if(exponent > max_exp) + else + { + if((1 << (in_mant + 1)) & mantissa) + { + mantissa >>= 1; + out_exponent++; + // No need to make 1 implicit now as it will be addressed later + } + } + + mantissa >>= (in_mant - out_mant); + + if(out_exponent > max_exp) { if(clip) { - mantissa = (1 << out_mant) - 1; - exponent = max_exp; + mantissa = (1 << out_mant) - 1; + out_exponent = max_exp; } else { @@ -125,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) } // check if x is 0.0 or -0.0 - if(exponent == 0 && mantissa == 0) + if(out_exponent == 0 && mantissa == 0) return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); mantissa &= (1 << out_mant) - 1; - return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; + return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; } template @@ -194,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x) if(exponent == 0) { // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - exponent++; - while(mantissa < (1 << in_mant)) - { - mantissa <<= 1; - exponent--; - } + int sh = 1 + clz(mantissa) - (32 - in_mant); + mantissa <<= sh; + exponent += 1 - sh; mantissa &= ((1 << in_mant) - 1); } exponent += exp_low_cutoff - 1; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index aba8baa593..12d628a4bb 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -145,7 +145,7 @@ inline __host__ __device__ f8_t type_convert(half_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion return type_convert(type_convert(x)); -#elif 0 +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; @@ -153,8 +153,6 @@ inline __host__ __device__ f8_t type_convert(half_t x) return utils:: cast_to_f8( x, rng); -#else - return type_convert(type_convert(x)); #endif } @@ -165,11 +163,9 @@ inline __host__ __device__ half_t type_convert(f8_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // use native conversion to float and convert to fp16 return type_convert(type_convert(x)); -#elif 0 +#else constexpr bool negative_zero_nan = true; return utils::cast_from_f8(x); -#else - return type_convert(type_convert(x)); #endif } @@ -223,7 +219,7 @@ inline __host__ __device__ bf8_t type_convert(half_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion return type_convert(type_convert(x)); -#elif 0 +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; @@ -231,8 +227,6 @@ inline __host__ __device__ bf8_t type_convert(half_t x) return utils:: cast_to_f8( x, rng); -#else - return type_convert(type_convert(x)); #endif } @@ -243,11 +237,9 @@ inline __host__ __device__ half_t type_convert(bf8_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // use native conversion to float and convert to fp16 return type_convert(type_convert(x)); -#elif 0 +#else constexpr bool negative_zero_nan = true; return utils::cast_from_f8(x); -#else - return type_convert(type_convert(x)); #endif } @@ -347,7 +339,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion return f8_convert_sr(type_convert(x)); -#elif 0 +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; @@ -356,8 +348,6 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) return utils:: cast_to_f8( x, rng); -#else - return f8_convert_sr(type_convert(x)); #endif } @@ -396,7 +386,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion return f8_convert_sr(type_convert(x)); -#elif 0 +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; @@ -406,8 +396,6 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) return utils:: cast_to_f8( x, rng); -#else - return f8_convert_sr(type_convert(x)); #endif }