From 166da6e48bfa6e645de6c3d79f84a03bdb2d95ab Mon Sep 17 00:00:00 2001 From: pangolulu Date: Fri, 19 Feb 2016 14:14:07 +0800 Subject: [PATCH 1/5] Support data instance class for user usage, and add test cases for model serialization --- .../apache/samoa/learners/DataInstance.java | 27 ++++ .../apache/samoa/learners/InstanceUtils.java | 139 ++++++++++++++++ .../java/org/apache/samoa/learners/Model.java | 4 +- .../classifiers/NominalDataInstance.java | 81 ++++++++++ .../classifiers/NumericDataInstance.java | 69 ++++++++ .../classifiers/ensemble/EnsembleModel.java | 19 +-- .../classifiers/rules/AMRulesModel.java | 8 +- .../classifiers/trees/HoeffdingTreeModel.java | 18 ++- .../learners/clusterers/CluStreamModel.java | 28 ++-- .../clusterers/ClusterDataInstance.java | 59 +++++++ .../src/main/resources/reference.conf | 3 + .../samoa/serialize/AMRulesModelTest.java | 102 ++++++++++++ .../samoa/serialize/CluStreamModelTest.java | 105 ++++++++++++ .../samoa/serialize/EnsembleModelTest.java | 149 ++++++++++++++++++ .../serialize/HoeffdingTreeModelTest.java | 142 +++++++++++++++++ 15 files changed, 920 insertions(+), 33 deletions(-) create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/DataInstance.java create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/InstanceUtils.java create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalDataInstance.java create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericDataInstance.java create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java create mode 100644 samoa-local/src/test/java/org/apache/samoa/serialize/AMRulesModelTest.java create mode 100644 samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java create mode 100644 samoa-local/src/test/java/org/apache/samoa/serialize/EnsembleModelTest.java create mode 100644 
samoa-local/src/test/java/org/apache/samoa/serialize/HoeffdingTreeModelTest.java diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/DataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/DataInstance.java new file mode 100644 index 00000000..99167360 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/DataInstance.java @@ -0,0 +1,27 @@ +package org.apache.samoa.learners; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2016 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.io.Serializable; + +public interface DataInstance extends Serializable { + +} diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/InstanceUtils.java b/samoa-api/src/main/java/org/apache/samoa/learners/InstanceUtils.java new file mode 100644 index 00000000..7065e920 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/InstanceUtils.java @@ -0,0 +1,139 @@ +package org.apache.samoa.learners; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2016 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.instances.*; +import org.apache.samoa.learners.classifiers.NominalDataInstance; +import org.apache.samoa.learners.classifiers.NumericDataInstance; +import org.apache.samoa.learners.clusterers.ClusterDataInstance; +import org.apache.samoa.moa.core.DataPoint; +import org.apache.samoa.moa.core.FastVector; + +import java.util.ArrayList; + +public class InstanceUtils { + static private InstancesHeader getNumericInstanceHeader(NumericDataInstance numericDataInstance) { + FastVector attributes = new FastVector<>(); + + for (int i = 0; i < numericDataInstance.getNumNumerics(); i++) { + attributes.addElement(new Attribute("numeric" + (i + 1))); + } + + FastVector classLabels = new FastVector<>(); + for (int i = 0; i < numericDataInstance.getNumClasses(); i++) { + classLabels.addElement("class" + (i + 1)); + } + attributes.addElement(new Attribute("class", classLabels)); + + InstancesHeader instancesHeader = new InstancesHeader( + new Instances(null, attributes, 0)); + instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1); + + return instancesHeader; + } + + static private Instance convertNumericInstance(NumericDataInstance numericDataInstance) { + InstancesHeader header = InstanceUtils.getNumericInstanceHeader(numericDataInstance); + Instance inst = new DenseInstance(header.numAttributes()); + + for (int i = 0; i < numericDataInstance.getNumNumerics(); i++) { + inst.setValue(i, numericDataInstance.getData()[i]); + } + + inst.setDataset(header); + inst.setClassValue(numericDataInstance.getTrueClass()); + + 
return inst; + } + + static private InstancesHeader getNominalInstanceHeader(NominalDataInstance nominalDataInstance) { + FastVector attributes = new FastVector<>(); + + for (int i = 0; i < nominalDataInstance.getNumNominals(); i++) { + FastVector nominalAttVals = new FastVector<>(); + for (int j = 0; j < nominalDataInstance.getNumValsPerNominal()[i]; j++) { + nominalAttVals.addElement("value" + (j + 1)); + } + attributes.addElement(new Attribute("nominal" + (i + 1), + nominalAttVals)); + } + + FastVector classLabels = new FastVector<>(); + for (int i = 0; i < nominalDataInstance.getNumClasses(); i++) { + classLabels.addElement("class" + (i + 1)); + } + attributes.addElement(new Attribute("class", classLabels)); + + InstancesHeader instancesHeader = new InstancesHeader( + new Instances(null, attributes, 0)); + instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1); + + return instancesHeader; + } + + static private Instance convertNominalInstance(NominalDataInstance nominalDataInstance) { + InstancesHeader header = InstanceUtils.getNominalInstanceHeader(nominalDataInstance); + Instance inst = new DenseInstance(header.numAttributes()); + + for (int i = 0; i < nominalDataInstance.getNumNominals(); i++) { + inst.setValue(i, nominalDataInstance.getData()[i]); + } + + inst.setDataset(header); + inst.setClassValue(nominalDataInstance.getTrueClass()); + + return inst; + } + + static private InstancesHeader getClusterInstanceHeader(ClusterDataInstance clusterDataInstance) { + ArrayList attributes = new ArrayList<>(); + + for (int i = 0; i < clusterDataInstance.getNumAtts(); i++) { + attributes.add(new Attribute("att" + (i + 1))); + } + + // attributes.add(new Attribute("class", null)); + + InstancesHeader instancesHeader = new InstancesHeader( + new Instances(null, attributes, 0)); + instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1); + + return instancesHeader; + } + + static private Instance convertClusterInstance(ClusterDataInstance 
clusterDataInstance) { + Instance inst = new DenseInstance(1.0, clusterDataInstance.getData()); + inst.setDataset(InstanceUtils.getClusterInstanceHeader(clusterDataInstance)); + return new DataPoint(inst, clusterDataInstance.getTimeStamp()); + } + + static public Instance convertToSamoaInstance(DataInstance dataInstance) { + if (dataInstance instanceof NumericDataInstance) { + return InstanceUtils.convertNumericInstance((NumericDataInstance) dataInstance); + } else if (dataInstance instanceof NominalDataInstance) { + return InstanceUtils.convertNominalInstance((NominalDataInstance) dataInstance); + } else if (dataInstance instanceof ClusterDataInstance) { + return InstanceUtils.convertClusterInstance((ClusterDataInstance) dataInstance); + } else { + throw new Error("Invalid input class!"); + } + } +} diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/Model.java b/samoa-api/src/main/java/org/apache/samoa/learners/Model.java index f955c041..37e05617 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/Model.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/Model.java @@ -21,10 +21,8 @@ */ -import org.apache.samoa.instances.Instance; - import java.io.Serializable; public interface Model extends Serializable { - double[] predict(Instance inst); + double[] predict(DataInstance dataInstance); } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalDataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalDataInstance.java new file mode 100644 index 00000000..fe3d747d --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalDataInstance.java @@ -0,0 +1,81 @@ +package org.apache.samoa.learners.classifiers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2016 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.learners.DataInstance; + +public class NominalDataInstance implements DataInstance { + + private int numNominals; + private int numClasses; + private double trueClass; // index started from 0 + private int[] numValsPerNominal; + private double[] data; // index started from 0 + + public NominalDataInstance(int numNominals, int numClasses, double trueClass, + int[] numValsPerNominal, double[] data) { + this.numNominals = numNominals; + this.numClasses = numClasses; + this.trueClass = trueClass; + this.numValsPerNominal = numValsPerNominal; + this.data = data; + } + + public int getNumNominals() { + return numNominals; + } + + public void setNumNominals(int numNominals) { + this.numNominals = numNominals; + } + + public int getNumClasses() { + return numClasses; + } + + public void setNumClasses(int numClasses) { + this.numClasses = numClasses; + } + + public double getTrueClass() { + return trueClass; + } + + public void setTrueClass(double trueClass) { + this.trueClass = trueClass; + } + + public int[] getNumValsPerNominal() { + return numValsPerNominal; + } + + public void setNumValsPerNominal(int[] numValsPerNominal) { + this.numValsPerNominal = numValsPerNominal; + } + + public double[] getData() { + return data; + } + + public void setData(double[] data) { + this.data = data; + } +} diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericDataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericDataInstance.java new file mode 100644 index 
00000000..d28858a2 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericDataInstance.java @@ -0,0 +1,69 @@ +package org.apache.samoa.learners.classifiers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2016 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.learners.DataInstance; + +public class NumericDataInstance implements DataInstance { + private int numNumerics; + private int numClasses; + private double trueClass; // index started from 0 + private double[] data; + + public NumericDataInstance(int numNumerics, int numClasses, double trueClass, double[] data) { + this.numNumerics = numNumerics; + this.numClasses = numClasses; + this.trueClass = trueClass; + this.data = data; + } + + public int getNumNumerics() { + return numNumerics; + } + + public void setNumNumerics(int numNumerics) { + this.numNumerics = numNumerics; + } + + public int getNumClasses() { + return numClasses; + } + + public void setNumClasses(int numClasses) { + this.numClasses = numClasses; + } + + public double getTrueClass() { + return trueClass; + } + + public void setTrueClass(double trueClass) { + this.trueClass = trueClass; + } + + public double[] getData() { + return data; + } + + public void setData(double[] data) { + this.data = data; + } +} diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java 
b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java index 00dd126a..c4c0cc00 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java @@ -2,12 +2,12 @@ import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Utils; +import org.apache.samoa.learners.DataInstance; +import org.apache.samoa.learners.InstanceUtils; import org.apache.samoa.learners.Model; import org.apache.samoa.moa.core.DoubleVector; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; /* * #%L @@ -33,19 +33,16 @@ public class EnsembleModel implements Model { private ArrayList modelList; private ArrayList modelWeightList; - public EnsembleModel() { - } - public EnsembleModel(ArrayList modelList, ArrayList modelWeightList) { this.modelList = modelList; this.modelWeightList = modelWeightList; } @Override - public double[] predict(Instance inst) { + public double[] predict(DataInstance dataInstance) { DoubleVector combinedVote = new DoubleVector(); for (int i = 0; i < modelList.size(); i++) { - double[] prediction = modelList.get(i).predict(inst); + double[] prediction = modelList.get(i).predict(dataInstance); DoubleVector vote = new DoubleVector(prediction); if (vote.sumOfValues() > 0.0) { vote.normalize(); @@ -56,9 +53,13 @@ public double[] predict(Instance inst) { return combinedVote.getArrayCopy(); } - public boolean evaluate(Instance inst) { + /* + Predict the class of an input data instance, and evaluate if it is the true class. 
+ */ + public boolean evaluate(DataInstance dataInstance) { + Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance); int trueClass = (int) inst.classValue(); - double[] prediction = this.predict(inst); + double[] prediction = this.predict(dataInstance); int predictedClass = Utils.maxIndex(prediction); return trueClass == predictedClass; } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java index 5565ea7d..061c66dc 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java @@ -21,6 +21,8 @@ */ import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.DataInstance; +import org.apache.samoa.learners.InstanceUtils; import org.apache.samoa.learners.Model; import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; import org.apache.samoa.learners.classifiers.rules.common.PassiveRule; @@ -34,9 +36,6 @@ public class AMRulesModel implements Model { private ErrorWeightedVote errorWeightedVote; private boolean unorderedRules; - public AMRulesModel() { - } - public AMRulesModel(ActiveRule defaultRule, List ruleSet, ErrorWeightedVote errorWeightedVote, boolean unorderedRules) { this.defaultRule = defaultRule; @@ -46,7 +45,8 @@ public AMRulesModel(ActiveRule defaultRule, List ruleSet, } @Override - public double[] predict(Instance inst) { + public double[] predict(DataInstance dataInstance) { + Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance); double[] prediction; boolean predictionCovered = false; diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java index 4b856795..b3398e1b 100644 --- 
a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java @@ -24,6 +24,8 @@ import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Instances; import org.apache.samoa.instances.Utils; +import org.apache.samoa.learners.DataInstance; +import org.apache.samoa.learners.InstanceUtils; import org.apache.samoa.learners.Model; public class HoeffdingTreeModel implements Model { @@ -35,12 +37,10 @@ public HoeffdingTreeModel(Instances dataset, Node treeRoot) { this.treeRoot = treeRoot; } - public HoeffdingTreeModel() { - } - @Override - public double[] predict(Instance inst) { + public double[] predict(DataInstance dataInstance) { double[] prediction; + Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance); // inst.setDataset(dataset); FoundNode foundNode; @@ -55,13 +55,19 @@ public double[] predict(Instance inst) { int numClasses = dataset.numClasses(); prediction = new double[numClasses]; } + return prediction; } - public boolean evaluate(Instance inst) { + /* + Predict the class of an input data instance, and evaluate if it is the true class. 
+ */ + public boolean evaluate(DataInstance dataInstance) { + Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance); int trueClass = (int) inst.classValue(); - double[] prediction = this.predict(inst); + double[] prediction = this.predict(dataInstance); int predictedClass = Utils.maxIndex(prediction); + return trueClass == predictedClass; } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java index dd153cd6..5bd34a4d 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java @@ -21,7 +21,8 @@ */ -import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.DataInstance; +import org.apache.samoa.learners.InstanceUtils; import org.apache.samoa.learners.Model; import org.apache.samoa.moa.cluster.Clustering; import org.apache.samoa.moa.core.DataPoint; @@ -36,13 +37,10 @@ public CluStreamModel(Clustering clustering) { this.clustering = clustering; } - public CluStreamModel() { - } - @Override - public double[] predict(Instance inst) { - DataPoint dataPoint = (DataPoint) inst; + public double[] predict(DataInstance dataInstance) { double[] distances = new double[clustering.size()]; + DataPoint dataPoint = (DataPoint) InstanceUtils.convertToSamoaInstance(dataInstance); for (int c = 0; c < clustering.size(); c++) { double distance = 0.0; double[] center = clustering.get(c).getCenter(); @@ -52,18 +50,26 @@ public double[] predict(Instance inst) { } distances[c] = Math.sqrt(distance); } + return distances; } - public double evaluate(ArrayList points, MeasureCollection measure) { - double score = 0.0; + /* + Given a list of data instances and a measure, evaluate the performance of the resulting cluster. 
+ */ + public double evaluate(ArrayList points, MeasureCollection measure) { + ArrayList dataPoints = new ArrayList<>(); + for (DataInstance dataInstance : points) { + dataPoints.add((DataPoint) InstanceUtils.convertToSamoaInstance(dataInstance)); + } + try { - measure.evaluateClusteringPerformance(clustering, null, points); - score = measure.getMean(0); + measure.evaluateClusteringPerformance(clustering, null, dataPoints); } catch (Exception e) { e.printStackTrace(); } - return score; + + return measure.getMean(0); } @Override diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java new file mode 100644 index 00000000..b45cef92 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java @@ -0,0 +1,59 @@ +package org.apache.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2016 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import org.apache.samoa.learners.DataInstance; + +public class ClusterDataInstance implements DataInstance { + private int numAtts; + private int timeStamp; + private double[] data; + + public ClusterDataInstance(int numAtts, int timeStamp, double[] data) { + this.numAtts = numAtts; + this.timeStamp = timeStamp; + this.data = data; + } + + public int getNumAtts() { + return numAtts; + } + + public void setNumAtts(int numAtts) { + this.numAtts = numAtts; + } + + public int getTimeStamp() { + return timeStamp; + } + + public void setTimeStamp(int timeStamp) { + this.timeStamp = timeStamp; + } + + public double[] getData() { + return data; + } + + public void setData(double[] data) { + this.data = data; + } +} diff --git a/samoa-gearpump/src/main/resources/reference.conf b/samoa-gearpump/src/main/resources/reference.conf index a6e4d8e5..fbb3f3b3 100644 --- a/samoa-gearpump/src/main/resources/reference.conf +++ b/samoa-gearpump/src/main/resources/reference.conf @@ -57,6 +57,9 @@ gearpump { "org.apache.samoa.learners.classifiers.trees.HoeffdingTreeModel" = "" "org.apache.samoa.learners.classifiers.trees.ActiveLearningNode" = "" "[Lorg.apache.samoa.learners.classifiers.trees.AttributeBatchContentEvent;" = "" + "com.github.javacliparser.IntOption" = "" + "com.github.javacliparser.FloatOption" = "" + "org.apache.samoa.learners.InstanceContent" = "" "java.util.ArrayList" = "" "java.util.LinkedList" = "" "java.util.HashMap" = "" diff --git a/samoa-local/src/test/java/org/apache/samoa/serialize/AMRulesModelTest.java b/samoa-local/src/test/java/org/apache/samoa/serialize/AMRulesModelTest.java new file mode 100644 index 00000000..627c2770 --- /dev/null +++ b/samoa-local/src/test/java/org/apache/samoa/serialize/AMRulesModelTest.java @@ -0,0 +1,102 @@ +package org.apache.samoa.serialize; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.Option; +import junit.framework.TestCase; +import org.apache.commons.io.FileUtils; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.DataInstance; +import org.apache.samoa.learners.classifiers.NumericDataInstance; +import org.apache.samoa.learners.classifiers.rules.AMRulesModel; +import org.apache.samoa.moa.core.SerializeUtils; +import org.apache.samoa.tasks.Task; +import org.apache.samoa.topology.impl.SimpleComponentFactory; +import org.apache.samoa.topology.impl.SimpleEngine; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.Arrays; + +public class AMRulesModelTest extends TestCase { + private static final String BASE_DIR = "amr"; + private static final int NUM_MODEL_IN_DIR = 10; + + private static final String CLISTRING = + "PrequentialEvaluation " + + "-l (org.apache.samoa.learners.classifiers.rules.VerticalAMRulesRegressor -p 4) " + + "-s (generators.RandomTreeGenerator -c 2 -o 0 -u 10)"; + + private static final String SUPPRESS_STATUS_OUT_MSG = "Suppress the task status output. Normally it is sent to stderr."; + private static final String SUPPRESS_RESULT_OUT_MSG = "Suppress the task result output. 
Normally it is sent to stdout."; + private static final String STATUS_UPDATE_FREQ_MSG = "Wait time in milliseconds between status updates."; + + private Option[] extraOptions; + + @Before + public void setUp() { + FlagOption suppressStatusOutOpt = new FlagOption("suppressStatusOut", 'S', SUPPRESS_STATUS_OUT_MSG); + FlagOption suppressResultOutOpt = new FlagOption("suppressResultOut", 'R', SUPPRESS_RESULT_OUT_MSG); + IntOption statusUpdateFreqOpt = new IntOption("statusUpdateFrequency", 'F', STATUS_UPDATE_FREQ_MSG, 1000, 0, + Integer.MAX_VALUE); + extraOptions = new Option[] { suppressStatusOutOpt, suppressResultOutOpt, statusUpdateFreqOpt }; + } + + @After + public void tearDown() { + + } + + @Test + public void testAMRulesModel() throws Exception { + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + FileUtils.forceMkdir(new File(BASE_DIR)); + + Task task = ClassOption.cliStringToObject(CLISTRING, Task.class, extraOptions); + task.setFactory(new SimpleComponentFactory()); + task.init(); + SimpleEngine.submitTopology(task.getTopology()); + + for (int i = 0; i < NUM_MODEL_IN_DIR; i++) { + File fileModel = new File(BASE_DIR + "/amr-model-0-" + i); + File fileData = new File(BASE_DIR + "/amr-data-0-" + i); + + AMRulesModel amRulesModel = (AMRulesModel) SerializeUtils.readFromFile(fileModel); + Instance inst = (Instance) SerializeUtils.readFromFile(fileData); + System.out.println("=== model: " + i + " ==="); + + double[] data = Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); + + DataInstance dataInstance = new NumericDataInstance(data.length, + inst.numClasses(), inst.classValue(), data); + + System.out.println(Arrays.toString(amRulesModel.predict(dataInstance))); + System.out.println("true predict: " + (int) inst.classValue()); + } + + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + } +} diff --git a/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java 
b/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java new file mode 100644 index 00000000..29edb074 --- /dev/null +++ b/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java @@ -0,0 +1,105 @@ +package org.apache.samoa.serialize; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.Option; +import junit.framework.TestCase; +import org.apache.commons.io.FileUtils; +import org.apache.samoa.evaluation.measures.SSQ; +import org.apache.samoa.learners.DataInstance; +import org.apache.samoa.learners.clusterers.CluStreamModel; +import org.apache.samoa.learners.clusterers.ClusterDataInstance; +import org.apache.samoa.moa.core.DataPoint; +import org.apache.samoa.moa.core.SerializeUtils; +import org.apache.samoa.moa.evaluation.MeasureCollection; +import org.apache.samoa.tasks.Task; +import org.apache.samoa.topology.impl.SimpleComponentFactory; +import org.apache.samoa.topology.impl.SimpleEngine; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; + +public class CluStreamModelTest extends TestCase { + private static final String BASE_DIR = "clu"; + private 
static final String CLISTRING = "ClusteringEvaluation"; + + private static final String SUPPRESS_STATUS_OUT_MSG = "Suppress the task status output. Normally it is sent to stderr."; + private static final String SUPPRESS_RESULT_OUT_MSG = "Suppress the task result output. Normally it is sent to stdout."; + private static final String STATUS_UPDATE_FREQ_MSG = "Wait time in milliseconds between status updates."; + + private Option[] extraOptions; + + @Before + public void setUp() { + FlagOption suppressStatusOutOpt = new FlagOption("suppressStatusOut", 'S', SUPPRESS_STATUS_OUT_MSG); + FlagOption suppressResultOutOpt = new FlagOption("suppressResultOut", 'R', SUPPRESS_RESULT_OUT_MSG); + IntOption statusUpdateFreqOpt = new IntOption("statusUpdateFrequency", 'F', STATUS_UPDATE_FREQ_MSG, 1000, 0, + Integer.MAX_VALUE); + extraOptions = new Option[] { suppressStatusOutOpt, suppressResultOutOpt, statusUpdateFreqOpt }; + } + + @After + public void tearDown() { + + } + + @Test + public void testCluStreamModel() throws Exception { + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + FileUtils.forceMkdir(new File(BASE_DIR)); + + Task task = ClassOption.cliStringToObject(CLISTRING, Task.class, extraOptions); + task.setFactory(new SimpleComponentFactory()); + task.init(); + SimpleEngine.submitTopology(task.getTopology()); + + File fileModel = new File("clu/clu-model"); + File fileData = new File("clu/clu-data"); + + ArrayList points = (ArrayList) SerializeUtils.readFromFile(fileData); + CluStreamModel cluStreamModel = (CluStreamModel) SerializeUtils.readFromFile(fileModel); + + assert points != null; + ArrayList dataInstances = new ArrayList<>(); + for (DataPoint point : points) { + double[] data = point.toDoubleArray(); + DataInstance dataInstance = new ClusterDataInstance(data.length, point.getTimestamp(), data); + dataInstances.add(dataInstance); + System.out.println(Arrays.toString(cluStreamModel.predict(dataInstance))); + } + + assert cluStreamModel != null; + 
System.out.println(cluStreamModel.toString()); + + MeasureCollection measure = new SSQ(); + double score = cluStreamModel.evaluate(dataInstances, measure); + System.out.println(score); + + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + } +} diff --git a/samoa-local/src/test/java/org/apache/samoa/serialize/EnsembleModelTest.java b/samoa-local/src/test/java/org/apache/samoa/serialize/EnsembleModelTest.java new file mode 100644 index 00000000..9c4244e2 --- /dev/null +++ b/samoa-local/src/test/java/org/apache/samoa/serialize/EnsembleModelTest.java @@ -0,0 +1,149 @@ +package org.apache.samoa.serialize; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.Option; +import junit.framework.TestCase; +import org.apache.commons.io.FileUtils; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.DataInstance; +import org.apache.samoa.learners.classifiers.NominalDataInstance; +import org.apache.samoa.learners.classifiers.NumericDataInstance; +import org.apache.samoa.learners.classifiers.ensemble.EnsembleModel; +import org.apache.samoa.moa.core.SerializeUtils; +import org.apache.samoa.tasks.Task; +import org.apache.samoa.topology.impl.SimpleComponentFactory; +import org.apache.samoa.topology.impl.SimpleEngine; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.Arrays; + +public class EnsembleModelTest extends TestCase { + private static final String ENS_BASE_DIR = "bagging"; + private static final String BASE_DIR = "vht"; + private static final int NUM_MODEL_IN_DIR = 10; + + private static final String CLISTRING_NUM = + "PrequentialEvaluation -i 1000000 -f 100000 " + + "-l (classifiers.ensemble.Bagging -s 10 -l (classifiers.trees.VerticalHoeffdingTree)) " + + "-s (generators.RandomTreeGenerator -c 2 -o 0 -u 10)"; + private static final String CLISTRING_NOM = + "PrequentialEvaluation -i 1000000 -f 100000 " + + "-l (classifiers.ensemble.Bagging -s 10 -l (classifiers.trees.VerticalHoeffdingTree)) " + + "-s (generators.RandomTreeGenerator -c 2 -o 10 -u 0)"; + + private static final String SUPPRESS_STATUS_OUT_MSG = "Suppress the task status output. Normally it is sent to stderr."; + private static final String SUPPRESS_RESULT_OUT_MSG = "Suppress the task result output. 
Normally it is sent to stdout."; + private static final String STATUS_UPDATE_FREQ_MSG = "Wait time in milliseconds between status updates."; + + private Option[] extraOptions; + + @Before + public void setUp() { + FlagOption suppressStatusOutOpt = new FlagOption("suppressStatusOut", 'S', SUPPRESS_STATUS_OUT_MSG); + FlagOption suppressResultOutOpt = new FlagOption("suppressResultOut", 'R', SUPPRESS_RESULT_OUT_MSG); + IntOption statusUpdateFreqOpt = new IntOption("statusUpdateFrequency", 'F', STATUS_UPDATE_FREQ_MSG, 1000, 0, + Integer.MAX_VALUE); + extraOptions = new Option[] { suppressStatusOutOpt, suppressResultOutOpt, statusUpdateFreqOpt }; + } + + @After + public void tearDown() { + + } + + @Test + public void testEnsembleNumber() throws Exception { + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + FileUtils.forceMkdir(new File(BASE_DIR)); + FileUtils.forceDeleteOnExit(new File(ENS_BASE_DIR)); + FileUtils.forceMkdir(new File(ENS_BASE_DIR)); + + Task task = ClassOption.cliStringToObject(CLISTRING_NUM, Task.class, extraOptions); + task.setFactory(new SimpleComponentFactory()); + task.init(); + SimpleEngine.submitTopology(task.getTopology()); + + for (int i = 0; i < NUM_MODEL_IN_DIR; i++) { + File fileModel = new File(ENS_BASE_DIR + "/bagging-model-" + i); + File fileData = new File(BASE_DIR + "/vht-data-0-" + i); + + EnsembleModel htm = (EnsembleModel) SerializeUtils.readFromFile(fileModel); + Instance inst = (Instance) SerializeUtils.readFromFile(fileData); + System.out.println("=== model: " + i + " ==="); + + double[] data = Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); + + DataInstance dataInstance = new NumericDataInstance(data.length, + inst.numClasses(), inst.classValue(), data); + + System.out.println(Arrays.toString(htm.predict(dataInstance))); + System.out.println("true predict: " + (int) inst.classValue()); + System.out.println("predict: " + htm.evaluate(dataInstance)); + } + + FileUtils.forceDeleteOnExit(new 
File(BASE_DIR)); + FileUtils.forceDeleteOnExit(new File(ENS_BASE_DIR)); + } + + @Test + public void testEnsembleNominal() throws Exception { + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + FileUtils.forceMkdir(new File(BASE_DIR)); + FileUtils.forceDeleteOnExit(new File(ENS_BASE_DIR)); + FileUtils.forceMkdir(new File(ENS_BASE_DIR)); + + Task task = ClassOption.cliStringToObject(CLISTRING_NOM, Task.class, extraOptions); + task.setFactory(new SimpleComponentFactory()); + task.init(); + SimpleEngine.submitTopology(task.getTopology()); + + for (int i = 0; i < NUM_MODEL_IN_DIR; i++) { + File fileModel = new File(ENS_BASE_DIR + "/bagging-model-" + i); + File fileData = new File(BASE_DIR + "/vht-data-0-" + i); + + EnsembleModel htm = (EnsembleModel) SerializeUtils.readFromFile(fileModel); + Instance inst = (Instance) SerializeUtils.readFromFile(fileData); + System.out.println("=== model: " + i + " ==="); + + double[] data = Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); + + int[] numValsPerNominal = new int[data.length]; + Arrays.fill(numValsPerNominal, inst.attribute(0).numValues()); + + DataInstance dataInstance = new NominalDataInstance(data.length, inst.numClasses(), + inst.classValue(), numValsPerNominal, data); + + System.out.println(Arrays.toString(htm.predict(dataInstance))); + System.out.println("true predict: " + (int) inst.classValue()); + System.out.println("predict: " + htm.evaluate(dataInstance)); + } + + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + FileUtils.forceDeleteOnExit(new File(ENS_BASE_DIR)); + } +} diff --git a/samoa-local/src/test/java/org/apache/samoa/serialize/HoeffdingTreeModelTest.java b/samoa-local/src/test/java/org/apache/samoa/serialize/HoeffdingTreeModelTest.java new file mode 100644 index 00000000..ad5fed59 --- /dev/null +++ b/samoa-local/src/test/java/org/apache/samoa/serialize/HoeffdingTreeModelTest.java @@ -0,0 +1,142 @@ +package org.apache.samoa.serialize; + +/* + * #%L + * SAMOA + * %% + * 
Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.Option; +import junit.framework.TestCase; +import org.apache.commons.io.FileUtils; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.DataInstance; +import org.apache.samoa.learners.classifiers.NominalDataInstance; +import org.apache.samoa.learners.classifiers.NumericDataInstance; +import org.apache.samoa.learners.classifiers.trees.HoeffdingTreeModel; +import org.apache.samoa.moa.core.SerializeUtils; +import org.apache.samoa.tasks.Task; +import org.apache.samoa.topology.impl.SimpleComponentFactory; +import org.apache.samoa.topology.impl.SimpleEngine; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.Arrays; + +public class HoeffdingTreeModelTest extends TestCase { + private static final String BASE_DIR = "vht"; + private static final int NUM_MODEL_IN_DIR = 10; + + private static final String CLISTRING_NUM = + "PrequentialEvaluation -i 1000000 -f 100000 " + + "-l (classifiers.trees.VerticalHoeffdingTree -p 4) " + + "-s (generators.RandomTreeGenerator -c 2 -o 0 -u 10)"; + private static final String CLISTRING_NOM = + "PrequentialEvaluation -i 1000000 -f 100000 " + + "-l 
(classifiers.trees.VerticalHoeffdingTree -p 4) " + + "-s (generators.RandomTreeGenerator -c 2 -o 10 -u 0)"; + + private static final String SUPPRESS_STATUS_OUT_MSG = "Suppress the task status output. Normally it is sent to stderr."; + private static final String SUPPRESS_RESULT_OUT_MSG = "Suppress the task result output. Normally it is sent to stdout."; + private static final String STATUS_UPDATE_FREQ_MSG = "Wait time in milliseconds between status updates."; + + private Option[] extraOptions; + + @Before + public void setUp() { + FlagOption suppressStatusOutOpt = new FlagOption("suppressStatusOut", 'S', SUPPRESS_STATUS_OUT_MSG); + FlagOption suppressResultOutOpt = new FlagOption("suppressResultOut", 'R', SUPPRESS_RESULT_OUT_MSG); + IntOption statusUpdateFreqOpt = new IntOption("statusUpdateFrequency", 'F', STATUS_UPDATE_FREQ_MSG, 1000, 0, + Integer.MAX_VALUE); + extraOptions = new Option[] { suppressStatusOutOpt, suppressResultOutOpt, statusUpdateFreqOpt }; + } + + @After + public void tearDown() { + + } + + @Test + public void testVHTNumber() throws Exception { + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + FileUtils.forceMkdir(new File(BASE_DIR)); + + Task task = ClassOption.cliStringToObject(CLISTRING_NUM, Task.class, extraOptions); + task.setFactory(new SimpleComponentFactory()); + task.init(); + SimpleEngine.submitTopology(task.getTopology()); + + for (int i = 0; i < NUM_MODEL_IN_DIR; i++) { + File fileModel = new File(BASE_DIR + "/vht-model-0-" + i); + File fileData = new File(BASE_DIR + "/vht-data-0-" + i); + + HoeffdingTreeModel htm = (HoeffdingTreeModel) SerializeUtils.readFromFile(fileModel); + Instance inst = (Instance) SerializeUtils.readFromFile(fileData); + System.out.println("=== model: " + i + " ==="); + + double[] data = Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); + + DataInstance dataInstance = new NumericDataInstance(data.length, + inst.numClasses(), inst.classValue(), data); + + 
System.out.println(Arrays.toString(htm.predict(dataInstance))); + System.out.println("true predict: " + (int) inst.classValue()); + System.out.println("predict: " + htm.evaluate(dataInstance)); + } + + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + } + + @Test + public void testVHTNominal() throws Exception { + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + FileUtils.forceMkdir(new File(BASE_DIR)); + + Task task = ClassOption.cliStringToObject(CLISTRING_NOM, Task.class, extraOptions); + task.setFactory(new SimpleComponentFactory()); + task.init(); + SimpleEngine.submitTopology(task.getTopology()); + + for (int i = 0; i < NUM_MODEL_IN_DIR; i++) { + File fileModel = new File(BASE_DIR + "/vht-model-0-" + i); + File fileData = new File(BASE_DIR + "/vht-data-0-" + i); + + HoeffdingTreeModel htm = (HoeffdingTreeModel) SerializeUtils.readFromFile(fileModel); + Instance inst = (Instance) SerializeUtils.readFromFile(fileData); + System.out.println("=== model: " + i + " ==="); + + double[] data = Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); + + int[] numValsPerNominal = new int[data.length]; + Arrays.fill(numValsPerNominal, inst.attribute(0).numValues()); + + DataInstance dataInstance = new NominalDataInstance(data.length, inst.numClasses(), + inst.classValue(), numValsPerNominal, data); + + System.out.println(Arrays.toString(htm.predict(dataInstance))); + System.out.println("true predict: " + (int) inst.classValue()); + System.out.println("predict: " + htm.evaluate(dataInstance)); + } + + FileUtils.forceDeleteOnExit(new File(BASE_DIR)); + } +} From 5644f6835005e1dd5e7c637d2045c31c77731b84 Mon Sep 17 00:00:00 2001 From: pangolulu Date: Tue, 1 Mar 2016 10:07:46 +0800 Subject: [PATCH 2/5] fix --- .../java/org/apache/samoa/serialize/CluStreamModelTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java 
b/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java index 29edb074..d9513e72 100644 --- a/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java +++ b/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java @@ -78,8 +78,8 @@ public void testCluStreamModel() throws Exception { task.init(); SimpleEngine.submitTopology(task.getTopology()); - File fileModel = new File("clu/clu-model"); - File fileData = new File("clu/clu-data"); + File fileModel = new File(BASE_DIR + "/clu-model"); + File fileData = new File(BASE_DIR + "/clu-data"); ArrayList points = (ArrayList) SerializeUtils.readFromFile(fileData); CluStreamModel cluStreamModel = (CluStreamModel) SerializeUtils.readFromFile(fileModel); From 29565b4ac0d7ea720153be5c19adf4da59448a85 Mon Sep 17 00:00:00 2001 From: pangolulu Date: Thu, 3 Mar 2016 17:41:59 +0800 Subject: [PATCH 3/5] fix Model and DataInstance --- .../apache/samoa/learners/InstanceUtils.java | 119 +++++++++--------- .../samoa/learners/ModelContentEvent.java | 9 +- .../ClassificationDataInstance.java | 76 +++++++++++ .../ClassificationModel.java} | 11 +- .../classifiers/NominalDataInstance.java | 81 ------------ .../classifiers/NumericDataInstance.java | 69 ---------- .../classifiers/ensemble/EnsembleModel.java | 21 ++-- .../ensemble/PredictionCombinerProcessor.java | 10 +- .../classifiers/rules/AMRulesModel.java | 10 +- .../classifiers/trees/HoeffdingTreeModel.java | 18 +-- .../learners/clusterers/CluStreamModel.java | 18 ++- .../clusterers/ClusterDataInstance.java | 26 ++-- .../ClusterModel.java} | 10 +- .../samoa/serialize/AMRulesModelTest.java | 15 +-- .../samoa/serialize/CluStreamModelTest.java | 7 +- .../samoa/serialize/EnsembleModelTest.java | 69 +++------- .../serialize/HoeffdingTreeModelTest.java | 60 ++------- 17 files changed, 233 insertions(+), 396 deletions(-) create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationDataInstance.java 
rename samoa-api/src/main/java/org/apache/samoa/learners/{Model.java => classifiers/ClassificationModel.java} (79%) delete mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalDataInstance.java delete mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericDataInstance.java rename samoa-api/src/main/java/org/apache/samoa/learners/{DataInstance.java => clusterers/ClusterModel.java} (80%) diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/InstanceUtils.java b/samoa-api/src/main/java/org/apache/samoa/learners/InstanceUtils.java index 7065e920..ab0faea9 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/InstanceUtils.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/InstanceUtils.java @@ -21,27 +21,24 @@ */ import org.apache.samoa.instances.*; -import org.apache.samoa.learners.classifiers.NominalDataInstance; -import org.apache.samoa.learners.classifiers.NumericDataInstance; +import org.apache.samoa.learners.classifiers.ClassificationDataInstance; import org.apache.samoa.learners.clusterers.ClusterDataInstance; import org.apache.samoa.moa.core.DataPoint; import org.apache.samoa.moa.core.FastVector; import java.util.ArrayList; +import java.util.Arrays; public class InstanceUtils { - static private InstancesHeader getNumericInstanceHeader(NumericDataInstance numericDataInstance) { - FastVector attributes = new FastVector<>(); - for (int i = 0; i < numericDataInstance.getNumNumerics(); i++) { - attributes.addElement(new Attribute("numeric" + (i + 1))); - } + private static InstancesHeader getClusterInstanceHeader(ClusterDataInstance dataInstance) { + ArrayList attributes = new ArrayList<>(); - FastVector classLabels = new FastVector<>(); - for (int i = 0; i < numericDataInstance.getNumClasses(); i++) { - classLabels.addElement("class" + (i + 1)); + for (int i = 0; i < dataInstance.getNumberFeatures(); i++) { + attributes.add(new Attribute("att" + (i + 1))); } - 
attributes.addElement(new Attribute("class", classLabels)); + + // attributes.add(new Attribute("class", null)); InstancesHeader instancesHeader = new InstancesHeader( new Instances(null, attributes, 0)); @@ -50,34 +47,24 @@ static private InstancesHeader getNumericInstanceHeader(NumericDataInstance nume return instancesHeader; } - static private Instance convertNumericInstance(NumericDataInstance numericDataInstance) { - InstancesHeader header = InstanceUtils.getNumericInstanceHeader(numericDataInstance); - Instance inst = new DenseInstance(header.numAttributes()); - - for (int i = 0; i < numericDataInstance.getNumNumerics(); i++) { - inst.setValue(i, numericDataInstance.getData()[i]); - } - - inst.setDataset(header); - inst.setClassValue(numericDataInstance.getTrueClass()); - - return inst; - } - - static private InstancesHeader getNominalInstanceHeader(NominalDataInstance nominalDataInstance) { + private static InstancesHeader getClassificationInstanceHeader(ClassificationDataInstance dataInstance) { FastVector attributes = new FastVector<>(); - for (int i = 0; i < nominalDataInstance.getNumNominals(); i++) { + for (int i = 0; i < dataInstance.getNumberNominalFeatures(); i++) { FastVector nominalAttVals = new FastVector<>(); - for (int j = 0; j < nominalDataInstance.getNumValsPerNominal()[i]; j++) { + for (int j = 0; j < dataInstance.getNumberValsPerNominalFeature()[i]; j++) { nominalAttVals.addElement("value" + (j + 1)); } attributes.addElement(new Attribute("nominal" + (i + 1), nominalAttVals)); } + for (int i = 0; i < dataInstance.getNumberNumericFeatures(); i++) { + attributes.addElement(new Attribute("numeric" + (i + 1))); + } + FastVector classLabels = new FastVector<>(); - for (int i = 0; i < nominalDataInstance.getNumClasses(); i++) { + for (int i = 0; i < dataInstance.getNumberLabels(); i++) { classLabels.addElement("class" + (i + 1)); } attributes.addElement(new Attribute("class", classLabels)); @@ -89,51 +76,67 @@ static private InstancesHeader 
getNominalInstanceHeader(NominalDataInstance nomi return instancesHeader; } - static private Instance convertNominalInstance(NominalDataInstance nominalDataInstance) { - InstancesHeader header = InstanceUtils.getNominalInstanceHeader(nominalDataInstance); + /** + * convert ClassificationDataInstance to SAMOA Instance + */ + public static Instance convertClassificationDataInstance(ClassificationDataInstance dataInstance) { + InstancesHeader header = InstanceUtils.getClassificationInstanceHeader(dataInstance); Instance inst = new DenseInstance(header.numAttributes()); - for (int i = 0; i < nominalDataInstance.getNumNominals(); i++) { - inst.setValue(i, nominalDataInstance.getData()[i]); + int numNomFeatures = dataInstance.getNumberNominalFeatures(); + int numNumFeatures = dataInstance.getNumberNumericFeatures(); + + for (int i = 0; i < numNomFeatures + numNumFeatures; i++) { + if (i < numNomFeatures) { + inst.setValue(i, dataInstance.getNominalData()[i]); + } else { + inst.setValue(i, dataInstance.getNumericData()[i - numNomFeatures]); + } } inst.setDataset(header); - inst.setClassValue(nominalDataInstance.getTrueClass()); + inst.setClassValue(dataInstance.getTrueLabel()); return inst; } - static private InstancesHeader getClusterInstanceHeader(ClusterDataInstance clusterDataInstance) { - ArrayList attributes = new ArrayList<>(); - - for (int i = 0; i < clusterDataInstance.getNumAtts(); i++) { - attributes.add(new Attribute("att" + (i + 1))); + /** + * convert SAMOA Instance to ClassificationDataInstance + */ + public static ClassificationDataInstance reConvertClassificationDataInstance( + Instance inst, int numberNominalFeatures, int numberNumericFeatures) { + double[] nominalDataTmp = Arrays.copyOfRange(inst.toDoubleArray(), 0, numberNominalFeatures); + int[] nominalData = new int[nominalDataTmp.length]; + for (int j = 0; j < nominalData.length; j++) { + nominalData[j] = (int) nominalDataTmp[j]; } - // attributes.add(new Attribute("class", null)); + double[] 
numericData = Arrays.copyOfRange( + inst.toDoubleArray(), numberNominalFeatures, + inst.toDoubleArray().length - 1); - InstancesHeader instancesHeader = new InstancesHeader( - new Instances(null, attributes, 0)); - instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1); + int[] numValsPerNominal = new int[nominalData.length]; + Arrays.fill(numValsPerNominal, inst.attribute(0).numValues()); - return instancesHeader; + return new ClassificationDataInstance( + numberNumericFeatures, numericData, numberNominalFeatures, + numValsPerNominal, nominalData, inst.numClasses(), (int) inst.classValue()); } - static private Instance convertClusterInstance(ClusterDataInstance clusterDataInstance) { - Instance inst = new DenseInstance(1.0, clusterDataInstance.getData()); - inst.setDataset(InstanceUtils.getClusterInstanceHeader(clusterDataInstance)); - return new DataPoint(inst, clusterDataInstance.getTimeStamp()); + /** + * convert ClusterDataInstance to SAMOA Instance + */ + public static Instance convertClusterDataInstance(ClusterDataInstance dataInstance) { + Instance inst = new DenseInstance(1.0, dataInstance.getData()); + inst.setDataset(InstanceUtils.getClusterInstanceHeader(dataInstance)); + return new DataPoint(inst, dataInstance.getTimeStamp()); } - static public Instance convertToSamoaInstance(DataInstance dataInstance) { - if (dataInstance instanceof NumericDataInstance) { - return InstanceUtils.convertNumericInstance((NumericDataInstance) dataInstance); - } else if (dataInstance instanceof NominalDataInstance) { - return InstanceUtils.convertNominalInstance((NominalDataInstance) dataInstance); - } else if (dataInstance instanceof ClusterDataInstance) { - return InstanceUtils.convertClusterInstance((ClusterDataInstance) dataInstance); - } else { - throw new Error("Invalid input class!"); - } + /** + * convert SAMOA Instance to ClusterDataInstance + */ + public static ClusterDataInstance reConvertClusterDataInstance(DataPoint point) { + double[] data = 
point.toDoubleArray(); + return new ClusterDataInstance(data.length, point.getTimestamp(), data); } } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/ModelContentEvent.java b/samoa-api/src/main/java/org/apache/samoa/learners/ModelContentEvent.java index bdf5d0d6..20088f15 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/ModelContentEvent.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/ModelContentEvent.java @@ -21,10 +21,11 @@ */ import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.learners.classifiers.ClassificationModel; final public class ModelContentEvent implements ContentEvent { final private boolean isLast; - private Model model; + private ClassificationModel model; private long modelIndex; private long instanceIndex; private int classifierIndex; @@ -38,7 +39,7 @@ public ModelContentEvent(boolean isLast) { this.isLast = isLast; } - public ModelContentEvent(boolean isLast, Model model, long modelIndex, long instanceIndex, + public ModelContentEvent(boolean isLast, ClassificationModel model, long modelIndex, long instanceIndex, int classifierIndex, int evaluationIndex) { this.isLast = isLast; this.model = model; @@ -71,11 +72,11 @@ public void setModelIndex(long modelIndex) { this.modelIndex = modelIndex; } - public Model getModel() { + public ClassificationModel getModel() { return model; } - public void setModel(Model model) { + public void setModel(ClassificationModel model) { this.model = model; } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationDataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationDataInstance.java new file mode 100644 index 00000000..7e7ee87b --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationDataInstance.java @@ -0,0 +1,76 @@ +package org.apache.samoa.learners.classifiers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2016 Apache Software 
Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.io.Serializable; + +public class ClassificationDataInstance implements Serializable { + private int numberNumericFeatures; + private double[] numericData; + + private int numberNominalFeatures; + private int[] numberValsPerNominalFeature; + private int[] nominalData; + + private int numberLabels; + private int trueLabel; + + public ClassificationDataInstance(int numberNumericFeatures, + double[] numericData, int numberNominalFeatures, + int[] numberValsPerNominalFeature, int[] nominalData, + int numberLabels, int trueLabel) { + this.numberNumericFeatures = numberNumericFeatures; + this.numericData = numericData; + this.numberNominalFeatures = numberNominalFeatures; + this.numberValsPerNominalFeature = numberValsPerNominalFeature; + this.nominalData = nominalData; + this.numberLabels = numberLabels; + this.trueLabel = trueLabel; + } + + public int getNumberNumericFeatures() { + return numberNumericFeatures; + } + + public double[] getNumericData() { + return numericData; + } + + public int getNumberNominalFeatures() { + return numberNominalFeatures; + } + + public int[] getNumberValsPerNominalFeature() { + return numberValsPerNominalFeature; + } + + public int[] getNominalData() { + return nominalData; + } + + public int getNumberLabels() { + return numberLabels; + } + + public int getTrueLabel() { + return trueLabel; + } +} diff --git 
a/samoa-api/src/main/java/org/apache/samoa/learners/Model.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationModel.java similarity index 79% rename from samoa-api/src/main/java/org/apache/samoa/learners/Model.java rename to samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationModel.java index 37e05617..8d8c1b5a 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/Model.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationModel.java @@ -1,4 +1,4 @@ -package org.apache.samoa.learners; +package org.apache.samoa.learners.classifiers; /* * #%L @@ -9,9 +9,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -20,9 +20,8 @@ * #L% */ - import java.io.Serializable; -public interface Model extends Serializable { - double[] predict(DataInstance dataInstance); +public interface ClassificationModel extends Serializable { + double[] predict(ClassificationDataInstance dataInstance); } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalDataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalDataInstance.java deleted file mode 100644 index fe3d747d..00000000 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalDataInstance.java +++ /dev/null @@ -1,81 +0,0 @@ -package org.apache.samoa.learners.classifiers; - -/* - * #%L - * SAMOA - * %% - * Copyright (C) 2014 - 2016 Apache Software Foundation - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * #L% - */ - -import org.apache.samoa.learners.DataInstance; - -public class NominalDataInstance implements DataInstance { - - private int numNominals; - private int numClasses; - private double trueClass; // index started from 0 - private int[] numValsPerNominal; - private double[] data; // index started from 0 - - public NominalDataInstance(int numNominals, int numClasses, double trueClass, - int[] numValsPerNominal, double[] data) { - this.numNominals = numNominals; - this.numClasses = numClasses; - this.trueClass = trueClass; - this.numValsPerNominal = numValsPerNominal; - this.data = data; - } - - public int getNumNominals() { - return numNominals; - } - - public void setNumNominals(int numNominals) { - this.numNominals = numNominals; - } - - public int getNumClasses() { - return numClasses; - } - - public void setNumClasses(int numClasses) { - this.numClasses = numClasses; - } - - public double getTrueClass() { - return trueClass; - } - - public void setTrueClass(double trueClass) { - this.trueClass = trueClass; - } - - public int[] getNumValsPerNominal() { - return numValsPerNominal; - } - - public void setNumValsPerNominal(int[] numValsPerNominal) { - this.numValsPerNominal = numValsPerNominal; - } - - public double[] getData() { - return data; - } - - public void setData(double[] data) { - this.data = data; - } -} diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericDataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericDataInstance.java deleted file mode 100644 index d28858a2..00000000 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericDataInstance.java +++ /dev/null @@ -1,69 +0,0 @@ -package org.apache.samoa.learners.classifiers; - -/* - * #%L - * SAMOA - * %% - * Copyright (C) 2014 - 2016 Apache Software Foundation - * %% - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * #L% - */ - -import org.apache.samoa.learners.DataInstance; - -public class NumericDataInstance implements DataInstance { - private int numNumerics; - private int numClasses; - private double trueClass; // index started from 0 - private double[] data; - - public NumericDataInstance(int numNumerics, int numClasses, double trueClass, double[] data) { - this.numNumerics = numNumerics; - this.numClasses = numClasses; - this.trueClass = trueClass; - this.data = data; - } - - public int getNumNumerics() { - return numNumerics; - } - - public void setNumNumerics(int numNumerics) { - this.numNumerics = numNumerics; - } - - public int getNumClasses() { - return numClasses; - } - - public void setNumClasses(int numClasses) { - this.numClasses = numClasses; - } - - public double getTrueClass() { - return trueClass; - } - - public void setTrueClass(double trueClass) { - this.trueClass = trueClass; - } - - public double[] getData() { - return data; - } - - public void setData(double[] data) { - this.data = data; - } -} diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java index c4c0cc00..bf6e15c1 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java @@ -2,9 +2,9 @@ import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Utils; -import 
org.apache.samoa.learners.DataInstance; import org.apache.samoa.learners.InstanceUtils; -import org.apache.samoa.learners.Model; +import org.apache.samoa.learners.classifiers.ClassificationDataInstance; +import org.apache.samoa.learners.classifiers.ClassificationModel; import org.apache.samoa.moa.core.DoubleVector; import java.util.ArrayList; @@ -29,17 +29,18 @@ * #L% */ -public class EnsembleModel implements Model { - private ArrayList modelList; +public class EnsembleModel implements ClassificationModel { + private ArrayList modelList; private ArrayList modelWeightList; - public EnsembleModel(ArrayList modelList, ArrayList modelWeightList) { + public EnsembleModel(ArrayList modelList, + ArrayList modelWeightList) { this.modelList = modelList; this.modelWeightList = modelWeightList; } @Override - public double[] predict(DataInstance dataInstance) { + public double[] predict(ClassificationDataInstance dataInstance) { DoubleVector combinedVote = new DoubleVector(); for (int i = 0; i < modelList.size(); i++) { double[] prediction = modelList.get(i).predict(dataInstance); @@ -53,11 +54,11 @@ public double[] predict(DataInstance dataInstance) { return combinedVote.getArrayCopy(); } - /* - Predict the class of an input data instance, and evaluate if it is the true class. + /** + * Predict the class of an input data instance, and evaluate if it is the true class. 
*/ - public boolean evaluate(DataInstance dataInstance) { - Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance); + public boolean evaluate(ClassificationDataInstance dataInstance) { + Instance inst = InstanceUtils.convertClassificationDataInstance(dataInstance); int trueClass = (int) inst.classValue(); double[] prediction = this.predict(dataInstance); int predictedClass = Utils.maxIndex(prediction); diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java index 0d9d22a4..aac42172 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java @@ -29,9 +29,9 @@ import org.apache.samoa.core.ContentEvent; import org.apache.samoa.core.Processor; -import org.apache.samoa.learners.Model; import org.apache.samoa.learners.ModelContentEvent; import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.learners.classifiers.ClassificationModel; import org.apache.samoa.moa.core.DoubleVector; import org.apache.samoa.moa.core.SerializeUtils; import org.apache.samoa.topology.Stream; @@ -97,7 +97,7 @@ public void setEnsembleSize(int ensembleSize) { protected Map mapCountsforModelReceived; // for serialize - protected Map> mapModelListforModelReceived; // for serialize + protected Map> mapModelListforModelReceived; // for serialize protected Map> mapModelWeightListforModelReceived; // for serialize /** @@ -136,7 +136,7 @@ public boolean process(ContentEvent event) { // for serialize protected boolean processModel(ModelContentEvent event) { - Model model = event.getModel(); + ClassificationModel model = event.getModel(); long modelIndex = event.getModelIndex(); long instanceIndex = event.getInstanceIndex(); int classifierIndex = 
event.getClassifierIndex(); @@ -213,13 +213,13 @@ protected void addStatisticsForInstanceReceived(int instanceIndex, int classifie } //for serialize - protected void addStatisticsForModelReceived(long modelIndex, int classifierIndex, Model model, int add) { + protected void addStatisticsForModelReceived(long modelIndex, int classifierIndex, ClassificationModel model, int add) { if (this.mapCountsforModelReceived == null) { this.mapCountsforModelReceived = new HashMap<>(); this.mapModelListforModelReceived = new HashMap<>(); this.mapModelWeightListforModelReceived = new HashMap<>(); } - ArrayList modelList = this.mapModelListforModelReceived.get(modelIndex); + ArrayList modelList = this.mapModelListforModelReceived.get(modelIndex); if (modelList == null) { modelList = new ArrayList<>(); } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java index 061c66dc..389d74d2 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java @@ -21,16 +21,16 @@ */ import org.apache.samoa.instances.Instance; -import org.apache.samoa.learners.DataInstance; import org.apache.samoa.learners.InstanceUtils; -import org.apache.samoa.learners.Model; +import org.apache.samoa.learners.classifiers.ClassificationDataInstance; +import org.apache.samoa.learners.classifiers.ClassificationModel; import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; import org.apache.samoa.learners.classifiers.rules.common.PassiveRule; import org.apache.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote; import java.util.List; -public class AMRulesModel implements Model { +public class AMRulesModel implements ClassificationModel { private List ruleSet; private ActiveRule defaultRule; private ErrorWeightedVote errorWeightedVote; @@ 
-45,8 +45,8 @@ public AMRulesModel(ActiveRule defaultRule, List ruleSet, } @Override - public double[] predict(DataInstance dataInstance) { - Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance); + public double[] predict(ClassificationDataInstance dataInstance) { + Instance inst = InstanceUtils.convertClassificationDataInstance(dataInstance); double[] prediction; boolean predictionCovered = false; diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java index b3398e1b..aa935570 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java @@ -24,11 +24,11 @@ import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Instances; import org.apache.samoa.instances.Utils; -import org.apache.samoa.learners.DataInstance; import org.apache.samoa.learners.InstanceUtils; -import org.apache.samoa.learners.Model; +import org.apache.samoa.learners.classifiers.ClassificationDataInstance; +import org.apache.samoa.learners.classifiers.ClassificationModel; -public class HoeffdingTreeModel implements Model { +public class HoeffdingTreeModel implements ClassificationModel { private Instances dataset; private Node treeRoot; @@ -38,9 +38,9 @@ public HoeffdingTreeModel(Instances dataset, Node treeRoot) { } @Override - public double[] predict(DataInstance dataInstance) { + public double[] predict(ClassificationDataInstance dataInstance) { double[] prediction; - Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance); + Instance inst = InstanceUtils.convertClassificationDataInstance(dataInstance); // inst.setDataset(dataset); FoundNode foundNode; @@ -59,11 +59,11 @@ public double[] predict(DataInstance dataInstance) { return prediction; } - /* - Predict the class of 
an input data instance, and evaluate if it is the true class. + /** + * Predict the class of an input data instance, and evaluate if it is the true class. */ - public boolean evaluate(DataInstance dataInstance) { - Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance); + public boolean evaluate(ClassificationDataInstance dataInstance) { + Instance inst = InstanceUtils.convertClassificationDataInstance(dataInstance); int trueClass = (int) inst.classValue(); double[] prediction = this.predict(dataInstance); int predictedClass = Utils.maxIndex(prediction); diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java index 5bd34a4d..ee8ee748 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java @@ -21,16 +21,14 @@ */ -import org.apache.samoa.learners.DataInstance; import org.apache.samoa.learners.InstanceUtils; -import org.apache.samoa.learners.Model; import org.apache.samoa.moa.cluster.Clustering; import org.apache.samoa.moa.core.DataPoint; import org.apache.samoa.moa.evaluation.MeasureCollection; import java.util.ArrayList; -public class CluStreamModel implements Model { +public class CluStreamModel implements ClusterModel { private Clustering clustering; public CluStreamModel(Clustering clustering) { @@ -38,9 +36,9 @@ public CluStreamModel(Clustering clustering) { } @Override - public double[] predict(DataInstance dataInstance) { + public double[] predict(ClusterDataInstance dataInstance) { double[] distances = new double[clustering.size()]; - DataPoint dataPoint = (DataPoint) InstanceUtils.convertToSamoaInstance(dataInstance); + DataPoint dataPoint = (DataPoint) InstanceUtils.convertClusterDataInstance(dataInstance); for (int c = 0; c < clustering.size(); c++) { double distance = 0.0; double[] center = 
clustering.get(c).getCenter(); @@ -54,13 +52,13 @@ public double[] predict(DataInstance dataInstance) { return distances; } - /* - Given a list of data instances and a measure, evaluate the performance of the resulting cluster. + /** + * Given a list of data instances and a measure, evaluate the performance of the resulting cluster. */ - public double evaluate(ArrayList points, MeasureCollection measure) { + public double evaluate(ArrayList points, MeasureCollection measure) { ArrayList dataPoints = new ArrayList<>(); - for (DataInstance dataInstance : points) { - dataPoints.add((DataPoint) InstanceUtils.convertToSamoaInstance(dataInstance)); + for (ClusterDataInstance dataInstance : points) { + dataPoints.add((DataPoint) InstanceUtils.convertClusterDataInstance(dataInstance)); } try { diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java index b45cef92..4d388754 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java @@ -20,40 +20,28 @@ * #L% */ -import org.apache.samoa.learners.DataInstance; +import java.io.Serializable; -public class ClusterDataInstance implements DataInstance { - private int numAtts; +public class ClusterDataInstance implements Serializable { + private int numberFeatures; private int timeStamp; private double[] data; - public ClusterDataInstance(int numAtts, int timeStamp, double[] data) { - this.numAtts = numAtts; + public ClusterDataInstance(int numberFeatures, int timeStamp, double[] data) { + this.numberFeatures = numberFeatures; this.timeStamp = timeStamp; this.data = data; } - public int getNumAtts() { - return numAtts; - } - - public void setNumAtts(int numAtts) { - this.numAtts = numAtts; + public int getNumberFeatures() { + return numberFeatures; } public int 
getTimeStamp() { return timeStamp; } - public void setTimeStamp(int timeStamp) { - this.timeStamp = timeStamp; - } - public double[] getData() { return data; } - - public void setData(double[] data) { - this.data = data; - } } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/DataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterModel.java similarity index 80% rename from samoa-api/src/main/java/org/apache/samoa/learners/DataInstance.java rename to samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterModel.java index 99167360..dd5f1892 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/DataInstance.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterModel.java @@ -1,4 +1,4 @@ -package org.apache.samoa.learners; +package org.apache.samoa.learners.clusterers; /* * #%L @@ -9,9 +9,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -22,6 +22,6 @@ import java.io.Serializable; -public interface DataInstance extends Serializable { - +public interface ClusterModel extends Serializable { + double[] predict(ClusterDataInstance dataInstance); } diff --git a/samoa-local/src/test/java/org/apache/samoa/serialize/AMRulesModelTest.java b/samoa-local/src/test/java/org/apache/samoa/serialize/AMRulesModelTest.java index 627c2770..6631708d 100644 --- a/samoa-local/src/test/java/org/apache/samoa/serialize/AMRulesModelTest.java +++ b/samoa-local/src/test/java/org/apache/samoa/serialize/AMRulesModelTest.java @@ -27,8 +27,8 @@ import junit.framework.TestCase; import org.apache.commons.io.FileUtils; import org.apache.samoa.instances.Instance; -import org.apache.samoa.learners.DataInstance; -import org.apache.samoa.learners.classifiers.NumericDataInstance; +import org.apache.samoa.learners.InstanceUtils; +import org.apache.samoa.learners.classifiers.ClassificationDataInstance; import org.apache.samoa.learners.classifiers.rules.AMRulesModel; import org.apache.samoa.moa.core.SerializeUtils; import org.apache.samoa.tasks.Task; @@ -45,10 +45,13 @@ public class AMRulesModelTest extends TestCase { private static final String BASE_DIR = "amr"; private static final int NUM_MODEL_IN_DIR = 10; + private static final int numberNumericFeatures = 10; + private static final int numberNominalFeatures = 10; private static final String CLISTRING = "PrequentialEvaluation " + "-l (org.apache.samoa.learners.classifiers.rules.VerticalAMRulesRegressor -p 4) " + - "-s (generators.RandomTreeGenerator -c 2 -o 0 -u 10)"; + "-s (generators.RandomTreeGenerator -c 2 -o " + + numberNominalFeatures + " -u " + numberNumericFeatures +")"; private static final String SUPPRESS_STATUS_OUT_MSG = "Suppress the task status output. Normally it is sent to stderr."; private static final String SUPPRESS_RESULT_OUT_MSG = "Suppress the task result output. 
Normally it is sent to stdout."; @@ -88,10 +91,8 @@ public void testAMRulesModel() throws Exception { Instance inst = (Instance) SerializeUtils.readFromFile(fileData); System.out.println("=== model: " + i + " ==="); - double[] data = Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); - - DataInstance dataInstance = new NumericDataInstance(data.length, - inst.numClasses(), inst.classValue(), data); + ClassificationDataInstance dataInstance = + InstanceUtils.reConvertClassificationDataInstance(inst, numberNominalFeatures, numberNumericFeatures); System.out.println(Arrays.toString(amRulesModel.predict(dataInstance))); System.out.println("true predict: " + (int) inst.classValue()); diff --git a/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java b/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java index d9513e72..c73ae64e 100644 --- a/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java +++ b/samoa-local/src/test/java/org/apache/samoa/serialize/CluStreamModelTest.java @@ -27,7 +27,7 @@ import junit.framework.TestCase; import org.apache.commons.io.FileUtils; import org.apache.samoa.evaluation.measures.SSQ; -import org.apache.samoa.learners.DataInstance; +import org.apache.samoa.learners.InstanceUtils; import org.apache.samoa.learners.clusterers.CluStreamModel; import org.apache.samoa.learners.clusterers.ClusterDataInstance; import org.apache.samoa.moa.core.DataPoint; @@ -85,10 +85,9 @@ public void testCluStreamModel() throws Exception { CluStreamModel cluStreamModel = (CluStreamModel) SerializeUtils.readFromFile(fileModel); assert points != null; - ArrayList dataInstances = new ArrayList<>(); + ArrayList dataInstances = new ArrayList<>(); for (DataPoint point : points) { - double[] data = point.toDoubleArray(); - DataInstance dataInstance = new ClusterDataInstance(data.length, point.getTimestamp(), data); + ClusterDataInstance dataInstance = 
InstanceUtils.reConvertClusterDataInstance(point); dataInstances.add(dataInstance); System.out.println(Arrays.toString(cluStreamModel.predict(dataInstance))); } diff --git a/samoa-local/src/test/java/org/apache/samoa/serialize/EnsembleModelTest.java b/samoa-local/src/test/java/org/apache/samoa/serialize/EnsembleModelTest.java index 9c4244e2..63fe4cec 100644 --- a/samoa-local/src/test/java/org/apache/samoa/serialize/EnsembleModelTest.java +++ b/samoa-local/src/test/java/org/apache/samoa/serialize/EnsembleModelTest.java @@ -27,9 +27,8 @@ import junit.framework.TestCase; import org.apache.commons.io.FileUtils; import org.apache.samoa.instances.Instance; -import org.apache.samoa.learners.DataInstance; -import org.apache.samoa.learners.classifiers.NominalDataInstance; -import org.apache.samoa.learners.classifiers.NumericDataInstance; +import org.apache.samoa.learners.InstanceUtils; +import org.apache.samoa.learners.classifiers.ClassificationDataInstance; import org.apache.samoa.learners.classifiers.ensemble.EnsembleModel; import org.apache.samoa.moa.core.SerializeUtils; import org.apache.samoa.tasks.Task; @@ -47,14 +46,13 @@ public class EnsembleModelTest extends TestCase { private static final String BASE_DIR = "vht"; private static final int NUM_MODEL_IN_DIR = 10; - private static final String CLISTRING_NUM = + private static final int numberNumericFeatures = 10; + private static final int numberNominalFeatures = 10; + private static final String CLISTRING = "PrequentialEvaluation -i 1000000 -f 100000 " + "-l (classifiers.ensemble.Bagging -s 10 -l (classifiers.trees.VerticalHoeffdingTree)) " + - "-s (generators.RandomTreeGenerator -c 2 -o 0 -u 10)"; - private static final String CLISTRING_NOM = - "PrequentialEvaluation -i 1000000 -f 100000 " + - "-l (classifiers.ensemble.Bagging -s 10 -l (classifiers.trees.VerticalHoeffdingTree)) " + - "-s (generators.RandomTreeGenerator -c 2 -o 10 -u 0)"; + "-s (generators.RandomTreeGenerator -c 2 -o " + + numberNominalFeatures + " 
-u " + numberNumericFeatures +")"; private static final String SUPPRESS_STATUS_OUT_MSG = "Suppress the task status output. Normally it is sent to stderr."; private static final String SUPPRESS_RESULT_OUT_MSG = "Suppress the task result output. Normally it is sent to stdout."; @@ -77,47 +75,13 @@ public void tearDown() { } @Test - public void testEnsembleNumber() throws Exception { - FileUtils.forceDeleteOnExit(new File(BASE_DIR)); - FileUtils.forceMkdir(new File(BASE_DIR)); - FileUtils.forceDeleteOnExit(new File(ENS_BASE_DIR)); - FileUtils.forceMkdir(new File(ENS_BASE_DIR)); - - Task task = ClassOption.cliStringToObject(CLISTRING_NUM, Task.class, extraOptions); - task.setFactory(new SimpleComponentFactory()); - task.init(); - SimpleEngine.submitTopology(task.getTopology()); - - for (int i = 0; i < NUM_MODEL_IN_DIR; i++) { - File fileModel = new File(ENS_BASE_DIR + "/bagging-model-" + i); - File fileData = new File(BASE_DIR + "/vht-data-0-" + i); - - EnsembleModel htm = (EnsembleModel) SerializeUtils.readFromFile(fileModel); - Instance inst = (Instance) SerializeUtils.readFromFile(fileData); - System.out.println("=== model: " + i + " ==="); - - double[] data = Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); - - DataInstance dataInstance = new NumericDataInstance(data.length, - inst.numClasses(), inst.classValue(), data); - - System.out.println(Arrays.toString(htm.predict(dataInstance))); - System.out.println("true predict: " + (int) inst.classValue()); - System.out.println("predict: " + htm.evaluate(dataInstance)); - } - - FileUtils.forceDeleteOnExit(new File(BASE_DIR)); - FileUtils.forceDeleteOnExit(new File(ENS_BASE_DIR)); - } - - @Test - public void testEnsembleNominal() throws Exception { + public void testEnsembleModel() throws Exception { FileUtils.forceDeleteOnExit(new File(BASE_DIR)); FileUtils.forceMkdir(new File(BASE_DIR)); FileUtils.forceDeleteOnExit(new File(ENS_BASE_DIR)); FileUtils.forceMkdir(new File(ENS_BASE_DIR)); - 
Task task = ClassOption.cliStringToObject(CLISTRING_NOM, Task.class, extraOptions); + Task task = ClassOption.cliStringToObject(CLISTRING, Task.class, extraOptions); task.setFactory(new SimpleComponentFactory()); task.init(); SimpleEngine.submitTopology(task.getTopology()); @@ -126,21 +90,16 @@ public void testEnsembleNominal() throws Exception { File fileModel = new File(ENS_BASE_DIR + "/bagging-model-" + i); File fileData = new File(BASE_DIR + "/vht-data-0-" + i); - EnsembleModel htm = (EnsembleModel) SerializeUtils.readFromFile(fileModel); + EnsembleModel em = (EnsembleModel) SerializeUtils.readFromFile(fileModel); Instance inst = (Instance) SerializeUtils.readFromFile(fileData); System.out.println("=== model: " + i + " ==="); - double[] data = Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); - - int[] numValsPerNominal = new int[data.length]; - Arrays.fill(numValsPerNominal, inst.attribute(0).numValues()); - - DataInstance dataInstance = new NominalDataInstance(data.length, inst.numClasses(), - inst.classValue(), numValsPerNominal, data); + ClassificationDataInstance dataInstance = + InstanceUtils.reConvertClassificationDataInstance(inst,numberNominalFeatures, numberNumericFeatures); - System.out.println(Arrays.toString(htm.predict(dataInstance))); + System.out.println(Arrays.toString(em.predict(dataInstance))); System.out.println("true predict: " + (int) inst.classValue()); - System.out.println("predict: " + htm.evaluate(dataInstance)); + System.out.println("predict: " + em.evaluate(dataInstance)); } FileUtils.forceDeleteOnExit(new File(BASE_DIR)); diff --git a/samoa-local/src/test/java/org/apache/samoa/serialize/HoeffdingTreeModelTest.java b/samoa-local/src/test/java/org/apache/samoa/serialize/HoeffdingTreeModelTest.java index ad5fed59..575a396c 100644 --- a/samoa-local/src/test/java/org/apache/samoa/serialize/HoeffdingTreeModelTest.java +++ b/samoa-local/src/test/java/org/apache/samoa/serialize/HoeffdingTreeModelTest.java @@ 
-27,9 +27,8 @@ import junit.framework.TestCase; import org.apache.commons.io.FileUtils; import org.apache.samoa.instances.Instance; -import org.apache.samoa.learners.DataInstance; -import org.apache.samoa.learners.classifiers.NominalDataInstance; -import org.apache.samoa.learners.classifiers.NumericDataInstance; +import org.apache.samoa.learners.InstanceUtils; +import org.apache.samoa.learners.classifiers.ClassificationDataInstance; import org.apache.samoa.learners.classifiers.trees.HoeffdingTreeModel; import org.apache.samoa.moa.core.SerializeUtils; import org.apache.samoa.tasks.Task; @@ -46,14 +45,13 @@ public class HoeffdingTreeModelTest extends TestCase { private static final String BASE_DIR = "vht"; private static final int NUM_MODEL_IN_DIR = 10; - private static final String CLISTRING_NUM = + private static final int numberNumericFeatures = 10; + private static final int numberNominalFeatures = 10; + private static final String CLISTRING = "PrequentialEvaluation -i 1000000 -f 100000 " + "-l (classifiers.trees.VerticalHoeffdingTree -p 4) " + - "-s (generators.RandomTreeGenerator -c 2 -o 0 -u 10)"; - private static final String CLISTRING_NOM = - "PrequentialEvaluation -i 1000000 -f 100000 " + - "-l (classifiers.trees.VerticalHoeffdingTree -p 4) " + - "-s (generators.RandomTreeGenerator -c 2 -o 10 -u 0)"; + "-s (generators.RandomTreeGenerator -c 2 -o " + + numberNominalFeatures + " -u " + numberNumericFeatures +")"; private static final String SUPPRESS_STATUS_OUT_MSG = "Suppress the task status output. Normally it is sent to stderr."; private static final String SUPPRESS_RESULT_OUT_MSG = "Suppress the task result output. 
Normally it is sent to stdout."; @@ -76,42 +74,11 @@ public void tearDown() { } @Test - public void testVHTNumber() throws Exception { - FileUtils.forceDeleteOnExit(new File(BASE_DIR)); - FileUtils.forceMkdir(new File(BASE_DIR)); - - Task task = ClassOption.cliStringToObject(CLISTRING_NUM, Task.class, extraOptions); - task.setFactory(new SimpleComponentFactory()); - task.init(); - SimpleEngine.submitTopology(task.getTopology()); - - for (int i = 0; i < NUM_MODEL_IN_DIR; i++) { - File fileModel = new File(BASE_DIR + "/vht-model-0-" + i); - File fileData = new File(BASE_DIR + "/vht-data-0-" + i); - - HoeffdingTreeModel htm = (HoeffdingTreeModel) SerializeUtils.readFromFile(fileModel); - Instance inst = (Instance) SerializeUtils.readFromFile(fileData); - System.out.println("=== model: " + i + " ==="); - - double[] data = Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); - - DataInstance dataInstance = new NumericDataInstance(data.length, - inst.numClasses(), inst.classValue(), data); - - System.out.println(Arrays.toString(htm.predict(dataInstance))); - System.out.println("true predict: " + (int) inst.classValue()); - System.out.println("predict: " + htm.evaluate(dataInstance)); - } - - FileUtils.forceDeleteOnExit(new File(BASE_DIR)); - } - - @Test - public void testVHTNominal() throws Exception { + public void testHoeffdingTreeModel() throws Exception { FileUtils.forceDeleteOnExit(new File(BASE_DIR)); FileUtils.forceMkdir(new File(BASE_DIR)); - Task task = ClassOption.cliStringToObject(CLISTRING_NOM, Task.class, extraOptions); + Task task = ClassOption.cliStringToObject(CLISTRING, Task.class, extraOptions); task.setFactory(new SimpleComponentFactory()); task.init(); SimpleEngine.submitTopology(task.getTopology()); @@ -124,13 +91,8 @@ public void testVHTNominal() throws Exception { Instance inst = (Instance) SerializeUtils.readFromFile(fileData); System.out.println("=== model: " + i + " ==="); - double[] data = 
Arrays.copyOfRange(inst.toDoubleArray(), 0, inst.toDoubleArray().length - 1); - - int[] numValsPerNominal = new int[data.length]; - Arrays.fill(numValsPerNominal, inst.attribute(0).numValues()); - - DataInstance dataInstance = new NominalDataInstance(data.length, inst.numClasses(), - inst.classValue(), numValsPerNominal, data); + ClassificationDataInstance dataInstance = + InstanceUtils.reConvertClassificationDataInstance(inst,numberNominalFeatures, numberNumericFeatures); System.out.println(Arrays.toString(htm.predict(dataInstance))); System.out.println("true predict: " + (int) inst.classValue()); From 03d0538b18ba5ef7fadf60de6dcce3a641460bff Mon Sep 17 00:00:00 2001 From: pangolulu Date: Fri, 4 Mar 2016 14:18:26 +0800 Subject: [PATCH 4/5] add comments --- .../learners/classifiers/ClassificationDataInstance.java | 4 ++++ .../samoa/learners/classifiers/ClassificationModel.java | 3 +++ .../apache/samoa/learners/clusterers/ClusterDataInstance.java | 4 ++++ .../org/apache/samoa/learners/clusterers/ClusterModel.java | 3 +++ 4 files changed, 14 insertions(+) diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationDataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationDataInstance.java index 7e7ee87b..d2427acb 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationDataInstance.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationDataInstance.java @@ -22,6 +22,10 @@ import java.io.Serializable; +/** + * DataInstance for classification problem + * There may be two types of feature in feature vector: numeric feature and nominal feature + */ public class ClassificationDataInstance implements Serializable { private int numberNumericFeatures; private double[] numericData; diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationModel.java
b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationModel.java index 8d8c1b5a..eb37dd5c 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ClassificationModel.java @@ -22,6 +22,9 @@ import java.io.Serializable; +/** + * Model for classification problem + */ public interface ClassificationModel extends Serializable { double[] predict(ClassificationDataInstance dataInstance); } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java index 4d388754..9b74b7f3 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterDataInstance.java @@ -22,6 +22,10 @@ import java.io.Serializable; +/** + * DataInstance for cluster problem + * The feature type is numeric + */ public class ClusterDataInstance implements Serializable { private int numberFeatures; private int timeStamp; diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterModel.java index dd5f1892..4dc519f4 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterModel.java @@ -22,6 +22,9 @@ import java.io.Serializable; +/** + * Model for cluster problem + */ public interface ClusterModel extends Serializable { double[] predict(ClusterDataInstance dataInstance); } From a32ae83d8282c2e7f6498166eecbcc154e1279cc Mon Sep 17 00:00:00 2001 From: pangolulu Date: Mon, 7 Mar 2016 10:22:08 +0800 Subject: [PATCH 5/5] rename ModelContentEvent --- .../apache/samoa/evaluation/EvaluatorProcessor.java | 4 ++-- ...Event.java => 
ClassificationModelContentEvent.java} | 10 +++++----- .../ensemble/BoostingPredictionCombinerProcessor.java | 6 +++--- .../ensemble/PredictionCombinerProcessor.java | 8 ++++---- .../classifiers/trees/ModelAggregatorProcessor.java | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) rename samoa-api/src/main/java/org/apache/samoa/learners/{ModelContentEvent.java => ClassificationModelContentEvent.java} (86%) diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java index d174dea6..6301b18d 100644 --- a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java @@ -31,7 +31,7 @@ import org.apache.samoa.core.ContentEvent; import org.apache.samoa.core.Processor; -import org.apache.samoa.learners.ModelContentEvent; +import org.apache.samoa.learners.ClassificationModelContentEvent; import org.apache.samoa.learners.ResultContentEvent; import org.apache.samoa.moa.core.Measurement; import org.apache.samoa.moa.evaluation.LearningCurve; @@ -74,7 +74,7 @@ private EvaluatorProcessor(Builder builder) { @Override public boolean process(ContentEvent event) { // for serialize - if (event instanceof ModelContentEvent) { + if (event instanceof ClassificationModelContentEvent) { return false; } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/ModelContentEvent.java b/samoa-api/src/main/java/org/apache/samoa/learners/ClassificationModelContentEvent.java similarity index 86% rename from samoa-api/src/main/java/org/apache/samoa/learners/ModelContentEvent.java rename to samoa-api/src/main/java/org/apache/samoa/learners/ClassificationModelContentEvent.java index 20088f15..c125f941 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/ModelContentEvent.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/ClassificationModelContentEvent.java @@ -23,7 
+23,7 @@ import org.apache.samoa.core.ContentEvent; import org.apache.samoa.learners.classifiers.ClassificationModel; -final public class ModelContentEvent implements ContentEvent { +final public class ClassificationModelContentEvent implements ContentEvent { final private boolean isLast; private ClassificationModel model; private long modelIndex; @@ -31,16 +31,16 @@ final public class ModelContentEvent implements ContentEvent { private int classifierIndex; private int evaluationIndex; - public ModelContentEvent() { + public ClassificationModelContentEvent() { this.isLast = false; } - public ModelContentEvent(boolean isLast) { + public ClassificationModelContentEvent(boolean isLast) { this.isLast = isLast; } - public ModelContentEvent(boolean isLast, ClassificationModel model, long modelIndex, long instanceIndex, - int classifierIndex, int evaluationIndex) { + public ClassificationModelContentEvent(boolean isLast, ClassificationModel model, long modelIndex, long instanceIndex, + int classifierIndex, int evaluationIndex) { this.isLast = isLast; this.model = model; this.modelIndex = modelIndex; diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java index b7bc995b..58e8c4b7 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java @@ -30,7 +30,7 @@ import org.apache.samoa.core.ContentEvent; import org.apache.samoa.instances.Instance; import org.apache.samoa.learners.InstanceContentEvent; -import org.apache.samoa.learners.ModelContentEvent; +import org.apache.samoa.learners.ClassificationModelContentEvent; import org.apache.samoa.learners.ResultContentEvent; import org.apache.samoa.moa.core.DoubleVector; 
import org.apache.samoa.moa.core.Utils; @@ -59,8 +59,8 @@ public class BoostingPredictionCombinerProcessor extends PredictionCombinerProce @Override public boolean process(ContentEvent event) { // for serialize - if (event instanceof ModelContentEvent) { - return this.processModel((ModelContentEvent) event); + if (event instanceof ClassificationModelContentEvent) { + return this.processModel((ClassificationModelContentEvent) event); } ResultContentEvent inEvent = (ResultContentEvent) event; diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java index aac42172..fa9d7507 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java @@ -29,7 +29,7 @@ import org.apache.samoa.core.ContentEvent; import org.apache.samoa.core.Processor; -import org.apache.samoa.learners.ModelContentEvent; +import org.apache.samoa.learners.ClassificationModelContentEvent; import org.apache.samoa.learners.ResultContentEvent; import org.apache.samoa.learners.classifiers.ClassificationModel; import org.apache.samoa.moa.core.DoubleVector; @@ -109,8 +109,8 @@ public void setEnsembleSize(int ensembleSize) { */ public boolean process(ContentEvent event) { // for serialize - if (event instanceof ModelContentEvent) { - return this.processModel((ModelContentEvent) event); + if (event instanceof ClassificationModelContentEvent) { + return this.processModel((ClassificationModelContentEvent) event); } ResultContentEvent inEvent = (ResultContentEvent) event; @@ -135,7 +135,7 @@ public boolean process(ContentEvent event) { } // for serialize - protected boolean processModel(ModelContentEvent event) { + protected boolean processModel(ClassificationModelContentEvent event) { 
ClassificationModel model = event.getModel(); long modelIndex = event.getModelIndex(); long instanceIndex = event.getInstanceIndex(); diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java index 9aabdbd3..06f479b9 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java @@ -41,7 +41,7 @@ import org.apache.samoa.instances.InstancesHeader; import org.apache.samoa.learners.InstanceContent; import org.apache.samoa.learners.InstancesContentEvent; -import org.apache.samoa.learners.ModelContentEvent; +import org.apache.samoa.learners.ClassificationModelContentEvent; import org.apache.samoa.learners.ResultContentEvent; import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector; @@ -321,7 +321,7 @@ private void processInstances(InstancesContentEvent instContentEvent) { e.printStackTrace(); } - this.modelStream.put(new ModelContentEvent( + this.modelStream.put(new ClassificationModelContentEvent( instContent.isLastEvent(), hoeffdingTreeModel, modelIndex, instContent.getInstanceIndex(), processorId, instContent.getEvaluationIndex())); // for serialize