revert change in tensor_core_tile_layout.cu
petrex committed Jan 9, 2025
1 parent c678cb0 commit 08d1cfb
Showing 2 changed files with 7 additions and 60 deletions.
setup.py (6 changes: 4 additions & 2 deletions)
@@ -148,9 +148,11 @@ def get_extensions():
    if IS_ROCM and use_cuda:
        # Add ROCm GPU architecture check
        gpu_arch = torch.cuda.get_device_properties(0).name
        if gpu_arch != 'gfx942':
        if gpu_arch != "gfx942":
            print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
            print("Currently only gfx942 is supported. Skipping compilation of ROCm extensions")
            print(
                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
            )
            return None
        sources += hip_sources

tensor_core_tile_layout.cu (61 changes: 3 additions & 58 deletions)
@@ -1,4 +1,4 @@
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
@@ -7,24 +7,13 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#if defined(USE_ROCM)
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#endif

template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  const uint64_t blocks = a / b + (a % b != 0);
  return blocks;
}

#if defined(USE_ROCM)
constexpr int32_t kWarpSize = 64;
#else
constexpr int32_t kWarpSize = 32;
#endif

//Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
//https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
@@ -41,71 +30,38 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
  uint32_t const source_i4s = source;

  // First, we extract the i4s and construct an intermediate fp16 number.
#if !defined(USE_ROCM)
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
#endif
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

  // We don't have enough mantissa to remove as much shift overhead as FP16, so
  // we must loop. No shift needed for first item.
  uint32_t i4s = source_i4s;
  // AMD MI300X ISA that performs two bitwise operations in a single instruction:
  // v_and_or_b32 performs H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
  // - First ANDs `i4s` with `MASK` (0x000f000f) to extract 4-bit values
  // - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16
#if defined(USE_ROCM)
  asm volatile("v_and_or_b32 %0, %1, %2, %3"
               : "=v"(h[0])
               : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif

#pragma unroll
  for (int ii = 1; ii < kElements / 2; ++ii) {
    i4s >>= 4; // or is it 8?
    // (i4s & 0x000f000f) | 0x43004300
#if defined(USE_ROCM)
    asm volatile("v_and_or_b32 %0, %1, %2, %3"
                 : "=v"(h[ii])
                 : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
    asm volatile(
        "lop3.b32 %0, %1, %2, %3, %4;\n"
        : "=r"(h[ii])
        : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif
  }

  // This is the BF16 {-136, -136} represented as an integer.
#if defined(USE_ROCM)
#if ROCM_VERSION >= 60200
  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
#else
  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
#endif
#else
  static constexpr uint32_t BF16_BIAS = 0xC308C308;
  static constexpr uint32_t BF16_ONE = 0x3F803F80;
#endif

  // Finally, we construct the output numbers.
#pragma unroll
  for (int ii = 0; ii < kElements / 2; ++ii) {
    // Since this section is for Ampere+, we use bf16 fma to do the bias
    // subtraction
#if defined(USE_ROCM)
    result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
#else
    asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(h[ii])
        : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
#endif
  }

  return result;
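For reference, a minimal host-side C++ sketch (not part of this commit) of the mask / magic-number / bias trick used above: each 4-bit value v is OR'ed into the low mantissa bits of the bf16 pattern 0x4300 (128.0), giving 128 + v, and the fma with BF16_ONE and BF16_BIAS (-136.0) then yields v - 8. The bf16 patterns are widened to fp32 here so the arithmetic can be checked on a CPU, and the nibble packing order of `source` is an assumption for illustration.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Interpret a 16-bit bf16 pattern as a float (bf16 is the high half of fp32).
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t f32_bits = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &f32_bits, sizeof(f));
  return f;
}

int main() {
  const uint32_t MASK = 0x000f000f;               // keep one 4-bit value per 16-bit lane
  const uint32_t MAGIC = 0x43004300;              // bf16 128.0 in each lane
  const float BIAS = bf16_bits_to_float(0xC308);  // -136.0f, as in BF16_BIAS

  // Pack the eight nibbles 0..7 into one 32-bit word, 4 bits apart (assumed layout).
  uint32_t source = 0;
  for (uint32_t v = 0; v < 8; ++v) {
    source |= v << (4 * v);
  }

  uint32_t i4s = source;
  for (int ii = 0; ii < 4; ++ii) {
    // Same combined operation as lop3.b32 / v_and_or_b32: (i4s & MASK) | MAGIC.
    uint32_t h = (i4s & MASK) | MAGIC;
    float lo = bf16_bits_to_float(static_cast<uint16_t>(h & 0xffff));
    float hi = bf16_bits_to_float(static_cast<uint16_t>(h >> 16));
    // The bf16x2 fma with ONE = 1.0 reduces to adding the -136 bias.
    printf("pair %d: %g %g\n", ii, lo + BIAS, hi + BIAS);
    i4s >>= 4;  // move to the next pair of nibbles, as in the kernel loop
  }
  return 0;  // prints (-8,-4), (-7,-3), (-6,-2), (-5,-1)
}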
@@ -167,22 +123,11 @@ __global__ void _dequantize_int4_kernel(
  // All b values within a 16x16 tile should fall within the same q group
  // Hence we load 1 scale and zero per loop
  int qgroup = ks[0] / groupSize;
#if defined(USE_ROCM)
  __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
  __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));

  if (scales_and_zeros) {
    const auto& sz = *scales_and_zeros;
    const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);

    scale2 = __bfloat162bfloat162(pSZ[0]);
    zero2 = __bfloat162bfloat162(pSZ[1]);
  }
#else
  const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);

  // Vectorize scales and zeros
  __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
  __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
#endif

#pragma unroll
  for (int i = 0; i < 4; i++) {
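A second hedged sketch: the hunk above only shows scale2 and zero2 being loaded and broadcast once per quantization group (qgroup = ks[0] / groupSize). Assuming the usual (q - 8) * scale + zero dequantization form, the per-group lookup behaves roughly like the plain C++ loop below. The concrete numbers and the formula itself are assumptions, since the kernel body that consumes scale2/zero2 is outside the shown hunk.

#include <cstdio>

int main() {
  // Illustrative values only; groupSize, k, scale, and zero are made up.
  const int groupSize = 128;         // quantization group width along k
  const int k = 300;                 // a column index handled by one iteration
  const int qgroup = k / groupSize;  // -> group 2; one scale/zero pair per group

  const float scale = 0.05f;         // stands in for scales_and_zeros[qgroup][n0][0]
  const float zero = 1.25f;          // stands in for scales_and_zeros[qgroup][n0][1]

  for (int q = 0; q < 16; ++q) {     // every possible int4 code
    float centered = static_cast<float>(q) - 8.0f;  // same offset the -136 bias produces
    float dequant = centered * scale + zero;        // assumed dequantization form
    printf("qgroup=%d q=%2d -> %g\n", qgroup, q, dequant);
  }
  return 0;
}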
