From a3146df58ae28485986818a1f1ac4059755a22d6 Mon Sep 17 00:00:00 2001 From: Mohamed Ibrahim Date: Sat, 23 Mar 2024 18:18:37 +0200 Subject: [PATCH] feat: add `ifftn` to jax frontend along with the test (#28550) passing --- ivy/functional/frontends/jax/numpy/fft.py | 8 ++++ .../test_jax/test_numpy/test_fft.py | 29 ++++++++++++++ .../test_experimental/test_nn/test_layers.py | 39 +++++++++++++++++++ 3 files changed, 76 insertions(+) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index d1500f394e5e1..125696d97faa3 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -70,6 +70,14 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None): return ivy.array(ivy.ifft2(a, s=s, dim=axes, norm=norm), dtype=ivy.dtype(a)) +@with_unsupported_dtypes({"1.24.3 and below": ("complex64", "bfloat16")}, "numpy") +@to_ivy_arrays_and_back +def ifftn(a, s=None, axes=None, norm=None): + a = ivy.asarray(a, dtype=ivy.complex128) + a = ivy.ifftn(a, s=s, axes=axes, norm=norm) + return a + + @to_ivy_arrays_and_back @with_unsupported_dtypes({"1.25.2 and below": ("float16", "bfloat16")}, "numpy") def rfft(a, n=None, axis=-1, norm=None): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 1030839ae8e07..09925136e06aa 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -4,6 +4,9 @@ # local import ivy_tests.test_ivy.helpers as helpers from ivy_tests.test_ivy.helpers import handle_frontend_test +from ivy_tests.test_ivy.test_functional.test_experimental.test_nn.test_layers import ( + _x_and_ifftn_jax, +) # fft @@ -242,6 +245,32 @@ def test_jax_numpy_ifft2( ) +@handle_frontend_test( + fn_tree="jax.numpy.fft.ifftn", + dtype_and_x=_x_and_ifftn_jax(), +) +def test_jax_numpy_ifftn( + dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device +): + input_dtype, x, s, axes, norm = dtype_and_x + + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + atol=1e-09, + rtol=1e-08, + a=x, + s=s, + axes=axes, + norm=norm, + ) + + # rfft @handle_frontend_test( fn_tree="jax.numpy.fft.rfft", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py index cbe9a2ace12ce..8289adef6238d 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py @@ -457,6 +457,45 @@ def _x_and_ifftn(draw): return dtype, x, s, axes, norm +@st.composite +def _x_and_ifftn_jax(draw): + min_fft_points = 2 + dtype = draw(helpers.get_dtypes("complex")) + x_dim = draw( + helpers.get_shape( + min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 + ) + ) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=tuple(x_dim), + min_value=-1e-10, + max_value=1e10, + ) + ) + axes = draw( + st.lists( + st.integers(0, len(x_dim) - 1), + min_size=1, + max_size=min(len(x_dim), 3), + unique=True, + ) + ) + norm = draw(st.sampled_from(["forward", "ortho", "backward"])) + + # Shape for s can be larger, smaller or equal to the size of the input + # along the axes specified by axes. + # Here, we're generating a list of integers corresponding to each axis in axes. + s = draw( + st.lists( + st.integers(min_fft_points, 256), min_size=len(axes), max_size=len(axes) + ) + ) + + return dtype, x, s, axes, norm + + @st.composite def _x_and_rfft(draw): min_fft_points = 2