diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md index 609580e73a2896..c3d3863de6761c 100644 --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -14,6 +14,10 @@ The following convention is followed: or any of the C++ headers in which the type is defined. * If `emitc.array` with a dimension of size zero, then the code requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html). +* If `_Float16` is used, the code requires the support of C additional + floating types. +* If `__bf16` is used, the code requires a compiler that supports it, such as + GCC or Clang. * Else the generated code is compatible with C99. These restrictions are neither inherent to the EmitC dialect itself nor to the diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 660a5b1b1309b4..4aa1d8af2edfa4 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -114,7 +114,21 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) { } bool mlir::emitc::isSupportedFloatType(Type type) { - return isa(type); + if (auto floatType = llvm::dyn_cast(type)) { + switch (floatType.getWidth()) { + case 16: { + if (llvm::isa(type)) + return true; + return false; + } + case 32: + case 64: + return isa(type); + default: + return false; + } + } + return false; } bool mlir::emitc::isPointerWideType(Type type) { diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index c043582b7be9c6..30657d8fccb154 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1258,6 +1258,12 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { val.toString(strValue, 0, 0, false); os << strValue; switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) { + case llvm::APFloatBase::S_IEEEhalf: + os << "f16"; + break; + case llvm::APFloatBase::S_BFloat: + os << "bf16"; + break; case llvm::APFloatBase::S_IEEEsingle: os << "f"; break; @@ -1277,17 +1283,19 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { // Print floating point attributes. if (auto fAttr = dyn_cast(attr)) { - if (!isa(fAttr.getType())) { - return emitError(loc, - "expected floating point attribute to be f32 or f64"); + if (!isa( + fAttr.getType())) { + return emitError( + loc, "expected floating point attribute to be f16, bf16, f32 or f64"); } printFloat(fAttr.getValue()); return success(); } if (auto dense = dyn_cast(attr)) { - if (!isa(dense.getElementType())) { - return emitError(loc, - "expected floating point attribute to be f32 or f64"); + if (!isa( + dense.getElementType())) { + return emitError( + loc, "expected floating point attribute to be f16, bf16, f32 or f64"); } os << '{'; interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); }); @@ -1640,6 +1648,14 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) { } if (auto fType = dyn_cast(type)) { switch (fType.getWidth()) { + case 16: { + if (llvm::isa(type)) + return (os << "_Float16"), success(); + else if (llvm::isa(type)) + return (os << "__bf16"), success(); + else + return emitError(loc, "cannot emit float type ") << type; + } case 32: return (os << "float"), success(); case 64: diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 1342eadc434e2b..27affa281a8650 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -31,36 +31,35 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> { } // ----- - -func.func @arith_cast_bf16(%arg0: bf16) -> i32 { +func.func @arith_cast_f80(%arg0: f80) -> i32 { // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} - %t = arith.fptosi %arg0 : bf16 to i32 + %t = arith.fptosi %arg0 : f80 to i32 return %t: i32 } // ----- -func.func @arith_cast_f16(%arg0: f16) -> i32 { +func.func @arith_cast_f128(%arg0: f128) -> i32 { // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} - %t = arith.fptosi %arg0 : f16 to i32 + %t = arith.fptosi %arg0 : f128 to i32 return %t: i32 } // ----- -func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 { +func.func @arith_cast_to_f80(%arg0: i32) -> f80 { // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}} - %t = arith.sitofp %arg0 : i32 to bf16 - return %t: bf16 + %t = arith.sitofp %arg0 : i32 to f80 + return %t: f80 } // ----- -func.func @arith_cast_to_f16(%arg0: i32) -> f16 { +func.func @arith_cast_to_f128(%arg0: i32) -> f128 { // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}} - %t = arith.sitofp %arg0 : i32 to f16 - return %t: f16 + %t = arith.sitofp %arg0 : i32 to f128 + return %t: f128 } // ----- @@ -135,23 +134,6 @@ func.func @arith_remui_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) -> vec return %divui: vector<5xi32> } -// ----- - -func.func @arith_extf_to_bf16(%arg0: f8E4M3FN) { - // expected-error @+1 {{failed to legalize operation 'arith.extf'}} - %ext = arith.extf %arg0 : f8E4M3FN to bf16 - return -} - -// ----- - -func.func @arith_extf_to_f16(%arg0: f8E4M3FN) { - // expected-error @+1 {{failed to legalize operation 'arith.extf'}} - %ext = arith.extf %arg0 : f8E4M3FN to f16 - return -} - - // ----- func.func @arith_extf_to_tf32(%arg0: f8E4M3FN) { @@ -194,22 +176,6 @@ func.func @arith_truncf_to_tf32(%arg0: f64) { // ----- -func.func @arith_truncf_to_f16(%arg0: f64) { - // expected-error @+1 {{failed to legalize operation 'arith.truncf'}} - %trunc = arith.truncf %arg0 : f64 to f16 - return -} - -// ----- - -func.func @arith_truncf_to_bf16(%arg0: f64) { - // expected-error @+1 {{failed to legalize operation 'arith.truncf'}} - %trunc = arith.truncf %arg0 : f64 to bf16 - return -} - -// ----- - func.func @arith_truncf_to_f8E4M3FN(%arg0: f64) { // expected-error @+1 {{failed to legalize operation 'arith.truncf'}} %trunc = arith.truncf %arg0 : f64 to f8E4M3FN diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir index 836d8aedefc1f0..dee9cc97a14493 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir @@ -46,9 +46,9 @@ memref.global "nested" constant @nested_global : memref<3x7xf32> // ----- -func.func @unsupported_type_f16() { +func.func @unsupported_type_f128() { // expected-error@+1 {{failed to legalize operation 'memref.alloca'}} - %0 = memref.alloca() : memref<4xf16> + %0 = memref.alloca() : memref<4xf128> return } diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir index c89a9a0e9bd4a3..87ad5b4cf1a9ca 100644 --- a/mlir/test/Dialect/EmitC/invalid_types.mlir +++ b/mlir/test/Dialect/EmitC/invalid_types.mlir @@ -132,22 +132,6 @@ func.func @illegal_f8E5M2FNUZ_type(%arg0: f8E5M2FNUZ, %arg1: f8E5M2FNUZ) { // ----- -func.func @illegal_f16_type(%arg0: f16, %arg1: f16) { - // expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f16'}} - %mul = "emitc.mul" (%arg0, %arg1) : (f16, f16) -> f16 - return -} - -// ----- - -func.func @illegal_bf16_type(%arg0: bf16, %arg1: bf16) { - // expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'bf16'}} - %mul = "emitc.mul" (%arg0, %arg1) : (bf16, bf16) -> bf16 - return -} - -// ----- - func.func @illegal_f80_type(%arg0: f80, %arg1: f80) { // expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f80'}} %mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80 diff --git a/mlir/test/Target/Cpp/const.mlir b/mlir/test/Target/Cpp/const.mlir index 3658455d669438..d3656f830c48c3 100644 --- a/mlir/test/Target/Cpp/const.mlir +++ b/mlir/test/Target/Cpp/const.mlir @@ -11,6 +11,8 @@ func.func @emitc_constant() { %c6 = "emitc.constant"(){value = 2 : index} : () -> index %c7 = "emitc.constant"(){value = 2.0 : f32} : () -> f32 %f64 = "emitc.constant"(){value = 4.0 : f64} : () -> f64 + %f16 = "emitc.constant"(){value = 2.0 : f16} : () -> f16 + %bf16 = "emitc.constant"(){value = 4.0 : bf16} : () -> bf16 %c8 = "emitc.constant"(){value = dense<0> : tensor} : () -> tensor %c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex> %c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> @@ -26,6 +28,8 @@ func.func @emitc_constant() { // CPP-DEFAULT-NEXT: size_t [[V6:[^ ]*]] = 2; // CPP-DEFAULT-NEXT: float [[V7:[^ ]*]] = 2.000000000e+00f; // CPP-DEFAULT-NEXT: double [[F64:[^ ]*]] = 4.00000000000000000e+00; +// CPP-DEFAULT-NEXT: _Float16 [[F16:[^ ]*]] = 2.00000e+00f16; +// CPP-DEFAULT-NEXT: __bf16 [[BF16:[^ ]*]] = 4.0000e+00bf16; // CPP-DEFAULT-NEXT: Tensor [[V8:[^ ]*]] = {0}; // CPP-DEFAULT-NEXT: Tensor [[V9:[^ ]*]] = {0, 1}; // CPP-DEFAULT-NEXT: Tensor [[V10:[^ ]*]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f}; @@ -40,6 +44,8 @@ func.func @emitc_constant() { // CPP-DECLTOP-NEXT: size_t [[V6:[^ ]*]]; // CPP-DECLTOP-NEXT: float [[V7:[^ ]*]]; // CPP-DECLTOP-NEXT: double [[F64:[^ ]*]]; +// CPP-DECLTOP-NEXT: _Float16 [[F16:[^ ]*]]; +// CPP-DECLTOP-NEXT: __bf16 [[BF16:[^ ]*]]; // CPP-DECLTOP-NEXT: Tensor [[V8:[^ ]*]]; // CPP-DECLTOP-NEXT: Tensor [[V9:[^ ]*]]; // CPP-DECLTOP-NEXT: Tensor [[V10:[^ ]*]]; @@ -52,6 +58,8 @@ func.func @emitc_constant() { // CPP-DECLTOP-NEXT: [[V6]] = 2; // CPP-DECLTOP-NEXT: [[V7]] = 2.000000000e+00f; // CPP-DECLTOP-NEXT: [[F64]] = 4.00000000000000000e+00; +// CPP-DECLTOP-NEXT: [[F16]] = 2.00000e+00f16; +// CPP-DECLTOP-NEXT: [[BF16]] = 4.0000e+00bf16; // CPP-DECLTOP-NEXT: [[V8]] = {0}; // CPP-DECLTOP-NEXT: [[V9]] = {0, 1}; // CPP-DECLTOP-NEXT: [[V10]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f}; diff --git a/mlir/test/Target/Cpp/types.mlir b/mlir/test/Target/Cpp/types.mlir index deda383b3b0a72..e7f935c7374382 100644 --- a/mlir/test/Target/Cpp/types.mlir +++ b/mlir/test/Target/Cpp/types.mlir @@ -22,6 +22,10 @@ func.func @ptr_types() { emitc.call_opaque "f"() {template_args = [!emitc.ptr]} : () -> () // CHECK-NEXT: f(); emitc.call_opaque "f"() {template_args = [!emitc.ptr]} : () -> () + // CHECK-NEXT: f<_Float16*>(); + emitc.call_opaque "f"() {template_args = [!emitc.ptr]} : () -> () + // CHECK-NEXT: f<__bf16*>(); + emitc.call_opaque "f"() {template_args = [!emitc.ptr]} : () -> () // CHECK-NEXT: f(); emitc.call_opaque "f"() {template_args = [!emitc.ptr]} : () -> () // CHECK-NEXT: f();