
Adding bf16 output dtype for fp8 gemm (opendatahub-io#111)
* add unified fp8 gemm kernel

* merge all fp8 gemm kernels into one

* merge fp8 gemm kernel

* add bf16 output for fp8 gemm

* fix lint

* fix lint

* fix lint
charlifu authored Jul 29, 2024
1 parent 200fbea commit 904b5b8
Showing 5 changed files with 75 additions and 264 deletions.
10 changes: 3 additions & 7 deletions csrc/ops.h
@@ -117,13 +117,9 @@ void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data,
torch::Tensor& scale);

#ifdef USE_ROCM
torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b,
torch::Tensor& scaleA, torch::Tensor& scaleB,
torch::Tensor& scaleD, int algo_idx);

torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
torch::Tensor& scaleA, torch::Tensor& scaleB,
int algo_idx);
void fp8_mm(torch::Tensor& a, torch::Tensor& b, torch::Tensor& result,
torch::Tensor& scale_a, torch::Tensor& scale_b,
const c10::optional<torch::Tensor>& scale_result, int64_t algo_idx);

void create_workspace();
#endif
3 changes: 1 addition & 2 deletions csrc/pybind.cpp
@@ -70,8 +70,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Convert the key and value cache to fp8 data type");

#ifdef USE_ROCM
ops.def("fp8_gemm", &fp8_gemm, "fp8 GEMM with fp8 output");
ops.def("fp8_gemm_16", &fp8_gemm_16, "fp8 GEMM with fp16 output");
ops.def("fp8_mm", &fp8_mm, "fp8 GEMM with fp8 fp16 bf16 output type");
ops.def("create_workspace", &create_workspace,
"Create workspace for fp8 GEMM");
#endif
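With the two old bindings collapsed into one, the output dtype is now selected by the dtype of a preallocated `result` tensor rather than by choosing between `fp8_gemm` and `fp8_gemm_16`. A minimal sketch of calling the raw binding (the import path and tensor shapes are illustrative assumptions; `vllm_ops` is the compiled extension module as used inside vllm/_custom_ops.py):

```python
import torch
from vllm import vllm_ops  # import path assumed; matches the name used in _custom_ops.py

a = torch.empty(16, 32, device="cuda", dtype=torch.float8_e4m3fnuz)
b = torch.empty(32, 64, device="cuda", dtype=torch.float8_e4m3fnuz)
scale = torch.tensor([1.0], device="cuda")

# The output dtype is chosen by how you allocate `result`; bf16 here,
# which is the new capability added by this commit.
result = torch.empty(16, 64, device="cuda", dtype=torch.bfloat16)
vllm_ops.fp8_mm(a, b, result, scale, scale, None, 0)
```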
208 changes: 28 additions & 180 deletions csrc/quantization/fp8/amd/gemm_kernel.cu
@@ -65,171 +65,35 @@ void create_workspace() {
CHECK_HIP_ERROR(hipMalloc(&workspace, workspace_size));
}

torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b,
torch::Tensor& scaleA, torch::Tensor& scaleB,
torch::Tensor& scaleD, int algo_idx) {
void fp8_mm(torch::Tensor& a, torch::Tensor& b, torch::Tensor& result,
torch::Tensor& scale_a, torch::Tensor& scale_b,
const c10::optional<torch::Tensor>& scale_result,
int64_t algo_idx) {
auto a_strides{a.strides()};
auto b_strides{b.strides()};
auto a_sizes{a.sizes()};
auto b_sizes{b.sizes()};

// CHECK_INPUT(a);
// CHECK_INPUT(b);
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz &&
b.dtype() == torch::kFloat8_e4m3fnuz,
"The input tensors should be in fp8.");
"The input tensors type should be float8_e4m3fnuz.");
TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D.");
TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0.");

auto options{
at::TensorOptions().dtype(torch::kFloat8_e4m3fnuz).device(at::kCUDA)};
auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)};

constexpr bool transpose_result = true;
bool transpose_a;
bool transpose_b;
if ((b_strides[0] == 1) &&
(b_strides[1] >= std::max<int64_t>(1, b_sizes[0]))) {
transpose_b = false;
} else if ((b_strides[1] == 1) &&
(b_strides[0] >= std::max<int64_t>(1, b_sizes[1]))) {
transpose_b = true;
auto out_dtype = result.dtype();
TORCH_CHECK(out_dtype == torch::kFloat8_e4m3fnuz ||
out_dtype == torch::kFloat16 || out_dtype == torch::kBFloat16,
"Only float16, bfloat16 or float8_e4m3fnuz are supported as the "
"output dtype.");
hipblasDatatype_t hipblas_out_type;
if (out_dtype == torch::kFloat8_e4m3fnuz) {
hipblas_out_type = HIP_R_8F_E4M3_FNUZ;
} else if (out_dtype == torch::kBFloat16) {
hipblas_out_type = HIP_R_16BF;
} else {
assert(false &&
"unusual strides detected, may need to clone a contiguous tensor");
}
if ((a_strides[0] == 1) &&
(a_strides[1] >= std::max<int64_t>(1, a_sizes[0]))) {
transpose_a = false;
} else if ((a_strides[1] == 1) &&
(a_strides[0] >= std::max<int64_t>(1, a_sizes[1]))) {
transpose_a = true;
} else {
assert(false &&
"unusual strides detected, may need to clone a contiguous tensor");
}

if (transpose_result) {
bool tmp = transpose_a;
transpose_a = !transpose_b;
transpose_b = !tmp;
a_strides = b.strides();
b_strides = a.strides();
a_sizes = b.sizes();
b_sizes = a.sizes();
hipblas_out_type = HIP_R_16F;
}

float alpha = 1.0f;
float beta = 0.0f;
int64_t m = a_sizes[transpose_result ? 1 : 0];
int64_t k = a_sizes[transpose_result ? 0 : 1];
int64_t n = b_sizes[transpose_result ? 0 : 1];

void* d_a = static_cast<void*>((transpose_result ? b : a).data_ptr());
void* d_b = static_cast<void*>((transpose_result ? a : b).data_ptr());
void* d_d = static_cast<void*>(result.data_ptr());

// void *d_scaleA, *d_scaleB, *d_workspace;
// CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float)));
// CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float)));
// CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size));
// CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA),
// sizeof(float), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(d_scaleB,
// &(transpose_result ? scaleA : scaleB), sizeof(float),
// hipMemcpyHostToDevice));
auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr();
auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr();
auto d_scaleD = scaleD.data_ptr();

auto handle = at::cuda::getCurrentCUDABlasLtHandle();
auto stream = at::cuda::getCurrentCUDAStream();

hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(workspace_size);
hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N,
HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ,
HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ,
HIPBLAS_COMPUTE_32F);

hipblaslt_ext::GemmEpilogue
epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT.
// (Gemm only)
hipblaslt_ext::GemmInputs inputs;
inputs.a = d_a;
inputs.b = d_b;
inputs.c = d_d;
inputs.d = d_d;
inputs.alpha = &alpha;
inputs.beta = &beta;
inputs.scaleA = d_scaleA;
inputs.scaleB = d_scaleB;
inputs.scaleD = d_scaleD;

auto&& problem = gemm.getProblemTypes();
auto lda = problem.op_a == HIPBLAS_OP_N ? m : k;
auto ldb = problem.op_b == HIPBLAS_OP_N ? k : n;
auto ldc = m;
auto strideA = m * k;
auto strideB = n * k;
auto strideC = m * n;

CHECK_HIPBLASLT_ERROR(gemm.setProblem(m, n, k, 1, lda, ldb, ldc, ldc, strideA,
strideB, strideC, strideC, epilogue,
inputs, problem));

if (algo_idx == 0) {
constexpr int request_solutions = 1024;
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
heuristicResult.reserve(request_solutions);
CHECK_HIPBLASLT_ERROR(
gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
static size_t solSize = 0;
if (heuristicResult.size() != solSize) {
std::cout << "fp8 sols: " << heuristicResult.size() << "\n";
solSize = heuristicResult.size();
for (auto& res : heuristicResult) {
auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo);
std::cout << idx << "\n";
}
}
TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!");
algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo);
}
std::vector<int> algoIndex(1);
algoIndex[0] = algo_idx;
std::vector<hipblasLtMatmulHeuristicResult_t> tmpAlgo;
TORCH_CUDABLAS_CHECK(
hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo));

CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, workspace));
CHECK_HIPBLASLT_ERROR(gemm.run(stream));

// hipFree(d_scaleA);
// hipFree(d_scaleB);

return result;
}

torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
torch::Tensor& scaleA, torch::Tensor& scaleB,
int algo_idx) {
auto a_strides{a.strides()};
auto b_strides{b.strides()};
auto a_sizes{a.sizes()};
auto b_sizes{b.sizes()};

// CHECK_INPUT(a);
// CHECK_INPUT(b);
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz &&
b.dtype() == torch::kFloat8_e4m3fnuz,
"The input tensors should be in fp8.");
TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D.");
TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0.");

auto options{at::TensorOptions().dtype(torch::kFloat16).device(at::kCUDA)};
auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)};

constexpr bool transpose_result = true;
bool transpose_a;
bool transpose_b;
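The `transpose_result` constant above encodes the standard row-major-on-column-major GEMM trick: hipBLASLt consumes column-major operands, so a row-major C = A @ B is computed as the column-major product Cᵀ = Bᵀ Aᵀ, which is why the operand pointers, sizes, strides, and transpose flags are swapped. A torch illustration of the identity (not part of the commit):

```python
import torch

A = torch.randn(4, 8)
B = torch.randn(8, 6)

# Row-major C = A @ B equals the transpose of the swapped product B^T @ A^T.
C = A @ B
C_swapped = (B.T @ A.T).T
assert torch.allclose(C, C_swapped)
```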
@@ -274,16 +138,10 @@ torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
void* d_b = static_cast<void*>((transpose_result ? a : b).data_ptr());
void* d_d = static_cast<void*>(result.data_ptr());

// void *d_scaleA, *d_scaleB, *d_workspace;
// CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float)));
// CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float)));
// CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size));
// CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA),
// sizeof(float), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(d_scaleB,
// &(transpose_result ? scaleA : scaleB), sizeof(float),
// hipMemcpyHostToDevice));
auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr();
auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr();
auto d_scale_a = transpose_result ? scale_b.data_ptr() : scale_a.data_ptr();
auto d_scale_b = transpose_result ? scale_a.data_ptr() : scale_b.data_ptr();
auto d_scale_d =
scale_result.has_value() ? scale_result.value().data_ptr() : nullptr;

auto handle = at::cuda::getCurrentCUDABlasLtHandle();
auto stream = at::cuda::getCurrentCUDAStream();
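The three scale pointers follow hipBLASLt's scaleA/scaleB/scaleD convention, stated here as an assumption since the diff only wires them through: the input scales dequantize the fp8 operands before fp32 accumulation, and `scale_result` rescales the accumulator, which only matters when the output itself is requantized to fp8, hence the `nullptr` default. A dequantization-based reference of the intended numerics in torch:

```python
import torch

def fp8_mm_reference(a, b, scale_a, scale_b, scale_result=None):
    """Emulates fp8_mm numerics: dequantize, accumulate in fp32, rescale.

    An assumption about the scale convention, not code from this diff.
    """
    acc = (a.float() * scale_a) @ (b.float() * scale_b)
    if scale_result is not None:  # used only when the output is fp8
        acc = acc * scale_result
    return acc
```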
@@ -292,9 +150,11 @@
gemmPref.setMaxWorkspaceBytes(workspace_size);
hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N,
HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F,
HIP_R_16F, HIPBLAS_COMPUTE_32F);
HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ,
hipblas_out_type, hipblas_out_type,
HIPBLAS_COMPUTE_32F);

// TODO(HaiShaw): Add Epilogue usage in cases to support Bias, etc.
hipblaslt_ext::GemmEpilogue
epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT.
// (Gemm only)
@@ -305,8 +165,9 @@ torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
inputs.d = d_d;
inputs.alpha = &alpha;
inputs.beta = &beta;
inputs.scaleA = d_scaleA;
inputs.scaleB = d_scaleB;
inputs.scaleA = d_scale_a;
inputs.scaleB = d_scale_b;
inputs.scaleD = d_scale_d;

auto&& problem = gemm.getProblemTypes();
auto lda = problem.op_a == HIPBLAS_OP_N ? m : k;
@@ -320,20 +181,12 @@ torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
strideB, strideC, strideC, epilogue,
inputs, problem));
if (algo_idx == 0) {
constexpr int request_solutions = 1024;
constexpr int request_solutions = 1;
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
heuristicResult.reserve(request_solutions);
CHECK_HIPBLASLT_ERROR(
gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
static size_t solSize = 0;
if (heuristicResult.size() != solSize) {
std::cout << "fp16 sols: " << heuristicResult.size() << "\n";
solSize = heuristicResult.size();
for (auto& res : heuristicResult) {
auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo);
std::cout << idx << "\n";
}
}
algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo);
TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!");
}
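As the hunk above shows, `algo_idx == 0` asks hipBLASLt for a heuristic solution at call time, while any nonzero value pins a specific solution index and skips the search. A hedged calling-pattern sketch, where `TUNED_SOLIDX` is a hypothetical placeholder for an offline-tuned index:

```python
import torch
from vllm import _custom_ops

a = torch.empty(16, 32, device="cuda", dtype=torch.float8_e4m3fnuz)
b = torch.empty(32, 64, device="cuda", dtype=torch.float8_e4m3fnuz)
scale = torch.tensor([1.0], device="cuda")

# 0 lets the kernel pick heuristically; a nonzero index pins a tuned
# hipBLASLt solution and avoids the per-call search.
TUNED_SOLIDX = 0  # hypothetical: replace with an index tuned for your shapes
out = _custom_ops.fp8_mm(a, b, torch.bfloat16, scale, scale, None, TUNED_SOLIDX)
```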
@@ -345,9 +198,4 @@ torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,

CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, workspace));
CHECK_HIPBLASLT_ERROR(gemm.run(stream));

// hipFree(d_scaleA);
// hipFree(d_scaleB);

return result;
}
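For reference, the output-dtype gate the unified kernel now applies, mirrored in Python (the hipDataType enum names come from the diff; the helper itself is illustrative):

```python
import torch

# torch output dtype -> hipBLASLt output type selected inside fp8_mm.
HIPBLASLT_OUT_TYPE = {
    torch.float8_e4m3fnuz: "HIP_R_8F_E4M3_FNUZ",
    torch.bfloat16: "HIP_R_16BF",
    torch.float16: "HIP_R_16F",
}

def out_type_for(dtype: torch.dtype) -> str:
    if dtype not in HIPBLASLT_OUT_TYPE:
        raise TypeError("Only float16, bfloat16 or float8_e4m3fnuz are "
                        f"supported as the output dtype, got {dtype}.")
    return HIPBLASLT_OUT_TYPE[dtype]
```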
16 changes: 14 additions & 2 deletions vllm/_custom_ops.py
@@ -328,8 +328,20 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor,

def convert_fp8(output: torch.Tensor,
input: torch.Tensor,
scale: float = 1.0) -> None:
vllm_ops.convert_fp8(output, input, torch.Tensor([scale]))
scale: Optional[torch.Tensor] = None) -> None:
if scale is None:
scale = torch.tensor([1.0], device=input.device)
vllm_ops.convert_fp8(output, input, scale)


def fp8_mm(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor,
scale_result: Optional[torch.Tensor], solidx: int) -> torch.Tensor:
result = torch.empty((a.shape[0], b.shape[1]),
dtype=out_dtype,
device=a.device)
vllm_ops.fp8_mm(a, b, result, scale_a, scale_b, scale_result, solidx)
return result


#TODO: cuda_utils, custom_ar
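A hedged end-to-end sketch of the two Python helpers touched above. Shapes and scale values are illustrative, and the convert_fp8 scale convention (a single multiplicative quantization scale) is an assumption:

```python
import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 32, device="cuda")
w = torch.randn(32, 64, device="cuda")
scale = torch.tensor([1.0], device="cuda")

# Quantize activations and weights to fp8 via the updated convert_fp8,
# which now takes an optional tensor scale instead of a Python float.
x_fp8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz)
w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fnuz)
ops.convert_fp8(x_fp8, x, scale)
ops.convert_fp8(w_fp8, w, scale)

# The new bf16 output path added by this commit.
y = ops.fp8_mm(x_fp8, w_fp8, torch.bfloat16, scale, scale, None, 0)
```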