Skip to content

Commit

Permalink
workaround with float (#992)
Browse files Browse the repository at this point in the history
Co-authored-by: Jing Zhang <[email protected]>
  • Loading branch information
zjing14 and Jing Zhang authored Oct 16, 2023
1 parent 707ad00 commit 39430bf
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions include/ck/utility/type_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,16 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#else
#elif 0
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<f8_t>(type_convert<float>(x));
#endif
}

Expand All @@ -164,9 +166,11 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
#elif 0
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif
}
#endif
Expand Down Expand Up @@ -222,14 +226,16 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<bf8_t>(type_convert<float>(x));
#else
#elif 0
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<bf8_t>(type_convert<float>(x));
#endif
}

Expand All @@ -240,9 +246,11 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
#elif 0
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif
}
#endif
Expand Down Expand Up @@ -354,7 +362,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<f8_t>(type_convert<float>(x));
return f8_convert_sr<f8_t>(type_convert<float>(x));
#endif
}
#endif
Expand Down Expand Up @@ -406,7 +414,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<bf8_t>(type_convert<float>(x));
return f8_convert_sr<bf8_t>(type_convert<float>(x));
#endif
}
#endif
Expand Down

0 comments on commit 39430bf

Please sign in to comment.