Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor peer connection establishment #541

Merged
merged 2 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,11 @@ class ElectrumClient(
?: ServerError(request, JsonRPCError(0, "timeout"))
return when (result) {
is ServerError -> {
logger.warning { "received error for ${request.method}: ${result.error.message}" }
when (request) {
// Some electrum servers don't seem to respond to ping requests, even though they keep the connection alive.
is Ping -> logger.debug { "received error for ${request.method}: ${result.error.message}" }
else -> logger.warning { "received error for ${request.method}: ${result.error.message}" }
}
Either.Left(result)
}
else -> Either.Right(result as T)
Expand Down
221 changes: 122 additions & 99 deletions src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
import kotlinx.coroutines.channels.onFailure
import kotlinx.coroutines.flow.*
import org.kodein.log.newLogger
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

sealed class PeerCommand
Expand Down Expand Up @@ -275,137 +276,159 @@ class Peer(
)
}

fun connect() {
if (connectionState.value is Connection.CLOSED) {
data class ConnectionJob(val job: Job, val socket: TcpSocket) {
fun cancel() {
job.cancel()
socket.close()
}
}

private var connectionJob: ConnectionJob? = null

suspend fun connect(connectTimeout: Duration, handshakeTimeout: Duration): Boolean {
return if (connectionState.value is Connection.CLOSED) {
// Clean up previous connection state: we do this here to ensure that it is handled before the Connected event for the new connection.
// That means we're not sending this event if we don't reconnect. It's ok, since that has the same effect as not detecting a disconnection and closing the app.
input.send(Disconnected)
_connectionState.value = Connection.ESTABLISHING
establishConnection()

val connectionId = currentTimestampMillis()
val logger = MDCLogger(nodeParams.loggerFactory.newLogger(this::class), staticMdc = mapOf("remoteNodeId" to remoteNodeId, "connectionId" to connectionId))
logger.info { "connecting to ${walletParams.trampolineNode.host}" }
val socket = openSocket(connectTimeout) ?: return false
dpad85 marked this conversation as resolved.
Show resolved Hide resolved

val priv = nodeParams.nodePrivateKey
val pub = priv.publicKey()
val keyPair = Pair(pub.value.toByteArray(), priv.value.toByteArray())
val (enc, dec, ck) = try {
withTimeout(handshakeTimeout) {
handshake(
keyPair,
remoteNodeId.value.toByteArray(),
{ s -> socket.receiveFully(s) },
{ b -> socket.send(b) }
)
}
} catch (ex: TcpSocket.IOException) {
logger.warning(ex) { "Noise handshake: ${ex.message}: " }
socket.close()
_connectionState.value = Connection.CLOSED(ex)
return false
}

val session = LightningSession(enc, dec, ck)
// TODO use atomic counter instead
val peerConnection = PeerConnection(connectionId, Channel(UNLIMITED), logger)
// Inform the peer about the new connection.
input.send(Connected(peerConnection))
connectionJob = connectionLoop(socket, session, peerConnection, logger)
true
} else {
logger.warning { "Peer is already connecting / connected" }
false
}
}

fun disconnect() {
if (this::socket.isInitialized) socket.close()
connectionJob?.cancel()
connectionJob = null
_connectionState.value = Connection.CLOSED(null)
}

// Warning : lateinit vars have to be used AFTER their init to avoid any crashes
//
// This shouldn't be used outside the establishConnection() function
// Except from the disconnect() one that check if the lateinit var has been initialized
private lateinit var socket: TcpSocket
private fun establishConnection() = launch {
// Clean up previous connection state: we do this here to ensure that it is handled before the Connected event for the new connection.
// That means we're not sending this event if we don't reconnect. It's ok, since that has the same effect as not detecting a disconnection and closing the app.
input.send(Disconnected)

val connectionId = currentTimestampMillis()
val logger = MDCLogger(nodeParams.loggerFactory.newLogger(this::class), staticMdc = mapOf("remoteNodeId" to remoteNodeId, "connectionId" to connectionId))

logger.info { "connecting to ${walletParams.trampolineNode.host}" }
socket = try {
socketBuilder?.connect(
host = walletParams.trampolineNode.host,
port = walletParams.trampolineNode.port,
tls = TcpSocket.TLS.DISABLED,
loggerFactory = nodeParams.loggerFactory
) ?: error("socket builder is null.")
private suspend fun openSocket(timeout: Duration): TcpSocket? {
var socket: TcpSocket? = null
return try {
withTimeout(timeout) {
socket = socketBuilder?.connect(
host = walletParams.trampolineNode.host,
port = walletParams.trampolineNode.port,
tls = TcpSocket.TLS.DISABLED,
loggerFactory = nodeParams.loggerFactory
) ?: error("socket builder is null.")
socket
}
} catch (ex: Throwable) {
logger.warning(ex) { "TCP connect: ${ex.message}: " }
val ioException = when (ex) {
is TcpSocket.IOException -> ex
else -> TcpSocket.IOException.ConnectionRefused(ex)
}
socket?.close()
_connectionState.value = Connection.CLOSED(ioException)
return@launch
}

fun closeSocket(ex: TcpSocket.IOException?) {
if (_connectionState.value is Connection.CLOSED) return
logger.warning(ex) { "closing TCP socket: " }
socket.close()
_connectionState.value = Connection.CLOSED(ex)
cancel()
null
}
}

val priv = nodeParams.nodePrivateKey
val pub = priv.publicKey()
val keyPair = Pair(pub.value.toByteArray(), priv.value.toByteArray())
val (enc, dec, ck) = try {
handshake(
keyPair,
remoteNodeId.value.toByteArray(),
{ s -> socket.receiveFully(s) },
{ b -> socket.send(b) }
)
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP handshake: ${ex.message}" }
closeSocket(ex)
return@launch
}
val session = LightningSession(enc, dec, ck)

// TODO use atomic counter instead
val peerConnection = PeerConnection(connectionId, Channel(UNLIMITED), logger)
// Inform the peer about the new connection.
input.send(Connected(peerConnection))

suspend fun doPing() {
val ping = Ping(10, ByteVector("deadbeef"))
while (isActive) {
delay(30.seconds)
peerConnection.send(ping)
private fun connectionLoop(socket: TcpSocket, session: LightningSession, peerConnection: PeerConnection, logger: MDCLogger): ConnectionJob {
val job = launch {
fun closeSocket(ex: TcpSocket.IOException?) {
if (_connectionState.value is Connection.CLOSED) return
logger.warning(ex) { "closing TCP socket: " }
socket.close()
_connectionState.value = Connection.CLOSED(ex)
cancel()
}
}

suspend fun checkPaymentsTimeout() {
while (isActive) {
delay(10.seconds) // we schedule a check every 10 seconds
input.send(CheckPaymentsTimeout)
suspend fun doPing() {
val ping = Ping(10, ByteVector("deadbeef"))
while (isActive) {
delay(30.seconds)
peerConnection.send(ping)
}
}
}

suspend fun receiveLoop() {
try {
suspend fun checkPaymentsTimeout() {
while (isActive) {
val received = session.receive { size -> socket.receiveFully(size) }
try {
val msg = LightningMessage.decode(received)
input.send(MessageReceived(peerConnection.id, msg))
} catch (e: Throwable) {
logger.warning { "cannot deserialize message: ${received.byteVector().toHex()}" }
delay(10.seconds) // we schedule a check every 10 seconds
input.send(CheckPaymentsTimeout)
}
}

suspend fun receiveLoop() {
try {
while (isActive) {
val received = session.receive { size -> socket.receiveFully(size) }
try {
val msg = LightningMessage.decode(received)
input.send(MessageReceived(peerConnection.id, msg))
} catch (e: Throwable) {
logger.warning { "cannot deserialize message: ${received.byteVector().toHex()}" }
}
}
closeSocket(null)
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP receive: ${ex.message}" }
closeSocket(ex)
} finally {
peerConnection.output.close()
}
closeSocket(null)
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP receive: ${ex.message}" }
closeSocket(ex)
} finally {
peerConnection.output.close()
}
}

suspend fun sendLoop() {
try {
for (msg in peerConnection.output) {
// Avoids polluting the logs with pings/pongs
if (msg !is Ping && msg !is Pong) logger.info { "sending $msg" }
val encoded = LightningMessage.encode(msg)
session.send(encoded) { data, flush -> socket.send(data, flush) }
suspend fun sendLoop() {
try {
for (msg in peerConnection.output) {
// Avoids polluting the logs with pings/pongs
if (msg !is Ping && msg !is Pong) logger.info { "sending $msg" }
val encoded = LightningMessage.encode(msg)
session.send(encoded) { data, flush -> socket.send(data, flush) }
}
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP send: ${ex.message}" }
closeSocket(ex)
} finally {
peerConnection.output.close()
}
} catch (ex: TcpSocket.IOException) {
logger.warning { "TCP send: ${ex.message}" }
closeSocket(ex)
} finally {
peerConnection.output.close()
}
}

launch { doPing() }
launch { checkPaymentsTimeout() }
launch { sendLoop() }
launch(CoroutineName("keep-alive")) { doPing() }
launch(CoroutineName("check-payments-timeout")) { checkPaymentsTimeout() }
launch(CoroutineName("send-loop")) { sendLoop() }
val receiveJob = launch(CoroutineName("receive-loop")) { receiveLoop() }
// Suspend until the coroutine is cancelled or the socket is closed.
receiveJob.join()
}

receiveLoop() // This suspends until the coroutines is cancelled or the socket is closed
return ConnectionJob(job, socket)
}

/** We try swapping funds in whenever one of those fields is updated. */
Expand Down
Loading