Add float16 (#570)
roastduck authored Dec 19, 2023
1 parent 52013f9 commit 9d54317
Showing 14 changed files with 350 additions and 33 deletions.
19 changes: 16 additions & 3 deletions ffi/array.cc
@@ -52,6 +52,8 @@ static DataType dtypeFromPyTorch(torch::ScalarType t) {
return DataType::Int32;
case torch::ScalarType::Long:
return DataType::Int64;
case torch::ScalarType::Half:
return DataType::Float16;
case torch::ScalarType::Float:
return DataType::Float32;
case torch::ScalarType::Double:
@@ -69,6 +71,8 @@ static torch::ScalarType dtypeToPyTorch(DataType dtype) {
return torch::ScalarType::Int;
case DataType::Int64:
return torch::ScalarType::Long;
case DataType::Float16:
return torch::ScalarType::Half;
case DataType::Float32:
return torch::ScalarType::Float;
case DataType::Float64:
@@ -116,9 +120,10 @@ void init_ffi_array(py::module_ &m) {
// or it will all end up in float64 (the first initializer)
throw DriverError(
"Unsupported data type or strides from a NumPy Array. "
"Please "
"use freetensor.array factory function, instead of "
"freetensor.Array, for strided arrays");
"If you are using strided arrays, please use "
"freetensor.array factory function, instead of "
"freetensor.Array. If you are using float16, please use "
"the PyTorch interface instead.");
}),
"data"_a, "dont_drop_borrow"_a = false, "moved"_a = false)
.def("__eq__", [](const Ref<Array> &lhs, const Ref<Array> &rhs) {
@@ -172,6 +177,14 @@ void init_ffi_array(py::module_ &m) {
SHARE_TO_NUMPY(int64_t, DataType::Int64)
SHARE_TO_NUMPY(int32_t, DataType::Int32)
SHARE_TO_NUMPY(bool, DataType::Bool)
case DataType::Float16:
            // TODO: Support fp16 once PyBind11 support for NumPy fp16 is
            // available. Status:
// https://github.com/pybind/pybind11/issues/1776,
// https://github.com/pybind/pybind11/issues/4061
throw DriverError(
"NumPy interface for float16 is not supported yet. Please "
"use the PyTorch interface instead.");
default:
ASSERT(false);
}
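For illustration, a minimal host-side sketch (libtorch, not FreeTensor's Python API) of how float16 data is expected to reach this mapping: torch::kHalf is torch::ScalarType::Half, which the new case above maps to DataType::Float16, while the NumPy path continues to reject float16.

#include <torch/torch.h>

int main() {
    // Half-precision tensor on the GPU: the dtype the new
    // dtypeFromPyTorch() case accepts as DataType::Float16.
    torch::Tensor t =
        torch::zeros({8, 8}, torch::dtype(torch::kHalf).device(torch::kCUDA));
    return t.scalar_type() == torch::ScalarType::Half ? 0 : 1;
}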
2 changes: 1 addition & 1 deletion include/codegen/code_gen_c.h
@@ -22,7 +22,7 @@ template <class Stream> class CodeGenC : public CodeGen<Stream> {
const std::vector<FuncRet> &returns)
: params_(params), returns_(returns) {}

static std::string gen(DataType dtype);
virtual std::string gen(const DataType &dtype);

protected:
virtual void genAlloc(const Ref<Tensor> &tensor, const std::string &rawPtr,
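The signature changes from `static std::string gen(DataType)` to `virtual std::string gen(const DataType &)` so that a backend can override how a dtype is spelled. A minimal sketch of the dispatch pattern, using hypothetical class names and a string parameter rather than the real CodeGenC interface:

#include <string>

// Hypothetical classes illustrating the dispatch; the real methods take a
// DataType, not a string.
struct BaseGen {
    virtual ~BaseGen() = default;
    virtual std::string gen(const std::string &dtype) {
        return dtype == "float32" ? "float" : "double"; // generic C spelling
    }
};
struct CudaGen : BaseGen {
    std::string gen(const std::string &dtype) override {
        return dtype == "float16" ? "__half" : BaseGen::gen(dtype); // CUDA-only spelling
    }
};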
3 changes: 3 additions & 0 deletions include/codegen/code_gen_cuda.h
@@ -47,6 +47,8 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {

Expr globalSize() const { return globalSize_; }

std::string gen(const DataType &dtype) override;

private:
bool inKernel() const;

@@ -87,6 +89,7 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {
void visit(const Abs &op) override;
void visit(const Floor &op) override;
void visit(const Ceil &op) override;
void visit(const Cast &op) override;
void visit(const ReduceTo &op) override;
void visit(const Var &op) override;
void visit(const For &op) override;
1 change: 1 addition & 0 deletions include/pass/const_fold.h
@@ -118,6 +118,7 @@ class ConstFold : public Mutator {
case DataType::Int32:
case DataType::Int64:
return wrap(int64_t(v));
case DataType::Float16:
case DataType::Float32:
case DataType::Float64:
return wrap(double(v));
5 changes: 4 additions & 1 deletion include/type/data_type.h
@@ -12,6 +12,7 @@ namespace freetensor {

enum class BaseDataType : size_t {
Void = 0, // Returns nothing. It is a Unit Type
Float16,
Float32,
Float64,
Int32,
@@ -24,7 +25,8 @@ enum class BaseDataType : size_t {
};

constexpr std::array baseDataTypeNames = {
"void", "float32", "float64", "int32", "int64", "bool", "custom", "never",
"void", "float16", "float32", "float64", "int32",
"int64", "bool", "custom", "never",
};
static_assert(baseDataTypeNames.size() == (size_t)BaseDataType::NumTypes);

@@ -105,6 +107,7 @@ class DataType {
// and remove the following lines
constexpr static auto Bool = BaseDataType::Bool;
constexpr static auto Custom = BaseDataType::Custom;
constexpr static auto Float16 = BaseDataType::Float16;
constexpr static auto Float32 = BaseDataType::Float32;
constexpr static auto Float64 = BaseDataType::Float64;
constexpr static auto Int32 = BaseDataType::Int32;
4 changes: 2 additions & 2 deletions python/freetensor/core/staging.py
@@ -82,8 +82,8 @@ def process_annotating_comments(src: str):
for line in src.splitlines():
indent = re.match('\\s*', line)[0]
rest_line = line[len(indent):]
if rest_line.startswith('#! '):
arg = rest_line[3:].replace('"', '\\"')
if rest_line.startswith('#!'):
arg = rest_line[2:].strip().replace('"', '\\"')
new_src.append(f'{indent}__staging_overload__.metadata("{arg}")')
else:
new_src.append(line)
7 changes: 5 additions & 2 deletions runtime/gpu_runtime.h
@@ -9,8 +9,11 @@
#include <stdexcept>
#include <type_traits>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include <cuda_fp16.h>

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>

#include "gpu_context.h"

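The new includes bring in `__half` (from cuda_fp16.h, used by generated kernels) and `cutlass::half_t` (from cutlass/numeric_types.h, used by generated GEMM calls). A small sketch of the assumption the generated code relies on: both are 16-bit storage, so a buffer can be reinterpreted between them. The helper name here is hypothetical.

#include <cuda_fp16.h>
#include <cutlass/numeric_types.h>

static_assert(sizeof(__half) == 2 && sizeof(cutlass::half_t) == 2,
              "both half types are 16-bit storage");

// Hypothetical helper: view a generated __half buffer as CUTLASS's half type,
// as the emitted MatMul code does via a raw pointer cast.
inline cutlass::half_t *asCutlassHalf(__half *p) {
    return reinterpret_cast<cutlass::half_t *>(p);
}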
155 changes: 139 additions & 16 deletions src/codegen/code_gen_cuda.cc
@@ -15,12 +15,14 @@

namespace freetensor {

static std::string genCUBLASType(DataType dtype) {
static std::string genCUBLASType(const DataType &dtype) {
switch (dtype.base()) {
case DataType::Float64:
return "CUDA_R_64F";
case DataType::Float32:
return "CUDA_R_32F";
case DataType::Float16:
return "CUDA_R_16F";
case DataType::Int64:
return "CUDA_R_64I";
case DataType::Int32:
@@ -30,9 +32,32 @@ static std::string genCUBLASType(DataType dtype) {
}
}

static std::string genCUTLASSType(const DataType &dtype) {
switch (dtype.base()) {
case DataType::Float64:
return "double";
case DataType::Float32:
return "float";
case DataType::Float16:
return "cutlass::half_t";
case DataType::Int64:
return "int64_t";
case DataType::Int32:
return "int32_t";
case DataType::Bool:
return "bool";
default:
ASSERT(false);
}
}

static bool canUseTensorCore(const Ref<GPUTarget> &target, DataType dtypeA,
DataType dtypeB, DataType dtypeC) {
// TODO: fp16 is supported after sm70
if (target->computeCapability().first >= 7 && dtypeA == DataType::Float16 &&
dtypeB == DataType::Float16 &&
(dtypeC == DataType::Float16 || dtypeC == DataType::Float32)) {
return true;
}
if (target->computeCapability().first >= 8 && dtypeA == DataType::Float64 &&
dtypeB == DataType::Float64 && dtypeC == DataType::Float64) {
return true;
@@ -49,7 +74,17 @@ CodeGenCUDA::genMdPtrType(const VarDef &def, bool isConst) {
    // Use a pointer instead of a reference for scalars, because when passing
    // an argument from host to a kernel, a reference means copying the value
    // from CPU to GPU, while a pointer means passing the address

    // NOTE: `[=]` implicitly capturing `this` is deprecated in C++20. Using
    // `[=]` will trigger a warning in GCC (because of the deprecation), but
    // using `[=, this]` will trigger a warning in Clang<17 (because it will
    // think `this` is duplicated).
#if defined(__clang__) && __clang_major__ < 17
return [=](std::ostream &os) -> std::ostream & {
#else
return [=, this](std::ostream &os) -> std::ostream & {
#endif
if (isConst) {
os << "const ";
}
@@ -77,6 +112,14 @@ void CodeGenCUDA::genMdPtrDef(const VarDef &def,
CodeGenC<CodeGenCUDAStream>::genMdPtrDef(def, genRawPtr, isConst);
}

std::string CodeGenCUDA::gen(const DataType &dtype) {
if (dtype == DataType::Float16) {
return "__half";
} else {
return CodeGenC::gen(dtype);
}
}

void CodeGenCUDA::genAlloc(const Ref<Tensor> &tensor, const std::string &rawPtr,
const std::string &shapePtr,
const std::string &dimPtr) {
@@ -383,6 +426,72 @@ void CodeGenCUDA::visit(const Ceil &op) {
os() << ")";
}

void CodeGenCUDA::visit(const Cast &op) {
if (op->destType_.base() == DataType::Float16) {
switch (op->expr_->dtype().base()) {
case DataType::Int32:
os() << "__int2half_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Int64:
os() << "__ll2half_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Float16:
(*this)(op->expr_);
break;
case DataType::Float32:
os() << "__float2half_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Float64:
os() << "__double2half("; // Always `_rn` (round to nearest even)
(*this)(op->expr_);
os() << ")";
break;
default:
throw InvalidProgram("Converting from " +
freetensor::toString(op->dtype()) +
" to float16 is not supported");
}
} else if (op->expr_->dtype().base() == DataType::Float16) {
switch (op->destType_.base()) {
case DataType::Int32:
os() << "__half2int_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Int64:
os() << "__half2ll_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Float16:
(*this)(op->expr_);
break;
case DataType::Float32:
os() << "__half2float("; // Short to long, no rounding is needed
(*this)(op->expr_);
os() << ")";
break;
case DataType::Float64:
os() << "__half2double("; // Short to long, no rounding is needed
(*this)(op->expr_);
os() << ")";
break;
default:
throw InvalidProgram("Converting from float16 to " +
freetensor::toString(op->dtype()) +
" is not supported");
}
} else {
CodeGenC::visit(op);
}
}
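For reference, a hand-written CUDA sketch (not generated output; the kernel and variable names are made up) exercising the same cuda_fp16.h intrinsics this `Cast` lowering emits:

#include <cuda_fp16.h>

// Demo kernel: __float2half_rn and __int2half_rn round to nearest even;
// widening back to float32 with __half2float is exact.
__global__ void castDemo(const float *xf, const int *xi, float *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        __half hf = __float2half_rn(xf[i]); // float32 -> float16
        __half hi = __int2half_rn(xi[i]);   // int32   -> float16
        out[i] = __half2float(hf) + __half2float(hi); // float16 -> float32
    }
}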

void CodeGenCUDA::visit(const Store &op) {
if (buffer(op->var_)->mtype() == MemType::GPUWarp) {
auto id = mangle(op->var_);
@@ -843,50 +952,64 @@ void CodeGenCUDA::visit(const MatMul &op) {
}

makeIndent();
os() << "cutlass::gemm::device::Gemm<" << gen(op->a_->dtype()) << ", "
beginBlock();
makeIndent();
os() << "using Gemm = cutlass::gemm::device::Gemm<"
<< genCUTLASSType(op->a_->dtype()) << ", "
<< (transA ? "cutlass::layout::ColumnMajor"
: "cutlass::layout::RowMajor")
<< ", " << gen(op->b_->dtype()) << ", "
<< ", " << genCUTLASSType(op->b_->dtype()) << ", "
<< (transB ? "cutlass::layout::ColumnMajor"
: "cutlass::layout::RowMajor")
<< ", " << gen(op->c_->dtype()) << ", "
<< ", " << genCUTLASSType(op->c_->dtype()) << ", "
<< (transC ? "cutlass::layout::ColumnMajor"
: "cutlass::layout::RowMajor")
<< ", " << gen(op->c_->dtype()) // TODO: accumulator type
<< ", "
<< genCUTLASSType(op->c_->dtype()) // TODO: accumulator type
<< ", "
<< (canUseTensorCore(target_, op->a_->dtype(), op->b_->dtype(),
op->c_->dtype())
? "cutlass::arch::OpClassTensorOp"
: "cutlass::arch::OpClassSimt")
<< ", FT_CUTLASS_ARCH> gemm;" << std::endl;
<< ", FT_CUTLASS_ARCH>;" << std::endl;
makeIndent();
os() << "checkCutlassError(gemm({{";
os() << "Gemm gemm;" << std::endl;
    // For clearer error messages, please keep the explicit argument types in
    // the following statement.
makeIndent();
os() << "checkCutlassError(gemm(Gemm::Arguments{{";
(*this)(m);
os() << ", ";
(*this)(n);
os() << ", ";
(*this)(k);
os() << "}, {&";
os() << "}, Gemm::TensorRefA{(const " << genCUTLASSType(op->a_->dtype())
<< "*)&";
(*this)(a);
os() << ", ";
(*this)(lda);
os() << "}, {&";
os() << "}, Gemm::TensorRefB{(const " << genCUTLASSType(op->b_->dtype())
<< "*)&";
(*this)(b);
os() << ", ";
(*this)(ldb);
os() << "}, {&";
os() << "}, Gemm::TensorRefC{(const " << genCUTLASSType(op->c_->dtype())
<< "*)&";
(*this)(c);
os() << ", ";
(*this)(ldc);
os() << "}, {&";
os() << "}, Gemm::TensorRefD{(" << genCUTLASSType(op->c_->dtype())
<< "*)&";
(*this)(c);
os() << ", ";
(*this)(ldc);
os() << "}, {";
os() << "}, Gemm::EpilogueOutputOp::Params{("
<< genCUTLASSType(op->c_->dtype()) << ")(";
(*this)(op->alpha_);
os() << ", ";
os() << "), (" << genCUTLASSType(op->c_->dtype()) << ")(";
(*this)(op->beta_);
os() << "}}, nullptr, __stream));" << std::endl;
os() << ")}}, nullptr, __stream));" << std::endl;
endBlock();
break;
}

@@ -988,7 +1111,7 @@ extern "C" {
for (size_t i = 0, iEnd = shape.size(); i < iEnd; i++) {
os << "__ByValArray<";
}
os << CodeGenCUDA::gen(tensor->dtype());
os << visitor.gen(tensor->dtype());
for (auto it = shape.rbegin(); it != shape.rend(); it++) {
ASSERT((*it)->nodeType() == ASTNodeType::IntConst);
os << ", " << (*it).as<IntConstNode>()->val_ << ">";
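Assembled as ordinary host code, the GEMM block emitted above corresponds roughly to the following sketch for float16 A and B with a float32 C (CUTLASS 2.x device API, sm70+, row-major, Tensor Cores). The wrapper function, its name, and the fixed leading dimensions are assumptions; buffer allocation and error handling are omitted.

#include <cuda_runtime.h>

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>

// Hypothetical wrapper: C = alpha * A * B + beta * C, half inputs, float
// output/accumulator. Leading dimensions are assumed to satisfy Tensor Core
// alignment (multiples of 8 half elements).
cutlass::Status runHalfGemm(const cutlass::half_t *A, const cutlass::half_t *B,
                            float *C, int m, int n, int k, float alpha,
                            float beta, cudaStream_t stream) {
    using Gemm = cutlass::gemm::device::Gemm<
        cutlass::half_t, cutlass::layout::RowMajor, // A
        cutlass::half_t, cutlass::layout::RowMajor, // B
        float, cutlass::layout::RowMajor,           // C
        float,                                      // accumulator
        cutlass::arch::OpClassTensorOp,             // canUseTensorCore == true
        cutlass::arch::Sm70>;
    Gemm gemm;
    Gemm::Arguments args{{m, n, k},
                         {A, k},         // TensorRefA, leading dimension k
                         {B, n},         // TensorRefB, leading dimension n
                         {C, n},         // TensorRefC (source)
                         {C, n},         // TensorRefD (destination)
                         {alpha, beta}}; // epilogue: linear combination
    return gemm(args, /*workspace=*/nullptr, stream);
}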
(The remaining 6 changed files are not shown.)
