From 69e3f0d37de437337d3bad91af651237c3402a80 Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Mon, 25 Nov 2024 03:30:19 -0800 Subject: [PATCH] [pallas:mosaic_gpu] Add test for FragmentedArray.bitcast. PiperOrigin-RevId: 699919048 --- .../mosaic/gpu/fragmented_array.py | 10 ++++-- tests/mosaic/gpu_test.py | 34 +++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index e1ee37f3d24d..094c683c1695 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -463,8 +463,8 @@ def __init__( if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype): raise TypeError( - "is_signed must only be non-None if the MLIR type is an integer" - f" type, got {_is_signed=} for {self.mlir_dtype}" + "is_signed must be non-None if and only if the MLIR type is an" + f" integer type, got {_is_signed=} for {self.mlir_dtype}" ) match self.layout: @@ -962,6 +962,12 @@ def fast_instr(x): return fast_instr def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): + if (output_is_signed is not None) != ir.IntegerType.isinstance(elt): + raise TypeError( + "output_is_signed must be non-None if and only if the MLIR type is an" + f" integer type, got {output_is_signed=} for {elt}" + ) + if elt == self.mlir_dtype: return self reg_type = self.registers.flat[0].type diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 87dc2c452041..aeddbc7e033d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1577,6 +1577,40 @@ def kernel(ctx, _): _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), (), None)() + @parameterized.parameters( + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), + (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), + ) + def test_bitcast(self, in_dtype, out_dtype): + out_ir_type = utils.dtype_to_ir_type(out_dtype) + in_is_signed = utils.is_signed(in_dtype) + out_is_signed = utils.is_signed(out_dtype) + + def kernel(ctx, inp, out, smem): + del ctx, smem + arr = mgpu.FragmentedArray.load_strided(inp, is_signed=in_is_signed) + arr = arr.bitcast(out_ir_type, output_is_signed=out_is_signed) + arr.store_untiled(out) + + x = jnp.arange(256, dtype=in_dtype) + reference = jax.lax.bitcast_convert_type(x, out_dtype) + + result = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + x, + reference, + None, + )(x) + np.testing.assert_array_equal(result, reference) + class ProfilerTest(TestCase):