Skip to content

Commit

Permalink
Add get and set for num_folds for cross validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamie Bullock committed Feb 15, 2019
1 parent 5c20523 commit bf3c140
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 41 deletions.
9 changes: 7 additions & 2 deletions sources/classification/ml_svm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ namespace ml
FLEXT_CADDATTR_SET(c, "cost", set_cost);
FLEXT_CADDATTR_SET(c, "nu", set_nu);
FLEXT_CADDATTR_SET(c, "probs", set_probs);
FLEXT_CADDATTR_SET(c, "mode", set_kfold_value);
FLEXT_CADDATTR_SET(c, "num_folds", set_kfold_value);
FLEXT_CADDATTR_SET(c, "enable_cross_validation", set_enable_cross_validation);

FLEXT_CADDATTR_GET(c, "type", get_type);
Expand All @@ -150,7 +150,7 @@ namespace ml
FLEXT_CADDATTR_GET(c, "cost", get_cost);
FLEXT_CADDATTR_GET(c, "nu", get_nu);
FLEXT_CADDATTR_GET(c, "probs", get_probs);
FLEXT_CADDATTR_GET(c, "mode", get_kfold_value);
FLEXT_CADDATTR_GET(c, "num_folds", get_kfold_value);

FLEXT_CADDMETHOD_(c, 0, "cross_validation", cross_validation);

Expand Down Expand Up @@ -297,6 +297,11 @@ namespace ml

}

void svm::get_kfold_value(int &mode) const
{
mode = grt_svm.getKFoldCrossValidationValue();
}

void svm::get_degree(int &degree) const
{
degree = grt_svm.getDegree();
Expand Down
48 changes: 9 additions & 39 deletions sources/ml_doc_populate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,41 +392,19 @@ namespace ml_doc
0.5
);

ranged_message_descriptor<float> epsilon(
"epsilon",
"set the epsilon in loss function of epsilon-SVR",
INFINITY * -1.f, INFINITY,
0.1
);

ranged_message_descriptor<int> cachesize(
"cachesize",
"set cache memory size in MB",
8,
1024,
100
);

ranged_message_descriptor<float> tolerance(
"tolerance",
"set tolerance of termination criterion",
INFINITY * -1.f, INFINITY,
0.001
);

valued_message_descriptor<bool> shrinking(
"shrinking",
"whether to use the shrinking heuristics",
{0, 1},
1
);

message_descriptor cross_validation(
"cross_validation",
"perform cross validation"
);

descriptors[ml::k_svm].add_message_descriptor(cross_validation, type, kernel, degree, svm_gamma, coef0, cost, nu, epsilon, cachesize, tolerance, shrinking);
ranged_message_descriptor<int> num_folds(
"num_folds",
"set the number of folds used for cross validation",
1, 100,
10
);

descriptors[ml::k_svm].add_message_descriptor(cross_validation, num_folds, type, kernel, degree, svm_gamma, coef0, cost, nu);

//---- ml.adaboost
ranged_message_descriptor<int> num_boosting_iterations(
Expand Down Expand Up @@ -524,14 +502,6 @@ namespace ml_doc
5
);

ranged_message_descriptor<int> num_symbols(
"num_symbols",
"sets the number of symbols in the model",
0,
100,
10
);

valued_message_descriptor<int> model_type(
"model_type",
"set the model type used, 0:ERGODIC, 1:LEFTRIGHT",
Expand Down Expand Up @@ -588,7 +558,7 @@ namespace ml_doc
);

descriptors[ml::k_hmm].insert_message_descriptor(record);
descriptors[ml::k_hmm].add_message_descriptor(num_states, num_symbols, model_type, delta, max_num_iterations, num_random_training_iterations, min_improvement, committee_size, downsample_factor);
descriptors[ml::k_hmm].add_message_descriptor(num_states, model_type, delta, max_num_iterations, num_random_training_iterations, min_improvement, committee_size, downsample_factor);

//---- ml.softmax

Expand Down

0 comments on commit bf3c140

Please sign in to comment.