From 1a85591e318431a89f7134295f43a1c2f1c37cfd Mon Sep 17 00:00:00 2001 From: Jordan Stomps Date: Mon, 31 Oct 2022 14:46:13 -0400 Subject: [PATCH] adding pytest for cotraining --- tests/test_models.py | 83 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index 5b66f65..3a28206 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,6 +14,8 @@ import scripts.utils as utils # models from models.LogReg import LogReg +# models +from models.SSML.CoTraining import CoTraining # testing write import joblib import os @@ -190,3 +192,84 @@ def test_LogReg(): assert model_file.best['params'] == model.best['params'] os.remove(filename+ext) + + +def test_CoTraining(): + # test saving model input parameters + params = {'max_iter': 2022, 'tol': 0.5, 'C': 5.0} + model = CoTraining(params=params) + + assert model.model1.max_iter == params['max_iter'] + assert model.model1.tol == params['tol'] + assert model.model1.C == params['C'] + + assert model.model2.max_iter == params['max_iter'] + assert model.model2.tol == params['tol'] + assert model.model2.C == params['C'] + + X, Ux, y, Uy = train_test_split(spectra, + labels, + test_size=0.5, + random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, + y, + test_size=0.2, + random_state=0) + + # normalization + normalizer = StandardScaler() + normalizer.fit(X_train) + + X_train = normalizer.transform(X_train) + X_test = normalizer.transform(X_test) + Ux = normalizer.transform(Ux) + + # default behavior + model = CoTraining(params=None, random_state=0) + model.train(X_train, y_train, Ux) + + # testing train and predict methods + pred, acc, *_ = model.predict(X_test, y_test) + + assert acc > 0.7 + np.testing.assert_equal(pred, y_test) + + # testing hyperopt optimize methods + space = {'max_iter': scope.int(hp.quniform('max_iter', + 10, + 10000, + 10)), + 'tol': hp.loguniform('tol', 1e-5, 1e-3), + 'C': hp.uniform('C', 1.0, 1000.0), + 'n_samples': scope.int(hp.quniform('n_samples', + 1, + 20, + 1)), + 'seed': 0 + } + data_dict = {'trainx': X_train, + 'testx': X_test, + 'trainy': y_train, + 'testy': y_test, + 'Ux': Ux + } + model.optimize(space, data_dict, max_evals=2, verbose=True) + + assert model.best['accuracy'] >= model.worst['accuracy'] + assert model.best['status'] == 'ok' + + # testing model plotting method + filename = 'test_plot' + model.plot_cotraining(model1_accs=model.best['model1_acc_history'], + model2_accs=model.best['model2_acc_history'], + filename=filename) + os.remove(filename+'.png') + + # testing model write to file method + filename = 'test_LogReg' + ext = '.joblib' + model.save(filename) + model_file = joblib.load(filename+ext) + assert model_file.best['params'] == model.best['params'] + + os.remove(filename+ext)