From 8c92df95c28b02c53a9ec374e5ff9ccec397eaa9 Mon Sep 17 00:00:00 2001 From: fineg74 <61437305+fineg74@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:20:09 -0800 Subject: [PATCH] [SYCL][ESIMD]Limit bfloat16 operators to scalars to enable operations with simd vectors (#12089) The purpose of this change is to limit operators defined for bfloat16 to scalar types to allow arithmetic operations between bfloat16 scalars and simd vectors. This allows to use simd operators that are defined separately and support operations between vectors and scalars --- .../esimd/detail/bfloat16_type_traits.hpp | 2 + sycl/include/sycl/ext/oneapi/bfloat16.hpp | 222 +++++++++++++----- .../bfloat16_half_vector_plus_eq_scalar.cpp | 5 +- .../bfloat16_vector_plus_scalar.cpp | 100 ++++++++ .../bfloat16_vector_plus_scalar_pvc.cpp | 14 ++ 5 files changed, 282 insertions(+), 61 deletions(-) create mode 100644 sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar.cpp create mode 100644 sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar_pvc.cpp diff --git a/sycl/include/sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp b/sycl/include/sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp index 2d0c17a3c6100..6a935ae9c3390 100644 --- a/sycl/include/sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp +++ b/sycl/include/sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp @@ -94,6 +94,8 @@ inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) { return O; } +template <> struct is_esimd_arithmetic_type : std::true_type {}; + } // namespace ext::intel::esimd::detail } // namespace _V1 } // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index bd3052e9a0488..fe0324f28ed53 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -132,69 +132,175 @@ class bfloat16 { #endif } -// Increment and decrement operators overloading + bfloat16 &operator+=(const bfloat16 &rhs) { + value = from_float(to_float(value) + to_float(rhs.value)); + return *this; + } + + bfloat16 &operator-=(const bfloat16 &rhs) { + value = from_float(to_float(value) - to_float(rhs.value)); + return *this; + } + + bfloat16 &operator*=(const bfloat16 &rhs) { + value = from_float(to_float(value) * to_float(rhs.value)); + return *this; + } + + bfloat16 &operator/=(const bfloat16 &rhs) { + value = from_float(to_float(value) / to_float(rhs.value)); + return *this; + } + + // Operator ++, -- + bfloat16 &operator++() { + float f = to_float(value); + value = from_float(++f); + return *this; + } + + bfloat16 operator++(int) { + bfloat16 ret(*this); + operator++(); + return ret; + } + + bfloat16 &operator--() { + float f = to_float(value); + value = from_float(--f); + return *this; + } + + bfloat16 operator--(int) { + bfloat16 ret(*this); + operator--(); + return ret; + } + +// Operator +, -, *, / #define OP(op) \ - friend bfloat16 &operator op(bfloat16 &lhs) { \ - float f = to_float(lhs.value); \ - lhs.value = from_float(op f); \ - return lhs; \ - } \ - friend bfloat16 operator op(bfloat16 &lhs, int) { \ - bfloat16 old = lhs; \ - operator op(lhs); \ - return old; \ - } - OP(++) - OP(--) + friend bfloat16 operator op(const bfloat16 lhs, const bfloat16 rhs) { \ + return to_float(lhs.value) op to_float(rhs.value); \ + } \ + friend double operator op(const bfloat16 lhs, const double rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend double operator op(const double lhs, const bfloat16 rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend float operator op(const bfloat16 lhs, const float rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend float operator op(const float lhs, const bfloat16 rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bfloat16 operator op(const bfloat16 lhs, const int rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bfloat16 operator op(const int lhs, const bfloat16 rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bfloat16 operator op(const bfloat16 lhs, const long rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bfloat16 operator op(const long lhs, const bfloat16 rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bfloat16 operator op(const bfloat16 lhs, const long long rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bfloat16 operator op(const long long lhs, const bfloat16 rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bfloat16 operator op(const bfloat16 &lhs, const unsigned int &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bfloat16 operator op(const unsigned int &lhs, const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bfloat16 operator op(const bfloat16 &lhs, const unsigned long &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bfloat16 operator op(const unsigned long &lhs, const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bfloat16 operator op(const bfloat16 &lhs, \ + const unsigned long long &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bfloat16 operator op(const unsigned long long &lhs, \ + const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } + OP(+) + OP(-) + OP(*) + OP(/) + #undef OP - // Assignment operators overloading +// Operator ==, !=, <, >, <=, >= #define OP(op) \ - friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \ - float f = static_cast(lhs); \ - f op static_cast(rhs); \ - return lhs = f; \ - } \ - template \ - friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \ - float f = static_cast(lhs); \ - f op static_cast(rhs); \ - return lhs = f; \ - } \ - template friend T &operator op(T &lhs, const bfloat16 &rhs) { \ - float f = static_cast(lhs); \ - f op static_cast(rhs); \ - return lhs = f; \ - } - OP(+=) - OP(-=) - OP(*=) - OP(/=) -#undef OP + friend bool operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \ + return to_float(lhs.value) op to_float(rhs.value); \ + } \ + friend bool operator op(const bfloat16 &lhs, const double &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bool operator op(const double &lhs, const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bool operator op(const bfloat16 &lhs, const float &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bool operator op(const float &lhs, const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bool operator op(const bfloat16 &lhs, const int &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bool operator op(const int &lhs, const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bool operator op(const bfloat16 &lhs, const long &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bool operator op(const long &lhs, const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bool operator op(const bfloat16 &lhs, const long long &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bool operator op(const long long &lhs, const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bool operator op(const bfloat16 &lhs, const unsigned int &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bool operator op(const unsigned int &lhs, const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bool operator op(const bfloat16 &lhs, const unsigned long &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bool operator op(const unsigned long &lhs, const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } \ + friend bool operator op(const bfloat16 &lhs, \ + const unsigned long long &rhs) { \ + return to_float(lhs.value) op rhs; \ + } \ + friend bool operator op(const unsigned long long &lhs, \ + const bfloat16 &rhs) { \ + return lhs op to_float(rhs.value); \ + } + OP(==) + OP(!=) + OP(<) + OP(>) + OP(<=) + OP(>=) -// Binary operators overloading -#define OP(type, op) \ - friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \ - return type{static_cast(lhs) op static_cast(rhs)}; \ - } \ - template \ - friend type operator op(const bfloat16 &lhs, const T &rhs) { \ - return type{static_cast(lhs) op static_cast(rhs)}; \ - } \ - template \ - friend type operator op(const T &lhs, const bfloat16 &rhs) { \ - return type{static_cast(lhs) op static_cast(rhs)}; \ - } - OP(bfloat16, +) - OP(bfloat16, -) - OP(bfloat16, *) - OP(bfloat16, /) - OP(bool, ==) - OP(bool, !=) - OP(bool, <) - OP(bool, >) - OP(bool, <=) - OP(bool, >=) #undef OP // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported diff --git a/sycl/test-e2e/ESIMD/regression/bfloat16_half_vector_plus_eq_scalar.cpp b/sycl/test-e2e/ESIMD/regression/bfloat16_half_vector_plus_eq_scalar.cpp index e6b456100762f..17ef57a34f3b6 100644 --- a/sycl/test-e2e/ESIMD/regression/bfloat16_half_vector_plus_eq_scalar.cpp +++ b/sycl/test-e2e/ESIMD/regression/bfloat16_half_vector_plus_eq_scalar.cpp @@ -91,12 +91,11 @@ int main() { } #ifdef USE_BF16 -// TODO: Reenable once the issue with bfloat16 is resolved -// Passed &= test(Q); + Passed &= test(Q); #endif #ifdef USE_TF32 Passed &= test(Q); #endif std::cout << (Passed ? "Passed\n" : "FAILED\n"); return Passed ? 0 : 1; -} \ No newline at end of file +} diff --git a/sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar.cpp b/sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar.cpp new file mode 100644 index 0000000000000..7f6388b0ffc2e --- /dev/null +++ b/sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar.cpp @@ -0,0 +1,100 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out +//==- bfloat16_vector_plus_scalar.cpp - Test for bfloat16 operators ------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "../esimd_test_utils.hpp" +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel::esimd; +using namespace sycl::ext::intel::experimental::esimd; + +template ESIMD_NOINLINE bool test(queue Q) { + std::cout << "Testing T=" << esimd_test::type_name() << "...\n"; + + constexpr int N = 8; + + constexpr int NumOps = 4; + constexpr int CSize = NumOps * N; + + T *Mem = malloc_shared(CSize, Q); + T TOne = static_cast(1); + T TTen = static_cast(10); + + Q.single_task([=]() SYCL_ESIMD_KERNEL { + { + simd Vec(TOne); + Vec = Vec + TTen; + Vec.copy_to(Mem); + } + { + simd Vec(TOne); + Vec = Vec - TTen; + Vec.copy_to(Mem + N); + } + { + simd Vec(TOne); + Vec = Vec * TTen; + Vec.copy_to(Mem + 2 * N); + } + { + simd Vec(TOne); + Vec = Vec / TTen; + Vec.copy_to(Mem + 3 * N); + } + }).wait(); + + bool ReturnValue = true; + for (int i = 0; i < N; ++i) { + if (Mem[i] != TOne + TTen) { + ReturnValue = false; + break; + } + if (Mem[i + N] != TOne - TTen) { + ReturnValue = false; + break; + } + if (Mem[i + 2 * N] != TOne * TTen) { + ReturnValue = false; + break; + } + if (!((Mem[i + 3 * N] == (TOne / TTen)) || + (std::abs((double)(Mem[i + 3 * N] - (TOne / TTen)) / + (double)(TOne / TTen)) <= 0.001))) { + ReturnValue = false; + break; + } + } + + free(Mem, Q); + return ReturnValue; +} + +int main() { + queue Q; + esimd_test::printTestLabel(Q); + + bool SupportsHalf = Q.get_device().has(aspect::fp16); + + bool Passed = true; + Passed &= test(Q); + Passed &= test(Q); + if (SupportsHalf) { + Passed &= test(Q); + } +#ifdef USE_BF16 + Passed &= test(Q); +#endif +#ifdef USE_TF32 + Passed &= test(Q); +#endif + std::cout << (Passed ? "Passed\n" : "FAILED\n"); + return Passed ? 0 : 1; +} \ No newline at end of file diff --git a/sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar_pvc.cpp b/sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar_pvc.cpp new file mode 100644 index 0000000000000..66fa9388b151e --- /dev/null +++ b/sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar_pvc.cpp @@ -0,0 +1,14 @@ +//==- bfloat16_vector_plus_scalar_pvc.cpp - Test for bfloat16 operators -==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: gpu-intel-pvc +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +#define USE_BF16 +#define USE_TF32 +#include "bfloat16_vector_plus_scalar.cpp" \ No newline at end of file