From a88c50df69cad78f595bd73ee4bb6839d7e7d243 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Sun, 27 Aug 2023 21:19:25 +0100 Subject: [PATCH] refactor(frontend): simplify train_size and test_size derivation in test_train_split function --- .../sklearn/model_selection/_split.py | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/ivy/functional/frontends/sklearn/model_selection/_split.py b/ivy/functional/frontends/sklearn/model_selection/_split.py index b331da8ab40b2..9738f472a6f97 100644 --- a/ivy/functional/frontends/sklearn/model_selection/_split.py +++ b/ivy/functional/frontends/sklearn/model_selection/_split.py @@ -40,27 +40,15 @@ def train_test_split(*arrays, test_size=None, train_size=None, random_state=None # TODO: implement stratify if stratify is not None: raise NotImplementedError - n_arrays = len(arrays) - if n_arrays == 0: + if len(arrays) == 0: raise ValueError("At least one array required as input") if test_size is None and train_size is None: test_size = 0.25 n_samples = arrays[0].shape[0] - test_size_type, train_size_type = type(test_size), type(train_size) - if "f" in str(test_size_type): - n_test = ivy.ceil(test_size * n_samples) - elif "i" in str(test_size_type): - n_test = float(test_size) - else: - n_test = 0 - - if "f" in str(train_size_type): - n_train = ivy.floor(train_size * n_samples) - elif "i" in str(train_size_type): - n_train = float(train_size) - else: - n_train = 0 - + n_train = ivy.floor(train_size * n_samples) if isinstance(train_size, float) \ + else float(train_size) if isinstance(train_size, int) else None + n_test = ivy.ceil(test_size * n_samples) if isinstance(test_size, float) \ + else float(test_size) if isinstance(test_size, int) else None if train_size is None: n_train = n_samples - n_test elif test_size is None: