From 35d7af936aac6471c4097b14a8d958c2e3086040 Mon Sep 17 00:00:00 2001 From: chkoar Date: Mon, 3 Feb 2020 14:04:35 +0200 Subject: [PATCH] Relax reconstructor checks. Add test for simple lists. --- imblearn/utils/_validation.py | 19 +++++------------- imblearn/utils/estimator_checks.py | 32 +++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index 0d02ea28b..df25a52a0 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -49,24 +49,15 @@ def _gets_props(self, array): def _transfrom(self, array, props): type_ = props["type"].lower() - msg = "Could not convert to {}".format(type_) if type_ == "list": ret = array.tolist() elif type_ == "dataframe": - try: - import pandas as pd - ret = pd.DataFrame(array, columns=props["columns"]) - ret = ret.astype(props["dtypes"]) - except Exception: - warnings.warn(msg) + import pandas as pd + ret = pd.DataFrame(array, columns=props["columns"]) + ret = ret.astype(props["dtypes"]) elif type_ == "series": - try: - import pandas as pd - ret = pd.Series(array, - dtype=props["dtypes"], - name=props["name"]) - except Exception: - warnings.warn(msg) + import pandas as pd + ret = pd.Series(array, dtype=props["dtypes"], name=props["name"]) else: ret = array return ret diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index 8f094397d..5b5ef34fe 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -258,7 +258,7 @@ def check_samplers_pandas(name, Sampler): X_res_df, y_res_df = sampler.fit_resample(X_df, y_df) X_res, y_res = sampler.fit_resample(X, y) - # check that we return the same type for dataframes or seires types + # check that we return the same type for dataframes or series types assert isinstance(X_res_df, pd.DataFrame) assert isinstance(y_res_df, pd.DataFrame) assert isinstance(y_res_s, pd.Series) @@ -272,6 +272,36 @@ def 
check_samplers_pandas(name, Sampler): assert_allclose(y_res_s.to_numpy(), y_res) +def check_samplers_list(name, Sampler): + # Check that the samplers can handle simple lists + X, y = make_classification( + n_samples=1000, + n_classes=3, + n_informative=4, + weights=[0.2, 0.3, 0.5], + random_state=0, + ) + X_list = X.tolist() + y_list = y.tolist() + sampler = Sampler() + if isinstance(Sampler(), NearMiss): + samplers = [Sampler(version=version) for version in (1, 2, 3)] + + else: + samplers = [Sampler()] + + for sampler in samplers: + set_random_state(sampler) + X_res, y_res = sampler.fit_resample(X, y) + X_res_list, y_res_list = sampler.fit_resample(X_list, y_list) + + assert isinstance(X_res_list, list) + assert isinstance(y_res_list, list) + + assert_allclose(X_res, X_res_list) + assert_allclose(y_res, y_res_list) + + def check_samplers_multiclass_ova(name, Sampler): # Check that multiclass target lead to the same results than OVA encoding X, y = make_classification(