Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing my own fairly naive GEMV kernel as a replacement for the default one used in nn.Linear gives 20% better speed on MI100 #1408

Open
Epliz opened this issue May 3, 2024 · 0 comments

Comments

@Epliz
Copy link

Epliz commented May 3, 2024

🐛 Describe the bug

Hi,

I profiled text generation with the Mistral 7B LLM on my MI100 GPU and saw that some fp16 GEMV kernels don't seem to reach the available memory bandwidth.
Implementing my own naive GEMV kernel seems to improve overall performance by ~20%, and apparently by up to 2x for some specific calls.

Code to generate the original (baseline) trace:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn

model_id = "mistralai/Mistral-7B-Instruct-v0.2"

# Left padding: every prompt in a batch ends at the same position, right
# where generation starts.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    padding_side="left",
)

# Padding a batch requires a pad token; reuse EOS when the tokenizer has none.
if tokenizer.pad_token is None:
  tokenizer.pad_token = tokenizer.eos_token

# Load the weights in fp16 and move the model onto the GPU.
model: nn.Module = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to(device="cuda", dtype=torch.float16)

from typing import List, Union


def generate(model, prompt: Union[str, List[str]], max_new_tokens=20) -> Union[str, List[str]]:
  """Sample a continuation for one prompt or for a batch of prompts.

  Tokenizes with the module-level ``tokenizer`` and moves the inputs to "cuda".
  Then samples up to ``max_new_tokens`` new tokens with ``model.generate``
  (``do_sample=True``) and decodes without special tokens.

  Returns only the continuation, i.e. each decoded output with its first
  ``len(prompt)`` characters removed. The result is a single string when
  ``prompt`` is a str, otherwise a list in the same order as the prompts.
  """
  is_single = isinstance(prompt, str)
  batch = [prompt] if is_single else prompt

  with torch.no_grad():
    encoded = tokenizer(batch, return_tensors="pt", padding="longest").to(device="cuda")
    generated = model.generate(**encoded, max_new_tokens=max_new_tokens, do_sample=True)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

  # Strip the echoed prompt by character count. This assumes decoding
  # reproduces each prompt verbatim at the start of its output.
  completions = [full_text[len(p):] for p, full_text in zip(batch, decoded)]

  return completions[0] if is_single else completions
  

def time_func(f):
  """Call ``f()`` once and measure how long the call takes.

  Returns a ``(result, elapsed_seconds)`` tuple, where ``result`` is whatever
  ``f()`` returned.

  ``time.perf_counter`` is used because it is monotonic and high-resolution.
  ``time.time`` follows the wall clock and can jump when the system clock is
  adjusted, which corrupts interval measurements.

  No device synchronization is performed, so asynchronous GPU work that ``f``
  leaves queued when it returns is not included in the measurement.
  """
  import time
  start = time.perf_counter()
  ret = f()
  elapsed_time = time.perf_counter() - start
  return ret, elapsed_time

def profile_func(f, trace_path="trace.json"):
  """Run ``f()`` under the PyTorch profiler and return whatever it returns.

  Both CPU and CUDA activity are recorded. The collected trace is exported
  in Chrome trace format to ``trace_path`` (viewable e.g. in chrome://tracing).
  """
  from torch.profiler import ProfilerActivity, profile

  activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
  with profile(activities=activities) as profiler:
    result = f()

  profiler.export_chrome_trace(trace_path)
  return result

# Two untimed warm-up generations, presumably so that one-time start-up costs
# (e.g. first-call kernel loading) do not inflate the timed run below.
for _ in range(2):
  generate(model, "Hello my name is", 50)

# Baseline measurement with the stock nn.Linear path. It is labelled
# "[Original]" (it was mislabelled "[Optimized]") to match trace_orig.json.
text, time = time_func(lambda: generate(model, "Hello my name is", 50))
print("[Original] Completion: ", text)
print("[Original] Time: ", time)

# The same generation again, this time under the profiler. The Chrome trace
# is written to trace_orig.json.
text, time = profile_func(
  lambda: time_func(lambda: generate(model, "Hello my name is", 50)),
  trace_path="trace_orig.json",
)

Naive GEMV kernel that seems to be faster:

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#include <cuda_fp16.h>

#define ROWS_PER_BLOCK 4
#define THREADS_PER_BLOCK 256

#define DIV_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define FULL_MASK32 0xffffffff
#define FULL_MASK64 0xffffffffffffffff

// Portable warp shuffle-down.
//
// The platform test must succeed in BOTH the host and the device compilation
// passes. __CUDA_ARCH__ is only defined while nvcc compiles device code, so
// testing it made the host pass fall through to #error on NVIDIA.
// __CUDACC__ is defined in both nvcc passes. HIP is tested first so that AMD
// builds always take the HIP branch.
#if defined(__HIP_PLATFORM_AMD__) // AMD
// HIP's __shfl_down takes no lane mask, so the mask argument is dropped.
#define __xx_shfl_down(mask, val, offset) __shfl_down(val, offset)
#elif defined(__CUDACC__) // NVIDIA (nvcc)
#define __xx_shfl_down(mask, val, offset) __shfl_down_sync(mask, val, offset)
#else
#error "Unsupported compiler"
#endif

// Sum `val` over the lanes of the calling warp using shuffle-down steps whose
// distance halves each time. Only lane 0 is guaranteed to end up holding the
// complete sum. 32- and 64-lane warps are handled; for any other warp size the
// value is returned unchanged.
__device__ float warpReduce(float val) {
  if (warpSize == 64) {
    // 64-lane warps (e.g. AMD CDNA wavefronts): distances 32, 16, ..., 1.
#pragma unroll
    for (int delta = 32; delta >= 1; delta >>= 1) {
      val += __xx_shfl_down(FULL_MASK64, val, delta);
    }
  } else if (warpSize == 32) {
    // 32-lane warps: distances 16, 8, ..., 1.
#pragma unroll
    for (int delta = 16; delta >= 1; delta >>= 1) {
      val += __xx_shfl_down(FULL_MASK32, val, delta);
    }
  }
  return val;
}

// fp16 matrix-vector product Y = W @ X (+ B), accumulated in fp32.
//
// Each block produces ROWS_PER_BLOCK consecutive entries of Y. Its threads walk
// K with a stride of THREADS_PER_BLOCK, so the kernel must be launched with
// exactly THREADS_PER_BLOCK threads per block and
// DIV_ROUND_UP(N, ROWS_PER_BLOCK) blocks.
//
// N does not need to be a multiple of ROWS_PER_BLOCK. Rows past N are never
// read from W/B out of bounds and are never written to Y.
//
// Reduction scheme:
//  1. Every warp reduces its partial dot products with warpReduce.
//  2. Lane 0 of each warp stores its result in a shared table.
//  3. After a barrier, one thread per row sums the per-warp partials in a
//     fixed order (deterministic, no atomics).
__global__ void muillm_gemv_kernel(
    const half* __restrict__ W, // weight matrix - size N x K, row-major
    const half* __restrict__ B, // optional bias - size N (nullptr if none)
    const half* __restrict__ X, // input - size K
    half* __restrict__ Y, // output - size N
    unsigned N,
    unsigned K
) {
  static_assert(ROWS_PER_BLOCK <= THREADS_PER_BLOCK,
                "one thread per row is needed for the final write");

  // Warps are at least 32 lanes wide on supported GPUs, so this bounds the
  // number of warps per block and sizes the shared partial-sum table.
  // (The previous code sized its table by ROWS_PER_BLOCK but indexed it by
  // warp id, which overflowed with 32-lane warps.)
  constexpr int MAX_WARPS = THREADS_PER_BLOCK / 32;
  __shared__ float partials[ROWS_PER_BLOCK][MAX_WARPS];

  const int warpCounts = THREADS_PER_BLOCK / warpSize;
  const int warpId = threadIdx.x / warpSize;
  const int laneId = threadIdx.x % warpSize;

  const unsigned first_row = blockIdx.x * ROWS_PER_BLOCK;

  // Row pointers for this block.
  // - Rows at or beyond N (only possible in the last block) are clamped to
  //   row N - 1, so every load stays in bounds; their results are discarded
  //   below.
  // - size_t avoids 32-bit overflow of row * K.
  const half* rows[ROWS_PER_BLOCK];
  float acc[ROWS_PER_BLOCK];
#pragma unroll
  for (int r = 0; r < ROWS_PER_BLOCK; r++) {
    unsigned row = first_row + r;
    if (row >= N) {
      row = N - 1;
    }
    rows[r] = W + (size_t)row * K;
    acc[r] = 0.f;
  }

  // Partial dot products: each X element is loaded once and reused for all rows.
  for (unsigned k = threadIdx.x; k < K; k += THREADS_PER_BLOCK) {
    const float x = __half2float(X[k]);
#pragma unroll
    for (int r = 0; r < ROWS_PER_BLOCK; r++) {
      acc[r] += __half2float(rows[r][k]) * x;
    }
  }

  // Reduce within each warp; lane 0 then holds the warp's sum for every row.
#pragma unroll
  for (int r = 0; r < ROWS_PER_BLOCK; r++) {
    acc[r] = warpReduce(acc[r]);
  }

  if (laneId == 0) {
#pragma unroll
    for (int r = 0; r < ROWS_PER_BLOCK; r++) {
      partials[r][warpId] = acc[r];
    }
  }
  // Make every warp's partials visible before the cross-warp reduction.
  __syncthreads();

  // Reduce across warps: thread r finalizes row first_row + r.
  if (threadIdx.x < ROWS_PER_BLOCK) {
    const unsigned row = first_row + threadIdx.x;
    if (row < N) {
      float sum = 0.f;
      for (int w = 0; w < warpCounts; w++) {
        sum += partials[threadIdx.x][w];
      }
      if (B != nullptr) { // add the bias if there is one
        sum += __half2float(B[row]);
      }
      Y[row] = __float2half(sum);
    }
  }
}

// Host entry point: computes y = weights @ x (+ bias) for a single fp16 input
// vector using muillm_gemv_kernel. The kernel is launched on the current
// stream of x's device.
//
// weights: (N, K) float16, contiguous, on a GPU.
// bias:    optional (N,) float16 contiguous tensor on the same device, or nullptr.
// x:       float16 contiguous tensor whose last dim is K and which holds exactly
//          K elements (e.g. shape (1, 1, K)). The kernel computes one
//          matrix-vector product; extra batch rows would otherwise be left
//          uninitialized in the output.
// returns: float16 tensor shaped like x, with the last dim replaced by N.
// Raises (via TORCH_CHECK) when these requirements are not met.
at::Tensor muillm_linear_forward_cuda(
    torch::Tensor& weights,
    torch::Tensor* bias,
    torch::Tensor& x) {
  TORCH_CHECK(weights.dim() == 2, "weights must be a 2D (N, K) tensor");
  TORCH_CHECK(weights.scalar_type() == torch::kFloat16 &&
                  x.scalar_type() == torch::kFloat16,
              "weights and x must be float16");
  TORCH_CHECK(weights.is_cuda() && x.device() == weights.device(),
              "weights and x must be on the same GPU");
  TORCH_CHECK(weights.is_contiguous() && x.is_contiguous(),
              "weights and x must be contiguous");

  const auto N = weights.size(0);
  const auto K = weights.size(1);

  TORCH_CHECK(x.dim() >= 1 && x.size(-1) == K,
              "the last dimension of x must equal weights.size(1)");
  TORCH_CHECK(x.numel() == K,
              "x must hold a single input vector (only gemv is supported)");

  if (bias != nullptr) {
    TORCH_CHECK(bias->scalar_type() == torch::kFloat16, "bias must be float16");
    TORCH_CHECK(bias->device() == weights.device() && bias->is_contiguous() &&
                    bias->numel() == N,
                "bias must be a contiguous (N,) tensor on the weights' device");
  }

  // Launch on x's device even when it is not the current device.
  const at::cuda::OptionalCUDAGuard device_guard(x.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // y has the same dimensions as x, except the last dim, which is given by
  // the out_features of weights. It is allocated on x's device rather than
  // on whatever the current device happens to be.
  auto output_sizes = x.sizes().vec();
  output_sizes.back() = N;

  auto y = torch::empty(
      output_sizes,
      x.options().dtype(torch::kFloat16).requires_grad(false));

  // Nothing to compute; a launch with zero blocks is an invalid configuration.
  if (N == 0) {
    return y;
  }

  // The kernel strides over K by THREADS_PER_BLOCK, so the launch must use it too.
  const int threads_per_blocks = THREADS_PER_BLOCK;
  const int num_blocks = DIV_ROUND_UP(N, ROWS_PER_BLOCK);

  muillm_gemv_kernel<<<num_blocks, threads_per_blocks, 0, stream>>>(
    (const half*)weights.data_ptr(),
    bias == nullptr ? nullptr : (const half*)bias->data_ptr(),
    (const half*)x.data_ptr(),
    (half*)y.data_ptr(),
    static_cast<unsigned>(N),
    static_cast<unsigned>(K)
  );
  // Surface launch-configuration errors immediately instead of at a later sync.
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return y;
}

Let me know if you want the glue code to make this kernel usable from PyTorch.

Versions

Environment:

Collecting environment information...
PyTorch version: 2.3.0.dev20240204+rocm5.7
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 5.7.31921-d1770ee1b

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI100 (gfx908:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 5.7.31921
MIOpen runtime version: 2.20.0
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             16
On-line CPU(s) list:                0-15
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen 7 5800X3D 8-Core Processor
CPU family:                         25
Model:                              33
Thread(s) per core:                 2
Core(s) per socket:                 8
Socket(s):                          1
Stepping:                           2
Frequency boost:                    enabled
CPU max MHz:                        4548.8281
CPU min MHz:                        2200.0000
BogoMIPS:                           6800.77
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                     AMD-V
L1d cache:                          256 KiB (8 instances)
L1i cache:                          256 KiB (8 instances)
L2 cache:                           4 MiB (8 instances)
L3 cache:                           96 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton-rocm==3.0.0+dafe145982
[pip3] torch==2.3.0.dev20240204+rocm5.7
[pip3] torchaudio==2.2.0.dev20240204+rocm5.7
[pip3] torchvision==0.18.0.dev20240204+rocm5.7
[conda] Could not collect

Python packages:

accelerate==0.28.0
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.2.0
build==1.2.1
certifi==2022.12.7
charset-normalizer==2.1.1
comm==0.2.1
contourpy==1.2.0
cycler==0.12.1
datasets==2.16.1
debugpy==1.8.1
decorator==5.1.1
deepspeed==0.14.0
diffusers==0.27.2
dill==0.3.7
exceptiongroup==1.2.0
executing==2.0.1
filelock==3.9.0
fonttools==4.48.1
frozenlist==1.4.1
fsspec==2023.10.0
hjson==3.1.0
huggingface-hub==0.20.3
idna==3.4
importlib_metadata==7.1.0
ipykernel==6.29.2
ipython==8.21.0
jedi==0.19.1
Jinja2==3.1.2
joblib==1.3.2
jupyter_client==8.6.0
jupyter_core==5.7.1
kiwisolver==1.4.5
MarkupSafe==2.1.3
matplotlib==3.8.2
matplotlib-inline==0.1.6
mpmath==1.2.1
multidict==6.0.5
multiprocess==0.70.15
nest-asyncio==1.6.0
networkx==3.0rc1
ninja==1.11.1.1
numpy==1.24.1
packaging==23.2
pandas==2.2.0
parso==0.8.3
peft==0.8.2
pexpect==4.9.0
Pillow==9.3.0
platformdirs==4.2.0
prompt-toolkit==3.0.43
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyarrow==15.0.0
pyarrow-hotfix==0.6
pydantic==2.6.4
pydantic_core==2.16.3
Pygments==2.17.2
pynvml==11.5.0
pyparsing==3.1.1
pyproject_hooks==1.0.0
python-dateutil==2.8.2
pytorch-triton-rocm==3.0.0+dafe145982
pytz==2024.1
PyYAML==6.0.1
pyzmq==25.1.2
regex==2023.12.25
requests==2.28.1
safetensors==0.4.2
scikit-learn==1.4.0
scipy==1.12.0
six==1.16.0
stack-data==0.6.3
sympy==1.11.1
threadpoolctl==3.2.0
tokenizers==0.15.1
tomli==2.0.1
torch==2.3.0.dev20240204+rocm5.7
torchaudio==2.2.0.dev20240204+rocm5.7
torchvision==0.18.0.dev20240204+rocm5.7
tornado==6.4
tqdm==4.66.1
traitlets==5.14.1
transformers==4.37.2
typing_extensions==4.8.0
tzdata==2023.4
UNKNOWN==0.0.0
urllib3==1.26.13
wcwidth==0.2.13
xxhash==3.4.1
yarl==1.9.4
zipp==3.18.1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

2 participants