Skip to content

Commit

Permalink
CPU/CUDA: Gemma 2 FlashAttention support (#8542)
Browse files Browse the repository at this point in the history
* CPU/CUDA: Gemma 2 FlashAttention support

* apply logit_softcap to scale in kernel

* disable logit softcapping tests on Metal

* remove metal check
  • Loading branch information
JohannesGaessler authored Aug 24, 2024
1 parent 8f824ff commit e11bd85
Show file tree
Hide file tree
Showing 12 changed files with 319 additions and 79 deletions.
3 changes: 2 additions & 1 deletion ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1760,7 +1760,8 @@ extern "C" {
struct ggml_tensor * v,
struct ggml_tensor * mask,
float scale,
float max_bias);
float max_bias,
float logit_softcap);

GGML_API void ggml_flash_attn_ext_set_prec(
struct ggml_tensor * a,
Expand Down
17 changes: 12 additions & 5 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
Expand Down Expand Up @@ -657,11 +658,17 @@ void launch_fattn(
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;

float scale = 1.0f;
float max_bias = 0.0f;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;

memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));

if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}

const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
Expand All @@ -675,7 +682,7 @@ void launch_fattn(
V_data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Expand Down
52 changes: 43 additions & 9 deletions ggml/src/ggml-cuda/fattn-tile-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#define FATTN_KQ_STRIDE_TILE_F16 64

template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
Expand All @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
Expand All @@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}

//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
Expand Down Expand Up @@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;

half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
half sum;
if (use_logit_softcap) {
const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
sum = logit_softcap * tanhf(tmp.x + tmp.y);
} else {
sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
}
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
Expand Down Expand Up @@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // FP16_AVAILABLE
}

template <int cols_per_block, int parallel_blocks>
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
default: {
Expand All @@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];

const int32_t precision = KQV->op_params[2];
const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);

float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}

if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}

constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
}
47 changes: 40 additions & 7 deletions ggml/src/ggml-cuda/fattn-tile-f32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#define FATTN_KQ_STRIDE_TILE_F32 32

template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
Expand All @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
Expand All @@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}

//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
Expand Down Expand Up @@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;

if (use_logit_softcap) {
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
}

sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;

kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
Expand Down Expand Up @@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
}
}

template <int cols_per_block, int parallel_blocks>
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
default: {
Expand All @@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
}

void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];

float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}

if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}

constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
}
Loading

0 comments on commit e11bd85

Please sign in to comment.