diff --git a/tests/test_models.py b/tests/test_models.py index 0669ccf..374573b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -67,15 +67,15 @@ def test_utils(): stratified=True) assert max_acc_model['accuracy'] >= 0.5 - # test cross validation for SSML with LabelProp - params = {'gamma': 10, 'n_neighbors': 15, 'max_iter': 2022, 'tol': 0.5} - model = LabelProp(params=params) - max_acc_model = utils.cross_validation(model=model, - X=np.append(X, Ux, axis=0), - y=np.append(y, Uy, axis=0), - params=params, - stratified=True) - assert max_acc_model['accuracy'] >= 0.5 + # # test cross validation for SSML with LabelProp + # params = {'gamma': 10, 'n_neighbors': 15, 'max_iter': 2022, 'tol': 0.5} + # model = LabelProp(params=params) + # max_acc_model = utils.cross_validation(model=model, + # X=np.append(X, Ux, axis=0), + # y=np.append(y, Uy, axis=0), + # params=params, + # stratified=True) + # assert max_acc_model['accuracy'] >= 0.5 # data split for data visualization X_train, X_test, y_train, y_test = train_test_split(X,