From 7ed3974d3e847b1c54ac9cfd92277452ff22ff71 Mon Sep 17 00:00:00 2001 From: Jamie Bullock Date: Sat, 16 Feb 2019 19:02:41 +0000 Subject: [PATCH] =?UTF-8?q?Add=20support=20for=20=E2=80=9Cprobs=E2=80=9D?= =?UTF-8?q?=20to=20mlp.=20Fixes=20=3DO=5F032=3D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sources/regression/ml_mlp.cpp | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/sources/regression/ml_mlp.cpp b/sources/regression/ml_mlp.cpp index ceec216..83105a8 100644 --- a/sources/regression/ml_mlp.cpp +++ b/sources/regression/ml_mlp.cpp @@ -86,7 +86,8 @@ namespace ml FLEXT_CADDATTR_SET(c, "use_validation_set", set_use_validation_set); FLEXT_CADDATTR_SET(c, "validation_set_size", set_validation_set_size); FLEXT_CADDATTR_SET(c, "randomize_training_order", set_randomise_training_order); - + FLEXT_CADDATTR_SET(c, "probs", set_probs); + FLEXT_CADDATTR_GET(c, "mode", get_mode); FLEXT_CADDATTR_GET(c, "num_outputs", get_num_outputs); FLEXT_CADDATTR_GET(c, "num_hidden", get_num_hidden); @@ -104,7 +105,8 @@ namespace ml FLEXT_CADDATTR_GET(c, "use_validation_set", get_use_validation_set); FLEXT_CADDATTR_GET(c, "validation_set_size", get_validation_set_size); FLEXT_CADDATTR_GET(c, "randomize_training_order", get_randomise_training_order); - + FLEXT_CADDATTR_GET(c, "probs", get_probs); + DefineHelp(c, object_name.c_str()); } @@ -132,7 +134,8 @@ namespace ml void set_use_validation_set(bool use_validation_set); void set_validation_set_size(int validation_set_size); void set_randomise_training_order(bool randomise_training_order); - + void set_probs(bool probs); + // Flext attribute getters void get_mode(int &mode) const; void get_num_outputs(int &num_outputs) const; @@ -152,7 +155,8 @@ namespace ml void get_use_validation_set(bool &use_validation_set) const; void get_validation_set_size(int &validation_set_size) const; void get_randomise_training_order(bool &randomise_training_order) const; - + void get_probs(bool &probs) const; + // Implement pure virtual methods GRT::MLBase &get_MLBase_instance(); const GRT::MLBase &get_MLBase_instance() const; @@ -184,6 +188,7 @@ namespace ml FLEXT_CALLVAR_B(get_use_validation_set, set_use_validation_set); FLEXT_CALLVAR_I(get_validation_set_size, set_validation_set_size); FLEXT_CALLVAR_B(get_randomise_training_order, set_randomise_training_order); + FLEXT_CALLVAR_B(get_probs, set_probs); // Virtual method override virtual const std::string get_object_name(void) const { return object_name; }; @@ -194,6 +199,7 @@ namespace ml GRT::Neuron::Type hidden_activation_function; GRT::Neuron::Type output_activation_function; + bool probs; }; // Flext attribute setters @@ -422,6 +428,11 @@ namespace ml } } + void mlp::set_probs(bool probs) + { + this->probs = probs; + } + // Flext attribute getters void mlp::get_mode(int &mode) const { @@ -520,6 +531,9 @@ namespace ml void mlp::get_randomise_training_order(bool &randomise_training_order) const { flext::error("function not implemented"); + void mlp::get_probs(bool &probs) const + { + probs = this->probs; } // Methods @@ -620,7 +634,6 @@ namespace ml return; } - // TODO: add probs to attributes if (grt_mlp.getClassificationModeActive()) { GRT::VectorDouble likelihoods = grt_mlp.getClassLikelihoods(); @@ -631,9 +644,9 @@ namespace ml { flext::error("labels / likelihoods size mismatch"); } - else + else if (probs) { - AtomList probs; + AtomList probs_list; for (unsigned count = 0; count < labels.size(); ++count) { @@ -643,10 +656,10 @@ namespace ml SetFloat(likelihood_a, static_cast(likelihoods[count])); SetInt(label_a, labels[count]); - probs.Append(label_a); - probs.Append(likelihood_a); + probs_list.Append(label_a); + probs_list.Append(likelihood_a); } - ToOutAnything(1, get_s_probs(), probs); + ToOutAnything(1, get_s_probs(), probs_list); } ToOutInt(0, classification);