From ce7529951ade4585751a5393fd1e3f56e674bd99 Mon Sep 17 00:00:00 2001 From: sstone Date: Mon, 27 Nov 2023 17:49:48 +0100 Subject: [PATCH] Add a musig2 secret nonce field to local/remote musing2 swap-in classes It makes the code cleaner and we get rid of the secret nonces map. These nonces are replaced with dummy values whenever this classes are serialized, which is safe since they're never reused for signing txs. --- .../acinq/lightning/channel/InteractiveTx.kt | 114 +++++++++++------- .../serialization/v4/Deserialization.kt | 5 +- 2 files changed, 76 insertions(+), 43 deletions(-) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt index b51981f8a..40bcbc3be 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt @@ -129,8 +129,35 @@ sealed class InteractiveTxInput { override val previousTx: Transaction, override val previousTxOutput: Long, override val sequence: UInt, - val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Local() { + val swapInParams: TxAddInputTlv.SwapInParamsMusig2, + val secretNonce: SecretNonce) : Local() { override val outPoint: OutPoint = OutPoint(previousTx, previousTxOutput) + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as LocalMusig2SwapIn + + if (serialId != other.serialId) return false + if (previousTx != other.previousTx) return false + if (previousTxOutput != other.previousTxOutput) return false + if (sequence != other.sequence) return false + if (swapInParams != other.swapInParams) return false + if (outPoint != other.outPoint) return false + + return true + } + + override fun hashCode(): Int { + var result = serialId.hashCode() + result = 31 * result + previousTx.hashCode() + result = 31 * result + previousTxOutput.hashCode() + result = 31 * result + sequence.hashCode() + result = 31 * result + swapInParams.hashCode() + result = 31 * result + outPoint.hashCode() + return result + } + } /** * A remote input that funds the interactive transaction. @@ -154,7 +181,32 @@ sealed class InteractiveTxInput { override val outPoint: OutPoint, override val txOut: TxOut, override val sequence: UInt, - val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Remote() + val swapInParams: TxAddInputTlv.SwapInParamsMusig2, + val secretNonce: SecretNonce) : Remote() { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as RemoteSwapInMusig2 + + if (serialId != other.serialId) return false + if (outPoint != other.outPoint) return false + if (txOut != other.txOut) return false + if (sequence != other.sequence) return false + if (swapInParams != other.swapInParams) return false + + return true + } + + override fun hashCode(): Int { + var result = serialId.hashCode() + result = 31 * result + outPoint.hashCode() + result = 31 * result + txOut.hashCode() + result = 31 * result + sequence.hashCode() + result = 31 * result + swapInParams.hashCode() + return result + } + } /** The shared input can be added by us or by our peer, depending on who initiated the protocol. */ data class Shared(override val serialId: Long, override val outPoint: OutPoint, override val txOut: TxOut, override val sequence: UInt, val localAmount: MilliSatoshi, val remoteAmount: MilliSatoshi) : InteractiveTxInput(), Incoming, Outgoing @@ -287,7 +339,8 @@ data class FundingContributions(val inputs: List, v i.previousTx.stripInputWitnesses(), i.outputIndex.toLong(), 0xfffffffdU, - TxAddInputTlv.SwapInParamsMusig2(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay) + TxAddInputTlv.SwapInParamsMusig2(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay), + SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPrivateKey.publicKey(), null, null, null, randomBytes32()), ) } } @@ -422,8 +475,7 @@ data class SharedTransaction( .filterIsInstance() .find { txIn.outPoint == it.outPoint } ?.let { input -> - val userNonce = session.secretNonces[input.serialId] - require(userNonce != null) + val userNonce = input.secretNonce require(session.txCompleteReceived != null) val serverNonce = session.txCompleteReceived.publicNonces[input.serialId] require(serverNonce != null) @@ -450,8 +502,7 @@ data class SharedTransaction( .find { txIn.outPoint == it.outPoint } ?.let { input -> val serverKey = keyManager.swapInOnChainWallet.localServerPrivateKey(remoteNodeId) - val userNonce = session.secretNonces[input.serialId] - require(userNonce != null) + val userNonce = input.secretNonce require(session.txCompleteReceived != null) val serverNonce = session.txCompleteReceived.publicNonces[input.serialId] require(serverNonce != null) @@ -598,9 +649,7 @@ data class InteractiveTxSession( val txCompleteSent: TxComplete? = null, val txCompleteReceived: TxComplete? = null, val inputsReceivedCount: Int = 0, - val outputsReceivedCount: Int = 0, - val secretNonces: Map = mapOf() -) { + val outputsReceivedCount: Int = 0) { // Example flow: // +-------+ +-------+ @@ -639,15 +688,12 @@ data class InteractiveTxSession( return when (val msg = toSend.firstOrNull()) { null -> { // generate a new secret nonce for each musig2 new swapin every time we send TxComplete - val currentNonces = secretNonces - fun userNonce(serialId: Long) = currentNonces.getOrElse(serialId) { SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPublicKey, null, null, null, randomBytes32()) } - fun serverNonce(serialId: Long, serverKey: PublicKey) = currentNonces.getOrElse(serialId) { SecretNonce.generate(null, serverKey, null, null, null, randomBytes32()) } val localMusig2SwapIns = localInputs.filterIsInstance() - val localNonces = localMusig2SwapIns.map { it.serialId to userNonce(it.serialId) }.toMap() + val localNonces = localMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }.toMap() val remoteMusig2SwapIns = remoteInputs.filterIsInstance() - val remoteNonces = remoteMusig2SwapIns.map { it.serialId to serverNonce(it.serialId, it.swapInParams.serverKey) }.toMap() - val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces).mapValues { it.value.publicNonce() }) - val next = copy(txCompleteSent = txComplete, secretNonces = localNonces + remoteNonces) + val remoteNonces = remoteMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }.toMap() + val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces)) + val next = copy(txCompleteSent = txComplete) if (next.isComplete) { Pair(next, next.validateTx(txComplete)) } else { @@ -714,7 +760,10 @@ data class InteractiveTxSession( val outpoint = OutPoint(message.previousTx, message.previousTxOutput) val txOut = message.previousTx.txOut[message.previousTxOutput.toInt()] when { - message.swapInParamsMusig2 != null -> InteractiveTxInput.RemoteSwapInMusig2(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsMusig2) + message.swapInParamsMusig2 != null -> { + val secretNonce = SecretNonce.generate(null, message.swapInParamsMusig2.serverKey, null, null, null, randomBytes32()) + InteractiveTxInput.RemoteSwapInMusig2(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsMusig2, secretNonce) + } message.swapInParams != null -> InteractiveTxInput.RemoteSwapIn(message.serialId, outpoint, txOut, message.sequence, message.swapInParams) else -> InteractiveTxInput.RemoteOnly(message.serialId, outpoint, txOut, message.sequence) } @@ -836,32 +885,13 @@ data class InteractiveTxSession( } sharedInputs.first() } - val localOnlyInputsWithNonces = localOnlyInputs.map { - when { - it is InteractiveTxInput.LocalSwapIn && swapInKeys.swapInProtocolMusig2.isMine(it.txOut) -> { - val userNonce = secretNonces[it.serialId] - val serverNonce = txCompleteReceived.publicNonces[it.serialId] - if (userNonce == null || serverNonce == null) return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId) - it - } - - else -> it - } + localOnlyInputs.filterIsInstance().forEach { + txCompleteReceived.publicNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId) } - val remoteOnlyInputsWithNonces = remoteOnlyInputs.map { - when { - it is InteractiveTxInput.RemoteSwapInMusig2 -> { - val userNonce = secretNonces[it.serialId] - val serverNonce = txCompleteReceived.publicNonces[it.serialId] - if (userNonce == null || serverNonce == null) return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId) - it - } - - else -> it - } + remoteOnlyInputs.filterIsInstance().forEach { + txCompleteReceived.publicNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId) } - - val sharedTx = SharedTransaction(sharedInput, sharedOutput, localOnlyInputsWithNonces, remoteOnlyInputsWithNonces, localOnlyOutputs, remoteOnlyOutputs, fundingParams.lockTime) + val sharedTx = SharedTransaction(sharedInput, sharedOutput, localOnlyInputs, remoteOnlyInputs, localOnlyOutputs, remoteOnlyOutputs, fundingParams.lockTime) val tx = sharedTx.buildUnsignedTx() if (sharedTx.localAmountIn < sharedTx.localAmountOut || sharedTx.remoteAmountIn < sharedTx.remoteAmountOut) { return InteractiveTxSessionAction.InvalidTxChangeAmount(fundingParams.channelId, tx.txid) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Deserialization.kt b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Deserialization.kt index 37b8a4c12..4c9c56be0 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Deserialization.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Deserialization.kt @@ -4,6 +4,7 @@ import fr.acinq.bitcoin.* import fr.acinq.bitcoin.io.ByteArrayInput import fr.acinq.bitcoin.io.Input import fr.acinq.bitcoin.musig2.PublicNonce +import fr.acinq.bitcoin.musig2.SecretNonce import fr.acinq.lightning.CltvExpiryDelta import fr.acinq.lightning.Features import fr.acinq.lightning.ShortChannelId @@ -240,6 +241,7 @@ object Deserialization { previousTxOutput = readNumber(), sequence = readNumber().toUInt(), swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this), + secretNonce = SecretNonce(PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One).publicKey()) ) else -> error("unknown discriminator $discriminator for class ${InteractiveTxInput.Local::class}") } @@ -263,7 +265,8 @@ object Deserialization { outPoint = readOutPoint(), txOut = TxOut.read(readDelimitedByteArray()), sequence = readNumber().toUInt(), - swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this) + swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this), + secretNonce = SecretNonce(PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One).publicKey()) ) else -> error("unknown discriminator $discriminator for class ${InteractiveTxInput.Remote::class}") }