From c3659b87cd8a5a413994301649900274c4b6305a Mon Sep 17 00:00:00 2001 From: saliherdemk Date: Sat, 19 Oct 2024 15:25:54 +0300 Subject: [PATCH] digitouput added --- index.html | 3 +- js/Draw/MLP/Playable.js | 25 ++++--- js/Draw/Minors/Viewers/CalculationViewer.js | 3 +- js/Draw/NeuronView.js | 12 +++- .../{DigitComponent.js => DigitInput.js} | 2 +- js/HandWritten/DigitOutput.js | 72 +++++++++++++++++++ js/script.js | 3 +- 7 files changed, 105 insertions(+), 15 deletions(-) rename js/HandWritten/{DigitComponent.js => DigitInput.js} (97%) create mode 100644 js/HandWritten/DigitOutput.js diff --git a/index.html b/index.html index 5303638..f161dd5 100644 --- a/index.html +++ b/index.html @@ -375,7 +375,8 @@

Grada

- + + diff --git a/js/Draw/MLP/Playable.js b/js/Draw/MLP/Playable.js index f5815d8..bfedfe4 100644 --- a/js/Draw/MLP/Playable.js +++ b/js/Draw/MLP/Playable.js @@ -144,7 +144,7 @@ class Playable extends Draggable { this.updateStatus( +( (this.getInput() instanceof InputLayer || - this.getInput() instanceof DigitComponent) && + this.getInput() instanceof DigitInput) && (this.isEval() || this.getOutput() instanceof OutputLayer) ), ); @@ -310,21 +310,30 @@ class Playable extends Draggable { let startTime = performance.now(); const inputData = this.getInput().getData(); const origin = this.origin; - this.calculationComponent?.setInputData( - inputData.slice(0, 5).map((row) => row.slice(0, 4)), - [inputData.length, inputData[0].length], - ); - const outputData = this.getOutput()?.getData() ?? null; + const calcComp = this.calculationComponent; + const output = this.getOutput(); + + calcComp?.setInputData(inputData.slice(0, 5).map((row) => row.slice(0, 4))); + + const outputData = output instanceof OutputLayer ? output.getData() : null; const mlp_output = await origin.forward(inputData); - this.calculationComponent && this.setCalculationData(); + + calcComp && this.setCalculationData(); outputData && (await origin.backward(mlp_output, outputData)); + this.updateParameters(); this.setGraphComponentData(); this.setMsPerStepText(performance.now() - startTime + "ms / step"); + + if (output instanceof DigitOutput) { + const lastLayer = origin.layers[origin.layers.length - 1]; + output.setData(lastLayer.outputs.data[0]); + } } fetchNext() { this.getInput().fetchNext(); - this.getOutput()?.fetchNext(); + const output = this.getOutput(); + output instanceof OutputLayer && output.fetchNext(); } } diff --git a/js/Draw/Minors/Viewers/CalculationViewer.js b/js/Draw/Minors/Viewers/CalculationViewer.js index 6842695..c503563 100644 --- a/js/Draw/Minors/Viewers/CalculationViewer.js +++ b/js/Draw/Minors/Viewers/CalculationViewer.js @@ -18,7 +18,8 @@ class CalculationViewer extends Viewer { this.line.from.parent.removeCalculationComponent(); } - setInputData(data, shape) { + setInputData(data) { + const shape = [data.length, data[0].length]; this.data = [this.formatMatrix(data, shape, 0, 0)]; } diff --git a/js/Draw/NeuronView.js b/js/Draw/NeuronView.js index 47e79a8..e99901d 100644 --- a/js/Draw/NeuronView.js +++ b/js/Draw/NeuronView.js @@ -2,6 +2,7 @@ class NeuronView { constructor() { this.x; this.y; + this.color = themeManager.getTheme("white").activeColor; this.output = ""; this.bias = ""; this.biasGrad = ""; @@ -10,8 +11,13 @@ class NeuronView { this.r = 25; } - setOutput(o) { - this.output = o.toFixed(2).toString(); + setColor(color) { + this.color = themeManager.getTheme(color).defaultColor; + } + + setOutput(o, int = false) { + const output = int ? parseInt(o) : o.toFixed(2); + this.output = output.toString(); } setBias(b) { @@ -85,7 +91,7 @@ class NeuronView { show() { const commands = [ - { func: "fill", args: [255] }, + { func: "fill", args: [this.color] }, { func: "circle", args: [this.x, this.y, this.r] }, { func: "textAlign", args: [CENTER, CENTER] }, { func: "textSize", args: [8] }, diff --git a/js/HandWritten/DigitComponent.js b/js/HandWritten/DigitInput.js similarity index 97% rename from js/HandWritten/DigitComponent.js rename to js/HandWritten/DigitInput.js index 77f111c..598a6cb 100644 --- a/js/HandWritten/DigitComponent.js +++ b/js/HandWritten/DigitInput.js @@ -1,4 +1,4 @@ -class DigitComponent extends Component { +class DigitInput extends Component { constructor(x, y) { super(x, y, 550); this.shrank = true; diff --git a/js/HandWritten/DigitOutput.js b/js/HandWritten/DigitOutput.js new file mode 100644 index 0000000..e7c7634 --- /dev/null +++ b/js/HandWritten/DigitOutput.js @@ -0,0 +1,72 @@ +class DigitOutput extends Component { + constructor(x, y) { + super(x, y, 50); + this.initialize(); + } + + initialize() { + this.outputDot.destroy(); + this.outputDot = null; + this.inputDot.setColor("cyan"); + this.adjustNeuronNum(10); + this.setShownNeuronsNum(10); + this.neurons.forEach((n, i) => n.setOutput(i, true)); + } + + setData(data) { + let maxIndex = 0; + data.forEach((val, idx) => { + if (val > data[maxIndex]) maxIndex = idx; + }); + this.neurons.forEach((n, i) => { + n.setColor(i == maxIndex ? "green" : "white"); + }); + } + + // FIXME: maybe we can merge those functions into a base class for output + fetchNext() {} + + connectLayer(targetLayer) { + const isEqual = this.getNeuronNum() == targetLayer.getNeuronNum(); + if (!isEqual) return; + + this.connectNeurons(targetLayer); + } + + connectNeurons(targetLayer) { + targetLayer.neurons.forEach((n1, i) => { + n1.removeLines(); + n1.addLine(new Line(n1, this.neurons[i])); + }); + this.inputDot.occupy(); + targetLayer.outputDot.occupy(); + targetLayer.parent.setOutputComponent(this); + this.connected = targetLayer; + } + + clearLines() { + this.connected.clearLines(this); + this.connected.parent.clearOutput(); + this.connected = null; + } + + show() { + const middleX = this.x + this.w / 2; + const commands = [ + { func: "rect", args: [this.x, this.y, this.w, this.h, 10] }, + { func: "textAlign", args: [CENTER, CENTER] }, + { + func: "text", + args: ["Grid Output", middleX, this.y - 10], + }, + ]; + + executeDrawingCommands(commands); + } + + draw() { + super.draw(); + this.show(); + this.neurons.forEach((neuron) => neuron.draw()); + } +} diff --git a/js/script.js b/js/script.js index d90ea6b..d0629d3 100644 --- a/js/script.js +++ b/js/script.js @@ -1,7 +1,8 @@ document.addEventListener("contextmenu", (event) => event.preventDefault()); function createHandWrittenInput() { - mainOrganizer.addComponent(new DigitComponent(100, 100)); + mainOrganizer.addComponent(new DigitInput(100, 100)); + mainOrganizer.addComponent(new DigitOutput(800, 100)); } function createLayer() {