From 08d1cfbe515098dcadfed25eb67a736b1440d910 Mon Sep 17 00:00:00 2001
From: "Peter Y. Yeh"
Date: Thu, 9 Jan 2025 16:15:11 -0600
Subject: [PATCH] revert change in tensor_core_tile_layout.cu

---
 setup.py                        |  6 +-
 .../tensor_core_tiled_layout.cu | 61 +------------------
 2 files changed, 7 insertions(+), 60 deletions(-)

diff --git a/setup.py b/setup.py
index 36810d0a49..fef71dcbdb 100644
--- a/setup.py
+++ b/setup.py
@@ -148,9 +148,11 @@ def get_extensions():
     if IS_ROCM and use_cuda:
         # Add ROCm GPU architecture check
         gpu_arch = torch.cuda.get_device_properties(0).name
-        if gpu_arch != 'gfx942':
+        if gpu_arch != "gfx942":
             print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-            print("Currently only gfx942 is supported. Skipping compilation of ROCm extensions")
+            print(
+                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
+            )
             return None
         sources += hip_sources
 
diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
index d1c5d49fda..d3ddd66fe6 100644
--- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
+++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
@@ -1,4 +1,4 @@
-#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
 // at least Ampere
 #include
 #include
@@ -7,24 +7,13 @@
 #include
 #include
 
-#if defined(USE_ROCM)
-#include
-#include
-#include
-#endif
-
 template <typename U, typename V>
 constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
   const uint64_t blocks = a / b + (a % b != 0);
   return blocks;
 }
-
-#if defined(USE_ROCM)
-constexpr int32_t kWarpSize = 64;
-#else
 constexpr int32_t kWarpSize = 32;
-#endif
 
 //Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
 //https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
@@ -41,71 +30,38 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
   uint32_t const source_i4s = source;
 
   // First, we extract the i4s and construct an intermediate fp16 number.
-#if !defined(USE_ROCM)
   static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
-#endif
   static constexpr uint32_t MASK = 0x000f000f;
   static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
 
   // We don't have enough mantissa to remove as much shift overhead as FP16, so
   // we must loop. No shift needed for first item.
   uint32_t i4s = source_i4s;
-// AMD MI300X ISA that performs two bitwise operations in a single instruction:
-// v_and_or_b32 performs H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
-// - First ANDs `i4s` with `MASK` (0x000f000f) to extract 4-bit values
-// - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16
-#if defined(USE_ROCM)
-  asm volatile("v_and_or_b32 %0, %1, %2, %3"
-               : "=v"(h[0])
-               : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
-#else
   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                : "=r"(h[0])
                : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-#endif
-
 #pragma unroll
   for (int ii = 1; ii < kElements / 2; ++ii) {
     i4s >>= 4; // or is it 8?
     // (i4s & 0x000f000f) | 0x43004300
-#if defined(USE_ROCM)
-    asm volatile("v_and_or_b32 %0, %1, %2, %3"
-                 : "=v"(h[ii])
-                 : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
-#else
     asm volatile(
         "lop3.b32 %0, %1, %2, %3, %4;\n"
         : "=r"(h[ii])
         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-#endif
   }
 
   // This is the BF16 {-136, -136} represented as an integer.
-#if defined(USE_ROCM)
-#if ROCM_VERSION >= 60200
-  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
-  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
-#else
-  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
-  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
-#endif
-#else
   static constexpr uint32_t BF16_BIAS = 0xC308C308;
   static constexpr uint32_t BF16_ONE = 0x3F803F80;
-#endif
 
   // Finally, we construct the output numbers.
 #pragma unroll
   for (int ii = 0; ii < kElements / 2; ++ii) {
     // Since this section is for Ampere+, we use bf16 fma to do the bias
     // subtraction
-#if defined(USE_ROCM)
-    result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
-#else
     asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
         : "=r"(h[ii])
         : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
-#endif
   }
 
   return result;
@@ -167,22 +123,11 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
-#if defined(USE_ROCM)
-  __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
-  __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
-
-  if (scales_and_zeros) {
-    const auto& sz = *scales_and_zeros;
-    const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);
-
-    scale2 = __bfloat162bfloat162(pSZ[0]);
-    zero2 = __bfloat162bfloat162(pSZ[1]);
-  }
-#else
   const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
+
+  // Vectorize scales and zeros
   __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
   __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
-#endif
 
 #pragma unroll
   for (int i = 0; i < 4; i++) {
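
Note (not part of the patch): the CUDA path that this revert keeps relies on the bf16 magic-number trick documented in convert_i4x8_to_bf16x2x4. Each lop3.b32 (or v_and_or_b32 on ROCm) builds (x & 0xF) | 0x4300 in a 16-bit lane, which reads back as the bf16 value 128 + x; the fma with BF16_ONE and BF16_BIAS (bf16 -136.0, encoded 0xC308) then maps it to the signed int4 value x - 8. The sketch below reproduces that arithmetic on the host so it can be checked without a GPU; the helper bf16_bits_to_float and the standalone main are illustrative names, not taken from the repository.

// Illustrative host-side check of the int4 -> bf16 dequant trick
// (assumed helper names, not from the repo).
#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 float32.
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

int main() {
  const uint16_t kMagic = 0x4300; // bf16 128.0, per-lane half of I4s_TO_BF16s_MAGIC_NUM
  const uint16_t kBias  = 0xC308; // bf16 -136.0, per-lane half of BF16_BIAS
  for (unsigned x = 0; x < 16; ++x) {
    // The bitwise step builds (x & 0xF) | 0x4300, i.e. the bf16 value 128 + x.
    float intermediate = bf16_bits_to_float(static_cast<uint16_t>(kMagic | x));
    // The fma with BF16_ONE adds the -136 bias, yielding the int4 value x - 8.
    float dequantized = intermediate + bf16_bits_to_float(kBias);
    std::printf("x=%2u  ->  %6.1f  ->  %5.1f\n", x, intermediate, dequantized);
  }
  return 0;
}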