diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt index 2eda1a72a..7d801eab9 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt @@ -4,6 +4,9 @@ import fr.acinq.bitcoin.* import fr.acinq.bitcoin.Script.tail import fr.acinq.bitcoin.crypto.musig2.IndividualNonce import fr.acinq.bitcoin.crypto.musig2.SecretNonce +import fr.acinq.bitcoin.utils.flatMap +import fr.acinq.bitcoin.utils.getOrDefault +import fr.acinq.bitcoin.utils.getOrElse import fr.acinq.lightning.Lightning.randomBytes32 import fr.acinq.lightning.MilliSatoshi import fr.acinq.lightning.blockchain.electrum.WalletState @@ -442,16 +445,14 @@ data class SharedTransaction( val swapUserPartialSigs = unsignedTx.txIn.mapIndexed { i, txIn -> localInputs .filterIsInstance() - .find { txIn.outPoint == it.outPoint } + .find { txIn.outPoint == it.outPoint && session.secretNonces.containsKey(it.serialId) && receivedNonces.containsKey(it.serialId) } ?.let { input -> - val userNonce = session.secretNonces[input.serialId] - require(userNonce != null) - require(session.txCompleteReceived != null) - val serverNonce = receivedNonces[input.serialId] - require(serverNonce != null) { "missing server nonce for input ${input.serialId}" } - val commonNonce = IndividualNonce.aggregate(listOf(userNonce.second, serverNonce)) - val psig = keyManager.swapInOnChainWallet.signSwapInputUser(unsignedTx, i, previousOutputs, userNonce.first, commonNonce) - TxSignatures.Companion.PartialSignature(psig, commonNonce) + val userNonce = session.secretNonces[input.serialId]!! + val serverNonce = receivedNonces[input.serialId]!! + IndividualNonce.aggregate(listOf(userNonce.second, serverNonce)) + .flatMap { commonNonce -> keyManager.swapInOnChainWallet.signSwapInputUser(unsignedTx, i, previousOutputs, userNonce.first, commonNonce) + .map { psig -> TxSignatures.Companion.PartialSignature(psig, commonNonce) } + }.getOrDefault(null) } }.filterNotNull() @@ -470,18 +471,16 @@ data class SharedTransaction( val swapServerPartialSigs = unsignedTx.txIn.mapIndexed { i, txIn -> remoteInputs .filterIsInstance() - .find { txIn.outPoint == it.outPoint } + .find { txIn.outPoint == it.outPoint && session.secretNonces.containsKey(it.serialId) && receivedNonces.containsKey(it.serialId) } ?.let { input -> val serverKey = keyManager.swapInOnChainWallet.localServerPrivateKey(remoteNodeId) - val userNonce = session.secretNonces[input.serialId] - require(userNonce != null) - require(session.txCompleteReceived != null) - val serverNonce = receivedNonces[input.serialId] - require(serverNonce != null) { "missing server nonce for input ${input.serialId}" } - val commonNonce = IndividualNonce.aggregate(listOf(userNonce.second, serverNonce)) + val userNonce = session.secretNonces[input.serialId]!! + val serverNonce = receivedNonces[input.serialId]!! val swapInProtocol = SwapInProtocol(input.swapInParams.userKey, serverKey.publicKey(), input.swapInParams.userRefundKey, input.swapInParams.refundDelay) - val psig = swapInProtocol.signSwapInputServer(unsignedTx, i, previousOutputs, commonNonce, serverKey, userNonce.first) - TxSignatures.Companion.PartialSignature(psig, commonNonce) + IndividualNonce.aggregate(listOf(userNonce.second, serverNonce)) + .flatMap { commonNonce -> swapInProtocol.signSwapInputServer(unsignedTx, i, previousOutputs, commonNonce, serverKey, userNonce.first) + .map { psig -> TxSignatures.Companion.PartialSignature(psig, commonNonce) } + }.getOrDefault(null) } }.filterNotNull() @@ -543,10 +542,10 @@ data class FullySignedSharedTransaction(override val tx: SharedTransaction, over val swapInProtocol = SwapInProtocol(i.swapInParams) val commonNonce = userSig.aggregatedPublicNonce val unsignedTx = tx.buildUnsignedTx() - val ctx = swapInProtocol.session(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce) - val commonSig = ctx.add(listOf(userSig.sig, serverSig.sig)) - val witness = swapInProtocol.witness(commonSig) - Pair(i.serialId, TxIn(i.outPoint, ByteVector.empty, i.sequence.toLong(), witness)) + val witness = swapInProtocol.session(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce) + .flatMap { s -> s.add(listOf(userSig.sig, serverSig.sig)).map { commonSig -> swapInProtocol.witness(commonSig) } } + require(witness.isRight) { "cannot compute aggregated signature" } + Pair(i.serialId, TxIn(i.outPoint, ByteVector.empty, i.sequence.toLong(), witness.right!!)) } val remoteOnlyTxIn = tx.remoteOnlyInputs().sortedBy { i -> i.serialId }.zip(remoteSigs.witnesses).map { (i, w) -> Pair(i.serialId, TxIn(i.outPoint, ByteVector.empty, i.sequence.toLong(), w)) } @@ -561,10 +560,10 @@ data class FullySignedSharedTransaction(override val tx: SharedTransaction, over val swapInProtocol = SwapInProtocol(i.swapInParams) val commonNonce = userSig.aggregatedPublicNonce val unsignedTx = tx.buildUnsignedTx() - val ctx = swapInProtocol.session(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce) - val commonSig = ctx.add(listOf(userSig.sig, serverSig.sig)) - val witness = swapInProtocol.witness(commonSig) - Pair(i.serialId, TxIn(i.outPoint, ByteVector.empty, i.sequence.toLong(), witness)) + val witness = swapInProtocol.session(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce) + .flatMap { s -> s.add(listOf(userSig.sig, serverSig.sig)).map { commonSig -> swapInProtocol.witness(commonSig) } } + require(witness.isRight) { "cannot compute aggregated signature" } + Pair(i.serialId, TxIn(i.outPoint, ByteVector.empty, i.sequence.toLong(), witness.right!!)) } val inputs = (sharedTxIn + localOnlyTxIn + localSwapTxIn + localSwapTxInMusig2 + remoteOnlyTxIn + remoteSwapTxIn + remoteSwapTxInMusig2).sortedBy { (serialId, _) -> serialId }.map { (_, i) -> i } val sharedTxOut = listOf(Pair(tx.sharedOutput.serialId, TxOut(tx.sharedOutput.amount, tx.sharedOutput.pubkeyScript))) @@ -692,7 +691,10 @@ data class InteractiveTxSession( val next1 = when (msg.value) { is InteractiveTxInput.LocalSwapIn -> { // generate a secret nonce for this input if we don't already have one - val secretNonce = next.secretNonces[msg.value.serialId] ?: SecretNonce.generate(randomBytes32(), swapInKeys.userPrivateKey, swapInKeys.userPublicKey, null, null, null) + val secretNonce = next.secretNonces[msg.value.serialId] ?: run { + val s = SecretNonce.generate(randomBytes32(), swapInKeys.userPrivateKey, swapInKeys.userPublicKey, null, null, null) + s.getOrElse { error("cannot generate secret nonce") } + } next.copy(secretNonces = next.secretNonces + (msg.value.serialId to secretNonce)) } else -> next @@ -763,6 +765,7 @@ data class InteractiveTxSession( val session2 = when (input) { is InteractiveTxInput.RemoteSwapIn -> { val secretNonce = secretNonces[input.serialId] ?: SecretNonce.generate(randomBytes32(), null, input.swapInParams.serverKey, null, null, null) + .getOrElse { error("cannot generate secret nonce") } session1.copy(secretNonces = secretNonces + (input.serialId to secretNonce)) } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/crypto/KeyManager.kt b/src/commonMain/kotlin/fr/acinq/lightning/crypto/KeyManager.kt index 4288eac4e..1e34459ea 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/crypto/KeyManager.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/crypto/KeyManager.kt @@ -5,6 +5,7 @@ import fr.acinq.bitcoin.DeterministicWallet.hardened import fr.acinq.bitcoin.crypto.musig2.AggregatedNonce import fr.acinq.bitcoin.crypto.musig2.SecretNonce import fr.acinq.bitcoin.io.ByteArrayInput +import fr.acinq.bitcoin.utils.Either import fr.acinq.lightning.DefaultSwapInParams import fr.acinq.lightning.NodeParams import fr.acinq.lightning.blockchain.fee.FeeratePerKw @@ -158,7 +159,7 @@ interface KeyManager { return legacySwapInProtocol.signSwapInputUser(fundingTx, index, parentTxOuts[fundingTx.txIn[index].outPoint.index.toInt()] , userPrivateKey) } - fun signSwapInputUser(fundingTx: Transaction, index: Int, parentTxOuts: List, userNonce: SecretNonce, commonNonce: AggregatedNonce): ByteVector32 { + fun signSwapInputUser(fundingTx: Transaction, index: Int, parentTxOuts: List, userNonce: SecretNonce, commonNonce: AggregatedNonce): Either { return swapInProtocol.signSwapInputUser(fundingTx, index, parentTxOuts, userPrivateKey, userNonce, commonNonce) } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/transactions/SwapInProtocol.kt b/src/commonMain/kotlin/fr/acinq/lightning/transactions/SwapInProtocol.kt index 6a7d77964..5dcf8a61d 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/transactions/SwapInProtocol.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/transactions/SwapInProtocol.kt @@ -5,6 +5,8 @@ import fr.acinq.bitcoin.crypto.musig2.AggregatedNonce import fr.acinq.bitcoin.crypto.musig2.KeyAggCache import fr.acinq.bitcoin.crypto.musig2.SecretNonce import fr.acinq.bitcoin.crypto.musig2.Session +import fr.acinq.bitcoin.utils.Either +import fr.acinq.bitcoin.utils.flatMap import fr.acinq.lightning.NodeParams import fr.acinq.lightning.wire.TxAddInputTlv @@ -25,10 +27,14 @@ class SwapInProtocol(val userPublicKey: PublicKey, val serverPublicKey: PublicKe private val merkleRoot = scriptTree.hash() // the internal pubkey is the musig2 aggregation of the user's and server's public keys: it does not depend upon the user's refund's key - private val internalPubKeyAndCache = KeyAggCache.Companion.add(listOf(userPublicKey, serverPublicKey), null) + private val internalPubKeyAndCache = run { + val c = KeyAggCache.add(listOf(userPublicKey, serverPublicKey), null) + if (c.isLeft) error("key aggregation failed") else c.right!! + } private val internalPubKey = internalPubKeyAndCache.first private val cache = internalPubKeyAndCache.second + // it is tweaked with the script's merkle root to get the pubkey that will be exposed private val commonPubKeyAndParity = internalPubKey.outputKey(Crypto.TaprootTweak.ScriptTweak(merkleRoot)) val commonPubKey = commonPubKeyAndParity.first @@ -45,12 +51,13 @@ class SwapInProtocol(val userPublicKey: PublicKey, val serverPublicKey: PublicKe fun witnessRefund(userSig: ByteVector64): ScriptWitness = ScriptWitness.empty.push(userSig).push(redeemScript).push(controlBlock) - fun signSwapInputUser(fundingTx: Transaction, index: Int, parentTxOuts: List, userPrivateKey: PrivateKey, userNonce: SecretNonce, commonNonce: AggregatedNonce): ByteVector32 { + fun signSwapInputUser(fundingTx: Transaction, index: Int, parentTxOuts: List, userPrivateKey: PrivateKey, userNonce: SecretNonce, commonNonce: AggregatedNonce): Either { require(userPrivateKey.publicKey() == userPublicKey) val txHash = Transaction.hashForSigningSchnorr(fundingTx, index, parentTxOuts, SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPROOT) - val cache1 = cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(merkleRoot)), true).first - val session = Session.build(commonNonce, txHash, cache1) - return session.sign(userNonce, userPrivateKey, cache1) + + return cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(merkleRoot)), true) + .flatMap { (c, _) -> Session.build(commonNonce, txHash, c).map { s -> Pair(s, c) } } + .flatMap { (s, c) -> s.sign(userNonce, userPrivateKey, c) } } fun signSwapInputRefund(fundingTx: Transaction, index: Int, parentTxOuts: List, userPrivateKey: PrivateKey): ByteVector64 { @@ -58,17 +65,17 @@ class SwapInProtocol(val userPublicKey: PublicKey, val serverPublicKey: PublicKe return Crypto.signSchnorr(txHash, userPrivateKey, Crypto.SchnorrTweak.NoTweak) } - fun signSwapInputServer(fundingTx: Transaction, index: Int, parentTxOuts: List, commonNonce: AggregatedNonce, serverPrivateKey: PrivateKey, serverNonce: SecretNonce): ByteVector32 { + fun signSwapInputServer(fundingTx: Transaction, index: Int, parentTxOuts: List, commonNonce: AggregatedNonce, serverPrivateKey: PrivateKey, serverNonce: SecretNonce): Either { val txHash = Transaction.hashForSigningSchnorr(fundingTx, index, parentTxOuts, SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPROOT) - val cache1 = cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(merkleRoot)), true).first - val session = Session.build(commonNonce, txHash, cache1) - return session.sign(serverNonce, serverPrivateKey, cache1) + return cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(merkleRoot)), true) + .flatMap { (c, _) -> Session.build(commonNonce, txHash, c).map { s -> Pair(s, c) } } + .flatMap { (s, c) -> s.sign(serverNonce, serverPrivateKey, c) } } - fun session(fundingTx: Transaction, index: Int, parentTxOuts: List, commonNonce: AggregatedNonce): Session { + fun session(fundingTx: Transaction, index: Int, parentTxOuts: List, commonNonce: AggregatedNonce): Either { val txHash = Transaction.hashForSigningSchnorr(fundingTx, index, parentTxOuts, SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPROOT) - val cache1 = cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(merkleRoot)), true).first - return Session.build(commonNonce, txHash, cache1) + return cache.tweak(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(merkleRoot)), true) + .flatMap { (c, _) -> Session.build(commonNonce, txHash, c) } } companion object { @@ -81,17 +88,18 @@ class SwapInProtocol(val userPublicKey: PublicKey, val serverPublicKey: PublicKe * @param masterRefundKey master private key for the refund keys. we assume that there is a single level of derivation to compute the refund keys * @return a taproot descriptor that can be imported in bitcoin core (from version 26 on) to recover user funds once the funding delay has passed */ - fun descriptor(chain: NodeParams.Chain, userPublicKey: PublicKey, serverPublicKey: PublicKey, refundDelay: Int, masterRefundKey: DeterministicWallet.ExtendedPrivateKey): String { + fun descriptor(chain: NodeParams.Chain, userPublicKey: PublicKey, serverPublicKey: PublicKey, refundDelay: Int, masterRefundKey: DeterministicWallet.ExtendedPrivateKey): Either { // the internal pubkey is the musig2 aggregation of the user's and server's public keys: it does not depend upon the user's refund's key - val (internalPubKey, _) = KeyAggCache.Companion.add(listOf(userPublicKey, serverPublicKey), null) - val prefix = when (chain) { - NodeParams.Chain.Mainnet -> DeterministicWallet.xprv - else -> DeterministicWallet.tprv + return KeyAggCache.Companion.add(listOf(userPublicKey, serverPublicKey)).map { (internalPubKey, _) -> + val prefix = when (chain) { + NodeParams.Chain.Mainnet -> DeterministicWallet.xprv + else -> DeterministicWallet.tprv + } + val xpriv = DeterministicWallet.encode(masterRefundKey, prefix) + val desc = "tr(${internalPubKey.value},and_v(v:pk($xpriv/*),older($refundDelay)))" + val checksum = Descriptor.checksum(desc) + "$desc#$checksum" } - val xpriv = DeterministicWallet.encode(masterRefundKey, prefix) - val desc = "tr(${internalPubKey.value},and_v(v:pk($xpriv/*),older($refundDelay)))" - val checksum = Descriptor.checksum(desc) - return "$desc#$checksum" } /** @@ -103,20 +111,20 @@ class SwapInProtocol(val userPublicKey: PublicKey, val serverPublicKey: PublicKe * @param masterRefundKey master public key for the refund keys. we assume that there is a single level of derivation to compute the refund keys * @return a taproot descriptor that can be imported in bitcoin core (from version 26 on) to create a watch-only wallet for your swap-in transactions */ - fun descriptor(chain: NodeParams.Chain, userPublicKey: PublicKey, serverPublicKey: PublicKey, refundDelay: Int, masterRefundKey: DeterministicWallet.ExtendedPublicKey): String { + fun descriptor(chain: NodeParams.Chain, userPublicKey: PublicKey, serverPublicKey: PublicKey, refundDelay: Int, masterRefundKey: DeterministicWallet.ExtendedPublicKey): Any { // the internal pubkey is the musig2 aggregation of the user's and server's public keys: it does not depend upon the user's refund's key - val (internalPubKey, _) = KeyAggCache.Companion.add(listOf(userPublicKey, serverPublicKey), null) - val prefix = when (chain) { - NodeParams.Chain.Mainnet -> DeterministicWallet.xpub - else -> DeterministicWallet.tpub + return KeyAggCache.Companion.add(listOf(userPublicKey, serverPublicKey)).map { (internalPubKey, _) -> + val prefix = when (chain) { + NodeParams.Chain.Mainnet -> DeterministicWallet.xpub + else -> DeterministicWallet.tpub + } + val xpub = DeterministicWallet.encode(masterRefundKey, prefix) + val path = masterRefundKey.path.toString().replace('\'', 'h').removePrefix("m") + val desc = "tr(${internalPubKey.value},and_v(v:pk($xpub$path/*),older($refundDelay)))" + val checksum = Descriptor.checksum(desc) + return "$desc#$checksum" } - val xpub = DeterministicWallet.encode(masterRefundKey, prefix) - val path = masterRefundKey.path.toString().replace('\'', 'h').removePrefix("m") - val desc = "tr(${internalPubKey.value},and_v(v:pk($xpub$path/*),older($refundDelay)))" - val checksum = Descriptor.checksum(desc) - return "$desc#$checksum" } - } } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/transactions/TransactionsTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/transactions/TransactionsTestsCommon.kt index 1ab3dee25..2468ed786 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/transactions/TransactionsTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/transactions/TransactionsTestsCommon.kt @@ -581,14 +581,14 @@ class TransactionsTestsCommon : LightningTestSuite() { ) // this is the beginning of an interactive musig2 signing session. if user and server are disconnected before they have exchanged partial // signatures they will have to start again with fresh nonces - val (_, cache) = KeyAggCache.add(listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey()), null) - val userNonce = SecretNonce.generate(randomBytes32(), userPrivateKey, userPrivateKey.publicKey(), null, cache, null) - val serverNonce = SecretNonce.generate(randomBytes32(), serverPrivateKey, serverPrivateKey.publicKey(), null, cache, null) - val commonNonce = IndividualNonce.aggregate(listOf(userNonce.second, serverNonce.second)) - val userSig = swapInProtocol.signSwapInputUser(tx, 0, swapInTx.txOut, userPrivateKey, userNonce.first, commonNonce) - val serverSig = swapInProtocol.signSwapInputServer(tx, 0, swapInTx.txOut, commonNonce, serverPrivateKey, serverNonce.first) - val ctx = swapInProtocol.session(tx, 0, swapInTx.txOut, commonNonce) - val commonSig = ctx.add(listOf(userSig, serverSig)) + val (_, cache) = KeyAggCache.add(listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey())).right!! + val userNonce = SecretNonce.generate(randomBytes32(), userPrivateKey, userPrivateKey.publicKey(), null, cache, null).right!! + val serverNonce = SecretNonce.generate(randomBytes32(), serverPrivateKey, serverPrivateKey.publicKey(), null, cache, null).right!! + val commonNonce = IndividualNonce.aggregate(listOf(userNonce.second, serverNonce.second)).right!! + val userSig = swapInProtocol.signSwapInputUser(tx, 0, swapInTx.txOut, userPrivateKey, userNonce.first, commonNonce).right!! + val serverSig = swapInProtocol.signSwapInputServer(tx, 0, swapInTx.txOut, commonNonce, serverPrivateKey, serverNonce.first).right!! + val ctx = swapInProtocol.session(tx, 0, swapInTx.txOut, commonNonce).right!! + val commonSig = ctx.add(listOf(userSig, serverSig)).right!! val signedTx = tx.updateWitness(0, swapInProtocol.witness(commonSig)) Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) }