From 3ca49a2f402c1d8f925d9ab6cc23d6d0ab3eeae5 Mon Sep 17 00:00:00 2001
From: Jacek Sieka
Date: Tue, 14 May 2024 07:10:34 +0200
Subject: [PATCH] fix(transport): various tcp transport races (#1095)

Co-authored-by: diegomrsantos
---
 libp2p/dialer.nim                  |  12 +-
 libp2p/errors.nim                  |   9 -
 libp2p/switch.nim                  |  19 +-
 libp2p/transports/tcptransport.nim | 453 +++++++++++++++++------------
 libp2p/transports/tortransport.nim |   2 +-
 libp2p/transports/transport.nim    |   2 +-
 libp2p/wire.nim                    |   9 +-
 7 files changed, 280 insertions(+), 226 deletions(-)

diff --git a/libp2p/dialer.nim b/libp2p/dialer.nim
index ab6e6873d4..d3cb926b7f 100644
--- a/libp2p/dialer.nim
+++ b/libp2p/dialer.nim
@@ -81,16 +81,18 @@ proc dialAndUpgrade(
         if dialed.dir != dir:
           dialed.dir = dir
         await transport.upgrade(dialed, peerId)
+      except CancelledError as exc:
+        await dialed.close()
+        raise exc
       except CatchableError as exc:
         # If we failed to establish the connection through one transport,
         # we won't succeed through another - no use in trying again
         await dialed.close()
         debug "Connection upgrade failed",
           err = exc.msg, peerId = peerId.get(default(PeerId))
-        if exc isnot CancelledError:
-          if dialed.dir == Direction.Out:
-            libp2p_failed_upgrades_outgoing.inc()
-          else:
-            libp2p_failed_upgrades_incoming.inc()
+        if dialed.dir == Direction.Out:
+          libp2p_failed_upgrades_outgoing.inc()
+        else:
+          libp2p_failed_upgrades_incoming.inc()
 
         # Try other address
         return nil
diff --git a/libp2p/errors.nim b/libp2p/errors.nim
index ea47e12e06..2eadd372fa 100644
--- a/libp2p/errors.nim
+++ b/libp2p/errors.nim
@@ -44,12 +44,3 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
       # We still don't abort but warn
       debug "A future has failed, enable trace logging for details", error=exc.name
       trace "Exception details", msg=exc.msg
-
-template tryAndWarn*(message: static[string]; body: untyped): untyped =
-  try:
-    body
-  except CancelledError as exc:
-    raise exc
-  except CatchableError as exc:
-    debug "An exception has ocurred, enable trace logging for details", name = exc.name, msg = message
-    trace "Exception details", exc = exc.msg
diff --git a/libp2p/switch.nim b/libp2p/switch.nim
index 7fc1bada84..518a48c846 100644
--- a/libp2p/switch.nim
+++ b/libp2p/switch.nim
@@ -273,6 +273,7 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises
     except CancelledError as exc:
       trace "releasing semaphore on cancellation"
       upgrades.release() # always release the slot
+      return
     except CatchableError as exc:
       error "Exception in accept loop, exiting", exc = exc.msg
       upgrades.release() # always release the slot
@@ -288,6 +289,12 @@ proc stop*(s: Switch) {.async, public.} =
 
   s.started = false
 
+  try:
+    # Stop accepting incoming connections
+    await allFutures(s.acceptFuts.mapIt(it.cancelAndWait())).wait(1.seconds)
+  except CatchableError as exc:
+    debug "Cannot cancel accepts", error = exc.msg
+
   for service in s.services:
     discard await service.stop(s)
 
@@ -302,18 +309,6 @@ proc stop*(s: Switch) {.async, public.} =
     except CatchableError as exc:
       warn "error cleaning up transports", msg = exc.msg
 
-  try:
-    await allFutures(s.acceptFuts)
-      .wait(1.seconds)
-  except CatchableError as exc:
-    trace "Exception while stopping accept loops", exc = exc.msg
-
-  # check that all futures were properly
-  # stopped and otherwise cancel them
-  for a in s.acceptFuts:
-    if not a.finished:
-      a.cancel()
-
   for service in s.services:
     discard await service.stop(s)
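The dialer change above is the pattern this patch applies throughout: in
chronos, CancelledError inherits from CatchableError, so a generic handler
placed first silently swallows cancellation - which the removed
`exc isnot CancelledError` guard only partially compensated for. A minimal
sketch of the idiom, using a hypothetical `runOrClose` wrapper:

    import chronos, chronicles

    # The dedicated CancelledError handler must come first: clean up, then
    # re-raise so the cancellation keeps propagating. Only genuine failures
    # fall through to the generic handler, where they are logged and swallowed.
    proc runOrClose(fut: Future[void]) {.async.} =
      try:
        await fut
      except CancelledError as exc:
        # release any per-attempt resources, then propagate
        raise exc
      except CatchableError as exc:
        debug "Operation failed", err = exc.msg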
diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim
index d7bc25d1a6..bcf398c10f 100644
--- a/libp2p/transports/tcptransport.nim
+++ b/libp2p/transports/tcptransport.nim
@@ -12,262 +12,327 @@
 {.push raises: [].}
 
 import std/[sequtils]
-import stew/results
 import chronos, chronicles
-import transport,
-       ../errors,
-       ../wire,
-       ../multicodec,
-       ../connmanager,
-       ../multiaddress,
-       ../stream/connection,
-       ../stream/chronosstream,
-       ../upgrademngrs/upgrade,
-       ../utility
+import
+  ./transport,
+  ../wire,
+  ../multiaddress,
+  ../stream/connection,
+  ../stream/chronosstream,
+  ../upgrademngrs/upgrade,
+  ../utility
 
 logScope:
   topics = "libp2p tcptransport"
 
-export transport, results
+export transport, connection, upgrade
 
-const
-  TcpTransportTrackerName* = "libp2p.tcptransport"
+const TcpTransportTrackerName* = "libp2p.tcptransport"
 
 type
+  AcceptFuture = typeof(default(StreamServer).accept())
+
   TcpTransport* = ref object of Transport
     servers*: seq[StreamServer]
     clients: array[Direction, seq[StreamTransport]]
     flags: set[ServerFlags]
     clientFlags: set[SocketFlags]
-    acceptFuts: seq[Future[StreamTransport]]
+    acceptFuts: seq[AcceptFuture]
     connectionsTimeout: Duration
+    stopping: bool
 
   TcpTransportError* = object of transport.TransportError
 
-proc connHandler*(self: TcpTransport,
-                  client: StreamTransport,
-                  observedAddr: Opt[MultiAddress],
-                  dir: Direction): Future[Connection] {.async.} =
-
-  trace "Handling tcp connection", address = $observedAddr,
-                                   dir = $dir,
-                                   clients = self.clients[Direction.In].len +
-                                             self.clients[Direction.Out].len
+proc connHandler*(
+    self: TcpTransport,
+    client: StreamTransport,
+    observedAddr: Opt[MultiAddress],
+    dir: Direction,
+): Connection =
+  trace "Handling tcp connection",
+    address = $observedAddr,
+    dir = $dir,
+    clients = self.clients[Direction.In].len + self.clients[Direction.Out].len
 
   let conn = Connection(
     ChronosStream.init(
       client = client,
       dir = dir,
       observedAddr = observedAddr,
-      timeout = self.connectionsTimeout
-    ))
+      timeout = self.connectionsTimeout,
+    )
+  )
 
   proc onClose() {.async: (raises: []).} =
-    try:
-      block:
-        let
-          fut1 = client.join()
-          fut2 = conn.join()
-        try: # https://github.com/status-im/nim-chronos/issues/516
-          discard await race(fut1, fut2)
-        except ValueError: raiseAssert("Futures list is not empty")
-        # at least one join() completed, cancel pending one, if any
-        if not fut1.finished: await fut1.cancelAndWait()
-        if not fut2.finished: await fut2.cancelAndWait()
+    await noCancel client.join()
 
-      trace "Cleaning up client", addrs = $client.remoteAddress,
-                                  conn
+    trace "Cleaning up client", addrs = $client.remoteAddress, conn
 
-      self.clients[dir].keepItIf( it != client )
+    self.clients[dir].keepItIf(it != client)
 
-      block:
-        let
-          fut1 = conn.close()
-          fut2 = client.closeWait()
-        await allFutures(fut1, fut2)
-        if fut1.failed:
-          let err = fut1.error()
-          debug "Error cleaning up client", errMsg = err.msg, conn
-        static: doAssert typeof(fut2).E is void # Cannot fail
+    # Propagate the chronos client being closed to the connection
+    # TODO This is somewhat dubious since it's the connection that owns the
+    #      client, but it allows the transport to close all connections when
+    #      shutting down (also dubious! it would make more sense that the owner
+    #      of all connections closes them, or the next read detects the closed
+    #      socket and does the right thing..)
-      trace "Cleaned up client", addrs = $client.remoteAddress,
-                                 conn
+    await conn.close()
 
-    except CancelledError as exc:
-      let useExc {.used.} = exc
-      debug "Error cleaning up client", errMsg = exc.msg, conn
+    trace "Cleaned up client", addrs = $client.remoteAddress, conn
 
   self.clients[dir].add(client)
+
   asyncSpawn onClose()
 
   return conn
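The rewritten `connHandler` replaces the per-connection race/cancelAndWait
dance with one detached cleanup task keyed on the client's lifetime. A sketch
of the idiom with a hypothetical `Tracker` holder (the `Connection` side is
elided): `noCancel` shields the join so cleanup always runs to completion, and
`asyncSpawn` makes an unexpected exception crash loudly instead of vanishing.

    import std/sequtils
    import chronos

    type Tracker = ref object
      clients: seq[StreamTransport]

    proc track(t: Tracker, client: StreamTransport) =
      proc onClose() {.async: (raises: []).} =
        # survives cancellation of whoever spawned us
        await noCancel client.join()
        # deregister once the underlying socket is gone
        t.clients.keepItIf(it != client)
      t.clients.add client
      asyncSpawn onClose()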
- trace "Cleaned up client", addrs = $client.remoteAddress, - conn + await conn.close() - except CancelledError as exc: - let useExc {.used.} = exc - debug "Error cleaning up client", errMsg = exc.msg, conn + trace "Cleaned up client", addrs = $client.remoteAddress, conn self.clients[dir].add(client) + asyncSpawn onClose() return conn proc new*( - T: typedesc[TcpTransport], - flags: set[ServerFlags] = {}, - upgrade: Upgrade, - connectionsTimeout = 10.minutes): T {.public.} = - - let - transport = T( - flags: flags, - clientFlags: - if ServerFlags.TcpNoDelay in flags: - compilesOr: - {SocketFlags.TcpNoDelay} - do: - doAssert(false) - default(set[SocketFlags]) - else: - default(set[SocketFlags]), - upgrader: upgrade, - networkReachability: NetworkReachability.Unknown, - connectionsTimeout: connectionsTimeout) - - return transport - -method start*( - self: TcpTransport, - addrs: seq[MultiAddress]) {.async.} = - ## listen on the transport - ## - - if self.running: - warn "TCP transport already running" - return - - await procCall Transport(self).start(addrs) - trace "Starting TCP transport" - trackCounter(TcpTransportTrackerName) + T: typedesc[TcpTransport], + flags: set[ServerFlags] = {}, + upgrade: Upgrade, + connectionsTimeout = 10.minutes, +): T {.public.} = + T( + flags: flags, + clientFlags: + if ServerFlags.TcpNoDelay in flags: + {SocketFlags.TcpNoDelay} + else: + default(set[SocketFlags]) + , + upgrader: upgrade, + networkReachability: NetworkReachability.Unknown, + connectionsTimeout: connectionsTimeout, + ) + +method start*(self: TcpTransport, addrs: seq[MultiAddress]): Future[void] = + ## Start transport listening to the given addresses - for dial-only transports, + ## start with an empty list + + # TODO remove `impl` indirection throughout when `raises` is added to base + + proc impl( + self: TcpTransport, addrs: seq[MultiAddress] + ): Future[void] {.async: (raises: [transport.TransportError, CancelledError]).} = + if self.running: + warn "TCP transport already running" + return - for i, ma in addrs: - if not self.handles(ma): - trace "Invalid address detected, skipping!", address = ma - continue + trace "Starting TCP transport" self.flags.incl(ServerFlags.ReusePort) - let server = createStreamServer( - ma = ma, - flags = self.flags, - udata = self) - # always get the resolved address in case we're bound to 0.0.0.0:0 - self.addrs[i] = MultiAddress.init( - server.sock.getLocalAddress() - ).tryGet() + var supported: seq[MultiAddress] + var initialized = false + try: + for i, ma in addrs: + if not self.handles(ma): + trace "Invalid address detected, skipping!", address = ma + continue - self.servers &= server + let + ta = initTAddress(ma).expect("valid address per handles check above") + server = + try: + createStreamServer(ta, flags = self.flags) + except common.TransportError as exc: + raise (ref TcpTransportError)(msg: exc.msg, parent: exc) + + self.servers &= server + + trace "Listening on", address = ma + supported.add( + MultiAddress.init(server.sock.getLocalAddress()).expect( + "Can init from local address" + ) + ) + + initialized = true + finally: + if not initialized: + # Clean up partial success on exception + await noCancel allFutures(self.servers.mapIt(it.closeWait())) + reset(self.servers) - trace "Listening on", address = ma + try: + await procCall Transport(self).start(supported) + except CatchableError: + raiseAssert "Base method does not raise" -method stop*(self: TcpTransport) {.async.} = - ## stop the transport - ## - try: - trace "Stopping TCP transport" + 
-method stop*(self: TcpTransport) {.async.} =
-  ## stop the transport
-  ##
-  try:
-    trace "Stopping TCP transport"
-
-    checkFutures(
-      await allFinished(
-        self.clients[Direction.In].mapIt(it.closeWait()) &
-        self.clients[Direction.Out].mapIt(it.closeWait())))
-
-    if not self.running:
+method stop*(self: TcpTransport): Future[void] =
+  ## Stop the transport and close all connections it created
+  proc impl(self: TcpTransport) {.async: (raises: []).} =
+    trace "Stopping TCP transport"
+    self.stopping = true
+    defer:
+      self.stopping = false
+
+    if self.running:
+      # Reset the running flag
+      try:
+        await noCancel procCall Transport(self).stop()
+      except CatchableError: # TODO remove when `accept` is annotated with raises
+        raiseAssert "doesn't actually raise"
+
+      # Stop each server by closing the socket - this will cause all accept loops
+      # to fail - since the running flag has been reset, it's also safe to close
+      # all known clients since no more of them will be added
+      await noCancel allFutures(
+        self.servers.mapIt(it.closeWait()) &
+        self.clients[Direction.In].mapIt(it.closeWait()) &
+        self.clients[Direction.Out].mapIt(it.closeWait())
+      )
+
+      self.servers = @[]
+
+      for acceptFut in self.acceptFuts:
+        if acceptFut.completed():
+          await acceptFut.value().closeWait()
+      self.acceptFuts = @[]
+
+      if self.clients[Direction.In].len != 0 or self.clients[Direction.Out].len != 0:
+        # Future updates could consider turning this warn into an assert since
+        # it should never happen if the shutdown code is correct
+        warn "Couldn't clean up clients",
+          len = self.clients[Direction.In].len + self.clients[Direction.Out].len
+
+      trace "Transport stopped"
+      untrackCounter(TcpTransportTrackerName)
+    else:
+      # For legacy reasons, `stop` on a transport that wasn't started is
+      # expected to close outgoing connections created by the transport
       warn "TCP transport already stopped"
-      return
-
-    await procCall Transport(self).stop() # call base
-    var toWait: seq[Future[void]]
-    for fut in self.acceptFuts:
-      if not fut.finished:
-        toWait.add(fut.cancelAndWait())
-      elif fut.done:
-        toWait.add(fut.read().closeWait())
-
-    for server in self.servers:
-      server.stop()
-      toWait.add(server.closeWait())
-
-    await allFutures(toWait)
+      doAssert self.clients[Direction.In].len == 0,
+        "No incoming connections possible without start"
+      await noCancel allFutures(self.clients[Direction.Out].mapIt(it.closeWait()))
 
-    self.servers = @[]
-    self.acceptFuts = @[]
+  impl(self)
 
-    trace "Transport stopped"
-    untrackCounter(TcpTransportTrackerName)
-  except CatchableError as exc:
-    trace "Error shutting down tcp transport", exc = exc.msg
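The ordering in `stop` is what closes the race with `accept` and `dial`: the
running flag flips first (via the base `stop`), the listening sockets close
next - failing any accept loop still parked in `await` - and only then are the
known clients closed, safe in the knowledge that no new ones can be
registered. A reduced sketch of the close step:

    import std/sequtils
    import chronos

    # Close servers and clients together; noCancel because even a cancelled
    # shutdown must release the sockets rather than leak them.
    proc closeAll(servers: seq[StreamServer],
                  clients: seq[StreamTransport]) {.async: (raises: []).} =
      await noCancel allFutures(
        servers.mapIt(it.closeWait()) & clients.mapIt(it.closeWait())
      )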
-method accept*(self: TcpTransport): Future[Connection] {.async.} =
-  ## accept a new TCP connection
+method accept*(self: TcpTransport): Future[Connection] =
+  ## accept a new TCP connection, returning nil on non-fatal errors
   ##
-
-  if not self.running:
-    raise newTransportClosedError()
-
-  try:
-    if self.acceptFuts.len <= 0:
-      self.acceptFuts = self.servers.mapIt(Future[StreamTransport](it.accept()))
-
-    if self.acceptFuts.len <= 0:
-      return
+  ## Raises an exception when the transport is broken and cannot be used for
+  ## accepting further connections
+  # TODO returning nil for non-fatal errors is problematic in that error
+  #      information is lost and must be logged here instead of being
+  #      available to the caller - further refactoring should propagate errors
+  #      to the caller instead
+  proc impl(
+      self: TcpTransport
+  ): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
+    if not self.running:
+      raise newTransportClosedError()
+
+    if self.acceptFuts.len <= 0:
+      self.acceptFuts = self.servers.mapIt(it.accept())
 
     let
-      finished = await one(self.acceptFuts)
+      finished =
+        try:
+          await one(self.acceptFuts)
+        except ValueError:
+          raise (ref TcpTransportError)(msg: "No listeners configured")
+
       index = self.acceptFuts.find(finished)
+      transp =
+        try:
+          await finished
+        except TransportTooManyError as exc:
+          debug "Too many files opened", exc = exc.msg
+          return nil
+        except TransportAbortedError as exc:
+          debug "Connection aborted", exc = exc.msg
+          return nil
+        except TransportUseClosedError as exc:
+          raise newTransportClosedError(exc)
+        except TransportOsError as exc:
+          raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
+        except common.TransportError as exc: # Needed for chronos 4.0.0 support
+          raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
+        except CancelledError as exc:
+          raise exc
+
+    if not self.running: # Stopped while waiting
+      await transp.closeWait()
+      raise newTransportClosedError()
 
     self.acceptFuts[index] = self.servers[index].accept()
-    let transp = await finished
-    try:
-      let observedAddr = MultiAddress.init(transp.remoteAddress).tryGet()
-      return await self.connHandler(transp, Opt.some(observedAddr), Direction.In)
-    except CancelledError as exc:
-      debug "CancelledError", exc = exc.msg
-      transp.close()
-      raise exc
-    except CatchableError as exc:
-      debug "Failed to handle connection", exc = exc.msg
-      transp.close()
-  except TransportTooManyError as exc:
-    debug "Too many files opened", exc = exc.msg
-  except TransportAbortedError as exc:
-    debug "Connection aborted", exc = exc.msg
-  except TransportUseClosedError as exc:
-    debug "Server was closed", exc = exc.msg
-    raise newTransportClosedError(exc)
-  except CancelledError as exc:
-    raise exc
-  except TransportOsError as exc:
-    info "OS Error", exc = exc.msg
-    raise exc
-  except CatchableError as exc:
-    info "Unexpected error accepting connection", exc = exc.msg
-    raise exc
+    let remote =
+      try:
+        transp.remoteAddress
+      except TransportOsError as exc:
+        # The connection had errors / was closed before `await` returned control
+        await transp.closeWait()
+        debug "Cannot read remote address", exc = exc.msg
+        return nil
+
+    let observedAddr =
+      MultiAddress.init(remote).expect("Can initialize from remote address")
+    self.connHandler(transp, Opt.some(observedAddr), Direction.In)
+
+  impl(self)
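`accept` multiplexes all listeners through `one`: whichever server accepts
first wins, and its slot in `acceptFuts` is immediately re-armed so that
listener keeps producing connections. A sketch of just that mechanism, with a
hypothetical `Acceptor` holder and the patch's error triage elided:

    import std/sequtils
    import chronos

    type
      AcceptFuture = typeof(default(StreamServer).accept())
      Acceptor = ref object
        servers: seq[StreamServer]
        pending: seq[AcceptFuture]

    proc acceptAny(a: Acceptor): Future[StreamTransport] {.async.} =
      if a.pending.len == 0:
        a.pending = a.servers.mapIt(it.accept())
      let finished = await one(a.pending)          # first accept to complete
      let index = a.pending.find(finished)
      a.pending[index] = a.servers[index].accept() # re-arm this listener
      return await finished # already finished: returns or re-raises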
-method dial*(
-  self: TcpTransport,
-  hostname: string,
-  address: MultiAddress,
-  peerId: Opt[PeerId] = Opt.none(PeerId)): Future[Connection] {.async.} =
-  ## dial a peer
-  ##
-
-  trace "Dialing remote peer", address = $address
-  let transp =
-    if self.networkReachability == NetworkReachability.NotReachable and self.addrs.len > 0:
-      self.clientFlags.incl(SocketFlags.ReusePort)
-      await connect(address, flags = self.clientFlags, localAddress = Opt.some(self.addrs[0]))
-    else:
-      await connect(address, flags = self.clientFlags)
-
-  try:
-    let observedAddr = MultiAddress.init(transp.remoteAddress).tryGet()
-    return await self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
-  except CatchableError as err:
-    await transp.closeWait()
-    raise err
+method dial*(
+    self: TcpTransport,
+    hostname: string,
+    address: MultiAddress,
+    peerId: Opt[PeerId] = Opt.none(PeerId),
+): Future[Connection] =
+  ## dial a peer
+  proc impl(
+      self: TcpTransport, hostname: string, address: MultiAddress, peerId: Opt[PeerId]
+  ): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
+    if self.stopping:
+      raise newTransportClosedError()
+
+    let ta = initTAddress(address).valueOr:
+      raise (ref TcpTransportError)(msg: "Unsupported address: " & $address)
+
+    trace "Dialing remote peer", address = $address
+    let transp =
+      try:
+        await(
+          if self.networkReachability == NetworkReachability.NotReachable and
+              self.addrs.len > 0:
+            let local = initTAddress(self.addrs[0]).expect("self address is valid")
+            self.clientFlags.incl(SocketFlags.ReusePort)
+            connect(ta, flags = self.clientFlags, localAddress = local)
+          else:
+            connect(ta, flags = self.clientFlags)
+        )
+      except CancelledError as exc:
+        raise exc
+      except CatchableError as exc:
+        raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
+
+    # If `stop` is called after `connect` but before `await` returns, we might
+    # end up with a race condition where `stop` returns but not all connections
+    # have been closed - we drop connections in this case in order not to leak
+    # them
+    if self.stopping:
+      # Stopped while waiting for new connection
+      await transp.closeWait()
+      raise newTransportClosedError()
+
+    let observedAddr =
+      try:
+        MultiAddress.init(transp.remoteAddress).expect("remote address is valid")
+      except TransportOsError as exc:
+        await transp.closeWait()
+        raise (ref TcpTransportError)(msg: exc.msg)
+
+    self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
+
+  impl(self, hostname, address, peerId)
 
-method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
+method handles*(t: TcpTransport, address: MultiAddress): bool =
   if procCall Transport(t).handles(address):
     if address.protocols.isOk:
       return TCP.match(address)
\ No newline at end of file
diff --git a/libp2p/transports/tortransport.nim b/libp2p/transports/tortransport.nim
index 54ee9e0ff8..d70d1873a3 100644
--- a/libp2p/transports/tortransport.nim
+++ b/libp2p/transports/tortransport.nim
@@ -200,7 +200,7 @@ method dial*(
 
   try:
     await dialPeer(transp, address)
-    return await self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
+    return self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
   except CatchableError as err:
     await transp.closeWait()
     raise err
diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim
index ee13addd8b..7eec34cd4b 100644
--- a/libp2p/transports/transport.nim
+++ b/libp2p/transports/transport.nim
@@ -35,7 +35,7 @@ type
     upgrader*: Upgrade
     networkReachability*: NetworkReachability
 
-proc newTransportClosedError*(parent: ref Exception = nil): ref LPError =
+proc newTransportClosedError*(parent: ref Exception = nil): ref TransportError =
   newException(TransportClosedError,
     "Transport closed, no more connections!", parent)
diff --git a/libp2p/wire.nim b/libp2p/wire.nim
index 70d5574fd1..3e92ece834 100644
--- a/libp2p/wire.nim
+++ b/libp2p/wire.nim
@@ -13,6 +13,8 @@ import chronos, stew/endians2
 import multiaddress, multicodec, errors, utility
 
+export multiaddress, chronos
+
 when defined(windows):
   import winlean
 else:
@@ -30,7 +32,6 @@ const
     UDP,
   )
 
-
 proc initTAddress*(ma: MultiAddress): MaResult[TransportAddress] =
   ## Initialize ``TransportAddress`` with MultiAddress ``ma``.
   ##
@@ -76,7 +77,7 @@ proc connect*(
     child: StreamTransport = nil,
     flags = default(set[SocketFlags]),
     localAddress: Opt[MultiAddress] = Opt.none(MultiAddress)): Future[StreamTransport]
-    {.raises: [LPError, MaInvalidAddress].} =
+    {.async.} =
   ## Open new connection to remote peer with address ``ma`` and create
   ## new transport object ``StreamTransport`` for established connection.
   ## ``bufferSize`` is size of internal buffer for transport.
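The `stopping` flag exists because `stop` can complete while a `dial` is still
suspended inside `connect`: without the re-check after the `await`, the
freshly connected socket would outlive the transport. A sketch of just that
re-check, written as if inside tcptransport.nim where the private field is
visible:

    import chronos

    proc dialChecked(self: TcpTransport,
                     ta: TransportAddress): Future[StreamTransport] {.async.} =
      let transp = await connect(ta)
      if self.stopping:
        # transport shut down while we were connecting - don't leak the socket
        await transp.closeWait()
        raise newTransportClosedError()
      return transp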
@@ -88,12 +89,12 @@ proc connect*(
   let transportAddress = initTAddress(ma).tryGet()
 
   compilesOr:
-    return connect(transportAddress, bufferSize, child,
+    return await connect(transportAddress, bufferSize, child,
       if localAddress.isSome():
        initTAddress(localAddress.expect("just checked")).tryGet()
      else:
        TransportAddress(),
      flags)
   do:
     # support for older chronos versions
-    return connect(transportAddress, bufferSize, child)
+    return await connect(transportAddress, bufferSize, child)
 
 proc createStreamServer*[T](ma: MultiAddress,
                             cbproc: StreamCallback,
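Both branches above now `await`, which works because `connect` itself became
`{.async.}`; the `compilesOr` helper is what keeps the `localAddress`-capable
chronos call on new versions while still building against older ones. For
reference, a `when compiles` fallback template of this shape (libp2p ships one
in utility.nim; reproduced here from memory, so treat it as a sketch):

    template compilesOr(a, b: untyped): untyped =
      when compiles(a):
        a
      else:
        b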