Skip to content

Commit

Permalink
Improve SSL error handling
Browse files Browse the repository at this point in the history
Eager reload if SSL communicaton fail due to race condition when cert is updated.
Mute SSL errors caused by missbehaving clients.
  • Loading branch information
eperott committed Feb 10, 2020
1 parent 90e852d commit 4881345
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,22 @@
import java.io.File;
import java.util.Collection;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ReloadWatcher {
private static final Logger logger = LoggerFactory.getLogger(ReloadWatcher.class);

private static final long RELOAD_MARGIN_MILLIS = 1000;
private final long intervalInMs;
private final Collection<File> files;

private long reloadAt;
private long nextReloadAt;
private long reloadedAt;

public ReloadWatcher(HttpServerOptions httpServerOptions) {
intervalInMs = httpServerOptions.sslReloadIntervalInSeconds * 1000;
intervalInMs = TimeUnit.SECONDS.toMillis(httpServerOptions.sslReloadIntervalInSeconds);
files = Stream.of(httpServerOptions.sslServerKeyFile,
httpServerOptions.sslServerKeyPasswordFile,
httpServerOptions.sslServerCertificateFile,
Expand All @@ -33,12 +35,29 @@ public ReloadWatcher(HttpServerOptions httpServerOptions) {
}

private void reset(long now) {
reloadedAt = now;
reloadAt = now + intervalInMs;
logger.debug("Reset reloaded at to {}", reloadedAt);
// Create a 1 second margin to compensate for poor resolution of File.lastModified()
reloadedAt = now - RELOAD_MARGIN_MILLIS;

nextReloadAt = now + intervalInMs;
logger.debug("Next reload at {}", nextReloadAt);
}

public synchronized void forceReload() {
if (!enabled()) {
return;
}

logger.info("Forced reload of exporter certificates on next scrape");

reloadedAt = 0L;
nextReloadAt = 0L;
}

boolean needReload() {
if (!enabled()) {
return false;
}

long now = System.currentTimeMillis();

if (timeToPoll(now)) {
Expand All @@ -48,8 +67,12 @@ boolean needReload() {
return false;
}

private boolean enabled() {
return intervalInMs > 0;
}

private boolean timeToPoll(long now) {
return intervalInMs > 0 && now > reloadAt;
return now > nextReloadAt;
}

private synchronized boolean reallyNeedReload(long now) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ public void maybeAddHandler(SocketChannel ch) {
if (isEnabled()) {
ch.pipeline()
.addFirst(createSslHandler(ch))
.addLast(createSuppressingSslExceptionHandler());
.addLast(new UnexpectedSslExceptionHandler(reloadWatcher))
.addLast(new SuppressingSslExceptionHandler());
}
}

Expand Down Expand Up @@ -97,10 +98,6 @@ private void maybeReloadContext() {
}
}

private ChannelHandler createSuppressingSslExceptionHandler() {
return new SuppressingSslExceptionHandler();
}

@VisibleForTesting
SslContext getSslContext() {
return sslContextRef.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLHandshakeException;
import java.net.SocketAddress;

/**
* This handler will catch and suppress exceptions which are triggered when a client send a
Expand All @@ -21,9 +22,11 @@ public class SuppressingSslExceptionHandler extends ChannelHandlerAdapter {

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (handshakeException(cause) || notAJdkSslRecord(cause)) {
if (handshakeException(cause)
|| sslRecordException(cause)
|| decoderSslRecordException(cause)) {
try {
logger.info("Exception while processing SSL scrape request: {}", cause.getMessage());
logger.info("Exception while processing SSL scrape request from {}: {}", remotePeer(ctx), cause.getMessage());
logger.debug("Exception while processing SSL scrape request", cause);
} finally {
ReferenceCountUtil.release(cause);
Expand All @@ -33,12 +36,24 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}

private SocketAddress remotePeer(ChannelHandlerContext ctx) {
if (ctx.channel() == null) {
return null;
}
return ctx.channel().remoteAddress();
}

private boolean handshakeException(Throwable cause) {
return cause instanceof DecoderException
&& cause.getCause() instanceof SSLHandshakeException;
}

private boolean notAJdkSslRecord(Throwable cause) {
private boolean sslRecordException(Throwable cause) {
return cause instanceof NotSslRecordException;
}

private boolean decoderSslRecordException(Throwable cause) {
return cause instanceof DecoderException
&& cause.getCause() instanceof NotSslRecordException;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.zegelin.cassandra.exporter.netty.ssl;

import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLException;

public class UnexpectedSslExceptionHandler extends ChannelHandlerAdapter {
private static final Logger logger = LoggerFactory.getLogger(UnexpectedSslExceptionHandler.class);

private final ReloadWatcher reloadWatcher;

UnexpectedSslExceptionHandler(ReloadWatcher reloadWatcher) {
this.reloadWatcher = reloadWatcher;
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
try {
if (unexpectedMessage(cause)) {
logger.warn(cause.getMessage());
// This may indicate that we're currently using invalid combo of key & cert
reloadWatcher.forceReload();
}
} finally {
ctx.fireExceptionCaught(cause);
}
}

private boolean unexpectedMessage(Throwable cause) {
return cause instanceof DecoderException
&& cause.getCause() instanceof SSLException
&& cause.getCause().getMessage().contains("unexpected_message");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import static org.assertj.core.api.Assertions.assertThat;

public class TestReloadWatcher {
public static final long INITIAL_FILE_AGE_MILLIS = 5000;
public static final long SLEEP_MILLIS = 1001;

private HttpServerOptions options;
private ReloadWatcher watcher;
Expand All @@ -19,9 +21,15 @@ public class TestReloadWatcher {
public void before() throws IOException {
options = new HttpServerOptions();
options.sslReloadIntervalInSeconds = 1;

options.sslServerKeyFile = givenTemporaryFile("server-key");
options.sslServerCertificateFile = givenTemporaryFile("server-cert");
options.sslTrustedCertificateFile = givenTemporaryFile("trusted-cert");

options.sslServerKeyFile.setLastModified(System.currentTimeMillis() - INITIAL_FILE_AGE_MILLIS);
options.sslServerCertificateFile.setLastModified(System.currentTimeMillis() - INITIAL_FILE_AGE_MILLIS);
options.sslTrustedCertificateFile.setLastModified(System.currentTimeMillis() - INITIAL_FILE_AGE_MILLIS);

watcher = new ReloadWatcher(options);
}

Expand All @@ -34,17 +42,60 @@ public void testNoImmediateReload() {

@Test
public void testNoReloadWhenFilesAreUntouched() throws InterruptedException {
Thread.sleep(1001);
Thread.sleep(SLEEP_MILLIS);

assertThat(watcher.needReload()).isFalse();
}

@Test
public void testReloadWhenFilesAreTouched() throws InterruptedException {
Thread.sleep(1001);
public void testReloadOnceWhenFilesAreTouched() throws InterruptedException {
Thread.sleep(SLEEP_MILLIS);

options.sslServerKeyFile.setLastModified(System.currentTimeMillis());
options.sslServerCertificateFile.setLastModified(System.currentTimeMillis());

Thread.sleep(SLEEP_MILLIS);

assertThat(watcher.needReload()).isTrue();

Thread.sleep(SLEEP_MILLIS);

assertThat(watcher.needReload()).isFalse();
}

// Verify that we reload certificates on next pass again in case files are modified
// just as we check for reload.
@Test
public void testReloadAgainWhenFilesAreTouchedJustAfterReload() throws InterruptedException {
Thread.sleep(SLEEP_MILLIS);

options.sslServerKeyFile.setLastModified(System.currentTimeMillis());
assertThat(watcher.needReload()).isTrue();
options.sslServerCertificateFile.setLastModified(System.currentTimeMillis());

Thread.sleep(SLEEP_MILLIS);

assertThat(watcher.needReload()).isTrue();
}

@Test
public void testReloadWhenForced() throws InterruptedException {
Thread.sleep(SLEEP_MILLIS);

watcher.forceReload();

assertThat(watcher.needReload()).isTrue();
}

@Test
public void testNoReloadWhenDisabled() throws InterruptedException {
options.sslReloadIntervalInSeconds = 0;
watcher = new ReloadWatcher(options);

Thread.sleep(SLEEP_MILLIS);
options.sslServerKeyFile.setLastModified(System.currentTimeMillis());

assertThat(watcher.needReload()).isFalse();
}

private File givenTemporaryFile(String filename) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.zegelin.cassandra.exporter.netty.ssl;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.ssl.NotSslRecordException;
Expand All @@ -10,37 +11,66 @@

import javax.net.ssl.SSLHandshakeException;

import java.net.InetSocketAddress;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;


public class TestSuppressingSslExceptionHandler {
@Mock
private ChannelHandlerContext context;

@Mock
private Channel channel;

private SuppressingSslExceptionHandler handler;

@BeforeMethod
public void before() {
MockitoAnnotations.initMocks(this);
handler = new SuppressingSslExceptionHandler();

when(context.channel()).thenReturn(channel);
when(channel.remoteAddress()).thenReturn(InetSocketAddress.createUnresolved("127.0.0.1", 12345));
}

@Test
public void testNotSslExceptionFromJdkImplementationIsMuted() throws Exception {
handler.exceptionCaught(context, new NotSslRecordException());
public void testNotSslExceptionFromJdkImplementationIsMuted() {
handler.exceptionCaught(context, new NotSslRecordException("Some HTTP_REQUEST in message"));
verify(context, times(0)).fireExceptionCaught(any());
}

@Test
public void testSslHandshakeExceptionFromOpenSslImplementationIsMuted() throws Exception {
public void testSslHandshakeExceptionFromOpenSslImplementationIsMuted() {
handler.exceptionCaught(context, new DecoderException(new SSLHandshakeException("Some HTTP_REQUEST in message")));
verify(context, times(0)).fireExceptionCaught(any());
}

@Test
public void testOtherExceptionIsPropagated() throws Exception {
public void testNotSslRecordExceptionIsMuted() {
handler.exceptionCaught(context, new DecoderException(new NotSslRecordException("Some HTTP_REQUEST in message")));
verify(context, times(0)).fireExceptionCaught(any());
}

@Test
public void testInfoLogDoNotBailOnNullChannel() {
when(context.channel()).thenReturn(null);
handler.exceptionCaught(context, new NotSslRecordException("Some HTTP_REQUEST in message"));
verify(context, times(0)).fireExceptionCaught(any());
}

@Test
public void testInfoLogDoNotBailOnNullRemoteAddress() {
when(channel.remoteAddress()).thenReturn(null);
handler.exceptionCaught(context, new NotSslRecordException("Some HTTP_REQUEST in message"));
verify(context, times(0)).fireExceptionCaught(any());
}

@Test
public void testOtherExceptionIsPropagated() {
handler.exceptionCaught(context, new NullPointerException());
verify(context, times(1)).fireExceptionCaught(any());
}
Expand Down
Loading

0 comments on commit 4881345

Please sign in to comment.