Add float16 #570

Merged · 1 commit · Dec 19, 2023
19 changes: 16 additions & 3 deletions ffi/array.cc
@@ -52,6 +52,8 @@ static DataType dtypeFromPyTorch(torch::ScalarType t) {
return DataType::Int32;
case torch::ScalarType::Long:
return DataType::Int64;
case torch::ScalarType::Half:
return DataType::Float16;
case torch::ScalarType::Float:
return DataType::Float32;
case torch::ScalarType::Double:
@@ -69,6 +71,8 @@ static torch::ScalarType dtypeToPyTorch(DataType dtype) {
return torch::ScalarType::Int;
case DataType::Int64:
return torch::ScalarType::Long;
case DataType::Float16:
return torch::ScalarType::Half;
case DataType::Float32:
return torch::ScalarType::Float;
case DataType::Float64:
@@ -116,9 +120,10 @@ void init_ffi_array(py::module_ &m) {
// or it will all end up in float64 (the first initializer)
throw DriverError(
"Unsupported data type or strides from a NumPy Array. "
"Please "
"use freetensor.array factory function, instead of "
"freetensor.Array, for strided arrays");
"If you are using strided arrays, please use "
"freetensor.array factory function, instead of "
"freetensor.Array. If you are using float16, please use "
"the PyTorch interface instead.");
}),
"data"_a, "dont_drop_borrow"_a = false, "moved"_a = false)
.def("__eq__", [](const Ref<Array> &lhs, const Ref<Array> &rhs) {
@@ -172,6 +177,14 @@ void init_ffi_array(py::module_ &m) {
SHARE_TO_NUMPY(int64_t, DataType::Int64)
SHARE_TO_NUMPY(int32_t, DataType::Int32)
SHARE_TO_NUMPY(bool, DataType::Bool)
case DataType::Float16:
// TODO: Support fp16 after PyBind11 for NumPy fp16 is
// available. Status:
// https://github.com/pybind/pybind11/issues/1776,
// https://github.com/pybind/pybind11/issues/4061
throw DriverError(
"NumPy interface for float16 is not supported yet. Please "
"use the PyTorch interface instead.");
default:
ASSERT(false);
}
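Because the NumPy path above deliberately throws for float16, half-precision data has to enter and leave through the PyTorch mapping added at the top of this file (torch::ScalarType::Half to DataType::Float16 and back). A minimal libtorch-side sketch, illustrative only and not FreeTensor's public API (the shape and values are arbitrary):

#include <torch/torch.h>

int main() {
    // torch::kHalf is torch::ScalarType::Half, which the new switch cases
    // translate to DataType::Float16 and back.
    auto x = torch::randn({16, 16},
                          torch::dtype(torch::kHalf).device(torch::kCUDA));
    // Handing `x` to the Array interface through the PyTorch path now resolves
    // to DataType::Float16, instead of hitting the "unsupported data type"
    // error that the NumPy constructor above still raises for fp16.
    return 0;
}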
2 changes: 1 addition & 1 deletion include/codegen/code_gen_c.h
@@ -22,7 +22,7 @@ template <class Stream> class CodeGenC : public CodeGen<Stream> {
const std::vector<FuncRet> &returns)
: params_(params), returns_(returns) {}

- static std::string gen(DataType dtype);
+ virtual std::string gen(const DataType &dtype);

protected:
virtual void genAlloc(const Ref<Tensor> &tensor, const std::string &rawPtr,
3 changes: 3 additions & 0 deletions include/codegen/code_gen_cuda.h
@@ -47,6 +47,8 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {

Expr globalSize() const { return globalSize_; }

std::string gen(const DataType &dtype) override;

private:
bool inKernel() const;

@@ -87,6 +89,7 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {
void visit(const Abs &op) override;
void visit(const Floor &op) override;
void visit(const Ceil &op) override;
void visit(const Cast &op) override;
void visit(const ReduceTo &op) override;
void visit(const Var &op) override;
void visit(const For &op) override;
1 change: 1 addition & 0 deletions include/pass/const_fold.h
@@ -118,6 +118,7 @@ class ConstFold : public Mutator {
case DataType::Int32:
case DataType::Int64:
return wrap(int64_t(v));
case DataType::Float16:
case DataType::Float32:
case DataType::Float64:
return wrap(double(v));
5 changes: 4 additions & 1 deletion include/type/data_type.h
@@ -12,6 +12,7 @@ namespace freetensor {

enum class BaseDataType : size_t {
Void = 0, // Returns nothing. It is a Unit Type
Float16,
Float32,
Float64,
Int32,
@@ -24,7 +25,8 @@ enum class BaseDataType : size_t {
};

constexpr std::array baseDataTypeNames = {
"void", "float32", "float64", "int32", "int64", "bool", "custom", "never",
"void", "float16", "float32", "float64", "int32",
"int64", "bool", "custom", "never",
};
static_assert(baseDataTypeNames.size() == (size_t)BaseDataType::NumTypes);

@@ -105,6 +107,7 @@ class DataType {
// and remove the following lines
constexpr static auto Bool = BaseDataType::Bool;
constexpr static auto Custom = BaseDataType::Custom;
constexpr static auto Float16 = BaseDataType::Float16;
constexpr static auto Float32 = BaseDataType::Float32;
constexpr static auto Float64 = BaseDataType::Float64;
constexpr static auto Int32 = BaseDataType::Int32;
4 changes: 2 additions & 2 deletions python/freetensor/core/staging.py
@@ -82,8 +82,8 @@ def process_annotating_comments(src: str):
for line in src.splitlines():
indent = re.match('\\s*', line)[0]
rest_line = line[len(indent):]
- if rest_line.startswith('#! '):
- arg = rest_line[3:].replace('"', '\\"')
+ if rest_line.startswith('#!'):
+ arg = rest_line[2:].strip().replace('"', '\\"')
new_src.append(f'{indent}__staging_overload__.metadata("{arg}")')
else:
new_src.append(line)
7 changes: 5 additions & 2 deletions runtime/gpu_runtime.h
@@ -9,8 +9,11 @@
#include <stdexcept>
#include <type_traits>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include <cuda_fp16.h>

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>

#include "gpu_context.h"

155 changes: 139 additions & 16 deletions src/codegen/code_gen_cuda.cc
@@ -15,12 +15,14 @@

namespace freetensor {

- static std::string genCUBLASType(DataType dtype) {
+ static std::string genCUBLASType(const DataType &dtype) {
switch (dtype.base()) {
case DataType::Float64:
return "CUDA_R_64F";
case DataType::Float32:
return "CUDA_R_32F";
case DataType::Float16:
return "CUDA_R_16F";
case DataType::Int64:
return "CUDA_R_64I";
case DataType::Int32:
@@ -30,9 +32,32 @@ static std::string genCUBLASType(DataType dtype) {
}
}

static std::string genCUTLASSType(const DataType &dtype) {
switch (dtype.base()) {
case DataType::Float64:
return "double";
case DataType::Float32:
return "float";
case DataType::Float16:
return "cutlass::half_t";
case DataType::Int64:
return "int64_t";
case DataType::Int32:
return "int32_t";
case DataType::Bool:
return "bool";
default:
ASSERT(false);
}
}

static bool canUseTensorCore(const Ref<GPUTarget> &target, DataType dtypeA,
DataType dtypeB, DataType dtypeC) {
// TODO: fp16 is supported after sm70
if (target->computeCapability().first >= 7 && dtypeA == DataType::Float16 &&
dtypeB == DataType::Float16 &&
(dtypeC == DataType::Float16 || dtypeC == DataType::Float32)) {
return true;
}
if (target->computeCapability().first >= 8 && dtypeA == DataType::Float64 &&
dtypeB == DataType::Float64 && dtypeC == DataType::Float64) {
return true;
@@ -49,7 +74,17 @@ CodeGenCUDA::genMdPtrType(const VarDef &def, bool isConst) {
// Use pointer instead of reference for scalars, because when passing an
// argument from host to a kernel, a reference means copy the value from
// CPU to GPU, while a pointer means passing the address

// NOTE: `[=]` implicitly capturing `this` is deprecated in C++20. Using
// `[=]` will trigger a warning in GCC (because of deprecation), but
// using
// `[=, this]` will trigger a warning in Clang<17 (because it will think
// `this` is duplicated).
#if defined(__clang__) && __clang_major__ < 17
return [=](std::ostream &os) -> std::ostream & {
#else
return [=, this](std::ostream &os) -> std::ostream & {
#endif
if (isConst) {
os << "const ";
}
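As an aside, the capture note above is easier to see in a standalone snippet. A minimal sketch, not taken from the PR, of why both spellings are kept behind the same preprocessor check:

#include <iostream>

struct Widget {
    int value = 42;

    auto reader() {
        // C++20 deprecates the implicit capture of `this` by `[=]`, so GCC
        // warns about the first spelling; Clang before 17 instead warns that
        // `this` is redundant in the second one.
#if defined(__clang__) && __clang_major__ < 17
        return [=] { return value; };
#else
        return [=, this] { return value; };
#endif
    }
};

int main() {
    Widget w;
    std::cout << w.reader()() << std::endl; // prints 42
}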
@@ -77,6 +112,14 @@ void CodeGenCUDA::genMdPtrDef(const VarDef &def,
CodeGenC<CodeGenCUDAStream>::genMdPtrDef(def, genRawPtr, isConst);
}

std::string CodeGenCUDA::gen(const DataType &dtype) {
if (dtype == DataType::Float16) {
return "__half";
} else {
return CodeGenC::gen(dtype);
}
}

void CodeGenCUDA::genAlloc(const Ref<Tensor> &tensor, const std::string &rawPtr,
const std::string &shapePtr,
const std::string &dimPtr) {
@@ -383,6 +426,72 @@ void CodeGenCUDA::visit(const Ceil &op) {
os() << ")";
}

void CodeGenCUDA::visit(const Cast &op) {
if (op->destType_.base() == DataType::Float16) {
switch (op->expr_->dtype().base()) {
case DataType::Int32:
os() << "__int2half_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Int64:
os() << "__ll2half_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Float16:
(*this)(op->expr_);
break;
case DataType::Float32:
os() << "__float2half_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Float64:
os() << "__double2half("; // Always `_rn` (round to nearest even)
(*this)(op->expr_);
os() << ")";
break;
default:
throw InvalidProgram("Converting from " +
freetensor::toString(op->expr_->dtype()) +
" to float16 is not supported");
}
} else if (op->expr_->dtype().base() == DataType::Float16) {
switch (op->destType_.base()) {
case DataType::Int32:
os() << "__half2int_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Int64:
os() << "__half2ll_rn(";
(*this)(op->expr_);
os() << ")";
break;
case DataType::Float16:
(*this)(op->expr_);
break;
case DataType::Float32:
os() << "__half2float("; // Short to long, no rounding is needed
(*this)(op->expr_);
os() << ")";
break;
case DataType::Float64:
os() << "__half2double("; // Short to long, no rounding is needed
(*this)(op->expr_);
os() << ")";
break;
default:
throw InvalidProgram("Converting from float16 to " +
freetensor::toString(op->dtype()) +
" is not supported");
}
} else {
CodeGenC::visit(op);
}
}
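All of the names emitted above are conversion intrinsics from cuda_fp16.h. A small hand-written kernel, illustrative only (not FreeTensor-generated code; the kernel and parameter names are made up), exercising a few of the same round trips:

#include <cuda_fp16.h>

__global__ void castDemo(const float *xf, const int *xi, float *yf, int *yi,
                         int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        // float32 -> float16: round to nearest even, as in the codegen above
        __half hf = __float2half_rn(xf[idx]);
        // int32 -> float16: also round to nearest even
        __half hi = __int2half_rn(xi[idx]);
        // float16 -> float32: widening, so no rounding is involved
        yf[idx] = __half2float(hf);
        // float16 -> int32: round to nearest even
        yi[idx] = __half2int_rn(hi);
    }
}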

void CodeGenCUDA::visit(const Store &op) {
if (buffer(op->var_)->mtype() == MemType::GPUWarp) {
auto id = mangle(op->var_);
@@ -843,50 +952,64 @@ void CodeGenCUDA::visit(const MatMul &op) {
}

makeIndent();
os() << "cutlass::gemm::device::Gemm<" << gen(op->a_->dtype()) << ", "
beginBlock();
makeIndent();
os() << "using Gemm = cutlass::gemm::device::Gemm<"
<< genCUTLASSType(op->a_->dtype()) << ", "
<< (transA ? "cutlass::layout::ColumnMajor"
: "cutlass::layout::RowMajor")
<< ", " << gen(op->b_->dtype()) << ", "
<< ", " << genCUTLASSType(op->b_->dtype()) << ", "
<< (transB ? "cutlass::layout::ColumnMajor"
: "cutlass::layout::RowMajor")
<< ", " << gen(op->c_->dtype()) << ", "
<< ", " << genCUTLASSType(op->c_->dtype()) << ", "
<< (transC ? "cutlass::layout::ColumnMajor"
: "cutlass::layout::RowMajor")
<< ", " << gen(op->c_->dtype()) // TODO: accumulator type
<< ", "
<< genCUTLASSType(op->c_->dtype()) // TODO: accumulator type
<< ", "
<< (canUseTensorCore(target_, op->a_->dtype(), op->b_->dtype(),
op->c_->dtype())
? "cutlass::arch::OpClassTensorOp"
: "cutlass::arch::OpClassSimt")
<< ", FT_CUTLASS_ARCH> gemm;" << std::endl;
<< ", FT_CUTLASS_ARCH>;" << std::endl;
makeIndent();
os() << "checkCutlassError(gemm({{";
os() << "Gemm gemm;" << std::endl;
// In order for clearer error message, please keep the explicit argument
// types in the following statement.
makeIndent();
os() << "checkCutlassError(gemm(Gemm::Arguments{{";
(*this)(m);
os() << ", ";
(*this)(n);
os() << ", ";
(*this)(k);
os() << "}, {&";
os() << "}, Gemm::TensorRefA{(const " << genCUTLASSType(op->a_->dtype())
<< "*)&";
(*this)(a);
os() << ", ";
(*this)(lda);
os() << "}, {&";
os() << "}, Gemm::TensorRefB{(const " << genCUTLASSType(op->b_->dtype())
<< "*)&";
(*this)(b);
os() << ", ";
(*this)(ldb);
os() << "}, {&";
os() << "}, Gemm::TensorRefC{(const " << genCUTLASSType(op->c_->dtype())
<< "*)&";
(*this)(c);
os() << ", ";
(*this)(ldc);
os() << "}, {&";
os() << "}, Gemm::TensorRefD{(" << genCUTLASSType(op->c_->dtype())
<< "*)&";
(*this)(c);
os() << ", ";
(*this)(ldc);
os() << "}, {";
os() << "}, Gemm::EpilogueOutputOp::Params{("
<< genCUTLASSType(op->c_->dtype()) << ")(";
(*this)(op->alpha_);
os() << ", ";
os() << "), (" << genCUTLASSType(op->c_->dtype()) << ")(";
(*this)(op->beta_);
os() << "}}, nullptr, __stream));" << std::endl;
os() << ")}}, nullptr, __stream));" << std::endl;
endBlock();
break;
}
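Assembled, the strings above emit host code of roughly the following shape for an all-fp16, row-major GEMM (D = alpha*A*B + beta*C). This is a reconstructed sketch rather than literal generated output: the wrapper function and its parameters are placeholders, and cutlass::arch::Sm80 stands in for the project's FT_CUTLASS_ARCH macro.

#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>

// Placeholder wrapper showing the shape of the emitted CUTLASS call.
cutlass::Status halfGemm(int m, int n, int k, const cutlass::half_t *a, int lda,
                         const cutlass::half_t *b, int ldb, cutlass::half_t *c,
                         int ldc, float alpha, float beta, cudaStream_t stream) {
    using Gemm = cutlass::gemm::device::Gemm<
        cutlass::half_t, cutlass::layout::RowMajor, // A
        cutlass::half_t, cutlass::layout::RowMajor, // B
        cutlass::half_t, cutlass::layout::RowMajor, // C (also used as D)
        cutlass::half_t,                            // accumulator (the TODO above)
        cutlass::arch::OpClassTensorOp,             // tensor cores, per canUseTensorCore
        cutlass::arch::Sm80>;                       // stand-in for FT_CUTLASS_ARCH
    Gemm gemm;
    // Explicit Arguments/TensorRef types, matching the generated code, keep
    // compiler errors readable if an element type ever mismatches.
    return gemm(Gemm::Arguments{{m, n, k},
                                Gemm::TensorRefA{a, lda},
                                Gemm::TensorRefB{b, ldb},
                                Gemm::TensorRefC{c, ldc},
                                Gemm::TensorRefD{c, ldc},
                                Gemm::EpilogueOutputOp::Params{
                                    cutlass::half_t(alpha),
                                    cutlass::half_t(beta)}},
                nullptr, stream);
}

With a float32 C and accumulator instead (the other combination canUseTensorCore accepts for fp16 inputs), only the element types derived from op->c_->dtype() would change.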

Expand Down Expand Up @@ -988,7 +1111,7 @@ extern "C" {
for (size_t i = 0, iEnd = shape.size(); i < iEnd; i++) {
os << "__ByValArray<";
}
- os << CodeGenCUDA::gen(tensor->dtype());
+ os << visitor.gen(tensor->dtype());
for (auto it = shape.rbegin(); it != shape.rend(); it++) {
ASSERT((*it)->nodeType() == ASTNodeType::IntConst);
os << ", " << (*it).as<IntConstNode>()->val_ << ">";