Skip to content

Commit

Permalink
CUDA: add FP32 FlashAttention vector kernel (ggerganov#7188)
Browse files Browse the repository at this point in the history
* CUDA: add FP32 FlashAttention vector kernel

* fixup! CUDA: add FP32 FlashAttention vector kernel

* fixup! fixup! CUDA: add FP32 FlashAttention vector kernel

* fixup! fixup! fixup! CUDA: add FP32 FlashAttention vector kernel
  • Loading branch information
JohannesGaessler authored May 12, 2024
1 parent 6f1b636 commit dc685be
Show file tree
Hide file tree
Showing 9 changed files with 899 additions and 458 deletions.
11 changes: 10 additions & 1 deletion ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2713,6 +2713,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
}

GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
Expand Down Expand Up @@ -2840,8 +2841,16 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
return true;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
#else
if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
return true;
}
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
default:
return false;
}
Expand Down
4 changes: 4 additions & 0 deletions ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,10 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {

#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA

static bool fast_fp16_available(const int cc) {
return cc >= CC_PASCAL && cc != 610;
}

static bool fp16_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
}
Expand Down
47 changes: 47 additions & 0 deletions ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_combine_results(
const float * __restrict__ VKQ_parts,
const float2 * __restrict__ VKQ_meta,
float * __restrict__ dst) {
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
dst += D * gridDim.y*blockIdx.x;

const int tid = threadIdx.x;
__builtin_assume(tid < D);

__shared__ float2 meta[parallel_blocks];
if (tid < 2*parallel_blocks) {
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
}

__syncthreads();

float kqmax = meta[0].x;
#pragma unroll
for (int l = 1; l < parallel_blocks; ++l) {
kqmax = max(kqmax, meta[l].x);
}

float VKQ_numerator = 0.0f;
float VKQ_denominator = 0.0f;
#pragma unroll
for (int l = 0; l < parallel_blocks; ++l) {
const float diff = meta[l].x - kqmax;
const float KQ_max_scale = expf(diff);
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;

VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
VKQ_denominator += KQ_max_scale * meta[l].y;
}

dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
}
Loading

0 comments on commit dc685be

Please sign in to comment.