From 08510a3aa31a09c45df4602c484ac81056f0c4e8 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Thu, 22 Jun 2023 08:31:47 +0200 Subject: [PATCH] Test support for bfloat16 using ml_dtypes. --- .github/workflows/ci.yml | 8 +++---- tests/array.rs | 47 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 376ded49c..14713e7d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,7 +67,7 @@ jobs: shell: python - name: Test run: | - pip install numpy + pip install numpy ml_dtypes cargo test --all-features # Not on PyPy, because no embedding API if: ${{ !startsWith(matrix.python-version, 'pypy') }} @@ -101,7 +101,7 @@ jobs: continue-on-error: true - uses: taiki-e/install-action@valgrind - run: | - pip install numpy + pip install numpy ml_dtypes cargo test --all-features --release env: CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: valgrind --leak-check=no --error-exitcode=1 @@ -115,7 +115,7 @@ jobs: - uses: Swatinem/rust-cache@v2 continue-on-error: true - run: | - pip install numpy + pip install numpy ml_dtypes cargo install --locked cargo-careful cargo careful test --all-features @@ -201,7 +201,7 @@ jobs: python-version: 3.9 architecture: x64 - name: Install numpy - run: pip install numpy + run: pip install numpy ml_dtypes - uses: Swatinem/rust-cache@v2 continue-on-error: true - uses: dtolnay/rust-toolchain@stable diff --git a/tests/array.rs b/tests/array.rs index 6cfa8ac63..3564c9c76 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1,7 +1,7 @@ use std::mem::size_of; #[cfg(feature = "half")] -use half::f16; +use half::{bf16, f16}; use ndarray::{array, s, Array1, Dim}; use numpy::{ dtype, get_array_module, npyffi::NPY_ORDER, pyarray, PyArray, PyArray1, PyArray2, PyArrayDescr, @@ -527,7 +527,7 @@ fn reshape() { #[cfg(feature = "half")] #[test] -fn half_works() { +fn half_f16_works() { Python::with_gil(|py| { let np = py.eval("__import__('numpy')", None, None).unwrap(); let locals = [("np", np)].into_py_dict(py); @@ -558,7 +558,48 @@ fn half_works() { py_run!( py, array np, - "np.testing.assert_array_almost_equal(array, np.array([[2, 4], [6, 8]], dtype='float16'))" + "assert np.all(array == np.array([[2, 4], [6, 8]], dtype='float16'))" + ); + }); +} + +#[cfg(feature = "half")] +#[test] +fn half_bf16_works() { + Python::with_gil(|py| { + let np = py.eval("__import__('numpy')", None, None).unwrap(); + // NumPy itself does not provide a `bfloat16` dtype itself, + // so we import ml_dtypes which does register such a dtype. + let mldt = py.eval("__import__('ml_dtypes')", None, None).unwrap(); + let locals = [("np", np), ("mldt", mldt)].into_py_dict(py); + + let array = py + .eval( + "np.array([[1, 2], [3, 4]], dtype='bfloat16')", + None, + Some(locals), + ) + .unwrap() + .downcast::>() + .unwrap(); + + assert_eq!( + array.readonly().as_array(), + array![ + [bf16::from_f32(1.0), bf16::from_f32(2.0)], + [bf16::from_f32(3.0), bf16::from_f32(4.0)] + ] + ); + + array + .readwrite() + .as_array_mut() + .map_inplace(|value| *value *= bf16::from_f32(2.0)); + + py_run!( + py, + array np, + "assert np.all(array == np.array([[2, 4], [6, 8]], dtype='bfloat16'))" ); }); }