From 88eb579be20c36d26d61ac484de07f327c22b30f Mon Sep 17 00:00:00 2001 From: Chizkiyahu Date: Fri, 6 Oct 2023 16:07:21 +0300 Subject: [PATCH] keras fix layer sharing inputs and outputs names fix layer shring inputs and outputs --- source/keras.js | 46 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/source/keras.js b/source/keras.js index 30f765c7db..3f6b088099 100644 --- a/source/keras.js +++ b/source/keras.js @@ -788,21 +788,54 @@ keras.Node = class { const innerType = this.inner ? this.inner.type : null; const innerMetadata = innerType ? metadata.type(innerType) : null; + + // handle layer sharing let inputIndex = 0; + let inputTypes = []; + const outputTypes = []; + if (this._type && this._type.inputs) { + for (let i = 0; i < this._type.inputs.length; i++) { + if (layer.inbound_nodes && layer.inbound_nodes.length > 1 && layer.inbound_nodes[0].length > i) { + for (let j = 0; j < layer.inbound_nodes.length; j++) { + const inputType = JSON.parse(JSON.stringify(this._type.inputs[i])); + inputType.name = j + ": "+ inputType.name; + inputTypes.push(inputType); + } + } else { + inputTypes.push(this._type.inputs[i]); + } + } + } + if (this._type && this._type.outputs) { + for (let i = 0; i < this._type.outputs.length; i++) { + if (layer.inbound_nodes && layer.inbound_nodes.length > 1 && layer.inbound_nodes[0].length > i) { + for (let j = 0; j < layer.inbound_nodes.length; j++) { + const outputType = JSON.parse(JSON.stringify(this._type.outputs[i])); + outputType.name = j + ": "+ outputType.name; + outputTypes.push(outputType); + } + } else { + outputTypes.push(this._type.inputs[i]); + } + } + } + const inbound_nodes_size = layer.inbound_nodes ? layer.inbound_nodes.length : 1; + outputs = Array(inbound_nodes_size).fill(outputs).flat(); + while (inputs.length > 0) { let list = false; let name = null; let visible = true; if (!innerMetadata || inputIndex == 0) { - if (this._type && this._type.inputs && inputIndex < this._type.inputs.length) { - const input = this._type.inputs[inputIndex]; + if (inputTypes && inputIndex < inputTypes.length) { + const input = inputTypes[inputIndex]; name = input.name; if (type === 'BatchNormalization' && name === 'gamma' && config.scale === false) { inputIndex++; continue; } visible = input.visible == false ? false : true; - if (this._type.inputs[inputIndex].list) { + if (inputTypes[inputIndex].list) { list = true; } } @@ -832,7 +865,8 @@ keras.Node = class { break; } } - const input = !list ? [ inputs.shift() ] : inputs.splice(0, inputs.length); + const size = !layer.inbound_nodes ? 1 : layer.inbound_nodes.length; + const input = !list ? [ inputs.shift() ] : inputs.splice(0, size); const inputArguments = input.map((input) => { if (input.name) { return values.map(input.name, null, initializers[input.name]); @@ -861,12 +895,12 @@ keras.Node = class { for (let i = 0; i < outputs.length; i++) { const output = outputs[i]; - const outputName = (this._type && this._type.outputs && i < this._type.outputs.length && this._type.outputs[i] && this._type.outputs[i].name) ? this._type.outputs[i].name : i.toString(); + const outputName = (outputTypes[i] && outputTypes[i].name) ? outputTypes[i].name : i.toString(); const argument = new keras.Argument(outputName, true, output.length === 0 ? [] : [ values.map(output) ]); this._outputs.push(argument); } - const inputTypes = new Map((this._type.inputs || []).map((input) => [ input.name, input.type ])); + inputTypes = new Map((inputTypes || []).map((input) => [ input.name, input.type ])); for (const entry of Object.entries(args)) { const name = entry[0]; const arg = entry[1];