Skip to content

Commit

Permalink
Add a musig2 secret nonce field to local/remote musing2 swap-in classes
Browse files Browse the repository at this point in the history
It makes the code cleaner and we get rid of the secret nonces map.
These nonces are replaced with dummy values whenever this classes are serialized, which is safe since they're never reused for signing txs.
  • Loading branch information
sstone committed Nov 27, 2023
1 parent 4c044fb commit ce75299
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 43 deletions.
114 changes: 72 additions & 42 deletions src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,35 @@ sealed class InteractiveTxInput {
override val previousTx: Transaction,
override val previousTxOutput: Long,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Local() {
val swapInParams: TxAddInputTlv.SwapInParamsMusig2,
val secretNonce: SecretNonce) : Local() {
override val outPoint: OutPoint = OutPoint(previousTx, previousTxOutput)
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as LocalMusig2SwapIn

if (serialId != other.serialId) return false
if (previousTx != other.previousTx) return false
if (previousTxOutput != other.previousTxOutput) return false
if (sequence != other.sequence) return false
if (swapInParams != other.swapInParams) return false
if (outPoint != other.outPoint) return false

return true
}

override fun hashCode(): Int {
var result = serialId.hashCode()
result = 31 * result + previousTx.hashCode()
result = 31 * result + previousTxOutput.hashCode()
result = 31 * result + sequence.hashCode()
result = 31 * result + swapInParams.hashCode()
result = 31 * result + outPoint.hashCode()
return result
}

}
/**
* A remote input that funds the interactive transaction.
Expand All @@ -154,7 +181,32 @@ sealed class InteractiveTxInput {
override val outPoint: OutPoint,
override val txOut: TxOut,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParamsMusig2) : Remote()
val swapInParams: TxAddInputTlv.SwapInParamsMusig2,
val secretNonce: SecretNonce) : Remote() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as RemoteSwapInMusig2

if (serialId != other.serialId) return false
if (outPoint != other.outPoint) return false
if (txOut != other.txOut) return false
if (sequence != other.sequence) return false
if (swapInParams != other.swapInParams) return false

return true
}

override fun hashCode(): Int {
var result = serialId.hashCode()
result = 31 * result + outPoint.hashCode()
result = 31 * result + txOut.hashCode()
result = 31 * result + sequence.hashCode()
result = 31 * result + swapInParams.hashCode()
return result
}
}

/** The shared input can be added by us or by our peer, depending on who initiated the protocol. */
data class Shared(override val serialId: Long, override val outPoint: OutPoint, override val txOut: TxOut, override val sequence: UInt, val localAmount: MilliSatoshi, val remoteAmount: MilliSatoshi) : InteractiveTxInput(), Incoming, Outgoing
Expand Down Expand Up @@ -287,7 +339,8 @@ data class FundingContributions(val inputs: List<InteractiveTxInput.Outgoing>, v
i.previousTx.stripInputWitnesses(),
i.outputIndex.toLong(),
0xfffffffdU,
TxAddInputTlv.SwapInParamsMusig2(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay)
TxAddInputTlv.SwapInParamsMusig2(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay),
SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPrivateKey.publicKey(), null, null, null, randomBytes32()),
)
}
}
Expand Down Expand Up @@ -422,8 +475,7 @@ data class SharedTransaction(
.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>()
.find { txIn.outPoint == it.outPoint }
?.let { input ->
val userNonce = session.secretNonces[input.serialId]
require(userNonce != null)
val userNonce = input.secretNonce
require(session.txCompleteReceived != null)
val serverNonce = session.txCompleteReceived.publicNonces[input.serialId]
require(serverNonce != null)
Expand All @@ -450,8 +502,7 @@ data class SharedTransaction(
.find { txIn.outPoint == it.outPoint }
?.let { input ->
val serverKey = keyManager.swapInOnChainWallet.localServerPrivateKey(remoteNodeId)
val userNonce = session.secretNonces[input.serialId]
require(userNonce != null)
val userNonce = input.secretNonce
require(session.txCompleteReceived != null)
val serverNonce = session.txCompleteReceived.publicNonces[input.serialId]
require(serverNonce != null)
Expand Down Expand Up @@ -598,9 +649,7 @@ data class InteractiveTxSession(
val txCompleteSent: TxComplete? = null,
val txCompleteReceived: TxComplete? = null,
val inputsReceivedCount: Int = 0,
val outputsReceivedCount: Int = 0,
val secretNonces: Map<Long, SecretNonce> = mapOf()
) {
val outputsReceivedCount: Int = 0) {

// Example flow:
// +-------+ +-------+
Expand Down Expand Up @@ -639,15 +688,12 @@ data class InteractiveTxSession(
return when (val msg = toSend.firstOrNull()) {
null -> {
// generate a new secret nonce for each musig2 new swapin every time we send TxComplete
val currentNonces = secretNonces
fun userNonce(serialId: Long) = currentNonces.getOrElse(serialId) { SecretNonce.generate(swapInKeys.userPrivateKey, swapInKeys.userPublicKey, null, null, null, randomBytes32()) }
fun serverNonce(serialId: Long, serverKey: PublicKey) = currentNonces.getOrElse(serialId) { SecretNonce.generate(null, serverKey, null, null, null, randomBytes32()) }
val localMusig2SwapIns = localInputs.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>()
val localNonces = localMusig2SwapIns.map { it.serialId to userNonce(it.serialId) }.toMap()
val localNonces = localMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }.toMap()
val remoteMusig2SwapIns = remoteInputs.filterIsInstance<InteractiveTxInput.RemoteSwapInMusig2>()
val remoteNonces = remoteMusig2SwapIns.map { it.serialId to serverNonce(it.serialId, it.swapInParams.serverKey) }.toMap()
val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces).mapValues { it.value.publicNonce() })
val next = copy(txCompleteSent = txComplete, secretNonces = localNonces + remoteNonces)
val remoteNonces = remoteMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }.toMap()
val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces))
val next = copy(txCompleteSent = txComplete)
if (next.isComplete) {
Pair(next, next.validateTx(txComplete))
} else {
Expand Down Expand Up @@ -714,7 +760,10 @@ data class InteractiveTxSession(
val outpoint = OutPoint(message.previousTx, message.previousTxOutput)
val txOut = message.previousTx.txOut[message.previousTxOutput.toInt()]
when {
message.swapInParamsMusig2 != null -> InteractiveTxInput.RemoteSwapInMusig2(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsMusig2)
message.swapInParamsMusig2 != null -> {
val secretNonce = SecretNonce.generate(null, message.swapInParamsMusig2.serverKey, null, null, null, randomBytes32())
InteractiveTxInput.RemoteSwapInMusig2(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsMusig2, secretNonce)
}
message.swapInParams != null -> InteractiveTxInput.RemoteSwapIn(message.serialId, outpoint, txOut, message.sequence, message.swapInParams)
else -> InteractiveTxInput.RemoteOnly(message.serialId, outpoint, txOut, message.sequence)
}
Expand Down Expand Up @@ -836,32 +885,13 @@ data class InteractiveTxSession(
}
sharedInputs.first()
}
val localOnlyInputsWithNonces = localOnlyInputs.map {
when {
it is InteractiveTxInput.LocalSwapIn && swapInKeys.swapInProtocolMusig2.isMine(it.txOut) -> {
val userNonce = secretNonces[it.serialId]
val serverNonce = txCompleteReceived.publicNonces[it.serialId]
if (userNonce == null || serverNonce == null) return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId)
it
}

else -> it
}
localOnlyInputs.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>().forEach {
txCompleteReceived.publicNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId)
}
val remoteOnlyInputsWithNonces = remoteOnlyInputs.map {
when {
it is InteractiveTxInput.RemoteSwapInMusig2 -> {
val userNonce = secretNonces[it.serialId]
val serverNonce = txCompleteReceived.publicNonces[it.serialId]
if (userNonce == null || serverNonce == null) return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId)
it
}

else -> it
}
remoteOnlyInputs.filterIsInstance<InteractiveTxInput.RemoteSwapInMusig2>().forEach {
txCompleteReceived.publicNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId)
}

val sharedTx = SharedTransaction(sharedInput, sharedOutput, localOnlyInputsWithNonces, remoteOnlyInputsWithNonces, localOnlyOutputs, remoteOnlyOutputs, fundingParams.lockTime)
val sharedTx = SharedTransaction(sharedInput, sharedOutput, localOnlyInputs, remoteOnlyInputs, localOnlyOutputs, remoteOnlyOutputs, fundingParams.lockTime)
val tx = sharedTx.buildUnsignedTx()
if (sharedTx.localAmountIn < sharedTx.localAmountOut || sharedTx.remoteAmountIn < sharedTx.remoteAmountOut) {
return InteractiveTxSessionAction.InvalidTxChangeAmount(fundingParams.channelId, tx.txid)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import fr.acinq.bitcoin.*
import fr.acinq.bitcoin.io.ByteArrayInput
import fr.acinq.bitcoin.io.Input
import fr.acinq.bitcoin.musig2.PublicNonce
import fr.acinq.bitcoin.musig2.SecretNonce
import fr.acinq.lightning.CltvExpiryDelta
import fr.acinq.lightning.Features
import fr.acinq.lightning.ShortChannelId
Expand Down Expand Up @@ -240,6 +241,7 @@ object Deserialization {
previousTxOutput = readNumber(),
sequence = readNumber().toUInt(),
swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this),
secretNonce = SecretNonce(PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One).publicKey())
)
else -> error("unknown discriminator $discriminator for class ${InteractiveTxInput.Local::class}")
}
Expand All @@ -263,7 +265,8 @@ object Deserialization {
outPoint = readOutPoint(),
txOut = TxOut.read(readDelimitedByteArray()),
sequence = readNumber().toUInt(),
swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this)
swapInParams = TxAddInputTlv.SwapInParamsMusig2.read(this),
secretNonce = SecretNonce(PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One), PrivateKey(ByteVector32.One).publicKey())
)
else -> error("unknown discriminator $discriminator for class ${InteractiveTxInput.Remote::class}")
}
Expand Down

0 comments on commit ce75299

Please sign in to comment.