Skip to content

Commit

Permalink
Add request counter to ensure request context is associated with corr…
Browse files Browse the repository at this point in the history
…esponding request

Signed-off-by: Craig Perkins <[email protected]>
  • Loading branch information
cwperks committed Oct 22, 2023
1 parent 75f83bc commit f9f0ed0
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 51 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import groovy.json.JsonBuilder

buildscript {
ext {
opensearch_version = System.getProperty("opensearch.version", "2.11.1-SNAPSHOT")
opensearch_version = System.getProperty("opensearch.version", "2.11.0-SNAPSHOT")
isSnapshot = "true" == System.getProperty("build.snapshot", "true")
buildVersionQualifier = System.getProperty("build.version_qualifier", "")

Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/security/filter/NettyAttribute.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ public static <T> Optional<T> popFrom(final ChannelHandlerContext ctx, final Att
return Optional.ofNullable(ctx.channel().attr(attribute).getAndSet(null));
}

/**
* Gets an attribute value from the request context
*/
public static <T> Optional<T> peekFrom(final RestRequest request, final AttributeKey<T> attribute) {
if (request.getHttpChannel() instanceof Netty4HttpChannel) {
Channel nettyChannel = ((Netty4HttpChannel) request.getHttpChannel()).getNettyChannel();
return Optional.ofNullable(nettyChannel.attr(attribute).get());
}
return Optional.empty();
}

/**
* Gets an attribute value from the channel handler context
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
Expand Down Expand Up @@ -56,6 +57,7 @@
import org.opensearch.security.privileges.RestLayerPrivilegesEvaluator;
import org.opensearch.security.securityconf.impl.AllowlistingSettings;
import org.opensearch.security.securityconf.impl.WhitelistingSettings;
import org.opensearch.security.ssl.http.netty.Netty4RequestContext;
import org.opensearch.security.ssl.transport.PrincipalExtractor;
import org.opensearch.security.ssl.util.ExceptionUtils;
import org.opensearch.security.ssl.util.SSLRequestHelper;
Expand All @@ -68,9 +70,7 @@

import static org.opensearch.security.OpenSearchSecurityPlugin.LEGACY_OPENDISTRO_PREFIX;
import static org.opensearch.security.OpenSearchSecurityPlugin.PLUGINS_PREFIX;
import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE;
import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE;
import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED;
import static org.opensearch.security.ssl.http.netty.SecuritySSLNettyHttpServerTransport.REQUEST_CONTEXTS;;

public class SecurityRestFilter {

Expand Down Expand Up @@ -131,31 +131,37 @@ public SecurityRestFilter(
*/
public RestHandler wrap(RestHandler original, AdminDNs adminDNs) {
return (request, channel, client) -> {
final Optional<Map<String, Netty4RequestContext>> requestContexts = NettyAttribute.peekFrom(request, REQUEST_CONTEXTS);
String requestId = request.header("X-Channel-Request-ID");

final Optional<SecurityResponse> maybeSavedResponse = NettyAttribute.popFrom(request, EARLY_RESPONSE);
if (maybeSavedResponse.isPresent()) {
NettyAttribute.clearAttribute(request, CONTEXT_TO_RESTORE);
NettyAttribute.clearAttribute(request, IS_AUTHENTICATED);
channel.sendResponse(maybeSavedResponse.get().asRestResponse());
return;
}
final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(request, channel);

NettyAttribute.popFrom(request, CONTEXT_TO_RESTORE).ifPresent(storedContext -> {
// X_OPAQUE_ID will be overritten on restore - save to apply after restoring the saved context
final String xOpaqueId = threadContext.getHeader(Task.X_OPAQUE_ID);
storedContext.restore();
if (xOpaqueId != null) {
threadContext.putHeader(Task.X_OPAQUE_ID, xOpaqueId);
if (requestContexts.isPresent() && requestId != null && requestContexts.get().get(requestId) != null) {
Netty4RequestContext requestContext = requestContexts.get().get(requestId);
requestContexts.get().remove(requestId);
if (requestContext.earlyResponse != null) {
channel.sendResponse(requestContext.earlyResponse.asRestResponse());
return;
}
});

final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(request, channel);
// Authenticate request
if (!requestContext.isAuthenticated) {
// we aren't authenticated so we should skip this step
checkAndAuthenticateRequest(requestChannel);
} else {
ThreadContext.StoredContext storedContext = requestContext.storedContext;
// X_OPAQUE_ID will be overritten on restore - save to apply after restoring the saved context
final String xOpaqueId = threadContext.getHeader(Task.X_OPAQUE_ID);
storedContext.restore();
if (xOpaqueId != null) {
threadContext.putHeader(Task.X_OPAQUE_ID, xOpaqueId);
}
}

// Authenticate request
if (!NettyAttribute.popFrom(request, IS_AUTHENTICATED).orElse(false)) {
// we aren't authenticated so we should skip this step
} else {
checkAndAuthenticateRequest(requestChannel);
}

if (requestChannel.getQueuedResponse().isPresent()) {
channel.sendResponse(requestChannel.getQueuedResponse().get().asRestResponse());
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,11 @@

package org.opensearch.security.http;

import io.netty.util.AttributeKey;
import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.security.filter.SecurityResponse;
import org.opensearch.security.filter.SecurityRestFilter;
import org.opensearch.security.ssl.SecurityKeyStore;
import org.opensearch.security.ssl.SslExceptionHandler;
Expand All @@ -45,13 +42,6 @@

public class SecurityHttpServerTransport extends SecuritySSLNettyHttpServerTransport {

public static final AttributeKey<SecurityResponse> EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response");
public static final AttributeKey<ThreadContext.StoredContext> CONTEXT_TO_RESTORE = AttributeKey.newInstance(
"opensearch-http-request-thread-context"
);
public static final AttributeKey<Boolean> SHOULD_DECOMPRESS = AttributeKey.newInstance("opensearch-http-should-decompress");
public static final AttributeKey<Boolean> IS_AUTHENTICATED = AttributeKey.newInstance("opensearch-http-is-authenticated");

public SecurityHttpServerTransport(
final Settings settings,
final NetworkService networkService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,49 @@
package org.opensearch.security.ssl.http.netty;

import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMessage;

import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE;
import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS;
import static org.opensearch.security.http.SecurityHttpServerTransport.REQUEST_CONTEXTS;

import org.opensearch.security.filter.NettyAttribute;
import java.util.Map;

@Sharable
public class Netty4ConditionalDecompressor extends HttpContentDecompressor {

private String contentEncodingOverride;

@Override
protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception {
final boolean hasAnEarlyReponse = NettyAttribute.peekFrom(ctx, EARLY_RESPONSE).isPresent();
final boolean shouldDecompress = NettyAttribute.popFrom(ctx, SHOULD_DECOMPRESS).orElse(false);
if (hasAnEarlyReponse || !shouldDecompress) {
if (contentEncodingOverride != null) {
// If there was an error prompting an early response,... don't decompress
// If there is no explicit decompress flag,... don't decompress
// If there is a decompress flag and it is false,... don't decompress
return super.newContentDecoder("identity");
return super.newContentDecoder(contentEncodingOverride);
}

// Decompresses the content based on its encoding
return super.newContentDecoder(contentEncoding);
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
contentEncodingOverride = null;
if (msg instanceof HttpMessage) {
final HttpMessage message = (HttpMessage) msg;
final HttpHeaders headers = message.headers();
Map<String, Netty4RequestContext> requestContexts = ctx.channel().attr(REQUEST_CONTEXTS).get();
String requestId = headers.get("X-Channel-Request-ID");
if (requestId != null && requestContexts != null) {
Netty4RequestContext requestContext = requestContexts.get(requestId);
if (requestContext != null && (!requestContext.shouldDecompress || requestContext.earlyResponse != null)) {
contentEncodingOverride = "identity";
}
}
}
super.channelRead(ctx, msg);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.util.ReferenceCountUtil;

import org.opensearch.ExceptionsHelper;
import org.opensearch.common.util.concurrent.ThreadContext;

Expand All @@ -34,16 +35,16 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.OpenSearchSecurityException;

import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;

import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.API_AUTHTOKEN_SUFFIX;
import static org.opensearch.security.filter.SecurityRestFilter.HEALTH_SUFFIX;
import static org.opensearch.security.filter.SecurityRestFilter.PATTERN_PATH_PREFIX;
import static org.opensearch.security.filter.SecurityRestFilter.WHO_AM_I_SUFFIX;
import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE;
import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE;
import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS;
import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED;
import static org.opensearch.security.http.SecurityHttpServerTransport.REQUEST_CONTEXTS;
import static org.opensearch.security.http.SecurityHttpServerTransport.REQUEST_COUNTER;

@Sharable
public class Netty4HttpRequestHeaderVerifier extends SimpleChannelInboundHandler<DefaultHttpRequest> {
Expand Down Expand Up @@ -78,9 +79,10 @@ public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) thro
return;
}

// Start by setting this value to false, only requests that meet all the criteria will be decompressed
ctx.channel().attr(SHOULD_DECOMPRESS).set(Boolean.FALSE);
ctx.channel().attr(IS_AUTHENTICATED).set(Boolean.FALSE);
Netty4RequestContext requestContext = new Netty4RequestContext();
Map<String, Netty4RequestContext> requestContexts = ctx.channel().attr(REQUEST_CONTEXTS).get();

// ctx.channel().attr(SHOULD_DECOMPRESS).set(Boolean.TRUE);

final Netty4HttpChannel httpChannel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
String rawPath = SecurityRestUtils.path(msg.uri());
Expand All @@ -107,23 +109,35 @@ public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) thro
}

ThreadContext.StoredContext contextToRestore = threadPool.getThreadContext().newStoredContext(false);
ctx.channel().attr(CONTEXT_TO_RESTORE).set(contextToRestore);
requestContext.storedContext = contextToRestore;
// ctx.channel().attr(CONTEXT_TO_RESTORE).set(contextToRestore);

requestChannel.getQueuedResponse().ifPresent(response -> ctx.channel().attr(EARLY_RESPONSE).set(response));
requestChannel.getQueuedResponse().ifPresent(response -> {
// ctx.channel().attr(EARLY_RESPONSE).set(response);
requestContext.earlyResponse = response;
});

// TODO Check if response code on queued response is 4XX
boolean shouldDecompress = !shouldSkipAuthentication && requestChannel.getQueuedResponse().isEmpty();

if (requestChannel.getQueuedResponse().isEmpty() || shouldSkipAuthentication) {
// Only allow decompression on authenticated requests that also aren't one of those ^
ctx.channel().attr(SHOULD_DECOMPRESS).set(Boolean.valueOf(shouldDecompress));
ctx.channel().attr(IS_AUTHENTICATED).set(Boolean.TRUE);
requestContext.shouldDecompress = Boolean.valueOf(shouldDecompress);
requestContext.isAuthenticated = Boolean.TRUE;
// ctx.channel().attr(SHOULD_DECOMPRESS).set(Boolean.valueOf(shouldDecompress));
// ctx.channel().attr(IS_AUTHENTICATED).set(Boolean.TRUE);
}
} catch (final OpenSearchSecurityException e) {
final SecurityResponse earlyResponse = new SecurityResponse(ExceptionsHelper.status(e).getStatus(), e);
ctx.channel().attr(EARLY_RESPONSE).set(earlyResponse);
requestContext.earlyResponse = earlyResponse;
// ctx.channel().attr(EARLY_RESPONSE).set(earlyResponse);
} catch (final SecurityRequestChannelUnsupported srcu) {
// Use defaults for unsupported channels
} finally {
AtomicInteger requestCounter = ctx.channel().attr(REQUEST_COUNTER).get();
int requestId = requestCounter.incrementAndGet();
msg.headers().add("X-Channel-Request-ID", requestId);
requestContexts.put(String.valueOf(requestId), requestContext);
ctx.fireChannelRead(msg);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.security.ssl.http.netty;

import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.security.filter.SecurityResponse;

public class Netty4RequestContext {
public SecurityResponse earlyResponse;
public ThreadContext.StoredContext storedContext;
public Boolean shouldDecompress;
public Boolean isAuthenticated;

public Netty4RequestContext() {
this.shouldDecompress = Boolean.FALSE;
this.isAuthenticated = Boolean.FALSE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AttributeKey;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand All @@ -41,6 +47,10 @@
import org.opensearch.transport.SharedGroupFactory;

public class SecuritySSLNettyHttpServerTransport extends Netty4HttpServerTransport {
public static final AttributeKey<Map<String, Netty4RequestContext>> REQUEST_CONTEXTS = AttributeKey.newInstance(
"opensearch-http-request-contexts"
);
public static final AttributeKey<AtomicInteger> REQUEST_COUNTER = AttributeKey.newInstance("opensearch-http-request-counter");

private static final Logger logger = LogManager.getLogger(SecuritySSLNettyHttpServerTransport.class);
private final SecurityKeyStore sks;
Expand Down Expand Up @@ -113,6 +123,8 @@ protected void initChannel(Channel ch) throws Exception {
super.initChannel(ch);
final SslHandler sslHandler = new SslHandler(SecuritySSLNettyHttpServerTransport.this.sks.createHTTPSSLEngine());
ch.pipeline().addFirst("ssl_http", sslHandler);
ch.attr(REQUEST_CONTEXTS).set(new HashMap<>());
ch.attr(REQUEST_COUNTER).set(new AtomicInteger(0));
}
}

Expand Down

0 comments on commit f9f0ed0

Please sign in to comment.