From 2746ddd33d234534b184736f8a0c4c6732905ab4 Mon Sep 17 00:00:00 2001 From: t-bast Date: Wed, 27 Sep 2023 10:01:23 +0200 Subject: [PATCH] Refactor connection establishment Use the same pattern as the `ElectrumClient`, with an explicit job and timeouts. The `connect` function now suspends, which makes it easier to handle from the caller's point of view. Fixes #531 --- .../kotlin/fr/acinq/lightning/io/Peer.kt | 221 ++++++++++-------- 1 file changed, 122 insertions(+), 99 deletions(-) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt index 95434404c..d23581c6c 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt @@ -28,6 +28,7 @@ import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.onFailure import kotlinx.coroutines.flow.* import org.kodein.log.newLogger +import kotlin.time.Duration import kotlin.time.Duration.Companion.seconds sealed class PeerCommand @@ -278,137 +279,159 @@ class Peer( ) } - fun connect() { - if (connectionState.value is Connection.CLOSED) { + data class ConnectionJob(val job: Job, val socket: TcpSocket) { + fun cancel() { + job.cancel() + socket.close() + } + } + + private var connectionJob: ConnectionJob? = null + + suspend fun connect(connectTimeout: Duration, handshakeTimeout: Duration): Boolean { + return if (connectionState.value is Connection.CLOSED) { + // Clean up previous connection state: we do this here to ensure that it is handled before the Connected event for the new connection. + // That means we're not sending this event if we don't reconnect. It's ok, since that has the same effect as not detecting a disconnection and closing the app. + input.send(Disconnected) _connectionState.value = Connection.ESTABLISHING - establishConnection() + + val connectionId = currentTimestampMillis() + val logger = MDCLogger(nodeParams.loggerFactory.newLogger(this::class), staticMdc = mapOf("remoteNodeId" to remoteNodeId, "connectionId" to connectionId)) + logger.info { "connecting to ${walletParams.trampolineNode.host}" } + val socket = openSocket(connectTimeout) ?: return false + + val priv = nodeParams.nodePrivateKey + val pub = priv.publicKey() + val keyPair = Pair(pub.value.toByteArray(), priv.value.toByteArray()) + val (enc, dec, ck) = try { + withTimeout(handshakeTimeout) { + handshake( + keyPair, + remoteNodeId.value.toByteArray(), + { s -> socket.receiveFully(s) }, + { b -> socket.send(b) } + ) + } + } catch (ex: TcpSocket.IOException) { + logger.warning(ex) { "Noise handshake: ${ex.message}: " } + socket.close() + _connectionState.value = Connection.CLOSED(ex) + return false + } + + val session = LightningSession(enc, dec, ck) + // TODO use atomic counter instead + val peerConnection = PeerConnection(connectionId, Channel(UNLIMITED), logger) + // Inform the peer about the new connection. + input.send(Connected(peerConnection)) + connectionJob = connectionLoop(socket, session, peerConnection, logger) + true } else { logger.warning { "Peer is already connecting / connected" } + false } } fun disconnect() { - if (this::socket.isInitialized) socket.close() + connectionJob?.cancel() + connectionJob = null _connectionState.value = Connection.CLOSED(null) } - // Warning : lateinit vars have to be used AFTER their init to avoid any crashes - // - // This shouldn't be used outside the establishConnection() function - // Except from the disconnect() one that check if the lateinit var has been initialized - private lateinit var socket: TcpSocket - private fun establishConnection() = launch { - // Clean up previous connection state: we do this here to ensure that it is handled before the Connected event for the new connection. - // That means we're not sending this event if we don't reconnect. It's ok, since that has the same effect as not detecting a disconnection and closing the app. - input.send(Disconnected) - - val connectionId = currentTimestampMillis() - val logger = MDCLogger(nodeParams.loggerFactory.newLogger(this::class), staticMdc = mapOf("remoteNodeId" to remoteNodeId, "connectionId" to connectionId)) - - logger.info { "connecting to ${walletParams.trampolineNode.host}" } - socket = try { - socketBuilder?.connect( - host = walletParams.trampolineNode.host, - port = walletParams.trampolineNode.port, - tls = TcpSocket.TLS.DISABLED, - loggerFactory = nodeParams.loggerFactory - ) ?: error("socket builder is null.") + private suspend fun openSocket(timeout: Duration): TcpSocket? { + var socket: TcpSocket? = null + return try { + withTimeout(timeout) { + socket = socketBuilder?.connect( + host = walletParams.trampolineNode.host, + port = walletParams.trampolineNode.port, + tls = TcpSocket.TLS.DISABLED, + loggerFactory = nodeParams.loggerFactory + ) ?: error("socket builder is null.") + socket + } } catch (ex: Throwable) { logger.warning(ex) { "TCP connect: ${ex.message}: " } val ioException = when (ex) { is TcpSocket.IOException -> ex else -> TcpSocket.IOException.ConnectionRefused(ex) } + socket?.close() _connectionState.value = Connection.CLOSED(ioException) - return@launch - } - - fun closeSocket(ex: TcpSocket.IOException?) { - if (_connectionState.value is Connection.CLOSED) return - logger.warning(ex) { "closing TCP socket: " } - socket.close() - _connectionState.value = Connection.CLOSED(ex) - cancel() + null } + } - val priv = nodeParams.nodePrivateKey - val pub = priv.publicKey() - val keyPair = Pair(pub.value.toByteArray(), priv.value.toByteArray()) - val (enc, dec, ck) = try { - handshake( - keyPair, - remoteNodeId.value.toByteArray(), - { s -> socket.receiveFully(s) }, - { b -> socket.send(b) } - ) - } catch (ex: TcpSocket.IOException) { - logger.warning { "TCP handshake: ${ex.message}" } - closeSocket(ex) - return@launch - } - val session = LightningSession(enc, dec, ck) - - // TODO use atomic counter instead - val peerConnection = PeerConnection(connectionId, Channel(UNLIMITED), logger) - // Inform the peer about the new connection. - input.send(Connected(peerConnection)) - - suspend fun doPing() { - val ping = Ping(10, ByteVector("deadbeef")) - while (isActive) { - delay(30.seconds) - peerConnection.send(ping) + private fun connectionLoop(socket: TcpSocket, session: LightningSession, peerConnection: PeerConnection, logger: MDCLogger): ConnectionJob { + val job = launch { + fun closeSocket(ex: TcpSocket.IOException?) { + if (_connectionState.value is Connection.CLOSED) return + logger.warning(ex) { "closing TCP socket: " } + socket.close() + _connectionState.value = Connection.CLOSED(ex) + cancel() } - } - suspend fun checkPaymentsTimeout() { - while (isActive) { - delay(10.seconds) // we schedule a check every 10 seconds - input.send(CheckPaymentsTimeout) + suspend fun doPing() { + val ping = Ping(10, ByteVector("deadbeef")) + while (isActive) { + delay(30.seconds) + peerConnection.send(ping) + } } - } - suspend fun receiveLoop() { - try { + suspend fun checkPaymentsTimeout() { while (isActive) { - val received = session.receive { size -> socket.receiveFully(size) } - try { - val msg = LightningMessage.decode(received) - input.send(MessageReceived(peerConnection.id, msg)) - } catch (e: Throwable) { - logger.warning { "cannot deserialize message: ${received.byteVector().toHex()}" } + delay(10.seconds) // we schedule a check every 10 seconds + input.send(CheckPaymentsTimeout) + } + } + + suspend fun receiveLoop() { + try { + while (isActive) { + val received = session.receive { size -> socket.receiveFully(size) } + try { + val msg = LightningMessage.decode(received) + input.send(MessageReceived(peerConnection.id, msg)) + } catch (e: Throwable) { + logger.warning { "cannot deserialize message: ${received.byteVector().toHex()}" } + } } + closeSocket(null) + } catch (ex: TcpSocket.IOException) { + logger.warning { "TCP receive: ${ex.message}" } + closeSocket(ex) + } finally { + peerConnection.output.close() } - closeSocket(null) - } catch (ex: TcpSocket.IOException) { - logger.warning { "TCP receive: ${ex.message}" } - closeSocket(ex) - } finally { - peerConnection.output.close() } - } - suspend fun sendLoop() { - try { - for (msg in peerConnection.output) { - // Avoids polluting the logs with pings/pongs - if (msg !is Ping && msg !is Pong) logger.info { "sending $msg" } - val encoded = LightningMessage.encode(msg) - session.send(encoded) { data, flush -> socket.send(data, flush) } + suspend fun sendLoop() { + try { + for (msg in peerConnection.output) { + // Avoids polluting the logs with pings/pongs + if (msg !is Ping && msg !is Pong) logger.info { "sending $msg" } + val encoded = LightningMessage.encode(msg) + session.send(encoded) { data, flush -> socket.send(data, flush) } + } + } catch (ex: TcpSocket.IOException) { + logger.warning { "TCP send: ${ex.message}" } + closeSocket(ex) + } finally { + peerConnection.output.close() } - } catch (ex: TcpSocket.IOException) { - logger.warning { "TCP send: ${ex.message}" } - closeSocket(ex) - } finally { - peerConnection.output.close() } - } - launch { doPing() } - launch { checkPaymentsTimeout() } - launch { sendLoop() } + launch(CoroutineName("keep-alive")) { doPing() } + launch(CoroutineName("check-payments-timeout")) { checkPaymentsTimeout() } + launch(CoroutineName("send-loop")) { sendLoop() } + val receiveJob = launch(CoroutineName("receive-loop")) { receiveLoop() } + // Suspend until the coroutine is cancelled or the socket is closed. + receiveJob.join() + } - receiveLoop() // This suspends until the coroutines is cancelled or the socket is closed + return ConnectionJob(job, socket) } /**