Skip to content

Commit

Permalink
[torch.compile] Dynamic fp8 + rms_norm fusion (vllm-project#10906)
Browse files Browse the repository at this point in the history
Signed-off-by: luka <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
  • Loading branch information
2 people authored and BKitor committed Dec 30, 2024
1 parent be75cc2 commit eeb7c15
Show file tree
Hide file tree
Showing 20 changed files with 1,736 additions and 252 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
Expand Down Expand Up @@ -300,7 +301,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
Expand Down
173 changes: 173 additions & 0 deletions benchmarks/fused_kernels/layernorm_rms_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import pickle as pkl
import time
from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm


@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
add_residual: bool
dtype: torch.dtype

def description(self):
return (f'N {self.num_tokens} '
f'x D {self.hidden_size} '
f'x R {self.add_residual} '
f'x DT {self.dtype}')


def get_bench_params() -> List[bench_params_t]:
## Test Fixtures
NUM_TOKENS = [2**x for x in range(11)]
HIDDEN_SIZES = list(range(1024, 8129, 1024))
ADD_RESIDUAL = [True, False]
DTYPES = [torch.bfloat16, torch.float]

combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
bench_params = list(map(lambda x: \
bench_params_t(x[0], x[1], x[2], x[3]), combinations))
return bench_params


# Reference impls
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

# Quant
torch_out, _, _ = ops.scaled_int8_quant(torch_out)


def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

# Quant
torch_out, _ = ops.scaled_fp8_quant(torch_out)


def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
out, _ = ops.rms_norm_dynamic_per_token_quant(x,
rms_norm_layer.weight,
1e-6,
quant_dtype,
residual=residual)


# Bench functions
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
quant_dtype: torch.dtype, label: str, sub_label: str,
fn: Callable, description: str) -> TMeasurement:

min_run_time = 1

globals = {
"rms_norm_layer": rms_norm_layer,
"x": x,
"residual": residual,
"quant_dtype": quant_dtype,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)

def bench(params: bench_params_t, label: str, sub_label: str) \
-> Iterable[TMeasurement]:

# Make inputs
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
# Make weights
layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs
scale = 1 / params.hidden_size
x = torch.randn(params.num_tokens,
params.hidden_size,
dtype=params.dtype,
device='cuda') * scale
residual = (torch.randn_like(x) * scale).to(device='cuda') \
if params.add_residual else None

timers = []

# unfused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label,
unfused_int8_impl, "unfused_int8_impl"))

# unfused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
unfused_fp8_impl, "unfused_fp8_impl"))

# fused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
"fused_int8_impl"))

# fused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
fused_impl, "fused_fp8_impl"))

print_timers(timers)

return timers


# launch bench
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()


def main():
torch.set_default_device('cuda')
bench_params = get_bench_params()

timers = []
for bp in tqdm(bench_params):
timers.extend(
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
print_timers(timers)

# pickle all the results
timestamp = int(time.time())
with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
pkl.dump(timers, f)


if __name__ == '__main__':
main()
14 changes: 14 additions & 0 deletions csrc/dispatch_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// TODO(luka/varun): use FP8_TYPE macro after refactoring
#ifndef USE_ROCM
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif

#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
Expand Down
8 changes: 8 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& weight,
torch::Tensor& scale, double epsilon);

void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales,
double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
Expand Down
26 changes: 7 additions & 19 deletions csrc/quantization/fp8/common.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#pragma once

#include "quantization/vectorization.cuh"

#include <cmath>
#include <c10/core/ScalarType.h>

#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
Expand All @@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz;
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;

namespace vllm {

Expand Down Expand Up @@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
}
}

template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};

typedef struct __align__(4) {
FP8_TYPE x;
FP8_TYPE y;
FP8_TYPE z;
FP8_TYPE w;
}
float8x4_t;

template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
Expand Down Expand Up @@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
using float8x4_t = q8x4_t<FP8_TYPE>;
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);

int64_t const num_vec_elems = num_elems >> 2;

Expand Down
Loading

0 comments on commit eeb7c15

Please sign in to comment.