Skip to content

Commit

Permalink
Added Categorical to jax frontend (#22146)
Browse files Browse the repository at this point in the history
Co-authored-by: Saeed Ashraf <[email protected]>
  • Loading branch information
VaishnaviMudaliar and saeedashrraf authored Aug 27, 2023
1 parent a88c50d commit 9d5f27e
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 2 deletions.
37 changes: 35 additions & 2 deletions ivy/functional/frontends/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def _get_seed(key):
def PRNGKey(seed):
return ivy.array([0, seed % 4294967295 - (seed // 4294967295)], dtype=ivy.int64)

def _remove_axis(shape, axis):
return shape[:axis] + shape[axis + 1 :]


@handle_jax_dtype
@to_ivy_arrays_and_back
Expand Down Expand Up @@ -371,8 +374,6 @@ def uniform(key, shape=(), dtype=None, minval=0.0, maxval=1.0):
low=minval, high=maxval, shape=shape, dtype=dtype, seed=ivy.to_scalar(key[1])
)


@handle_jax_dtype
@to_ivy_arrays_and_back
@with_unsupported_dtypes(
{
Expand All @@ -383,9 +384,41 @@ def uniform(key, shape=(), dtype=None, minval=0.0, maxval=1.0):
},
"jax",
)

def categorical(key, logits, axis, shape=None):
_get_seed(key)
logits_arr = ivy.asarray(logits)

if axis >= 0:
axis -= len(logits_arr.shape)
batch_shape = tuple(_remove_axis(logits_arr.shape, axis))

if shape is None:
shape = batch_shape
else:
shape = tuple(shape)
if shape != batch_shape:
raise ValueError(
+ f"Shape {shape} is not compatible with reference shape {batch_shape}"
)

shape_prefix = shape[: len(shape) - len(batch_shape)]
logits_shape = list(shape[len(shape) - len(batch_shape) :])
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])

gumbel_noise = gumbel(key, ivy.array(logits_shape), logits_arr.dtype)
expanded_logits = ivy.expand_dims(logits_arr, axis=axis)
noisy_logits = gumbel_noise + expanded_logits

# Use Ivy's argmax to get indices
indices = ivy.argmax(noisy_logits, axis=axis)

return indices

def weibull_min(key, scale, concentration, shape=(), dtype="float64"):
seed = _get_seed(key)
uniform_x = ivy.random_uniform(seed=seed, shape=shape, dtype=dtype)
x = 1 - uniform_x
weibull = x ** (concentration - 1) * -ivy.log(x / scale)
return weibull

55 changes: 55 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,3 +1518,58 @@ def call():
for u, v in zip(ret_np, ret_from_np):
assert u.dtype == v.dtype
assert u.shape == v.shape


@pytest.mark.xfail
@handle_frontend_test(
fn_tree="jax.random.categorical",
dtype_key=helpers.dtype_and_values(
available_dtypes=["uint32"],
min_value=0,
max_value=2000,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
max_dim_size=2,
),
shape=helpers.get_shape(
min_dim_size=1, max_num_dims=6, max_dim_size=6, min_num_dims=1, allow_none=False
),
dtype=helpers.get_dtypes("float", full=False),
)
def test_jax_categorical(
*,
dtype_key,
shape,
dtype,
on_device,
fn_tree,
frontend,
test_flags,
backend_fw,
):

input_dtype,key = dtype_key

def call():
return helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
test_flags=test_flags,
backend_to_test=backend_fw,
fn_tree=fn_tree,
on_device=on_device,
test_values=False,
key=key[0],
shape=shape,
dtype=dtype[0],
)
ret = call()
if not ivy.exists(ret):
return
ret_np, ret_from_np = ret
ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw)
ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw)
for u, v in zip(ret_np, ret_from_np):
assert u.dtype == v.dtype
assert u.shape == v.shape

0 comments on commit 9d5f27e

Please sign in to comment.