Add the jax.random.categorical frontend and a test for it.

categorical is inserted above the decorator stack that belongs to
weibull_min, so weibull_min keeps @handle_jax_dtype and the rest of its
decorators. _remove_axis normalises negative axes before slicing. The test
passes the arguments categorical actually accepts (key, logits, axis).

The blob ids on the index lines are carried over unchanged and no longer
identify the post-image.

diff --git a/ivy/functional/frontends/jax/random.py b/ivy/functional/frontends/jax/random.py
index fec1047b83ec5..4bd88d7b02aa1 100644
--- a/ivy/functional/frontends/jax/random.py
+++ b/ivy/functional/frontends/jax/random.py
@@ -27,6 +27,20 @@ def _get_seed(key):
 def PRNGKey(seed):
     return ivy.array([0, seed % 4294967295 - (seed // 4294967295)], dtype=ivy.int64)
 
 
+def _remove_axis(shape, axis):
+    """Return ``shape`` as a tuple with the entry at ``axis`` removed.
+
+    ``axis`` may be negative. It is normalised before slicing because
+    ``shape[axis + 1 :]`` would otherwise become ``shape[0:]`` for ``axis == -1``.
+    """
+    shape = tuple(shape)
+    rank = len(shape)
+    if not -rank <= axis < rank:
+        raise IndexError(f"axis {axis} is out of bounds for a shape of rank {rank}")
+    axis %= rank
+    return shape[:axis] + shape[axis + 1 :]
+
+
 @handle_jax_dtype
 @to_ivy_arrays_and_back
@@ -372,6 +386,68 @@ def uniform(key, shape=(), dtype=None, minval=0.0, maxval=1.0):
     )
 
 
+@to_ivy_arrays_and_back
+def categorical(key, logits, axis=-1, shape=None):
+    """Draw class indices from the categorical distributions described by logits.
+
+    Sampling uses the Gumbel-max trick: Gumbel noise is added to the logits and
+    the index of the largest perturbed value along ``axis`` is returned.
+
+    Parameters
+    ----------
+    key
+        PRNG key; passed on to ``gumbel`` to generate the noise.
+    logits
+        Unnormalised log-probabilities. ``axis`` is the class axis, the remaining
+        axes form the batch shape.
+    axis
+        Class axis of ``logits``; defaults to the last one.
+    shape
+        Shape of the result. The batch shape must broadcast to it. Defaults to
+        the batch shape.
+
+    Returns
+    -------
+    ret
+        Array of shape ``shape`` holding the sampled class indices.
+
+    Raises
+    ------
+    ValueError
+        If the batch shape of ``logits`` does not broadcast to ``shape``.
+    """
+    logits_arr = ivy.asarray(logits)
+    logits_shape = tuple(logits_arr.shape)
+
+    # Use a negative axis: it keeps pointing at the class axis after sample
+    # dimensions have been added in front of the batch shape.
+    if axis >= 0:
+        axis -= len(logits_shape)
+    batch_shape = _remove_axis(logits_shape, axis)
+
+    if shape is None:
+        shape = batch_shape
+    else:
+        shape = tuple(shape)
+        num_prefix = len(shape) - len(batch_shape)
+        if num_prefix < 0 or any(
+            dim not in (1, size) for dim, size in zip(batch_shape, shape[num_prefix:])
+        ):
+            raise ValueError(
+                f"Shape {shape} is not compatible with reference shape {batch_shape}"
+            )
+
+    # The noise has the requested shape with the class axis put back where
+    # ``logits`` has it, counted from the end.
+    noise_shape = list(shape)
+    noise_shape.insert(len(shape) + axis + 1, logits_shape[axis])
+    gumbel_noise = gumbel(key, tuple(noise_shape), logits_arr.dtype)
+
+    # The logits broadcast against the noise from the right, so the leading
+    # sample dimensions need no explicit expansion.
+    return ivy.argmax(gumbel_noise + logits_arr, axis=axis)
+
+
 @handle_jax_dtype
 @to_ivy_arrays_and_back
 @with_unsupported_dtypes(
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py
index aa63058977a9b..bec5396f8dbab 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py
@@ -1518,3 +1518,65 @@ def call():
     for u, v in zip(ret_np, ret_from_np):
         assert u.dtype == v.dtype
         assert u.shape == v.shape
+
+
+@pytest.mark.xfail
+@handle_frontend_test(
+    fn_tree="jax.random.categorical",
+    dtype_key=helpers.dtype_and_values(
+        available_dtypes=["uint32"],
+        min_value=0,
+        max_value=2000,
+        min_num_dims=1,
+        max_num_dims=1,
+        min_dim_size=2,
+        max_dim_size=2,
+    ),
+    dtype_logits=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_value=-10,
+        max_value=10,
+        min_num_dims=1,
+        max_num_dims=3,
+        min_dim_size=2,
+        max_dim_size=5,
+    ),
+)
+def test_jax_categorical(
+    *,
+    dtype_key,
+    dtype_logits,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+):
+    # ``categorical`` takes (key, logits, axis, shape); it has no ``dtype`` argument.
+    input_dtype, key = dtype_key
+    logits_dtype, logits = dtype_logits
+
+    def call():
+        return helpers.test_frontend_function(
+            input_dtypes=input_dtype + logits_dtype,
+            frontend=frontend,
+            test_flags=test_flags,
+            backend_to_test=backend_fw,
+            fn_tree=fn_tree,
+            on_device=on_device,
+            test_values=False,
+            key=key[0],
+            logits=logits[0],
+            axis=-1,
+        )
+
+    ret = call()
+    if not ivy.exists(ret):
+        return
+    # Samples are random, so only dtypes and shapes are compared.
+    ret_np, ret_from_np = ret
+    ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw)
+    ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw)
+    for u, v in zip(ret_np, ret_from_np):
+        assert u.dtype == v.dtype
+        assert u.shape == v.shape