Skip to content

Commit

Permalink
Refactor connection establishment
Browse files Browse the repository at this point in the history
Use the same pattern as the `ElectrumClient`, with an explicit job and
timeouts. The `connect` function now suspends, which makes it easier to
handle from the caller's point of view.

Fixes #531
  • Loading branch information
t-bast committed Nov 3, 2023
1 parent 5f1cc0f commit ff7a089
Showing 1 changed file with 122 additions and 99 deletions.
221 changes: 122 additions & 99 deletions src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
import kotlinx.coroutines.channels.onFailure
import kotlinx.coroutines.flow.*
import org.kodein.log.newLogger
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

sealed class PeerCommand
Expand Down Expand Up @@ -275,137 +276,159 @@ class Peer(
)
}

fun connect() {
if (connectionState.value is Connection.CLOSED) {
data class ConnectionJob(val job: Job, val socket: TcpSocket) {
fun cancel() {
job.cancel()
socket.close()
}
}

private var connectionJob: ConnectionJob? = null

suspend fun connect(connectTimeout: Duration, handshakeTimeout: Duration): Boolean {
return if (connectionState.value is Connection.CLOSED) {
// Clean up previous connection state: we do this here to ensure that it is handled before the Connected event for the new connection.
// That means we're not sending this event if we don't reconnect. It's ok, since that has the same effect as not detecting a disconnection and closing the app.
input.send(Disconnected)
_connectionState.value = Connection.ESTABLISHING
establishConnection()

val connectionId = currentTimestampMillis()
val logger = MDCLogger(nodeParams.loggerFactory.newLogger(this::class), staticMdc = mapOf("remoteNodeId" to remoteNodeId, "connectionId" to connectionId))
logger.info { "connecting to ${walletParams.trampolineNode.host}" }
val socket = openSocket(connectTimeout) ?: return false

val priv = nodeParams.nodePrivateKey
val pub = priv.publicKey()
val keyPair = Pair(pub.value.toByteArray(), priv.value.toByteArray())
val (enc, dec, ck) = try {
withTimeout(handshakeTimeout) {
handshake(
keyPair,
remoteNodeId.value.toByteArray(),
{ s -> socket.receiveFully(s) },
{ b -> socket.send(b) }
)
}
} catch (ex: TcpSocket.IOException) {
logger.warning(ex) { "Noise handshake: ${ex.message}: " }
socket.close()
_connectionState.value = Connection.CLOSED(ex)
return false
}

val session = LightningSession(enc, dec, ck)
// TODO use atomic counter instead
val peerConnection = PeerConnection(connectionId, Channel(UNLIMITED), logger)
// Inform the peer about the new connection.
input.send(Connected(peerConnection))
connectionJob = connectionLoop(socket, session, peerConnection, logger)
true
} else {
logger.warning { "Peer is already connecting / connected" }
false
}
}

fun disconnect() {
if (this::socket.isInitialized) socket.close()
connectionJob?.cancel()
connectionJob = null
_connectionState.value = Connection.CLOSED(null)
}

// Warning : lateinit vars have to be used AFTER their init to avoid any crashes
//
// This shouldn't be used outside the establishConnection() function
// Except from the disconnect() one that check if the lateinit var has been initialized
private lateinit var socket: TcpSocket
private fun establishConnection() = launch {
// Clean up previous connection state: we do this here to ensure that it is handled before the Connected event for the new connection.
// That means we're not sending this event if we don't reconnect. It's ok, since that has the same effect as not detecting a disconnection and closing the app.
input.send(Disconnected)

val connectionId = currentTimestampMillis()
val logger = MDCLogger(nodeParams.loggerFactory.newLogger(this::class), staticMdc = mapOf("remoteNodeId" to remoteNodeId, "connectionId" to connectionId))

logger.info { "connecting to ${walletParams.trampolineNode.host}" }
socket = try {
socketBuilder?.connect(
host = walletParams.trampolineNode.host,
port = walletParams.trampolineNode.port,
tls = TcpSocket.TLS.DISABLED,
loggerFactory = nodeParams.loggerFactory
) ?: error("socket builder is null.")
private suspend fun openSocket(timeout: Duration): TcpSocket? {
var socket: TcpSocket? = null
return try {
withTimeout(timeout) {
socket = socketBuilder?.connect(
host = walletParams.trampolineNode.host,
port = walletParams.trampolineNode.port,
tls = TcpSocket.TLS.DISABLED,
loggerFactory = nodeParams.loggerFactory
) ?: error("socket builder is null.")
socket
}
} catch (ex: Throwable) {
logger.warning(ex) { "TCP connect: ${ex.message}: " }
val ioException = when (ex) {
is TcpSocket.IOException -> ex
else -> TcpSocket.IOException.ConnectionRefused(ex)
}
socket?.close()
_connectionState.value = Connection.CLOSED(ioException)
return@launch
}

fun closeSocket(ex: TcpSocket.IOException?) {
if (_connectionState.value is Connection.CLOSED) return
logger.warning(ex) { "closing TCP socket: " }
socket.close()
_connectionState.value = Connection.CLOSED(ex)
cancel()
null
}
}

val priv = nodeParams.nodePrivateKey
val pub = priv.publicKey()
val keyPair = Pair(pub.value.toByteArray(), priv.value.toByteArray())
val (enc, dec, ck) = try {
handshake(
keyPair,
remoteNodeId.value.toByteArray(),
{ s -> socket.receiveFully(s) },
{ b -> socket.send(b) }
)
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP handshake: ${ex.message}" }
closeSocket(ex)
return@launch
}
val session = LightningSession(enc, dec, ck)

// TODO use atomic counter instead
val peerConnection = PeerConnection(connectionId, Channel(UNLIMITED), logger)
// Inform the peer about the new connection.
input.send(Connected(peerConnection))

suspend fun doPing() {
val ping = Ping(10, ByteVector("deadbeef"))
while (isActive) {
delay(30.seconds)
peerConnection.send(ping)
private fun connectionLoop(socket: TcpSocket, session: LightningSession, peerConnection: PeerConnection, logger: MDCLogger): ConnectionJob {
val job = launch {
fun closeSocket(ex: TcpSocket.IOException?) {
if (_connectionState.value is Connection.CLOSED) return
logger.warning(ex) { "closing TCP socket: " }
socket.close()
_connectionState.value = Connection.CLOSED(ex)
cancel()
}
}

suspend fun checkPaymentsTimeout() {
while (isActive) {
delay(10.seconds) // we schedule a check every 10 seconds
input.send(CheckPaymentsTimeout)
suspend fun doPing() {
val ping = Ping(10, ByteVector("deadbeef"))
while (isActive) {
delay(30.seconds)
peerConnection.send(ping)
}
}
}

suspend fun receiveLoop() {
try {
suspend fun checkPaymentsTimeout() {
while (isActive) {
val received = session.receive { size -> socket.receiveFully(size) }
try {
val msg = LightningMessage.decode(received)
input.send(MessageReceived(peerConnection.id, msg))
} catch (e: Throwable) {
logger.warning { "cannot deserialize message: ${received.byteVector().toHex()}" }
delay(10.seconds) // we schedule a check every 10 seconds
input.send(CheckPaymentsTimeout)
}
}

suspend fun receiveLoop() {
try {
while (isActive) {
val received = session.receive { size -> socket.receiveFully(size) }
try {
val msg = LightningMessage.decode(received)
input.send(MessageReceived(peerConnection.id, msg))
} catch (e: Throwable) {
logger.warning { "cannot deserialize message: ${received.byteVector().toHex()}" }
}
}
closeSocket(null)
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP receive: ${ex.message}" }
closeSocket(ex)
} finally {
peerConnection.output.close()
}
closeSocket(null)
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP receive: ${ex.message}" }
closeSocket(ex)
} finally {
peerConnection.output.close()
}
}

suspend fun sendLoop() {
try {
for (msg in peerConnection.output) {
// Avoids polluting the logs with pings/pongs
if (msg !is Ping && msg !is Pong) logger.info { "sending $msg" }
val encoded = LightningMessage.encode(msg)
session.send(encoded) { data, flush -> socket.send(data, flush) }
suspend fun sendLoop() {
try {
for (msg in peerConnection.output) {
// Avoids polluting the logs with pings/pongs
if (msg !is Ping && msg !is Pong) logger.info { "sending $msg" }
val encoded = LightningMessage.encode(msg)
session.send(encoded) { data, flush -> socket.send(data, flush) }
}
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP send: ${ex.message}" }
closeSocket(ex)
} finally {
peerConnection.output.close()
}
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP send: ${ex.message}" }
closeSocket(ex)
} finally {
peerConnection.output.close()
}
}

launch { doPing() }
launch { checkPaymentsTimeout() }
launch { sendLoop() }
launch(CoroutineName("keep-alive")) { doPing() }
launch(CoroutineName("check-payments-timeout")) { checkPaymentsTimeout() }
launch(CoroutineName("send-loop")) { sendLoop() }
val receiveJob = launch(CoroutineName("receive-loop")) { receiveLoop() }
// Suspend until the coroutine is cancelled or the socket is closed.
receiveJob.join()
}

receiveLoop() // This suspends until the coroutines is cancelled or the socket is closed
return ConnectionJob(job, socket)
}

/** We try swapping funds in whenever one of those fields is updated. */
Expand Down

0 comments on commit ff7a089

Please sign in to comment.