diff --git a/e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala b/e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala index ce787dc6f..38400f037 100644 --- a/e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala +++ b/e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala @@ -31,42 +31,42 @@ abstract class GrpcServiceSpecBase extends AnyFunSpec with Matchers { protected[this] def withInMemoryTransportScalaServer[T](f: ManagedChannel => T): T = { val channelName = "test-in-mem-server" - val server = InProcessServerBuilder.forName(channelName) - .addService(Service1GrpcScala.bindService(new Service1ScalaImpl, singleThreadExecutionContext)) - .build() - - try { - server.start() - val channel = InProcessChannelBuilder.forName(channelName) + withManagedServer( + InProcessServerBuilder + .forName(channelName) + .addService(Service1GrpcScala.bindService(new Service1ScalaImpl, singleThreadExecutionContext)) + .build() + ) { + val channel = InProcessChannelBuilder + .forName(channelName) .usePlaintext() .build() f(channel) - } finally { - server.shutdown() - server.awaitTermination(3000, TimeUnit.MILLISECONDS) - () } } private[this] def withServer[T]( createServer: NettyServerBuilder => Server )(f: ManagedChannel => T): T = { - val port = UniquePortGenerator.get() - val server = createServer(NettyServerBuilder.forPort(port)) - try { - server.start() + val port = UniquePortGenerator.get() + withManagedServer(createServer(NettyServerBuilder.forPort(port))) { val channel = NettyChannelBuilder .forAddress("localhost", port) .negotiationType(NegotiationType.PLAINTEXT) .build() f(channel) - } finally { - server.shutdown() - server.awaitTermination(3000, TimeUnit.MILLISECONDS) - () } } + private[this] def withManagedServer[T](server: Server)(f: => T): T = try { + server.start() + f + } finally { + server.shutdown() + server.awaitTermination(3000, TimeUnit.MILLISECONDS) + () + } + private[this] val singleThreadExecutionContext = new ExecutionContext { override def reportFailure(cause: Throwable): Unit = cause.printStackTrace()