diff --git a/js/Draw/MLP/MLPView.js b/js/Draw/MLP/MLPView.js index 328dc96..41f5bc2 100644 --- a/js/Draw/MLP/MLPView.js +++ b/js/Draw/MLP/MLPView.js @@ -16,6 +16,10 @@ class MlpView extends Playable { this.selected = false; } + isEval() { + return this.mode == "eval"; + } + getAllParameters() { let layersParameters = []; @@ -90,7 +94,10 @@ class MlpView extends Playable { handleSetMode(mode) { this.setMode(mode); - this.isInitialized() && this.checkCompleted(); + if (this.isInitialized()) { + this.setGraphComponentData(); + this.checkCompleted(); + } } handleSetZenMode(mode) { @@ -407,7 +414,7 @@ class MlpView extends Playable { { func: "text", args: [`Total Parameters: ${totalParams}\n`, x, y - 10] }, ]; - const commands = this.getMode() == "train" ? this.getTrainCommands() : []; + const commands = this.isEval() ? [] : this.getTrainCommands(); executeDrawingCommands([...commands, ...common]); } diff --git a/js/Draw/MLP/Playable.js b/js/Draw/MLP/Playable.js index 1cfd4a4..250fca0 100644 --- a/js/Draw/MLP/Playable.js +++ b/js/Draw/MLP/Playable.js @@ -75,8 +75,14 @@ class Playable extends Draggable { this.x - (graphComponent.w - this.w) / 2, this.y - 400, ); - graphComponent.setData(this.origin.getLossData()); this.graphComponent = graphComponent; + this.setGraphComponentData(); + } + + setGraphComponentData() { + this.graphComponent?.setData( + this.isEval() ? this.origin.getEvalLoss() : this.origin.getTrainLoss(), + ); } removeGraphComponent() { @@ -134,13 +140,12 @@ class Playable extends Draggable { } checkCompleted() { - const isEval = this.getMode() == "eval"; this.pause(); this.updateStatus( +( (this.getInput() instanceof InputLayer || this.getInput() instanceof DigitComponent) && - (isEval || this.getOutput() instanceof OutputLayer) + (this.isEval() || this.getOutput() instanceof OutputLayer) ), ); @@ -304,16 +309,17 @@ class Playable extends Draggable { async executeOnce() { let startTime = performance.now(); const inputData = this.getInput().getData(); + const origin = this.origin; this.calculationComponent?.setInputData( inputData.slice(0, 5).map((row) => row.slice(0, 5)), [parseInt(this.batchSize), this.getInput().shape[1]], ); const outputData = this.getOutput()?.getData() ?? null; - const mlp_output = await this.origin.forward(inputData); + const mlp_output = await origin.forward(inputData); this.calculationComponent && this.setCalculationData(); - await this.origin.backward(mlp_output, outputData); + outputData && (await origin.backward(mlp_output, outputData)); this.updateParameters(); - this.graphComponent?.setData(this.origin.getLossData()); + this.setGraphComponentData(); this.setMsPerStepText(performance.now() - startTime + "ms / step"); } diff --git a/js/MLP/Mlp.js b/js/MLP/Mlp.js index 57fba82..d139aa2 100644 --- a/js/MLP/Mlp.js +++ b/js/MLP/Mlp.js @@ -52,12 +52,15 @@ class MLP extends MlpParams { } backward(mlp_output, y_batch) { - if (this.mode == "eval") return; const loss = errFuncManager.getFunction(this.errFunc)( mlp_output, new Tensor(y_batch), ); - this.addLossData(loss.data[0][0]); + if (this.mode == "eval") { + this.addEvalLoss(loss.data[0][0]); + return; + } + this.addTrainLoss(loss.data[0][0]); this.zeroGrad(); loss.backward(); @@ -74,14 +77,16 @@ class MLP extends MlpParams { biases: biases.map((b) => b.data), seenRecordNum: this.seenRecordNum, stepCounter: this.stepCounter, - lossData: this.lossData, + trainLoss: this.trainLoss, + evalLoss: this.evalLoss, }; } - import({ weights, biases, seenRecordNum, stepCounter, lossData }) { + import({ weights, biases, seenRecordNum, stepCounter, trainLoss, evalLoss }) { this.seenRecordNum = seenRecordNum; this.stepCounter = stepCounter; - this.lossData = lossData; + this.trainLoss = trainLoss; + this.evalLoss = evalLoss; for (let i = 0; i < this.layers.length; i++) { this.layers[i].setParameters(weights[i], biases[i]); } diff --git a/js/MLP/MlpParams.js b/js/MLP/MlpParams.js index c0bd120..af91c25 100644 --- a/js/MLP/MlpParams.js +++ b/js/MLP/MlpParams.js @@ -9,15 +9,24 @@ class MlpParams { this.epoch = 0; this.seenRecordNum = 0; this.mode = "train"; - this.lossData = []; + this.trainLoss = []; + this.evalLoss = []; } - getLossData() { - return this.lossData; + getTrainLoss() { + return this.trainLoss; } - addLossData(lossData) { - this.lossData.push(lossData); + getEvalLoss() { + return this.evalLoss; + } + + addTrainLoss(lossData) { + this.trainLoss.push(lossData); + } + + addEvalLoss(lossData) { + this.evalLoss.push(lossData); } setLr(lr) {