Skip to content

Commit

Permalink
Merge epoch listener with file listener
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Nov 14, 2017
1 parent 7332d31 commit 597f363
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import weka.dl4j.layers.OutputLayer;
import weka.dl4j.layers.SubsamplingLayer;
import weka.dl4j.listener.EpochListener;
import weka.dl4j.listener.FileIterationListener;
import weka.dl4j.zoo.CustomNet;
import weka.dl4j.zoo.FaceNetNN4Small2;
import weka.dl4j.zoo.GoogLeNet;
Expand Down Expand Up @@ -579,7 +578,11 @@ public void initializeClassifier(Instances data) throws Exception {

// Set the iteration listener
m_model.setListeners(getListener());
m_earlyStopping.init(m_valIterator);

// Init early stopping
if (useEarlyStopping()){
m_earlyStopping.init(m_valIterator);
}

m_NumEpochsPerformed = 0;
} finally {
Expand Down Expand Up @@ -826,21 +829,13 @@ private List<IterationListener> getListener() throws Exception {
List<IterationListener> listeners = new ArrayList<>();

// Initialize weka listener
if (m_iterationListener instanceof weka.dl4j.listener.IterationListener) {
if (m_iterationListener instanceof weka.dl4j.listener.EpochListener) {
int numEpochs = getNumEpochs();
((weka.dl4j.listener.IterationListener) m_iterationListener).init
((weka.dl4j.listener.EpochListener) m_iterationListener).init
(m_trainData.numClasses(), numEpochs, numSamples, m_trainIterator, m_valIterator);
((weka.dl4j.listener.EpochListener) m_iterationListener).setLogFile(m_logFile);
listeners.add(m_iterationListener);
}


// if the log file doesn't point to a directory, set up the listener
if (getLogFile() != null && !getLogFile().isDirectory()) {
FileIterationListener fil = new FileIterationListener(getLogFile().getAbsolutePath());
fil.init(m_trainData.numClasses(), getNumEpochs(), numSamples, m_trainIterator, m_valIterator);
listeners.add(fil);
}

return listeners;
}

Expand Down
87 changes: 80 additions & 7 deletions package/src/main/java/weka/dl4j/listener/EpochListener.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package weka.dl4j.listener;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Model;
Expand All @@ -8,9 +9,12 @@
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.OptionMetadata;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;
import java.util.Map;

Expand All @@ -20,19 +24,41 @@
*
* @author Steven Lang
*/
@Slf4j
public class EpochListener extends IterationListener implements TrainingListener {
private static final Logger log = LoggerFactory.getLogger(weka.dl4j.listener.EpochListener.class);
private static final long serialVersionUID = -8852994767947925554L;

/**
* Epoch counter
*/
private int currentEpoch = 0;

/**
* Evaluate every N epochs
*/
private int n = 5;

/**
* Log to this file if set
*/
private transient PrintWriter logFile;

@Override
public void onEpochEnd(Model model) {
currentEpoch++;

// Skip if this is not an evaluation epoch
if (currentEpoch % n !=0){
return;
}

String s = "Epoch [" + currentEpoch + "/" + numEpochs + "]\n";
s += "Train: " + evaluateDataSetIterator(model, trainIterator);
if (validationIterator != null){

if (validationIterator != null) {
s += "Validation: " + evaluateDataSetIterator(model, validationIterator);
}

log(s);
}

Expand All @@ -49,7 +75,9 @@ private String evaluateDataSetIterator(Model model, DataSetIterator iterator) {
Evaluation cEval = new Evaluation(numClasses);
RegressionEvaluation rEval = new RegressionEvaluation(1);
while (iterator.hasNext()) {
DataSet next = iterator.next();
// TODO: figure out which batch size is feasible for inference
final int batch = iterator.batch() * 8;
DataSet next = iterator.next(batch);
scoreSum += net.score(next);
iterations++;
INDArray output = net.outputSingle(next.getFeatureMatrix()); //get the networks prediction
Expand All @@ -58,8 +86,8 @@ private String evaluateDataSetIterator(Model model, DataSetIterator iterator) {
}

double score = 0;
if (iterations != 0){
score = scoreSum/iterations;
if (iterations != 0) {
score = scoreSum / iterations;
}
if (isClassification) {
s += String.format("Accuracy: %4.2f%%", cEval.accuracy() * 100);
Expand All @@ -71,6 +99,9 @@ private String evaluateDataSetIterator(Model model, DataSetIterator iterator) {
}
} catch (UnsupportedOperationException e) {
return "Validation set is too small and does not contain all labels.";
} catch (Exception e) {
log.error("Evaluation after epoch failed. Error: ", e);
return "Not available";
} finally {
iterator.reset();
}
Expand All @@ -79,9 +110,24 @@ private String evaluateDataSetIterator(Model model, DataSetIterator iterator) {
return s;
}

/**
* Set the log file
*
* @param logFile Logging file
*/
public void setLogFile(File logFile) throws IOException {
if (logFile.exists()) logFile.delete();
System.out.println("Creating debug file at: " + logFile.getAbsolutePath());
this.logFile = new PrintWriter(new FileWriter(logFile, false));
}

@Override
public void log(String msg) {
log.info(msg);
if (logFile != null){
logFile.write(msg + "\n");
logFile.flush();
}
}

@Override
Expand Down Expand Up @@ -111,4 +157,31 @@ public void onGradientCalculation(Model model) {
public void onBackwardPass(Model model) {

}

@OptionMetadata(
displayName = "evaluate every N epochs",
description = "Evaluate every N epochs (default = 5).",
commandLineParamName = "n", commandLineParamSynopsis = "-n <int>",
displayOrder = 0)
public void setN(int evaluateEveryNEpochs) {
if (evaluateEveryNEpochs < 1) {
// Never evaluate
this.n = Integer.MAX_VALUE;
}

this.n = evaluateEveryNEpochs;
}

public int getN(){return this.n;}

/**
* Returns a string describing this search method
*
* @return a description of the search method suitable for displaying in the
* explorer/experimenter gui
*/
public String globalInfo() {
return "A listener which evaluates the model while training every N " +
"epochs.";
}
}

This file was deleted.

0 comments on commit 597f363

Please sign in to comment.