Skip to content

Commit

Permalink
feat(frontend): implement array generating function make_circles alon…
Browse files Browse the repository at this point in the history
…g with test (passing) for sklearn
  • Loading branch information
Ishticode committed Aug 26, 2023
1 parent 920c09c commit 6ae1289
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 0 deletions.
1 change: 1 addition & 0 deletions ivy/functional/frontends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"jax": "0.4.14",
"scipy": "1.10.1",
"paddle": "2.5.1",
"sklearn": "1.3.0",
}


Expand Down
2 changes: 2 additions & 0 deletions ivy/functional/frontends/sklearn/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import _samples_generator
from ._samples_generator import *
21 changes: 21 additions & 0 deletions ivy/functional/frontends/sklearn/datasets/_samples_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import ivy
import numbers


def make_circles(n_samples=100, *, shuffle=True, noise=None, random_state=None, factor=0.8):
# numbers.Integral also includes bool
if isinstance(n_samples, numbers.Integral):
n_samples_out = n_samples // 2
n_samples_in = n_samples - n_samples_out
elif isinstance(n_samples, tuple):
n_samples_out, n_samples_in = n_samples

outer_circ_x = ivy.cos(ivy.linspace(0, 2 * ivy.pi, num=n_samples_out, endpoint=False))
outer_circ_y = ivy.sin(ivy.linspace(0, 2 * ivy.pi, num=n_samples_out, endpoint=False))
inner_circ_x = ivy.cos(ivy.linspace(0, 2 * ivy.pi, num=n_samples_in, endpoint=False)) * factor
inner_circ_y = ivy.sin(ivy.linspace(0, 2 * ivy.pi, num=n_samples_in, endpoint=False)) * factor
X = ivy.concat([ivy.stack([outer_circ_x, outer_circ_y], axis=1),
ivy.stack([inner_circ_x, inner_circ_y], axis=1)], axis=0)
y = ivy.concat([ivy.zeros(n_samples_out, dtype=ivy.int32),
ivy.ones(n_samples_in, dtype=ivy.int32)], axis=0)
return X, y
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test


@handle_frontend_test(
fn_tree="sklearn.datasets.make_circles",
n_samples=helpers.ints(min_value=1, max_value=10),
)
def test_sklearn_make_circles(
n_samples,
on_device,
fn_tree,
frontend,
test_flags,
backend_fw,
):
helpers.test_frontend_function(
n_samples=n_samples,
input_dtypes=["int32"],
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
frontend=frontend,
on_device=on_device,
test_values=False,
)

0 comments on commit 6ae1289

Please sign in to comment.