From c884a9f9a6c4846c5c3a5650ca7083ec2f89e9ac Mon Sep 17 00:00:00 2001 From: sstone Date: Mon, 27 Nov 2023 18:41:21 +0100 Subject: [PATCH] Rework TxComplete to use implicit ordering for musig2 nonces Instead of sending an explicit serialId -> nonce map, we send a list of public nonces ordered by serial id. This matches how signatures are sent in TxSignatures. --- .../acinq/lightning/channel/InteractiveTx.kt | 47 ++++++++++++------- .../acinq/lightning/wire/InteractiveTxTlv.kt | 19 +++----- .../acinq/lightning/wire/LightningMessages.kt | 4 +- 3 files changed, 39 insertions(+), 31 deletions(-) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt index 491855153..62a06ab24 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt @@ -379,11 +379,11 @@ data class FundingContributions(val inputs: List, v ) fun weight(walletInputs: List): Int = walletInputs.sumOf { - when { - Script.isPay2wsh(it.previousTx.txOut[it.outputIndex].publicKeyScript.toByteArray()) -> Transactions.swapInputWeight - else -> Transactions.swapInputWeightMusig2 - } + when { + Script.isPay2wsh(it.previousTx.txOut[it.outputIndex].publicKeyScript.toByteArray()) -> Transactions.swapInputWeight + else -> Transactions.swapInputWeightMusig2 } + } /** We always randomize the order of inputs and outputs. */ private fun sortFundingContributions(params: InteractiveTxParams, inputs: List, outputs: List): FundingContributions { @@ -392,7 +392,7 @@ data class FundingContributions(val inputs: List, v when (input) { is InteractiveTxInput.LocalOnly -> input.copy(serialId = serialId) is InteractiveTxInput.LocalSwapIn -> input.copy(serialId = serialId) - is InteractiveTxInput.LocalMusig2SwapIn-> input.copy(serialId = serialId) + is InteractiveTxInput.LocalMusig2SwapIn -> input.copy(serialId = serialId) is InteractiveTxInput.Shared -> input.copy(serialId = serialId) } } @@ -462,6 +462,16 @@ data class SharedTransaction( val previousOutputsMap = sharedOutput + localOutputs + remoteOutputs val previousOutputs = unsignedTx.txIn.map { previousOutputsMap[it.outPoint]!! }.toList() + // nonces that we've received for all musig2 swap-in + val receivedNonces: Map = when (session.txCompleteReceived) { + null -> mapOf() + else -> (localInputs.filterIsInstance() + remoteInputs.filterIsInstance()) + .sortedBy { it.serialId } + .zip(session.txCompleteReceived.publicNonces) + .associate { it.first.serialId to it.second } + } + + // If we are swapping funds in, we provide our partial signatures to the corresponding inputs. val swapUserSigs = unsignedTx.txIn.mapIndexed { i, txIn -> localInputs @@ -477,8 +487,8 @@ data class SharedTransaction( ?.let { input -> val userNonce = input.secretNonce require(session.txCompleteReceived != null) - val serverNonce = session.txCompleteReceived.publicNonces[input.serialId] - require(serverNonce != null) + val serverNonce = receivedNonces[input.serialId] + require(serverNonce != null) { "missing server nonce for input ${input.serialId}" } val commonNonce = PublicNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce)) TxSignatures.Companion.PartialSignature(keyManager.swapInOnChainWallet.signSwapInputUserMusig2(unsignedTx, i, previousOutputs, userNonce, serverNonce), commonNonce) } @@ -504,8 +514,8 @@ data class SharedTransaction( val serverKey = keyManager.swapInOnChainWallet.localServerPrivateKey(remoteNodeId) val userNonce = input.secretNonce require(session.txCompleteReceived != null) - val serverNonce = session.txCompleteReceived.publicNonces[input.serialId] - require(serverNonce != null) + val serverNonce = receivedNonces[input.serialId] + require(serverNonce != null) { "missing server nonce for input ${input.serialId}" } val commonNonce = PublicNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce)) val swapInProtocol = SwapInProtocolMusig2(input.swapInParams.userKey, serverKey.publicKey(), input.swapInParams.userRefundKey, input.swapInParams.refundDelay) TxSignatures.Companion.PartialSignature(swapInProtocol.signSwapInputServer(unsignedTx, i, previousOutputs, serverNonce, serverKey, userNonce), commonNonce) @@ -568,7 +578,7 @@ data class FullySignedSharedTransaction(override val tx: SharedTransaction, over val localSwapTxInMusig2 = tx.localInputs.filterIsInstance().sortedBy { i -> i.serialId }.zip(localSigs.swapInUserPartialSigs.zip(remoteSigs.swapInServerPartialSigs)).map { (i, sigs) -> val (userSig, serverSig) = sigs val swapInProtocol = SwapInProtocolMusig2(i.swapInParams) - require(userSig.aggregatedPublicNonce == serverSig.aggregatedPublicNonce){ "aggregated public nonces mismatch for local input ${i.serialId}"} + require(userSig.aggregatedPublicNonce == serverSig.aggregatedPublicNonce) { "aggregated public nonces mismatch for local input ${i.serialId}" } val commonNonce = userSig.aggregatedPublicNonce val unsignedTx = tx.buildUnsignedTx() val ctx = swapInProtocol.signingCtx(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce) @@ -587,7 +597,7 @@ data class FullySignedSharedTransaction(override val tx: SharedTransaction, over val remoteSwapTxInMusig2 = tx.remoteInputs.filterIsInstance().sortedBy { i -> i.serialId }.zip(remoteSigs.swapInUserPartialSigs.zip(localSigs.swapInServerPartialSigs)).map { (i, sigs) -> val (userSig, serverSig) = sigs val swapInProtocol = SwapInProtocolMusig2(i.swapInParams) - require(userSig.aggregatedPublicNonce == serverSig.aggregatedPublicNonce){ "aggregated public nonces mismatch for remote input ${i.serialId}"} + require(userSig.aggregatedPublicNonce == serverSig.aggregatedPublicNonce) { "aggregated public nonces mismatch for remote input ${i.serialId}" } val commonNonce = userSig.aggregatedPublicNonce val unsignedTx = tx.buildUnsignedTx() val ctx = swapInProtocol.signingCtx(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce) @@ -689,10 +699,10 @@ data class InteractiveTxSession( null -> { // generate a new secret nonce for each musig2 new swapin every time we send TxComplete val localMusig2SwapIns = localInputs.filterIsInstance() - val localNonces = localMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }.toMap() + val localNonces = localMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() } val remoteMusig2SwapIns = remoteInputs.filterIsInstance() - val remoteNonces = remoteMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }.toMap() - val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces)) + val remoteNonces = remoteMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() } + val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces).sortedBy { it.first }.map { it.second }) val next = copy(txCompleteSent = txComplete) if (next.isComplete) { Pair(next, next.validateTx(txComplete)) @@ -885,11 +895,16 @@ data class InteractiveTxSession( } sharedInputs.first() } + val receivedNonces = (localInputs.filterIsInstance() + remoteInputs.filterIsInstance()) + .sortedBy { it.serialId } + .zip(txCompleteReceived.publicNonces) + .associate { it.first.serialId to it.second } + localOnlyInputs.filterIsInstance().forEach { - txCompleteReceived.publicNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId) + receivedNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId) } remoteOnlyInputs.filterIsInstance().forEach { - txCompleteReceived.publicNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId) + receivedNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId) } val sharedTx = SharedTransaction(sharedInput, sharedOutput, localOnlyInputs, remoteOnlyInputs, localOnlyOutputs, remoteOnlyOutputs, fundingParams.lockTime) val tx = sharedTx.buildUnsignedTx() diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/InteractiveTxTlv.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/InteractiveTxTlv.kt index c80cff9a7..be40b13d7 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/InteractiveTxTlv.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/InteractiveTxTlv.kt @@ -71,27 +71,20 @@ sealed class TxRemoveInputTlv : Tlv sealed class TxRemoveOutputTlv : Tlv sealed class TxCompleteTlv : Tlv { - data class Nonces(val nonces: Map): TxCompleteTlv() { + /** nonces for all Musig2 swap-in inputs, ordered by serial id */ + data class Nonces(val nonces: List): TxCompleteTlv() { override val tag: Long get() = Nonces.tag override fun write(out: Output) { - LightningCodecs.writeU16(nonces.size, out) - nonces.forEach { (serialId, nonce) -> - LightningCodecs.writeBigSize(serialId, out) - LightningCodecs.writeBytes(nonce.toByteArray(), out) - } + nonces.forEach { LightningCodecs.writeBytes(it.toByteArray(), out) } } companion object : TlvValueReader { const val tag: Long = 101 override fun read(input: Input): Nonces { - val noncesCount = LightningCodecs.u16(input) - val nonces = (1..noncesCount).map { - val serialId = LightningCodecs.bigSize(input) - val nonce = PublicNonce.fromBin(LightningCodecs.bytes(input, 66)) - serialId to nonce - } - return Nonces(nonces.toMap()) + val count = input.availableBytes / 66 + val nonces = (0 until count).map { PublicNonce.fromBin(LightningCodecs.bytes(input, 66)) } + return Nonces(nonces) } } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt index 27f220e98..805efa6c3 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt @@ -451,9 +451,9 @@ data class TxComplete( ) : InteractiveTxConstructionMessage(), HasChannelId { override val type: Long get() = TxComplete.type - val publicNonces: Map = tlvs.get()?.nonces?.toMap() ?: mapOf() + val publicNonces: List = tlvs.get()?.nonces ?: listOf() - constructor(channelId: ByteVector32, publicNonces: Map) : this(channelId, TlvStream(TxCompleteTlv.Nonces(publicNonces))) + constructor(channelId: ByteVector32, publicNonces: List) : this(channelId, TlvStream(TxCompleteTlv.Nonces(publicNonces))) override fun write(out: Output) { LightningCodecs.writeBytes(channelId.toByteArray(), out)