Allow arbitrary class labels for MLP
Jamie Bullock committed Mar 9, 2019
1 parent ea69436 commit f4d8db4
Showing 3 changed files with 125 additions and 8 deletions.
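
The heart of this commit is a pair of std::unordered_map members that translate between arbitrary user-supplied class labels and the contiguous, 1-based indices GRT expects internally. A minimal standalone sketch of that idea (hypothetical names, not the ml-lib or GRT API):

// Sketch only: arbitrary labels are remapped to contiguous indices starting
// at 1, and mapped back again when results are reported.
#include <unordered_map>

class LabelIndexMap
{
public:
    // Return the internal index for a class label, creating one if needed.
    int index_for_label(int label)
    {
        const auto it = labelToIndex.find(label);
        if (it != labelToIndex.end())
        {
            return it->second;
        }
        const int index = static_cast<int>(labelToIndex.size()) + 1; // indices start at 1
        labelToIndex[label] = index;
        indexToLabel[index] = label;
        return index;
    }

    // Return the original class label for an internal index, or -1 if unknown.
    int label_for_index(int index) const
    {
        const auto it = indexToLabel.find(index);
        return it == indexToLabel.end() ? -1 : it->second;
    }

    // Forget all mappings, e.g. when the training data is cleared.
    void clear()
    {
        labelToIndex.clear();
        indexToLabel.clear();
    }

private:
    std::unordered_map<int, int> labelToIndex;
    std::unordered_map<int, int> indexToLabel;
};

With such a mapping a user can add samples labelled, say, 7 and 42: GRT trains on 1 and 2, and predictions are translated back to 7 and 42 before being output, mirroring get_index_for_class, get_class_id_for_index, and clear_index_maps in the diff below.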
@@ -40,7 +40,6 @@
buildConfiguration = "Development"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
language = ""
shouldUseLaunchSchemeArgsEnv = "YES">
<Testables>
</Testables>
@@ -60,7 +59,6 @@
buildConfiguration = "Development"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
language = ""
launchStyle = "0"
useCustomWorkingDirectory = "NO"
ignoresPersistentStateOnLaunch = "NO"
@@ -86,7 +84,7 @@
isEnabled = "YES">
</CommandLineArgument>
<CommandLineArgument
argument = "~/Documents/ml-lib/ml.mlp/ml.mlp-help.pd"
argument = "~/Documents/ml-lib/help/pd/ml.mlp-help.pd"
isEnabled = "YES">
</CommandLineArgument>
</CommandLineArguments>
2 changes: 1 addition & 1 deletion sources/ml_ml.h
@@ -73,6 +73,7 @@ namespace ml

data_type get_data_type() const;
void set_data_type(data_type type);
void set_num_inputs(uint16_t num_inputs);

virtual GRT::MLBase &get_MLBase_instance() = 0;
virtual const GRT::MLBase &get_MLBase_instance() const = 0;
@@ -103,7 +104,6 @@ namespace ml
private:

void record_(bool state);
void set_num_inputs(uint16_t num_inputs);

// Flext method wrappers
FLEXT_CALLBACK_A(any);
127 changes: 123 additions & 4 deletions sources/regression/ml_mlp.cpp
@@ -19,6 +19,8 @@
#include "ml_ml.h"
#include "ml_defaults.h"

#include <unordered_map>

namespace ml
{
const std::string object_name = ML_NAME_PREFIX "mlp";
@@ -111,6 +113,7 @@ namespace ml
DefineHelp(c, object_name.c_str());
}

void add(int argc, const t_atom *argv);
void clear();
void train();
void map(int argc, const t_atom *argv);
@@ -166,6 +169,9 @@

private:
void set_activation_function(int activation_function, mlp_layer layer);
int get_index_for_class(int classID);
int get_class_id_for_index(int index);
void clear_index_maps();

// Flext method wrappers
FLEXT_CALLBACK(error);
@@ -199,6 +205,8 @@
GRT::Neuron::Type input_activation_function;
GRT::Neuron::Type hidden_activation_function;
GRT::Neuron::Type output_activation_function;
std::unordered_map<int, int> classLabelToIndex;
std::unordered_map<int, int> indexToClassLabel;

bool probs;
};
@@ -374,6 +382,47 @@ namespace ml
post("activation function set to " + grt_mlp.activationFunctionToString(activation_function_));
}

// returns the index for a class label, adding a new mapping if one doesn't exist
int mlp::get_index_for_class(int classLabel)
{
const int count = classLabelToIndex.count(classLabel);

if (count == 1)
{
return classLabelToIndex.at(classLabel);
}
else if (count != 0)
{
assert(false);
}

const int index = classLabelToIndex.size() + 1; // GRT labels (i.e. indices) must start from 1
classLabelToIndex[classLabel] = index;
indexToClassLabel[index] = classLabel;

return index;
}

// assumes the index exists; returns -1 as a failsafe if it doesn't
int mlp::get_class_id_for_index(int index)
{
const int count = indexToClassLabel.count(index);

if (count == 0)
{
assert(false);
return -1;
}

return indexToClassLabel.at(index);
}

void mlp::clear_index_maps()
{
indexToClassLabel.clear();
classLabelToIndex.clear();
}

void mlp::set_input_activation_function(int activation_function)
{
set_activation_function(activation_function, LAYER_INPUT);
@@ -591,10 +640,79 @@ namespace ml
ToOutAnything(1, get_s_train(), 1, &a_success);
}

void mlp::add(int argc, const t_atom *argv)
{
if (get_data_type() != data_type::LABELLED_CLASSIFICATION)
{
ml::add(argc, argv);
return;
}

// work around a bug in GRT where class labels must be contiguous
if (argc < 2)
{
flext::error("invalid input length, must contain at least 2 values");
return;
}

GRT::UINT numInputDimensions = classification_data.getNumDimensions();
GRT::UINT numOutputDimensions = 1;
GRT::UINT combinedVectorSize = numInputDimensions + numOutputDimensions;

if ((unsigned)argc != combinedVectorSize)
{
numInputDimensions = argc - numOutputDimensions;

if (numInputDimensions < 1)
{
flext::error(std::string("invalid input length, expected at least " + std::to_string(numOutputDimensions + 1)).c_str());
return;
}
post("new input vector size, adjusting num_inputs to " + std::to_string(numInputDimensions));
set_num_inputs(numInputDimensions);
}

GRT::VectorDouble inputVector(numInputDimensions);
GRT::VectorDouble targetVector(numOutputDimensions);

for (uint32_t index = 0; index < (unsigned)argc; ++index)
{
float value = GetAFloat(argv[index]);

if (index < numOutputDimensions)
{
targetVector[index] = value;
}
else
{
inputVector[index - numOutputDimensions] = value;
}
}

GRT::UINT label = get_index_for_class((GRT::UINT)targetVector[0]);

assert(label > 0);

// if ((double)label != targetVector[0])
// {
// flext::error("class label must be a positive integer");
// return;
// }
//
// if (label == 0)
// {
// flext::error("class label must be non-zero");
// return;
// }

classification_data.addSample(label, inputVector);
}

void mlp::clear()
{
grt_mlp.clear();
ml::clear();
clear_index_maps();
}

void mlp::map(int argc, const t_atom *argv)
@@ -639,9 +757,10 @@

if (grt_mlp.getClassificationModeActive())
{
GRT::VectorDouble likelihoods = grt_mlp.getClassLikelihoods();
GRT::Vector<GRT::UINT> labels = classification_data.getClassLabels();
GRT::UINT classification = grt_mlp.getPredictedClassLabel();
const GRT::VectorDouble likelihoods = grt_mlp.getClassLikelihoods();
const GRT::Vector<GRT::UINT> labels = classification_data.getClassLabels();
const GRT::UINT predicted = grt_mlp.getPredictedClassLabel();
const GRT::UINT classification = predicted == 0 ? 0 : get_class_id_for_index(predicted);

if (likelihoods.size() != labels.size())
{
@@ -657,7 +776,7 @@
t_atom likelihood_a;

SetFloat(likelihood_a, static_cast<float>(likelihoods[count]));
SetInt(label_a, labels[count]);
SetInt(label_a, get_class_id_for_index(labels[count]));

probs_list.Append(label_a);
probs_list.Append(likelihood_a);
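
The prediction code above reports results in terms of the original labels by mapping GRT's internal indices back through indexToClassLabel. A self-contained sketch of that translation step, with hypothetical names and no dependency on the ml-lib classes:

// Sketch only: translate internal class indices back to original labels when
// pairing each class with its likelihood.
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

// indexToLabel is the reverse map built while adding samples (see the sketch above).
std::vector<std::pair<int, double>> labelled_likelihoods(
    const std::unordered_map<int, int> &indexToLabel,
    const std::vector<uint32_t> &indices,      // internal class indices
    const std::vector<double> &likelihoods)    // one likelihood per index
{
    std::vector<std::pair<int, double>> out;

    if (indices.size() != likelihoods.size())
    {
        return out; // sizes must agree, as the real code also checks
    }

    for (std::size_t i = 0; i < indices.size(); ++i)
    {
        const auto it = indexToLabel.find(static_cast<int>(indices[i]));
        const int label = it == indexToLabel.end() ? -1 : it->second;
        out.emplace_back(label, likelihoods[i]);
    }

    return out;
}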
