Skip to content

Commit

Permalink
Try all resolved IP addresses when client fails to connect (#2905)
Browse files Browse the repository at this point in the history
* Try all resolved IP addresses when client fails to connect

* Fix for Scala 2.12
  • Loading branch information
kyri-petrou authored Jun 13, 2024
1 parent 5ea81ef commit baf5399
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,31 +106,41 @@ object NettyConnectionPool {

for {
resolvedHosts <- dnsResolver.resolve(location.host)
pickedHost <- Random.nextIntBounded(resolvedHosts.size)
host = resolvedHosts(pickedHost)
channelFuture <- ZIO.attempt {
val bootstrap = new Bootstrap()
.channelFactory(channelFactory)
.group(eventLoopGroup)
.remoteAddress(new InetSocketAddress(host, location.port))
.withOption[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionTimeout.map(_.toMillis.toInt))
.handler(initializer)
(localAddress match {
case Some(addr) => bootstrap.localAddress(addr)
case _ => bootstrap
}).connect()
}
ch <- ZIO.attempt(channelFuture.channel())
_ <- Scope.addFinalizer {
NettyFutureExecutor.executed {
channelFuture.cancel(true)
ch.close()
}.when(ch.isOpen).ignoreLogged
hosts <- Random.shuffle(resolvedHosts.toList)
hostsNec <- ZIO.succeed(NonEmptyChunk.fromIterable(hosts.head, hosts.tail))
ch <- collectFirstSuccess(hostsNec) { host =>
ZIO.suspend {
val bootstrap = new Bootstrap()
.channelFactory(channelFactory)
.group(eventLoopGroup)
.remoteAddress(new InetSocketAddress(host, location.port))
.withOption[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionTimeout.map(_.toMillis.toInt))
.handler(initializer)
localAddress.foreach(bootstrap.localAddress)

val channelFuture = bootstrap.connect()
val ch = channelFuture.channel()
Scope.addFinalizer {
NettyFutureExecutor.executed {
channelFuture.cancel(true)
ch.close()
}.when(ch.isOpen).ignoreLogged
} *> NettyFutureExecutor.executed(channelFuture).as(ch)
}
}
_ <- NettyFutureExecutor.executed(channelFuture)
} yield ch
}

private def collectFirstSuccess[R, E, A, B](
as: NonEmptyChunk[A],
)(f: A => ZIO[R, E, B])(implicit trace: Trace): ZIO[R, E, B] = {
ZIO.suspendSucceed {
val it = as.iterator
def loop: ZIO[R, E, B] = f(it.next()).catchAll(e => if (it.hasNext) loop else ZIO.fail(e))
loop
}
}

/**
* Refreshes the idle timeout handler on the channel pipeline.
* @return
Expand Down
64 changes: 64 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/ClientConnectionSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package zio.http

import java.net.{InetAddress, UnknownHostException}

import zio._
import zio.test.Assertion._
import zio.test.TestAspect._
import zio.test._

import zio.http.internal.{DynamicServer, HttpRunnableSpec, serverTestLayer}
import zio.http.netty.NettyConfig

object ClientConnectionSpec extends HttpRunnableSpec {

private def tests =
List(
test("tries a different IP address when the connection fails") {
val app = Handler.ok.toRoutes.deploy(Request()).map(_.status)
assertZIO(app)(equalTo(Status.Ok))
} @@ nonFlaky(10),
)

override def spec = {
suite("ClientConnectionSpec") {
serve.as(tests)
}.provideSome[DynamicServer & Server & Client](Scope.default)
.provideShared(
DynamicServer.live,
serverTestLayer,
Client.live,
ZLayer.succeed(Client.Config.default.connectionTimeout(10.millis)),
ZLayer.succeed(NettyConfig.defaultWithFastShutdown),
ZLayer.succeed(TestResolver),
) @@ sequential @@ withLiveClock @@ withLiveRandom
}

private object TestResolver extends DnsResolver {
import scala.collection.compat._

override def resolve(host: String)(implicit trace: Trace): ZIO[Any, UnknownHostException, Chunk[InetAddress]] = {
ZIO.succeed {
Chunk.from((0 to 10).map { i =>
InetAddress.getByAddress(Array(127, 0, 0, i.toByte))
})
}
}
}
}

0 comments on commit baf5399

Please sign in to comment.