Skip to content

Commit

Permalink
Add specific splicing nonces to channel_reestablish
Browse files Browse the repository at this point in the history
  • Loading branch information
sstone committed Nov 18, 2024
1 parent e225e66 commit affe638
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ sealed class ChannelState {
/** A channel state that is persisted to the DB. */
sealed class PersistedChannelState : ChannelState() {
abstract val channelId: ByteVector32

internal fun ChannelContext.createChannelReestablish(): HasEncryptedChannelData = when (val state = this@PersistedChannelState) {
is WaitForFundingSigned -> {
val myFirstPerCommitmentPoint = keyManager.channelKeys(state.channelParams.localParams.fundingKeyPath).commitmentPoint(0)
Expand All @@ -332,14 +331,37 @@ sealed class PersistedChannelState : ChannelState() {
true -> state.commitments.active.map { channelKeys.verificationNonce(it.fundingTxIndex, state.commitments.localCommitIndex + 1).second }
else -> null
}
val spliceNonces = when {
state.commitments.isTaprootChannel && state is Normal && state.spliceStatus is SpliceStatus.WaitingForSigs -> {
logger.info { "splice in progress, re-sending splice nonces" }
val localCommitIndex = when (state.spliceStatus.session.localCommit) {
is Either.Left -> state.spliceStatus.session.localCommit.value.index
is Either.Right -> state.spliceStatus.session.localCommit.value.index
}
listOf(
channelKeys.verificationNonce(state.spliceStatus.session.fundingTxIndex, localCommitIndex).second,
channelKeys.verificationNonce(state.spliceStatus.session.fundingTxIndex, localCommitIndex + 1).second
)
}

state.commitments.isTaprootChannel && state.commitments.latest.localFundingStatus is LocalFundingStatus.UnconfirmedFundingTx -> {
logger.info { "splice may not have confirmed yet, re-sending splice nonces" }
listOf(
channelKeys.verificationNonce(state.commitments.latest.fundingTxIndex, state.commitments.latest.localCommit.index).second,
channelKeys.verificationNonce(state.commitments.latest.fundingTxIndex, state.commitments.latest.localCommit.index + 1).second
)
}
else -> null
}
val unsignedFundingTxId = when (state) {
is WaitForFundingConfirmed -> state.getUnsignedFundingTxId()
is Normal -> state.getUnsignedFundingTxId() // a splice was in progress, we tell our peer that we are remembering it and are expecting signatures
else -> null
}
val tlvs: TlvStream<ChannelReestablishTlv> = TlvStream(setOfNotNull(
unsignedFundingTxId?.let { ChannelReestablishTlv.NextFunding(it) },
myNextLocalNonces?.let { ChannelReestablishTlv.NextLocalNoncesTlv(it) }
myNextLocalNonces?.let { ChannelReestablishTlv.NextLocalNoncesTlv(it) },
spliceNonces?.let { ChannelReestablishTlv.SpliceNoncesTlv(it) }
))
ChannelReestablish(
channelId = channelId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,15 @@ data class Syncing(val state: PersistedChannelState, val channelReestablishSent:
val spliceStatus1 = if (state.spliceStatus is SpliceStatus.WaitingForSigs && state.spliceStatus.session.fundingTx.txId == cmd.message.nextFundingTxId) {
// We retransmit our commit_sig, and will send our tx_signatures once we've received their commit_sig.
logger.info { "re-sending commit_sig for splice attempt with fundingTxIndex=${state.spliceStatus.session.fundingTxIndex} fundingTxId=${state.spliceStatus.session.fundingTx.txId}" }
val commitSig = state.spliceStatus.session.remoteCommit.sign(channelKeys(), state.commitments.params, state.spliceStatus.session, cmd.message.nextLocalNonces.firstOrNull())
val spliceNonce = when {
state.spliceStatus.session.remoteCommit.index == cmd.message.nextLocalCommitmentNumber -> cmd.message.secondSpliceNonce
state.spliceStatus.session.remoteCommit.index == cmd.message.nextLocalCommitmentNumber - 1 -> cmd.message.firstSpliceNonce
else -> {
// we should never end up here, it would have been handled in handleSync()
error("invalid nextLocalCommitmentNumber in ChannelReestablish")
}
}
val commitSig = state.spliceStatus.session.remoteCommit.sign(channelKeys(), state.commitments.params, state.spliceStatus.session, spliceNonce)
actions.add(ChannelAction.Message.Send(commitSig))
state.spliceStatus
} else if (state.commitments.latest.fundingTxId == cmd.message.nextFundingTxId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,23 @@ sealed class ChannelReestablishTlv : Tlv {
}
}
}

data class SpliceNoncesTlv(val nonces: List<IndividualNonce>) : ChannelReestablishTlv() {
override val tag: Long get() = SpliceNoncesTlv.tag

override fun write(out: Output) {
nonces.forEach { LightningCodecs.writeBytes(it.toByteArray(), out) }
}

companion object : TlvValueReader<SpliceNoncesTlv> {
const val tag: Long = 6
override fun read(input: Input): SpliceNoncesTlv {
val count = input.availableBytes / 66
val nonces = (0 until count).map { IndividualNonce(LightningCodecs.bytes(input, 66)) }
return SpliceNoncesTlv(nonces)
}
}
}
}

sealed class ShutdownTlv : Tlv {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,9 @@ data class ChannelReestablish(

val nextFundingTxId: TxId? = tlvStream.get<ChannelReestablishTlv.NextFunding>()?.txId
val nextLocalNonces: List<IndividualNonce> = tlvStream.get<ChannelReestablishTlv.NextLocalNoncesTlv>()?.nonces ?: listOf()
val spliceNonces: List<IndividualNonce> = tlvStream.get<ChannelReestablishTlv.SpliceNoncesTlv>()?.nonces ?: listOf()
val firstSpliceNonce = if (spliceNonces.isNotEmpty()) spliceNonces[0] else null
val secondSpliceNonce = if (spliceNonces.isNotEmpty()) spliceNonces[1] else null

override val channelData: EncryptedChannelData get() = tlvStream.get<ChannelReestablishTlv.ChannelData>()?.ecb ?: EncryptedChannelData.empty
override fun withNonEmptyChannelData(ecd: EncryptedChannelData): ChannelReestablish = copy(tlvStream = tlvStream.addOrUpdate(ChannelReestablishTlv.ChannelData(ecd)))
Expand All @@ -1391,6 +1394,7 @@ data class ChannelReestablish(
ChannelReestablishTlv.ChannelData.tag to ChannelReestablishTlv.ChannelData.Companion as TlvValueReader<ChannelReestablishTlv>,
ChannelReestablishTlv.NextFunding.tag to ChannelReestablishTlv.NextFunding.Companion as TlvValueReader<ChannelReestablishTlv>,
ChannelReestablishTlv.NextLocalNoncesTlv.tag to ChannelReestablishTlv.NextLocalNoncesTlv.Companion as TlvValueReader<ChannelReestablishTlv>,
ChannelReestablishTlv.SpliceNoncesTlv.tag to ChannelReestablishTlv.SpliceNoncesTlv.Companion as TlvValueReader<ChannelReestablishTlv>,
)

override fun read(input: Input): ChannelReestablish {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,33 @@ class SpliceTestsCommon : LightningTestSuite() {
resolveHtlcs(alice4, bob4, htlcs, commitmentsCount = 2)
}

@Test
fun `disconnect -- commit_sig not received -- simple taproot channels`() {
val (alice, bob) = reachNormalWithConfirmedFundingTx(channelType = ChannelType.SupportedChannelType.SimpleTaprootStaging)
val (alice0, bob0, htlcs) = setupHtlcs(alice, bob)
val (alice1, _, bob1, _) = spliceInAndOutWithoutSigs(alice0, bob0, inAmounts = listOf(50_000.sat), outAmount = 100_000.sat)

val spliceStatus = alice1.state.spliceStatus
assertIs<SpliceStatus.WaitingForSigs>(spliceStatus)

val (alice2, bob2, channelReestablishAlice) = disconnect(alice1, bob1)
assertEquals(channelReestablishAlice.nextFundingTxId, spliceStatus.session.fundingTx.txId)
val (bob3, actionsBob3) = bob2.process(ChannelCommand.MessageReceived(channelReestablishAlice))
assertIs<LNChannel<Normal>>(bob3)
assertEquals(actionsBob3.size, 4)
val channelReestablishBob = actionsBob3.findOutgoingMessage<ChannelReestablish>()
val commitSigBob = actionsBob3.findOutgoingMessage<CommitSig>()
assertEquals(htlcs.aliceToBob.map { it.second }.toSet(), actionsBob3.filterIsInstance<ChannelAction.ProcessIncomingHtlc>().map { it.add }.toSet())
assertEquals(channelReestablishBob.nextFundingTxId, spliceStatus.session.fundingTx.txId)
val (alice3, actionsAlice3) = alice2.process(ChannelCommand.MessageReceived(channelReestablishBob))
assertIs<LNChannel<Normal>>(alice3)
assertEquals(actionsAlice3.size, 3)
val commitSigAlice = actionsAlice3.findOutgoingMessage<CommitSig>()
val (alice4, bob4) = exchangeSpliceSigs(alice3, commitSigAlice, bob3, commitSigBob)
assertEquals(htlcs.bobToAlice.map { it.second }.toSet(), actionsAlice3.filterIsInstance<ChannelAction.ProcessIncomingHtlc>().map { it.add }.toSet())
resolveHtlcs(alice4, bob4, htlcs, commitmentsCount = 2)
}

@Test
fun `disconnect -- commit_sig received by alice`() {
val (alice, bob) = reachNormalWithConfirmedFundingTx()
Expand Down Expand Up @@ -814,6 +841,35 @@ class SpliceTestsCommon : LightningTestSuite() {
resolveHtlcs(alice6, bob5, htlcs, commitmentsCount = 2)
}

@Test
fun `disconnect -- commit_sig received by alice -- simple taproot channels`() {
val (alice, bob) = reachNormalWithConfirmedFundingTx(channelType = ChannelType.SupportedChannelType.SimpleTaprootStaging)
val (alice1, bob1, htlcs) = setupHtlcs(alice, bob)
val (alice2, _, bob2, commitSigBob1) = spliceInAndOutWithoutSigs(alice1, bob1, inAmounts = listOf(50_000.sat), outAmount = 100_000.sat)
val (alice3, actionsAlice3) = alice2.process(ChannelCommand.MessageReceived(commitSigBob1))
assertIs<LNChannel<Normal>>(alice3)
assertTrue(actionsAlice3.isEmpty())
val spliceStatus = alice3.state.spliceStatus
assertIs<SpliceStatus.WaitingForSigs>(spliceStatus)

val (alice4, bob3, channelReestablishAlice) = disconnect(alice3, bob2)
assertEquals(channelReestablishAlice.nextFundingTxId, spliceStatus.session.fundingTx.txId)
val (bob4, actionsBob4) = bob3.process(ChannelCommand.MessageReceived(channelReestablishAlice))
assertIs<LNChannel<Normal>>(bob4)
assertEquals(actionsBob4.size, 4)
val channelReestablishBob = actionsBob4.findOutgoingMessage<ChannelReestablish>()
val commitSigBob2 = actionsBob4.findOutgoingMessage<CommitSig>()
assertEquals(htlcs.aliceToBob.map { it.second }.toSet(), actionsBob4.filterIsInstance<ChannelAction.ProcessIncomingHtlc>().map { it.add }.toSet())
assertEquals(channelReestablishBob.nextFundingTxId, spliceStatus.session.fundingTx.txId)
val (alice5, actionsAlice5) = alice4.process(ChannelCommand.MessageReceived(channelReestablishBob))
assertIs<LNChannel<Normal>>(alice5)
assertEquals(actionsAlice5.size, 3)
val commitSigAlice = actionsAlice5.findOutgoingMessage<CommitSig>()
assertEquals(htlcs.bobToAlice.map { it.second }.toSet(), actionsAlice5.filterIsInstance<ChannelAction.ProcessIncomingHtlc>().map { it.add }.toSet())
val (alice6, bob5) = exchangeSpliceSigs(alice5, commitSigAlice, bob4, commitSigBob2)
resolveHtlcs(alice6, bob5, htlcs, commitmentsCount = 2)
}

@Test
fun `disconnect -- tx_signatures sent by bob`() {
val (alice, bob) = reachNormalWithConfirmedFundingTx()
Expand Down Expand Up @@ -861,6 +917,53 @@ class SpliceTestsCommon : LightningTestSuite() {
actionsBob6.has<ChannelAction.Storage.StoreState>()
}

@Test
fun `disconnect -- tx_signatures sent by bob -- simple taproot channels`() {
val (alice, bob) = reachNormalWithConfirmedFundingTx(channelType = ChannelType.SupportedChannelType.SimpleTaprootStaging)
val (alice0, bob0, htlcs) = setupHtlcs(alice, bob)
val (alice1, commitSigAlice1, bob1, _) = spliceInAndOutWithoutSigs(alice0, bob0, inAmounts = listOf(80_000.sat), outAmount = 50_000.sat)
val (bob2, actionsBob2) = bob1.process(ChannelCommand.MessageReceived(commitSigAlice1))
assertIs<LNChannel<Normal>>(bob2)
val spliceTxId = actionsBob2.hasOutgoingMessage<TxSignatures>().txId
assertEquals(bob2.state.spliceStatus, SpliceStatus.None)

val (alice2, bob3, channelReestablishAlice) = disconnect(alice1, bob2)
assertEquals(channelReestablishAlice.nextFundingTxId, spliceTxId)
val (bob4, actionsBob4) = bob3.process(ChannelCommand.MessageReceived(channelReestablishAlice))
assertEquals(actionsBob4.size, 5)
val channelReestablishBob = actionsBob4.findOutgoingMessage<ChannelReestablish>()
val commitSigBob2 = actionsBob4.findOutgoingMessage<CommitSig>()
assertEquals(htlcs.aliceToBob.map { it.second }.toSet(), actionsBob4.filterIsInstance<ChannelAction.ProcessIncomingHtlc>().map { it.add }.toSet())
val txSigsBob = actionsBob4.findOutgoingMessage<TxSignatures>()
assertEquals(channelReestablishBob.nextFundingTxId, spliceTxId)
val (alice3, actionsAlice3) = alice2.process(ChannelCommand.MessageReceived(channelReestablishBob))
assertEquals(actionsAlice3.size, 3)
assertEquals(htlcs.bobToAlice.map { it.second }.toSet(), actionsAlice3.filterIsInstance<ChannelAction.ProcessIncomingHtlc>().map { it.add }.toSet())
val commitSigAlice2 = actionsAlice3.findOutgoingMessage<CommitSig>()

val (alice4, actionsAlice4) = alice3.process(ChannelCommand.MessageReceived(commitSigBob2))
assertTrue(actionsAlice4.isEmpty())
val (alice5, actionsAlice5) = alice4.process(ChannelCommand.MessageReceived(txSigsBob))
assertIs<LNChannel<Normal>>(alice5)
assertEquals(alice5.state.commitments.active.size, 2)
assertEquals(actionsAlice5.size, 8)
assertEquals(actionsAlice5.hasPublishTx(ChannelAction.Blockchain.PublishTx.Type.FundingTx).txid, spliceTxId)
assertEquals(htlcs.bobToAlice.map { it.second }.toSet(), actionsAlice5.filterIsInstance<ChannelAction.ProcessIncomingHtlc>().map { it.add }.toSet())
actionsAlice5.hasWatchConfirmed(spliceTxId)
actionsAlice5.has<ChannelAction.Storage.StoreState>()
actionsAlice5.has<ChannelAction.Storage.StoreOutgoingPayment.ViaSpliceOut>()
val txSigsAlice = actionsAlice5.findOutgoingMessage<TxSignatures>()

val (bob5, actionsBob5) = bob4.process(ChannelCommand.MessageReceived(commitSigAlice2))
assertTrue(actionsBob5.isEmpty())
val (bob6, actionsBob6) = bob5.process(ChannelCommand.MessageReceived(txSigsAlice))
assertIs<LNChannel<Normal>>(bob6)
assertEquals(bob6.state.commitments.active.size, 2)
assertEquals(actionsBob6.size, 2)
assertEquals(actionsBob6.hasPublishTx(ChannelAction.Blockchain.PublishTx.Type.FundingTx).txid, spliceTxId)
actionsBob6.has<ChannelAction.Storage.StoreState>()
}

@Test
fun `disconnect -- tx_signatures sent by bob -- zero-conf`() {
val (alice, bob) = reachNormalWithConfirmedFundingTx(zeroConf = true)
Expand Down Expand Up @@ -1448,8 +1551,8 @@ class SpliceTestsCommon : LightningTestSuite() {
companion object {
private val spliceFeerate = FeeratePerKw(253.sat)

private fun reachNormalWithConfirmedFundingTx(zeroConf: Boolean = false): Pair<LNChannel<Normal>, LNChannel<Normal>> {
val (alice, bob) = reachNormal(zeroConf = zeroConf)
private fun reachNormalWithConfirmedFundingTx(channelType: ChannelType.SupportedChannelType = ChannelType.SupportedChannelType.AnchorOutputs, zeroConf: Boolean = false): Pair<LNChannel<Normal>, LNChannel<Normal>> {
val (alice, bob) = reachNormal(channelType = channelType, zeroConf = zeroConf)
val fundingTx = alice.commitments.latest.localFundingStatus.signedTx!!
val (alice1, _) = alice.process(ChannelCommand.WatchReceived(WatchEventConfirmed(alice.channelId, BITCOIN_FUNDING_DEPTHOK, 42, 3, fundingTx)))
val (bob1, _) = bob.process(ChannelCommand.WatchReceived(WatchEventConfirmed(bob.channelId, BITCOIN_FUNDING_DEPTHOK, 42, 3, fundingTx)))
Expand Down

0 comments on commit affe638

Please sign in to comment.