diff --git a/sources/classification/ml_svm.cpp b/sources/classification/ml_svm.cpp index c387458..b47b2a3 100644 --- a/sources/classification/ml_svm.cpp +++ b/sources/classification/ml_svm.cpp @@ -139,7 +139,7 @@ namespace ml FLEXT_CADDATTR_SET(c, "cost", set_cost); FLEXT_CADDATTR_SET(c, "nu", set_nu); FLEXT_CADDATTR_SET(c, "probs", set_probs); - FLEXT_CADDATTR_SET(c, "mode", set_kfold_value); + FLEXT_CADDATTR_SET(c, "num_folds", set_kfold_value); FLEXT_CADDATTR_SET(c, "enable_cross_validation", set_enable_cross_validation); FLEXT_CADDATTR_GET(c, "type", get_type); @@ -150,7 +150,7 @@ namespace ml FLEXT_CADDATTR_GET(c, "cost", get_cost); FLEXT_CADDATTR_GET(c, "nu", get_nu); FLEXT_CADDATTR_GET(c, "probs", get_probs); - FLEXT_CADDATTR_GET(c, "mode", get_kfold_value); + FLEXT_CADDATTR_GET(c, "num_folds", get_kfold_value); FLEXT_CADDMETHOD_(c, 0, "cross_validation", cross_validation); @@ -297,6 +297,11 @@ namespace ml } + void svm::get_kfold_value(int &mode) const + { + mode = grt_svm.getKFoldCrossValidationValue(); + } + void svm::get_degree(int °ree) const { degree = grt_svm.getDegree(); diff --git a/sources/ml_doc_populate.cpp b/sources/ml_doc_populate.cpp index f72ce25..26a5653 100644 --- a/sources/ml_doc_populate.cpp +++ b/sources/ml_doc_populate.cpp @@ -392,41 +392,19 @@ namespace ml_doc 0.5 ); - ranged_message_descriptor epsilon( - "epsilon", - "set the epsilon in loss function of epsilon-SVR", - INFINITY * -1.f, INFINITY, - 0.1 - ); - - ranged_message_descriptor cachesize( - "cachesize", - "set cache memory size in MB", - 8, - 1024, - 100 - ); - - ranged_message_descriptor tolerance( - "tolerance", - "set tolerance of termination criterion", - INFINITY * -1.f, INFINITY, - 0.001 - ); - - valued_message_descriptor shrinking( - "shrinking", - "whether to use the shrinking heuristics", - {0, 1}, - 1 - ); - message_descriptor cross_validation( "cross_validation", "perform cross validation" ); - descriptors[ml::k_svm].add_message_descriptor(cross_validation, type, kernel, degree, svm_gamma, coef0, cost, nu, epsilon, cachesize, tolerance, shrinking); + ranged_message_descriptor num_folds( + "num_folds", + "set the number of folds used for cross validation", + 1, 100, + 10 + ); + + descriptors[ml::k_svm].add_message_descriptor(cross_validation, num_folds, type, kernel, degree, svm_gamma, coef0, cost, nu); //---- ml.adaboost ranged_message_descriptor num_boosting_iterations( @@ -524,14 +502,6 @@ namespace ml_doc 5 ); - ranged_message_descriptor num_symbols( - "num_symbols", - "sets the number of symbols in the model", - 0, - 100, - 10 - ); - valued_message_descriptor model_type( "model_type", "set the model type used, 0:ERGODIC, 1:LEFTRIGHT", @@ -588,7 +558,7 @@ namespace ml_doc ); descriptors[ml::k_hmm].insert_message_descriptor(record); - descriptors[ml::k_hmm].add_message_descriptor(num_states, num_symbols, model_type, delta, max_num_iterations, num_random_training_iterations, min_improvement, committee_size, downsample_factor); + descriptors[ml::k_hmm].add_message_descriptor(num_states, model_type, delta, max_num_iterations, num_random_training_iterations, min_improvement, committee_size, downsample_factor); //---- ml.softmax