Add support for “probs” to mlp. Fixes =O_032=
Jamie Bullock committed Feb 16, 2019
1 parent 4e5e931 commit 7ed3974
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions sources/regression/ml_mlp.cpp
@@ -86,7 +86,8 @@ namespace ml
FLEXT_CADDATTR_SET(c, "use_validation_set", set_use_validation_set);
FLEXT_CADDATTR_SET(c, "validation_set_size", set_validation_set_size);
FLEXT_CADDATTR_SET(c, "randomize_training_order", set_randomise_training_order);

FLEXT_CADDATTR_SET(c, "probs", set_probs);

FLEXT_CADDATTR_GET(c, "mode", get_mode);
FLEXT_CADDATTR_GET(c, "num_outputs", get_num_outputs);
FLEXT_CADDATTR_GET(c, "num_hidden", get_num_hidden);
@@ -104,7 +105,8 @@ namespace ml
FLEXT_CADDATTR_GET(c, "use_validation_set", get_use_validation_set);
FLEXT_CADDATTR_GET(c, "validation_set_size", get_validation_set_size);
FLEXT_CADDATTR_GET(c, "randomize_training_order", get_randomise_training_order);

FLEXT_CADDATTR_GET(c, "probs", get_probs);

DefineHelp(c, object_name.c_str());
}

@@ -132,7 +134,8 @@ namespace ml
void set_use_validation_set(bool use_validation_set);
void set_validation_set_size(int validation_set_size);
void set_randomise_training_order(bool randomise_training_order);

+void set_probs(bool probs);

// Flext attribute getters
void get_mode(int &mode) const;
void get_num_outputs(int &num_outputs) const;
@@ -152,7 +155,8 @@ namespace ml
void get_use_validation_set(bool &use_validation_set) const;
void get_validation_set_size(int &validation_set_size) const;
void get_randomise_training_order(bool &randomise_training_order) const;

+void get_probs(bool &probs) const;

// Implement pure virtual methods
GRT::MLBase &get_MLBase_instance();
const GRT::MLBase &get_MLBase_instance() const;
@@ -184,6 +188,7 @@ namespace ml
FLEXT_CALLVAR_B(get_use_validation_set, set_use_validation_set);
FLEXT_CALLVAR_I(get_validation_set_size, set_validation_set_size);
FLEXT_CALLVAR_B(get_randomise_training_order, set_randomise_training_order);
+FLEXT_CALLVAR_B(get_probs, set_probs);

// Virtual method override
virtual const std::string get_object_name(void) const { return object_name; };
@@ -194,6 +199,7 @@ namespace ml
GRT::Neuron::Type hidden_activation_function;
GRT::Neuron::Type output_activation_function;

+bool probs;
};

// Flext attribute setters
@@ -422,6 +428,11 @@ namespace ml
}
}

+void mlp::set_probs(bool probs)
+{
+probs = this->probs;
+}

// Flext attribute getters
void mlp::get_mode(int &mode) const
{
@@ -520,6 +531,9 @@ namespace ml
void mlp::get_randomise_training_order(bool &randomise_training_order) const
{
flext::error("function not implemented");
}

+void mlp::get_probs(bool &probs) const
+{
+probs = this->probs;
+}

// Methods
@@ -620,7 +634,6 @@ namespace ml
return;
}

-// TODO: add probs to attributes
if (grt_mlp.getClassificationModeActive())
{
GRT::VectorDouble likelihoods = grt_mlp.getClassLikelihoods();
@@ -631,9 +644,9 @@ namespace ml
{
flext::error("labels / likelihoods size mismatch");
}
-else
+else if (probs)
{
-AtomList probs;
+AtomList probs_list;

for (unsigned count = 0; count < labels.size(); ++count)
{
@@ -643,10 +656,10 @@ namespace ml
SetFloat(likelihood_a, static_cast<float>(likelihoods[count]));
SetInt(label_a, labels[count]);

-probs.Append(label_a);
-probs.Append(likelihood_a);
+probs_list.Append(label_a);
+probs_list.Append(likelihood_a);
}
-ToOutAnything(1, get_s_probs(), probs);
+ToOutAnything(1, get_s_probs(), probs_list);
}

ToOutInt(0, classification);
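For reference, here is a minimal stand-alone sketch of the behaviour the new attribute gates: when probs is enabled and the class labels and likelihoods returned by the network agree in size, the object emits an alternating label/likelihood list on its second outlet. The sketch mirrors that logic outside of flext and GRT; ProbsList, make_probs_list, and the plain std::vector inputs are illustrative stand-ins, not the real flext or GRT API, and the error handling is simplified.

```cpp
// Illustrative sketch only: plain C++ stand-ins for the flext AtomList and the
// GRT likelihood vector, showing the gating logic this commit introduces.
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for the alternating label/likelihood AtomList built in the patch.
using ProbsList = std::vector<std::pair<int, double>>;

// Return an empty list when "probs" is disabled or when the label and
// likelihood sizes disagree; otherwise pair each label with its likelihood.
ProbsList make_probs_list(const std::vector<int> &labels,
                          const std::vector<double> &likelihoods,
                          bool probs)
{
    ProbsList probs_list;

    if (!probs || labels.size() != likelihoods.size())
    {
        return probs_list; // nothing to send to the "probs" outlet
    }

    for (std::size_t count = 0; count < labels.size(); ++count)
    {
        probs_list.emplace_back(labels[count], likelihoods[count]);
    }
    return probs_list;
}

int main()
{
    const std::vector<int> labels = {1, 2, 3};
    const std::vector<double> likelihoods = {0.1, 0.7, 0.2};

    // With probs enabled this prints "1 0.1", "2 0.7", "3 0.2" -- the same
    // label/likelihood pairs the object appends to its output AtomList.
    for (const auto &entry : make_probs_list(labels, likelihoods, true))
    {
        std::cout << entry.first << " " << entry.second << "\n";
    }
    return 0;
}
```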
