diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index 6dbb7b7676..b183593a91 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -11,17 +11,21 @@ package com.amazon.dlic.auth.http.jwt; +import static org.apache.http.HttpHeaders.AUTHORIZATION; + import java.nio.file.Path; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Collection; +import java.util.Map; +import java.util.Optional; import java.util.Map.Entry; import java.util.regex.Pattern; import com.google.common.annotations.VisibleForTesting; import org.apache.cxf.rs.security.jose.jwt.JwtClaims; import org.apache.cxf.rs.security.jose.jwt.JwtToken; -import org.apache.hc.core5.http.HttpHeaders; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -35,10 +39,10 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator { @@ -62,8 +66,8 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) { jwtUrlParameter = settings.get("jwt_url_parameter"); - jwtHeaderName = settings.get("jwt_header", HttpHeaders.AUTHORIZATION); - isDefaultAuthHeader = HttpHeaders.AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); + jwtHeaderName = settings.get("jwt_header", AUTHORIZATION); + isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); rolesKey = settings.get("roles_key"); subjectKey = settings.get("subject_key"); clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS); @@ -82,7 +86,8 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) { @Override @SuppressWarnings("removal") - public AuthCredentials extractCredentials(RestRequest request, ThreadContext context) throws OpenSearchSecurityException { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) + throws OpenSearchSecurityException { final SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -99,7 +104,7 @@ public AuthCredentials run() { return creds; } - private AuthCredentials extractCredentials0(final RestRequest request) throws OpenSearchSecurityException { + private AuthCredentials extractCredentials0(final SecurityRequest request) throws OpenSearchSecurityException { String jwtString = getJwtTokenString(request); @@ -140,7 +145,7 @@ private AuthCredentials extractCredentials0(final RestRequest request) throws Op } - protected String getJwtTokenString(RestRequest request) { + protected String getJwtTokenString(SecurityRequest request) { String jwtToken = request.header(jwtHeaderName); if (isDefaultAuthHeader && jwtToken != null && BASIC.matcher(jwtToken).matches()) { jwtToken = null; @@ -148,10 +153,10 @@ protected String getJwtTokenString(RestRequest request) { if (jwtUrlParameter != null) { if (jwtToken == null || jwtToken.isEmpty()) { - jwtToken = request.param(jwtUrlParameter); + jwtToken = request.params().get(jwtUrlParameter); } else { // just consume to avoid "contains unrecognized parameter" - request.param(jwtUrlParameter); + request.params().get(jwtUrlParameter); } } @@ -235,10 +240,10 @@ public String[] extractRoles(JwtClaims claims) { protected abstract KeyProvider initKeyProvider(Settings settings, Path configPath) throws Exception; @Override - public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials authCredentials) { - final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); - wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - return wwwAuthenticateResponse; + public Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials authCredentials) { + return Optional.of( + new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""), "") + ); } public String getRequiredAudience() { diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java index 338a490037..df1c605167 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java @@ -11,10 +11,14 @@ package com.amazon.dlic.auth.http.jwt; +import static org.apache.http.HttpHeaders.AUTHORIZATION; + import java.nio.file.Path; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Collection; +import java.util.Map; +import java.util.Optional; import java.util.Map.Entry; import java.util.regex.Pattern; @@ -22,7 +26,8 @@ import io.jsonwebtoken.JwtParser; import io.jsonwebtoken.JwtParserBuilder; import io.jsonwebtoken.security.WeakKeyException; -import org.apache.hc.core5.http.HttpHeaders; + +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -30,10 +35,9 @@ import org.opensearch.SpecialPermission; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestRequest; -import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.util.KeyUtils; @@ -58,8 +62,8 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { String signingKey = settings.get("signing_key"); jwtUrlParameter = settings.get("jwt_url_parameter"); - jwtHeaderName = settings.get("jwt_header", HttpHeaders.AUTHORIZATION); - isDefaultAuthHeader = HttpHeaders.AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); + jwtHeaderName = settings.get("jwt_header", AUTHORIZATION); + isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); rolesKey = settings.get("roles_key"); subjectKey = settings.get("subject_key"); requireAudience = settings.get("required_audience"); @@ -83,7 +87,8 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { @Override @SuppressWarnings("removal") - public AuthCredentials extractCredentials(RestRequest request, ThreadContext context) throws OpenSearchSecurityException { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) + throws OpenSearchSecurityException { final SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -100,7 +105,7 @@ public AuthCredentials run() { return creds; } - private AuthCredentials extractCredentials0(final RestRequest request) { + private AuthCredentials extractCredentials0(final SecurityRequest request) { if (jwtParser == null) { log.error("Missing Signing Key. JWT authentication will not work"); return null; @@ -112,10 +117,10 @@ private AuthCredentials extractCredentials0(final RestRequest request) { } if ((jwtToken == null || jwtToken.isEmpty()) && jwtUrlParameter != null) { - jwtToken = request.param(jwtUrlParameter); + jwtToken = request.params().get(jwtUrlParameter); } else { // just consume to avoid "contains unrecognized parameter" - request.param(jwtUrlParameter); + request.params().get(jwtUrlParameter); } if (jwtToken == null || jwtToken.length() == 0) { @@ -170,10 +175,10 @@ private AuthCredentials extractCredentials0(final RestRequest request) { } @Override - public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { - final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); - wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - return wwwAuthenticateResponse; + public Optional reRequestAuthentication(final SecurityRequest channel, AuthCredentials creds) { + return Optional.of( + new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""), "") + ); } @Override @@ -181,7 +186,7 @@ public String getType() { return "jwt"; } - protected String extractSubject(final Claims claims, final RestRequest request) { + protected String extractSubject(final Claims claims, final SecurityRequest request) { String subject = claims.getSubject(); if (subjectKey != null) { // try to get roles from claims, first as Object to avoid having to catch the ExpectedTypeException @@ -205,7 +210,7 @@ protected String extractSubject(final Claims claims, final RestRequest request) } @SuppressWarnings("unchecked") - protected String[] extractRoles(final Claims claims, final RestRequest request) { + protected String[] extractRoles(final Claims claims, final SecurityRequest request) { // no roles key specified if (rolesKey == null) { return new String[0]; diff --git a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java index 99bc23b746..ad24b8db95 100644 --- a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java @@ -11,6 +11,8 @@ package com.amazon.dlic.auth.http.kerberos; +import static org.apache.http.HttpStatus.SC_UNAUTHORIZED; + import java.io.Serializable; import java.nio.file.Files; import java.nio.file.Path; @@ -22,13 +24,17 @@ import java.security.PrivilegedExceptionAction; import java.util.Base64; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; +import java.util.Optional; import java.util.Set; import javax.security.auth.Subject; import javax.security.auth.login.LoginException; import com.google.common.base.Strings; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.ietf.jgss.GSSContext; @@ -48,15 +54,13 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.env.Environment; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestRequest; -import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; public class HTTPSpnegoAuthenticator implements HTTPAuthenticator { - private static final String EMPTY_STRING = ""; private static final Oid[] KRB_OIDS = new Oid[] { KrbConstants.SPNEGO, KrbConstants.KRB5MECH }; protected final Logger log = LogManager.getLogger(this.getClass()); @@ -170,7 +174,7 @@ public Void run() { @Override @SuppressWarnings("removal") - public AuthCredentials extractCredentials(final RestRequest request, ThreadContext threadContext) { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext threadContext) { final SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -187,7 +191,7 @@ public AuthCredentials run() { return creds; } - private AuthCredentials extractCredentials0(final RestRequest request) { + private AuthCredentials extractCredentials0(final SecurityRequest request) { if (acceptorPrincipal == null || acceptorKeyTabPath == null) { log.error("Missing acceptor principal or keytab configuration. Kerberos authentication will not work"); @@ -279,25 +283,22 @@ public GSSCredential run() throws GSSException { } @Override - public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { - final BytesRestResponse wwwAuthenticateResponse; - XContentBuilder response = getNegotiateResponseBody(); - - if (response != null) { - wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, response); - } else { - wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, EMPTY_STRING); + public Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials creds) { + final Map headers = new HashMap<>(); + String responseBody = ""; + final String negotiateResponseBody = getNegotiateResponseBody(); + if (negotiateResponseBody != null) { + responseBody = negotiateResponseBody; + headers.putAll(SecurityResponse.CONTENT_TYPE_APP_JSON); } - if (credentials == null || credentials.getNativeCredentials() == null) { - wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Negotiate"); + if (creds == null || creds.getNativeCredentials() == null) { + headers.put("WWW-Authenticate", "Negotiate"); } else { - wwwAuthenticateResponse.addHeader( - "WWW-Authenticate", - "Negotiate " + Base64.getEncoder().encodeToString((byte[]) credentials.getNativeCredentials()) - ); + headers.put("WWW-Authenticate", "Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials())); } - return wwwAuthenticateResponse; + + return Optional.of(new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody)); } @Override @@ -369,7 +370,7 @@ private static String getUsernameFromGSSContext(final GSSContext gssContext, fin return null; } - private XContentBuilder getNegotiateResponseBody() { + private String getNegotiateResponseBody() { try { XContentBuilder negotiateResponseBody = XContentFactory.jsonBuilder(); negotiateResponseBody.startObject(); @@ -381,7 +382,7 @@ private XContentBuilder getNegotiateResponseBody() { negotiateResponseBody.endObject(); negotiateResponseBody.endObject(); negotiateResponseBody.endObject(); - return negotiateResponseBody; + return negotiateResponseBody.toString(); } catch (Exception ex) { log.error("Can't construct response body", ex); return null; diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index 7210ed5950..a344c653f3 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -19,6 +19,7 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.List; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -34,6 +35,7 @@ import com.onelogin.saml2.exception.ValidationError; import com.onelogin.saml2.settings.Saml2Settings; import com.onelogin.saml2.util.Util; + import org.apache.commons.lang3.StringUtils; import org.apache.cxf.jaxrs.json.basic.JsonMapObjectReaderWriter; import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; @@ -44,21 +46,21 @@ import org.apache.cxf.rs.security.jose.jwt.JwtClaims; import org.apache.cxf.rs.security.jose.jwt.JwtToken; import org.apache.cxf.rs.security.jose.jwt.JwtUtils; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.joda.time.DateTime; - import org.opensearch.OpenSearchSecurityException; import org.opensearch.SpecialPermission; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.DefaultObjectMapper; import org.opensearch.security.dlic.rest.api.AuthTokenProcessorAction; +import org.opensearch.security.filter.SecurityResponse; class AuthTokenProcessorHandler { private static final Logger log = LogManager.getLogger(AuthTokenProcessorHandler.class); @@ -118,7 +120,7 @@ class AuthTokenProcessorHandler { } @SuppressWarnings("removal") - BytesRestResponse handle(RestRequest restRequest) throws Exception { + Optional handle(RestRequest restRequest) throws Exception { try { final SecurityManager sm = System.getSecurityManager(); @@ -126,9 +128,9 @@ BytesRestResponse handle(RestRequest restRequest) throws Exception { sm.checkPermission(new SpecialPermission()); } - return AccessController.doPrivileged(new PrivilegedExceptionAction() { + return AccessController.doPrivileged(new PrivilegedExceptionAction>() { @Override - public BytesRestResponse run() throws SamlConfigException, IOException { + public Optional run() throws SamlConfigException, IOException { return handleLowLevel(restRequest); } }); @@ -181,7 +183,7 @@ private AuthTokenProcessorAction.Response handleImpl( } } - private BytesRestResponse handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException { + private Optional handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException { try { if (restRequest.getMediaType() != XContentType.JSON) { @@ -229,15 +231,15 @@ private BytesRestResponse handleLowLevel(RestRequest restRequest) throws SamlCon AuthTokenProcessorAction.Response responseBody = this.handleImpl(samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings); if (responseBody == null) { - return null; + return Optional.empty(); } String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody); - return new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); + return Optional.of(new SecurityResponse(HttpStatus.SC_OK, SecurityResponse.CONTENT_TYPE_APP_JSON, responseBodyString)); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); - return new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); + return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, null, "JSON could not be parsed")); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index 64d816fabf..a3f37ba46e 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -18,6 +18,8 @@ import java.security.PrivateKey; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.Map; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -35,6 +37,7 @@ import net.shibboleth.utilities.java.support.xml.BasicParserPool; import org.apache.commons.lang3.StringEscapeUtils; import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensaml.core.config.InitializationException; @@ -55,11 +58,13 @@ import org.opensearch.SpecialPermission; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.Destroyable; import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityRequestChannelUnsupported; +import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.OpenSearchRequest; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.PemKeyReader; import org.opensearch.security.user.AuthCredentials; @@ -148,17 +153,18 @@ public HTTPSamlAuthenticator(final Settings settings, final Path configPath) { } @Override - public AuthCredentials extractCredentials(RestRequest restRequest, ThreadContext threadContext) throws OpenSearchSecurityException { - Matcher matcher = PATTERN_PATH_PREFIX.matcher(restRequest.path()); + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext threadContext) + throws OpenSearchSecurityException { + Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { return null; } - AuthCredentials authCredentials = this.httpJwtAuthenticator.extractCredentials(restRequest, threadContext); + AuthCredentials authCredentials = this.httpJwtAuthenticator.extractCredentials(request, threadContext); if (AUTHINFO_SUFFIX.equals(suffix)) { - this.initLogoutUrl(restRequest, threadContext, authCredentials); + this.initLogoutUrl(threadContext, authCredentials); } return authCredentials; @@ -170,26 +176,32 @@ public String getType() { } @Override - public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + public Optional reRequestAuthentication(final SecurityRequest request, final AuthCredentials authCredentials) { try { Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; + if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { - final BytesRestResponse restResponse = this.authTokenProcessorHandler.handle(request); - if (restResponse != null) { - return restResponse; + // Verficiation of SAML ASC endpoint only works with RestRequests + if (!(request instanceof OpenSearchRequest)) { + throw new SecurityRequestChannelUnsupported(); + } else { + final OpenSearchRequest openSearchRequest = (OpenSearchRequest) request; + final RestRequest restRequest = openSearchRequest.breakEncapsulationForRequest(); + Optional restResponse = this.authTokenProcessorHandler.handle(restRequest); + if (restResponse.isPresent()) { + return restResponse; + } } } - Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached(); - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); - - authenticateResponse.addHeader("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)); - - return authenticateResponse; + final Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached(); + return Optional.of( + new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)), "") + ); } catch (Exception e) { log.error("Error in reRequestAuthentication()", e); - return null; + return Optional.empty(); } } @@ -397,7 +409,7 @@ String buildLogoutUrl(AuthCredentials authCredentials) { } - private void initLogoutUrl(RestRequest restRequest, ThreadContext threadContext, AuthCredentials authCredentials) { + private void initLogoutUrl(ThreadContext threadContext, AuthCredentials authCredentials) { threadContext.putTransient(ConfigConstants.SSO_LOGOUT_URL, buildLogoutUrl(authCredentials)); } diff --git a/src/main/java/org/opensearch/security/auditlog/AuditLog.java b/src/main/java/org/opensearch/security/auditlog/AuditLog.java index 612b790686..997b9e4b87 100644 --- a/src/main/java/org/opensearch/security/auditlog/AuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/AuditLog.java @@ -35,23 +35,23 @@ import org.opensearch.index.engine.Engine.IndexResult; import org.opensearch.index.get.GetResult; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auditlog.config.AuditConfig; import org.opensearch.security.compliance.ComplianceConfig; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequest; public interface AuditLog extends Closeable { // login - void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request); + void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request); - void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request); + void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request); // privs - void logMissingPrivileges(String privilege, String effectiveUser, RestRequest request); + void logMissingPrivileges(String privilege, String effectiveUser, SecurityRequest request); - void logGrantedPrivileges(String effectiveUser, RestRequest request); + void logGrantedPrivileges(String effectiveUser, SecurityRequest request); void logMissingPrivileges(String privilege, TransportRequest request, Task task); @@ -63,13 +63,13 @@ public interface AuditLog extends Closeable { // spoof void logBadHeaders(TransportRequest request, String action, Task task); - void logBadHeaders(RestRequest request); + void logBadHeaders(SecurityRequest request); void logSecurityIndexAttempt(TransportRequest request, String action, Task task); void logSSLException(TransportRequest request, Throwable t, String action, Task task); - void logSSLException(RestRequest request, Throwable t); + void logSSLException(SecurityRequest request, Throwable t); void logDocumentRead(String index, String id, ShardId shardId, Map fieldNameValues); diff --git a/src/main/java/org/opensearch/security/auditlog/AuditLogSslExceptionHandler.java b/src/main/java/org/opensearch/security/auditlog/AuditLogSslExceptionHandler.java index 942f06804f..df96400f96 100644 --- a/src/main/java/org/opensearch/security/auditlog/AuditLogSslExceptionHandler.java +++ b/src/main/java/org/opensearch/security/auditlog/AuditLogSslExceptionHandler.java @@ -27,7 +27,7 @@ package org.opensearch.security.auditlog; import org.opensearch.OpenSearchException; -import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequestChannel; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequest; @@ -42,7 +42,7 @@ public AuditLogSslExceptionHandler(final AuditLog auditLog) { } @Override - public void logError(Throwable t, RestRequest request, int type) { + public void logError(Throwable t, SecurityRequestChannel request, int type) { switch (type) { case 0: auditLog.logSSLException(request, t); diff --git a/src/main/java/org/opensearch/security/auditlog/NullAuditLog.java b/src/main/java/org/opensearch/security/auditlog/NullAuditLog.java index 809951d48a..1ac4492a94 100644 --- a/src/main/java/org/opensearch/security/auditlog/NullAuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/NullAuditLog.java @@ -35,9 +35,9 @@ import org.opensearch.index.engine.Engine.IndexResult; import org.opensearch.index.get.GetResult; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auditlog.config.AuditConfig; import org.opensearch.security.compliance.ComplianceConfig; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequest; @@ -49,12 +49,12 @@ public void close() throws IOException { } @Override - public void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request) { + public void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request) { // noop, intentionally left empty } @Override - public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request) { + public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request) { // noop, intentionally left empty } @@ -79,7 +79,7 @@ public void logBadHeaders(TransportRequest request, String action, Task task) { } @Override - public void logBadHeaders(RestRequest request) { + public void logBadHeaders(SecurityRequest request) { // noop, intentionally left empty } @@ -94,17 +94,17 @@ public void logSSLException(TransportRequest request, Throwable t, String action } @Override - public void logSSLException(RestRequest request, Throwable t) { + public void logSSLException(SecurityRequest request, Throwable t) { // noop, intentionally left empty } @Override - public void logMissingPrivileges(String privilege, String effectiveUser, RestRequest request) { + public void logMissingPrivileges(String privilege, String effectiveUser, SecurityRequest request) { // noop, intentionally left empty } @Override - public void logGrantedPrivileges(String effectiveUser, RestRequest request) { + public void logGrantedPrivileges(String effectiveUser, SecurityRequest request) { // noop, intentionally left empty } diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java index a8f511be97..de53f0b744 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java @@ -60,12 +60,12 @@ import org.opensearch.index.engine.Engine.IndexResult; import org.opensearch.index.get.GetResult; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.rest.RestRequest; import org.opensearch.security.DefaultObjectMapper; import org.opensearch.security.auditlog.AuditLog; import org.opensearch.security.auditlog.config.AuditConfig; import org.opensearch.security.compliance.ComplianceConfig; import org.opensearch.security.dlic.rest.support.Utils; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.support.Base64Helper; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.user.User; @@ -139,7 +139,7 @@ public ComplianceConfig getComplianceConfig() { } @Override - public void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request) { + public void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request) { if (!checkRestFilter(AuditCategory.FAILED_LOGIN, effectiveUser, request)) { return; @@ -157,7 +157,7 @@ public void logFailedLogin(String effectiveUser, boolean securityadmin, String i } @Override - public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request) { + public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request) { if (!checkRestFilter(AuditCategory.AUTHENTICATED, effectiveUser, request)) { return; @@ -174,7 +174,7 @@ public void logSucceededLogin(String effectiveUser, boolean securityadmin, Strin } @Override - public void logMissingPrivileges(String privilege, String effectiveUser, RestRequest request) { + public void logMissingPrivileges(String privilege, String effectiveUser, SecurityRequest request) { if (!checkRestFilter(AuditCategory.MISSING_PRIVILEGES, effectiveUser, request)) { return; } @@ -189,7 +189,7 @@ public void logMissingPrivileges(String privilege, String effectiveUser, RestReq } @Override - public void logGrantedPrivileges(String effectiveUser, RestRequest request) { + public void logGrantedPrivileges(String effectiveUser, SecurityRequest request) { if (!checkRestFilter(AuditCategory.GRANTED_PRIVILEGES, effectiveUser, request)) { return; } @@ -348,7 +348,7 @@ public void logBadHeaders(TransportRequest request, String action, Task task) { } @Override - public void logBadHeaders(RestRequest request) { + public void logBadHeaders(SecurityRequest request) { if (!checkRestFilter(AuditCategory.BAD_HEADERS, getUser(), request)) { return; @@ -437,7 +437,7 @@ public void logSSLException(TransportRequest request, Throwable t, String action } @Override - public void logSSLException(RestRequest request, Throwable t) { + public void logSSLException(SecurityRequest request, Throwable t) { if (!checkRestFilter(AuditCategory.SSL_EXCEPTION, getUser(), request)) { return; @@ -898,7 +898,7 @@ private boolean checkComplianceFilter( } @VisibleForTesting - boolean checkRestFilter(final AuditCategory category, final String effectiveUser, RestRequest request) { + boolean checkRestFilter(final AuditCategory category, final String effectiveUser, SecurityRequest request) { final boolean isTraceEnabled = log.isTraceEnabled(); if (isTraceEnabled) { log.trace( diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AuditLogImpl.java b/src/main/java/org/opensearch/security/auditlog/impl/AuditLogImpl.java index c88f1fca3f..8da4b13d4c 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AuditLogImpl.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AuditLogImpl.java @@ -31,9 +31,9 @@ import org.opensearch.index.engine.Engine.IndexResult; import org.opensearch.index.get.GetResult; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auditlog.config.AuditConfig; import org.opensearch.security.auditlog.routing.AuditMessageRouter; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportRequest; @@ -131,28 +131,28 @@ protected void save(final AuditMessage msg) { } @Override - public void logFailedLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, RestRequest request) { + public void logFailedLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, SecurityRequest request) { if (enabled) { super.logFailedLogin(effectiveUser, securityAdmin, initiatingUser, request); } } @Override - public void logSucceededLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, RestRequest request) { + public void logSucceededLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, SecurityRequest request) { if (enabled) { super.logSucceededLogin(effectiveUser, securityAdmin, initiatingUser, request); } } @Override - public void logMissingPrivileges(String privilege, String effectiveUser, RestRequest request) { + public void logMissingPrivileges(String privilege, String effectiveUser, SecurityRequest request) { if (enabled) { super.logMissingPrivileges(privilege, effectiveUser, request); } } @Override - public void logGrantedPrivileges(String effectiveUser, RestRequest request) { + public void logGrantedPrivileges(String effectiveUser, SecurityRequest request) { if (enabled) { super.logGrantedPrivileges(effectiveUser, request); } @@ -187,7 +187,7 @@ public void logBadHeaders(TransportRequest request, String action, Task task) { } @Override - public void logBadHeaders(RestRequest request) { + public void logBadHeaders(SecurityRequest request) { if (enabled) { super.logBadHeaders(request); } @@ -208,7 +208,7 @@ public void logSSLException(TransportRequest request, Throwable t, String action } @Override - public void logSSLException(RestRequest request, Throwable t) { + public void logSSLException(SecurityRequest request, Throwable t) { if (enabled) { super.logSSLException(request, t); } diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java index 7c3bf5c77f..a41b4625c2 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java @@ -48,6 +48,8 @@ import org.opensearch.security.auditlog.AuditLog.Origin; import org.opensearch.security.auditlog.config.AuditConfig; import org.opensearch.security.dlic.rest.support.Utils; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.OpenSearchRequest; import org.opensearch.security.securityconf.impl.CType; import org.opensearch.security.support.WildcardMatcher; @@ -369,16 +371,31 @@ void addRestMethod(final RestRequest.Method method) { } } - void addRestRequestInfo(final RestRequest request, final AuditConfig.Filter filter) { + void addRestRequestInfo(final SecurityRequest request, final AuditConfig.Filter filter) { if (request != null) { - final String path = request.path(); + final String path = request.path().toString(); addPath(path); addRestHeaders(request.getHeaders(), filter.shouldExcludeSensitiveHeaders()); addRestParams(request.params()); addRestMethod(request.method()); - if (filter.shouldLogRequestBody() && request.hasContentOrSourceParam()) { + + if (filter.shouldLogRequestBody()) { + + if (!(request instanceof OpenSearchRequest)) { + // The request body is only avaliable on some request sources + return; + } + + final OpenSearchRequest securityRestRequest = (OpenSearchRequest) request; + final RestRequest restRequest = securityRestRequest.breakEncapsulationForRequest(); + + if (!(restRequest.hasContentOrSourceParam())) { + // If there is no content, don't attempt to save any body information + return; + } + try { - final Tuple xContentTuple = request.contentOrSourceParam(); + final Tuple xContentTuple = restRequest.contentOrSourceParam(); final String requestBody = XContentHelper.convertToJson(xContentTuple.v2(), false, xContentTuple.v1()); if (path != null && requestBody != null diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index 35d4e05c4c..88e88c3c52 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -32,6 +32,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.SortedSet; import java.util.concurrent.Callable; @@ -43,6 +44,7 @@ import com.google.common.cache.RemovalListener; import com.google.common.cache.RemovalNotification; import com.google.common.collect.Multimap; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.greenrobot.eventbus.Subscribe; @@ -50,15 +52,14 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auditlog.AuditLog; import org.opensearch.security.auth.blocking.ClientBlockRegistry; import org.opensearch.security.auth.internal.NoOpAuthenticationBackend; import org.opensearch.security.configuration.AdminDNs; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.http.OnBehalfOfAuthenticator; import org.opensearch.security.http.XFFResolver; import org.opensearch.security.securityconf.DynamicConfigModel; @@ -68,6 +69,10 @@ import org.opensearch.security.user.User; import org.opensearch.threadpool.ThreadPool; +import static org.apache.http.HttpStatus.SC_FORBIDDEN; +import static org.apache.http.HttpStatus.SC_SERVICE_UNAVAILABLE; +import static org.apache.http.HttpStatus.SC_UNAUTHORIZED; + public class BackendRegistry { protected final Logger log = LogManager.getLogger(this.getClass()); @@ -185,16 +190,18 @@ public void onDynamicConfigModelChanged(DynamicConfigModel dcm) { * @return The authenticated user, null means another roundtrip * @throws OpenSearchSecurityException */ - public boolean authenticate(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { + public boolean authenticate(final SecurityRequestChannel request) { final boolean isDebugEnabled = log.isDebugEnabled(); - if (request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress - && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).getAddress())) { + final boolean isBlockedBasedOnAddress = request.getRemoteAddress() + .map(InetSocketAddress::getAddress) + .map(address -> isBlocked(address)) + .orElse(false); + if (isBlockedBasedOnAddress) { if (isDebugEnabled) { - log.debug("Rejecting REST request because of blocked address: {}", request.getHttpChannel().getRemoteAddress()); + log.debug("Rejecting REST request because of blocked address: {}", request.getRemoteAddress().orElse(null)); } - channel.sendResponse(new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed")); - + request.queueForSending(new SecurityResponse(SC_UNAUTHORIZED, null, "Authentication finally failed")); return false; } @@ -214,17 +221,17 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g if (!isInitialized()) { log.error("Not yet initialized (you may need to run securityadmin)"); - channel.sendResponse(new BytesRestResponse(RestStatus.SERVICE_UNAVAILABLE, "OpenSearch Security not initialized.")); + request.queueForSending(new SecurityResponse(SC_SERVICE_UNAVAILABLE, null, "OpenSearch Security not initialized.")); return false; } final TransportAddress remoteAddress = xffResolver.resolve(request); final boolean isTraceEnabled = log.isTraceEnabled(); if (isTraceEnabled) { - log.trace("Rest authentication request from {} [original: {}]", remoteAddress, request.getHttpChannel().getRemoteAddress()); + log.trace("Rest authentication request from {} [original: {}]", remoteAddress, request.getRemoteAddress().orElse(null)); } - threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, remoteAddress); + threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, remoteAddress); boolean authenticated = false; @@ -256,7 +263,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } final AuthCredentials ac; try { - ac = httpAuthenticator.extractCredentials(request, threadContext); + ac = httpAuthenticator.extractCredentials(request, threadPool.getThreadContext()); } catch (Exception e1) { if (isDebugEnabled) { log.debug("'{}' extracting credentials from {} http authenticator", e1.toString(), httpAuthenticator.getType(), e1); @@ -281,17 +288,16 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } if (authDomain.isChallenge()) { - final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, null); - if (restResponse != null) { + final Optional restResponse = httpAuthenticator.reRequestAuthentication(request, null); + if (restResponse.isPresent()) { auditLog.logFailedLogin("", false, null, request); if (isTraceEnabled) { log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); } notifyIpAuthFailureListeners(request, authCredentials); - channel.sendResponse(restResponse); + request.queueForSending(restResponse.get()); return false; } - } else { // no reRequest possible if (isTraceEnabled) { @@ -302,12 +308,11 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } else { org.apache.logging.log4j.ThreadContext.put("user", ac.getUsername()); if (!ac.isComplete()) { - final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, ac); // credentials found in request but we need another client challenge - if (restResponse != null) { - // auditLog.logFailedLogin(ac.getUsername()+" ", request); --noauditlog + final Optional restResponse = httpAuthenticator.reRequestAuthentication(request, ac); + if (restResponse.isPresent()) { notifyIpAuthFailureListeners(request, ac); - channel.sendResponse(restResponse); + request.queueForSending(restResponse.get()); return false; } else { // no reRequest possible @@ -334,9 +339,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g authDomain.getBackend().getClass().getName() )) { authFailureListener.onAuthFailure( - (request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress) - ? ((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).getAddress() - : null, + request.getRemoteAddress().map(InetSocketAddress::getAddress).orElse(null), ac, request ); @@ -347,9 +350,10 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g if (adminDns.isAdmin(authenticatedUser)) { log.error("Cannot authenticate rest user because admin user is not permitted to login via HTTP"); auditLog.logFailedLogin(authenticatedUser.getName(), true, null, request); - channel.sendResponse( - new BytesRestResponse( - RestStatus.FORBIDDEN, + request.queueForSending( + new SecurityResponse( + SC_FORBIDDEN, + null, "Cannot authenticate user because admin user is not permitted to login via HTTP" ) ); @@ -370,10 +374,8 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g if (authenticated) { final User impersonatedUser = impersonate(request, authenticatedUser); - threadContext.putTransient( - ConfigConstants.OPENDISTRO_SECURITY_USER, - impersonatedUser == null ? authenticatedUser : impersonatedUser - ); + threadPool.getThreadContext() + .putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, impersonatedUser == null ? authenticatedUser : impersonatedUser); auditLog.logSucceededLogin( (impersonatedUser == null ? authenticatedUser : impersonatedUser).getName(), false, @@ -390,14 +392,15 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g User anonymousUser = new User(User.ANONYMOUS.getName(), new HashSet(User.ANONYMOUS.getRoles()), null); anonymousUser.setRequestedTenant(tenant); - threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, anonymousUser); + threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, anonymousUser); auditLog.logSucceededLogin(anonymousUser.getName(), false, null, request); if (isDebugEnabled) { log.debug("Anonymous User is authenticated"); } return true; } - BytesRestResponse challengeResponse = null; + + Optional challengeResponse = Optional.empty(); if (firstChallengingHttpAuthenticator != null) { @@ -406,7 +409,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } challengeResponse = firstChallengingHttpAuthenticator.reRequestAuthentication(request, null); - if (challengeResponse != null) { + if (challengeResponse.isPresent()) { if (isDebugEnabled) { log.debug("Rerequest {} failed", firstChallengingHttpAuthenticator.getClass()); } @@ -422,25 +425,16 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g notifyIpAuthFailureListeners(request, authCredentials); - channel.sendResponse( - challengeResponse != null - ? challengeResponse - : new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed") + request.queueForSending( + challengeResponse.orElseGet(() -> new SecurityResponse(SC_UNAUTHORIZED, null, "Authentication finally failed")) ); return false; } - return authenticated; } - private void notifyIpAuthFailureListeners(RestRequest request, AuthCredentials authCredentials) { - notifyIpAuthFailureListeners( - (request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress) - ? ((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).getAddress() - : null, - authCredentials, - request - ); + private void notifyIpAuthFailureListeners(SecurityRequestChannel request, AuthCredentials authCredentials) { + notifyIpAuthFailureListeners(request.getRemoteAddress().map(InetSocketAddress::getAddress).orElse(null), authCredentials, request); } private void notifyIpAuthFailureListeners(InetAddress remoteAddress, AuthCredentials authCredentials, Object request) { @@ -587,7 +581,7 @@ public User call() throws Exception { } } - private User impersonate(final RestRequest request, final User originalUser) throws OpenSearchSecurityException { + private User impersonate(final SecurityRequest request, final User originalUser) throws OpenSearchSecurityException { final String impersonatedUserHeader = request.header("opendistro_security_impersonate_as"); diff --git a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java index 70b94d6dbf..f5a4df34b5 100644 --- a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java +++ b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java @@ -26,10 +26,13 @@ package org.opensearch.security.auth; +import java.util.Optional; + import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; /** @@ -67,7 +70,7 @@ public interface HTTPAuthenticator { * If the authentication flow needs another roundtrip with the request originator do not mark it as complete. * @throws OpenSearchSecurityException */ - AuthCredentials extractCredentials(RestRequest request, ThreadContext context) throws OpenSearchSecurityException; + AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) throws OpenSearchSecurityException; /** * If the {@code extractCredentials()} call was not successful or the authentication flow needs another roundtrip this method @@ -75,10 +78,9 @@ public interface HTTPAuthenticator { * If the custom HTTP authenticator does support re-request authentication or supports authentication flows with multiple roundtrips * then the response will be returned which can then be sent via response channel. * - * @param request + * @param request The request to reauthenticate or not * @param credentials The credentials from the prior authentication attempt - * @return null if re-request is not supported/necessary, response object otherwise. - * If an object is returned {@code channel.sendResponse()} must be called so that the request completes. + * @return Optional response if is not supported/necessary, response object otherwise. */ - BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials); + Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials credentials); } diff --git a/src/main/java/org/opensearch/security/auth/UserInjector.java b/src/main/java/org/opensearch/security/auth/UserInjector.java index 30df84ef5f..57750bfa9a 100644 --- a/src/main/java/org/opensearch/security/auth/UserInjector.java +++ b/src/main/java/org/opensearch/security/auth/UserInjector.java @@ -41,8 +41,8 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auditlog.AuditLog; +import org.opensearch.security.filter.SecurityRequestChannel; import org.opensearch.security.http.XFFResolver; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.SecurityUtils; @@ -185,7 +185,7 @@ public InjectedUser getInjectedUser() { return injectedUser; } - boolean injectUser(RestRequest request) { + boolean injectUser(SecurityRequestChannel request) { InjectedUser injectedUser = getInjectedUser(); if (injectedUser == null) { return false; diff --git a/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java b/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java index 293f934434..6cbd7eaf78 100644 --- a/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java +++ b/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java @@ -49,6 +49,7 @@ import org.opensearch.security.dlic.rest.validation.EndpointValidator; import org.opensearch.security.dlic.rest.validation.RequestContentValidator; import org.opensearch.security.dlic.rest.validation.ValidationResult; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.securityconf.DynamicConfigFactory; import org.opensearch.security.securityconf.impl.CType; import org.opensearch.security.securityconf.impl.SecurityDynamicConfiguration; @@ -371,6 +372,7 @@ protected final ValidationResult> loadConfigurat ) { final var configuration = load(cType, logComplianceEvent); if (configuration.getSeqNo() < 0) { + return ValidationResult.error( RestStatus.FORBIDDEN, forbiddenMessage( @@ -411,16 +413,19 @@ public RestApiAdminPrivilegesEvaluator restApiAdminPrivilegesEvaluator() { @Override public ValidationResult onConfigDelete(SecurityConfiguration securityConfiguration) throws IOException { + return ValidationResult.error(RestStatus.FORBIDDEN, forbiddenMessage("Access denied")); } @Override public ValidationResult onConfigLoad(SecurityConfiguration securityConfiguration) throws IOException { + return ValidationResult.error(RestStatus.FORBIDDEN, forbiddenMessage("Access denied")); } @Override public ValidationResult onConfigChange(SecurityConfiguration securityConfiguration) throws IOException { + return ValidationResult.error(RestStatus.FORBIDDEN, forbiddenMessage("Access denied")); } @@ -551,12 +556,12 @@ protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClie final String userName = user == null ? null : user.getName(); if (authError != null) { LOGGER.error("No permission to access REST API: " + authError); - securityApiDependencies.auditLog().logMissingPrivileges(authError, userName, request); + securityApiDependencies.auditLog().logMissingPrivileges(authError, userName, SecurityRequestFactory.from(request)); // for rest request request.params().clear(); return channel -> forbidden(channel, "No permission to access REST API: " + authError); } else { - securityApiDependencies.auditLog().logGrantedPrivileges(userName, request); + securityApiDependencies.auditLog().logGrantedPrivileges(userName, SecurityRequestFactory.from(request)); } final var originalUserAndRemoteAddress = Utils.userAndRemoteAddressFrom(threadPool.getThreadContext()); diff --git a/src/main/java/org/opensearch/security/dlic/rest/api/NodesDnApiAction.java b/src/main/java/org/opensearch/security/dlic/rest/api/NodesDnApiAction.java index a10139b594..ed1f3e0fbb 100644 --- a/src/main/java/org/opensearch/security/dlic/rest/api/NodesDnApiAction.java +++ b/src/main/java/org/opensearch/security/dlic/rest/api/NodesDnApiAction.java @@ -35,6 +35,7 @@ import org.opensearch.security.securityconf.impl.NodesDn; import org.opensearch.security.securityconf.impl.SecurityDynamicConfiguration; import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.tools.SecurityAdmin; import org.opensearch.threadpool.ThreadPool; import static org.opensearch.security.dlic.rest.api.Responses.forbiddenMessage; diff --git a/src/main/java/org/opensearch/security/dlic/rest/api/RestApiPrivilegesEvaluator.java b/src/main/java/org/opensearch/security/dlic/rest/api/RestApiPrivilegesEvaluator.java index 35f4332520..f1a336986b 100644 --- a/src/main/java/org/opensearch/security/dlic/rest/api/RestApiPrivilegesEvaluator.java +++ b/src/main/java/org/opensearch/security/dlic/rest/api/RestApiPrivilegesEvaluator.java @@ -35,6 +35,8 @@ import org.opensearch.rest.RestRequest.Method; import org.opensearch.security.configuration.AdminDNs; import org.opensearch.security.dlic.rest.support.Utils; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.privileges.PrivilegesEvaluator; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.util.SSLRequestHelper; @@ -446,7 +448,8 @@ private String checkAdminCertBasedAccessPermissions(RestRequest request) throws } // Certificate based access, Check if we have an admin TLS certificate - SSLRequestHelper.SSLInfo sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, request, principalExtractor); + final SecurityRequest securityRequest = SecurityRequestFactory.from(request); + SSLRequestHelper.SSLInfo sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, securityRequest, principalExtractor); if (sslInfo == null) { // here we log on error level, since authentication finally failed diff --git a/src/main/java/org/opensearch/security/dlic/rest/validation/EndpointValidator.java b/src/main/java/org/opensearch/security/dlic/rest/validation/EndpointValidator.java index 442d39cf43..6eb865e937 100644 --- a/src/main/java/org/opensearch/security/dlic/rest/validation/EndpointValidator.java +++ b/src/main/java/org/opensearch/security/dlic/rest/validation/EndpointValidator.java @@ -104,6 +104,7 @@ default ValidationResult entityStatic(final SecurityConfi final var configuration = securityConfiguration.configuration(); final var entityName = securityConfiguration.entityName(); if (configuration.isStatic(entityName)) { + return ValidationResult.error(RestStatus.FORBIDDEN, forbiddenMessage("Resource '" + entityName + "' is static.")); } return ValidationResult.success(securityConfiguration); @@ -122,6 +123,7 @@ default ValidationResult entityHidden(final SecurityConfi final var configuration = securityConfiguration.configuration(); final var entityName = securityConfiguration.entityName(); if (configuration.isHidden(entityName)) { + return ValidationResult.error(RestStatus.NOT_FOUND, notFoundMessage("Resource '" + entityName + "' is not available.")); } return ValidationResult.success(securityConfiguration); @@ -149,6 +151,7 @@ default ValidationResult isAllowedToChangeEntityWithRestA final var configuration = securityConfiguration.configuration(); final var existingEntity = configuration.getCEntry(securityConfiguration.entityName()); if (restApiAdminPrivilegesEvaluator().containsRestApiAdminPermissions(existingEntity)) { + return ValidationResult.error(RestStatus.FORBIDDEN, forbiddenMessage("Access denied")); } } else { @@ -158,6 +161,7 @@ default ValidationResult isAllowedToChangeEntityWithRestA configuration.getImplementingClass() ); if (restApiAdminPrivilegesEvaluator().containsRestApiAdminPermissions(configEntityContent)) { + return ValidationResult.error(RestStatus.FORBIDDEN, forbiddenMessage("Access denied")); } } diff --git a/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java b/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java new file mode 100644 index 0000000000..85c70b8f7a --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.net.InetSocketAddress; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import javax.net.ssl.SSLEngine; + +import org.opensearch.http.netty4.Netty4HttpChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestRequest.Method; + +import io.netty.handler.ssl.SslHandler; + +/** + * Wraps the functionality of RestRequest for use in the security plugin + */ +public class OpenSearchRequest implements SecurityRequest { + + protected final RestRequest underlyingRequest; + + OpenSearchRequest(final RestRequest request) { + underlyingRequest = request; + } + + @Override + public Map> getHeaders() { + return underlyingRequest.getHeaders(); + } + + @Override + public SSLEngine getSSLEngine() { + if (underlyingRequest == null + || underlyingRequest.getHttpChannel() == null + || !(underlyingRequest.getHttpChannel() instanceof Netty4HttpChannel)) { + return null; + } + + // We look for Ssl_handler called `ssl_http` in the outbound pipeline of Netty channel first, and if its not + // present we look for it in inbound channel. If its present in neither we return null, else we return the sslHandler. + final Netty4HttpChannel httpChannel = (Netty4HttpChannel) underlyingRequest.getHttpChannel(); + SslHandler sslhandler = (SslHandler) httpChannel.getNettyChannel().pipeline().get("ssl_http"); + if (sslhandler == null && httpChannel.inboundPipeline() != null) { + sslhandler = (SslHandler) httpChannel.inboundPipeline().get("ssl_http"); + } + + return sslhandler != null ? sslhandler.engine() : null; + } + + @Override + public String path() { + return underlyingRequest.path(); + } + + @Override + public Method method() { + return underlyingRequest.method(); + } + + @Override + public Optional getRemoteAddress() { + return Optional.ofNullable(this.underlyingRequest.getHttpChannel().getRemoteAddress()); + } + + @Override + public String uri() { + return underlyingRequest.uri(); + } + + @Override + public Map params() { + return underlyingRequest.params(); + } + + /** Gets access to the underlying request object */ + public RestRequest breakEncapsulationForRequest() { + return underlyingRequest; + } +} diff --git a/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java b/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java new file mode 100644 index 0000000000..45035e0d83 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java @@ -0,0 +1,97 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; + +public class OpenSearchRequestChannel extends OpenSearchRequest implements SecurityRequestChannel { + + private final Logger log = LogManager.getLogger(OpenSearchRequest.class); + + private final AtomicReference responseRef = new AtomicReference(null); + private final AtomicBoolean hasCompleted = new AtomicBoolean(false); + private final RestChannel underlyingChannel; + + OpenSearchRequestChannel(final RestRequest request, final RestChannel channel) { + super(request); + underlyingChannel = channel; + } + + /** Gets access to the underlying channel object */ + public RestChannel breakEncapsulationForChannel() { + return underlyingChannel; + } + + @Override + public void queueForSending(final SecurityResponse response) { + if (underlyingChannel == null) { + throw new UnsupportedOperationException("Channel was not defined"); + } + + if (hasCompleted.get()) { + throw new UnsupportedOperationException("This channel has already completed"); + } + + if (getQueuedResponse().isPresent()) { + throw new UnsupportedOperationException("Another response was already queued"); + } + + responseRef.set(response); + } + + @Override + public Optional getQueuedResponse() { + return Optional.ofNullable(responseRef.get()); + } + + @Override + public boolean sendResponse() { + if (underlyingChannel == null) { + throw new UnsupportedOperationException("Channel was not defined"); + } + + if (hasCompleted.get()) { + throw new UnsupportedOperationException("This channel has already completed"); + } + + if (getQueuedResponse().isEmpty()) { + throw new UnsupportedOperationException("No response has been associated with this channel"); + } + + final SecurityResponse response = responseRef.get(); + + try { + final BytesRestResponse restResponse = new BytesRestResponse(RestStatus.fromCode(response.getStatus()), response.getBody()); + if (response.getHeaders() != null) { + response.getHeaders().forEach(restResponse::addHeader); + } + underlyingChannel.sendResponse(restResponse); + + return true; + } catch (final Exception e) { + log.error("Error when attempting to send response", e); + throw new RuntimeException(e); + } finally { + hasCompleted.set(true); + } + + } +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityFilter.java b/src/main/java/org/opensearch/security/filter/SecurityFilter.java index f433a5857d..00b117ebb8 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityFilter.java @@ -464,6 +464,7 @@ public void onFailure(Exception e) { : String.format("no permissions for %s and %s", pres.getMissingPrivileges(), user); } log.debug(err); + listener.onFailure(new OpenSearchSecurityException(err, RestStatus.FORBIDDEN)); } } catch (OpenSearchException e) { diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequest.java b/src/main/java/org/opensearch/security/filter/SecurityRequest.java new file mode 100644 index 0000000000..7e6e94e0a6 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRequest.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.net.InetSocketAddress; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; + +import javax.net.ssl.SSLEngine; + +import org.opensearch.rest.RestRequest.Method; + +/** How the security plugin interacts with requests */ +public interface SecurityRequest { + + /** Collection of headers associated with the request */ + public Map> getHeaders(); + + /** The SSLEngine associated with the request */ + public SSLEngine getSSLEngine(); + + /** The path of the request */ + public String path(); + + /** The method type of this request */ + public Method method(); + + /** The remote address of the request, possible null */ + public Optional getRemoteAddress(); + + /** The full uri of the request */ + public String uri(); + + /** Finds the first value of the matching header or null */ + default public String header(final String headerName) { + final Optional>> headersMap = Optional.ofNullable(getHeaders()); + return headersMap.map(headers -> headers.get(headerName)).map(List::stream).flatMap(Stream::findFirst).orElse(null); + } + + /** The parameters associated with this request */ + public Map params(); +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java b/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java new file mode 100644 index 0000000000..1eec754c08 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.Optional; + +/** + * When a request is recieved by the security plugin this governs getting information about the request and complete with with a response + */ +public interface SecurityRequestChannel extends SecurityRequest { + + /** Associate a response with this channel */ + public void queueForSending(final SecurityResponse response); + + /** Acess the queued response */ + public Optional getQueuedResponse(); + + /** Send the response through the channel */ + public boolean sendResponse(); +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java b/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java new file mode 100644 index 0000000000..bcacc2cf7a --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java @@ -0,0 +1,17 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +/** Thrown when a security rest channel is not supported */ +public class SecurityRequestChannelUnsupported extends RuntimeException { + +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java new file mode 100644 index 0000000000..de74df01ff --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java @@ -0,0 +1,31 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; + +/** + * Generates wrapped versions of requests for use in the security plugin + */ +public class SecurityRequestFactory { + + /** Creates a security request from a RestRequest */ + public static SecurityRequest from(final RestRequest request) { + return new OpenSearchRequest(request); + } + + /** Creates a security request channel from a RestRequest & RestChannel */ + public static SecurityRequestChannel from(final RestRequest request, final RestChannel channel) { + return new OpenSearchRequestChannel(request, channel); + } +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityResponse.java b/src/main/java/org/opensearch/security/filter/SecurityResponse.java new file mode 100644 index 0000000000..8618be3aab --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityResponse.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.Map; + +import org.apache.http.HttpHeaders; + +public class SecurityResponse { + + public static final Map CONTENT_TYPE_APP_JSON = Map.of(HttpHeaders.CONTENT_TYPE, "application/json"); + + private final int status; + private final Map headers; + private final String body; + + public SecurityResponse(final int status, final Map headers, final String body) { + this.status = status; + this.headers = headers; + this.body = body; + } + + public int getStatus() { + return status; + } + + public Map getHeaders() { + return headers; + } + + public String getBody() { + return body; + } + +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java index 54d98ffe13..ce80a9e143 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java @@ -35,6 +35,7 @@ import javax.net.ssl.SSLPeerUnverifiedException; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.greenrobot.eventbus.Subscribe; @@ -42,13 +43,9 @@ import org.opensearch.OpenSearchException; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.NamedRoute; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; -import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; -import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auditlog.AuditLog; import org.opensearch.security.auditlog.AuditLog.Origin; import org.opensearch.security.auth.BackendRegistry; @@ -131,17 +128,40 @@ public SecurityRestFilter( public RestHandler wrap(RestHandler original, AdminDNs adminDNs) { return (request, channel, client) -> { org.apache.logging.log4j.ThreadContext.clearAll(); - if (!checkAndAuthenticateRequest(request, channel)) { - User user = threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); - boolean isSuperAdminUser = userIsSuperAdmin(user, adminDNs); - if (isSuperAdminUser - || (whitelistingSettings.checkRequestIsAllowed(request, channel, client) - && allowlistingSettings.checkRequestIsAllowed(request, channel, client))) { - if (isSuperAdminUser || authorizeRequest(original, request, channel, user)) { - original.handleRequest(request, channel, client); - } - } + final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(request, channel); + + // Authenticate request + checkAndAuthenticateRequest(requestChannel); + if (requestChannel.getQueuedResponse().isPresent()) { + requestChannel.sendResponse(); + return; + } + + // Authorize Request + final User user = threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); + if (userIsSuperAdmin(user, adminDNs)) { + // Super admins are always authorized + original.handleRequest(request, channel, client); + return; + } + + final Optional deniedResponse = whitelistingSettings.checkRequestIsAllowed(requestChannel) + .or(() -> allowlistingSettings.checkRequestIsAllowed(requestChannel)); + + if (deniedResponse.isPresent()) { + requestChannel.queueForSending(deniedResponse.orElseThrow()); + requestChannel.sendResponse(); + return; + } + + authorizeRequest(original, requestChannel, user); + if (requestChannel.getQueuedResponse().isPresent()) { + requestChannel.sendResponse(); + return; } + + // Caller was authorized, forward the request to the handler + original.handleRequest(request, channel, client); }; } @@ -152,8 +172,7 @@ private boolean userIsSuperAdmin(User user, AdminDNs adminDNs) { return user != null && adminDNs.isAdmin(user); } - private boolean authorizeRequest(RestHandler original, RestRequest request, RestChannel channel, User user) { - + private void authorizeRequest(RestHandler original, SecurityRequestChannel request, User user) { List restRoutes = original.routes(); Optional handler = restRoutes.stream() .filter(rh -> rh.getMethod().equals(request.method())) @@ -190,38 +209,37 @@ private boolean authorizeRequest(RestHandler original, RestRequest request, Rest err = String.format("no permissions for %s and %s", pres.getMissingPrivileges(), user); } log.debug(err); - channel.sendResponse(new BytesRestResponse(RestStatus.UNAUTHORIZED, err)); - return false; + + request.queueForSending(new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, null, err)); + return; } } - - // if handler is not an instance of NamedRoute then we pass through to eval at Transport Layer. - return true; } - private boolean checkAndAuthenticateRequest(RestRequest request, RestChannel channel) throws Exception { - + public void checkAndAuthenticateRequest(SecurityRequestChannel requestChannel) throws Exception { threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN, Origin.REST.toString()); - if (HTTPHelper.containsBadHeader(request)) { + if (HTTPHelper.containsBadHeader(requestChannel)) { final OpenSearchException exception = ExceptionUtils.createBadHeaderException(); log.error(exception.toString()); - auditLog.logBadHeaders(request); - channel.sendResponse(new BytesRestResponse(channel, RestStatus.FORBIDDEN, exception)); - return true; + auditLog.logBadHeaders(requestChannel); + + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, exception.toString())); + return; } if (SSLRequestHelper.containsBadHeader(threadContext, ConfigConstants.OPENDISTRO_SECURITY_CONFIG_PREFIX)) { final OpenSearchException exception = ExceptionUtils.createBadHeaderException(); log.error(exception.toString()); - auditLog.logBadHeaders(request); - channel.sendResponse(new BytesRestResponse(channel, RestStatus.FORBIDDEN, exception)); - return true; + auditLog.logBadHeaders(requestChannel); + + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, exception.toString())); + return; } final SSLInfo sslInfo; try { - if ((sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, request, principalExtractor)) != null) { + if ((sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, requestChannel, principalExtractor)) != null) { if (sslInfo.getPrincipal() != null) { threadContext.putTransient("_opendistro_security_ssl_principal", sslInfo.getPrincipal()); } @@ -234,22 +252,23 @@ private boolean checkAndAuthenticateRequest(RestRequest request, RestChannel cha } } catch (SSLPeerUnverifiedException e) { log.error("No ssl info", e); - auditLog.logSSLException(request, e); - channel.sendResponse(new BytesRestResponse(channel, RestStatus.FORBIDDEN, e)); - return true; + auditLog.logSSLException(requestChannel, e); + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, null)); + return; } if (!compatConfig.restAuthEnabled()) { - return false; + // Authentication is disabled + return; } - Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); + Matcher matcher = PATTERN_PATH_PREFIX.matcher(requestChannel.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; - if (request.method() != Method.OPTIONS && !(HEALTH_SUFFIX.equals(suffix)) && !(WHO_AM_I_SUFFIX.equals(suffix))) { - if (!registry.authenticate(request, channel, threadContext)) { + if (requestChannel.method() != Method.OPTIONS && !(HEALTH_SUFFIX.equals(suffix)) && !(WHO_AM_I_SUFFIX.equals(suffix))) { + if (!registry.authenticate(requestChannel)) { // another roundtrip org.apache.logging.log4j.ThreadContext.remove("user"); - return true; + return; } else { // make it possible to filter logs by username org.apache.logging.log4j.ThreadContext.put( @@ -258,8 +277,6 @@ private boolean checkAndAuthenticateRequest(RestRequest request, RestChannel cha ); } } - - return false; } @Subscribe diff --git a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java index 16a822df6d..ff07db147e 100644 --- a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java @@ -27,16 +27,18 @@ package org.opensearch.security.http; import java.nio.file.Path; +import java.util.Map; +import java.util.Optional; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestRequest; -import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.support.HTTPHelper; import org.opensearch.security.user.AuthCredentials; @@ -50,9 +52,9 @@ public HTTPBasicAuthenticator(final Settings settings, final Path configPath) { } @Override - public AuthCredentials extractCredentials(final RestRequest request, ThreadContext threadContext) { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext threadContext) { - final boolean forceLogin = request.paramAsBoolean("force_login", false); + final boolean forceLogin = Boolean.getBoolean(request.params().get("force_login")); if (forceLogin) { return null; @@ -64,10 +66,10 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { - final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, "Unauthorized"); - wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Basic realm=\"OpenSearch Security\""); - return wwwAuthenticateResponse; + public Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials creds) { + return Optional.of( + new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", "Basic realm=\"OpenSearch Security\""), "") + ); } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java index cea59f7311..433ec01458 100644 --- a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; import javax.naming.InvalidNameException; import javax.naming.ldap.LdapName; @@ -41,9 +42,9 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.user.AuthCredentials; @@ -57,7 +58,7 @@ public HTTPClientCertAuthenticator(final Settings settings, final Path configPat } @Override - public AuthCredentials extractCredentials(final RestRequest request, final ThreadContext threadContext) { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext threadContext) { final String principal = threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_PRINCIPAL); @@ -98,8 +99,8 @@ public AuthCredentials extractCredentials(final RestRequest request, final Threa } @Override - public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { - return null; + public Optional reRequestAuthentication(final SecurityRequest response, AuthCredentials creds) { + return Optional.empty(); } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java index 1db7b9769b..de57e2f9e6 100644 --- a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java @@ -27,6 +27,7 @@ package org.opensearch.security.http; import java.nio.file.Path; +import java.util.Optional; import java.util.regex.Pattern; import com.google.common.base.Predicates; @@ -37,9 +38,9 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.user.AuthCredentials; @@ -56,7 +57,7 @@ public HTTPProxyAuthenticator(Settings settings, final Path configPath) { } @Override - public AuthCredentials extractCredentials(final RestRequest request, ThreadContext context) { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) { if (context.getTransient(ConfigConstants.OPENDISTRO_SECURITY_XFF_DONE) != Boolean.TRUE) { throw new OpenSearchSecurityException("xff not done"); @@ -89,8 +90,8 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { - return null; + public Optional reRequestAuthentication(final SecurityRequest response, AuthCredentials creds) { + return Optional.empty(); } @Override diff --git a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java index 02077bed7c..8499b88f62 100644 --- a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java @@ -15,6 +15,7 @@ import java.security.PrivilegedAction; import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.Map.Entry; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -33,10 +34,10 @@ import org.opensearch.SpecialPermission; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.authtoken.jwt.EncryptionDecryptionUtil; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.ssl.util.ExceptionUtils; import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.util.KeyUtils; @@ -120,7 +121,8 @@ private String[] extractBackendRolesFromClaims(Claims claims) { @Override @SuppressWarnings("removal") - public AuthCredentials extractCredentials(RestRequest request, ThreadContext context) throws OpenSearchSecurityException { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) + throws OpenSearchSecurityException { final SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -137,7 +139,7 @@ public AuthCredentials run() { return creds; } - private AuthCredentials extractCredentials0(final RestRequest request) { + private AuthCredentials extractCredentials0(final SecurityRequest request) { if (!oboEnabled) { log.error("On-behalf-of authentication is disabled"); return null; @@ -201,7 +203,7 @@ private AuthCredentials extractCredentials0(final RestRequest request) { return null; } - private String extractJwtFromHeader(RestRequest request) { + private String extractJwtFromHeader(SecurityRequest request) { String jwtToken = request.header(HttpHeaders.AUTHORIZATION); if (jwtToken == null || jwtToken.isEmpty()) { @@ -229,7 +231,7 @@ private void logDebug(String message, Object... args) { } } - public Boolean isRequestAllowed(final RestRequest request) { + public Boolean isRequestAllowed(final SecurityRequest request) { Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; if (isAccessToRestrictedEndpoints(request, suffix)) { @@ -241,8 +243,8 @@ public Boolean isRequestAllowed(final RestRequest request) { } @Override - public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { - return null; + public Optional reRequestAuthentication(final SecurityRequest response, AuthCredentials creds) { + return Optional.empty(); } @Override diff --git a/src/main/java/org/opensearch/security/http/RemoteIpDetector.java b/src/main/java/org/opensearch/security/http/RemoteIpDetector.java index f464c0653a..7b76a82c42 100644 --- a/src/main/java/org/opensearch/security/http/RemoteIpDetector.java +++ b/src/main/java/org/opensearch/security/http/RemoteIpDetector.java @@ -43,6 +43,7 @@ package org.opensearch.security.http; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.LinkedList; import java.util.List; @@ -52,7 +53,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.support.ConfigConstants; final class RemoteIpDetector { @@ -115,8 +116,12 @@ public String getRemoteIpHeader() { return remoteIpHeader; } - String detect(RestRequest request, ThreadContext threadContext) { - final String originalRemoteAddr = ((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).getAddress().getHostAddress(); + String detect(SecurityRequest request, ThreadContext threadContext) { + + final String originalRemoteAddr = request.getRemoteAddress() + .map(InetSocketAddress::getAddress) + .map(InetAddress::getHostAddress) + .orElseThrow(); final boolean isTraceEnabled = log.isTraceEnabled(); if (isTraceEnabled) { @@ -173,8 +178,10 @@ String detect(RestRequest request, ThreadContext threadContext) { if (remoteIp != null) { if (isTraceEnabled) { - final String originalRemoteHost = ((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).getAddress() - .getHostName(); + final String originalRemoteHost = request.getRemoteAddress() + .map(InetSocketAddress::getAddress) + .map(InetAddress::getHostName) + .orElseThrow(); log.trace( "Incoming request {} with originalRemoteAddr '{}', originalRemoteHost='{}', will be seen as newRemoteAddr='{}'", request.uri(), @@ -196,7 +203,7 @@ String detect(RestRequest request, ThreadContext threadContext) { log.trace( "Skip RemoteIpDetector for request {} with originalRemoteAddr '{}' cause no internal proxy matches", request.uri(), - request.getHttpChannel().getRemoteAddress() + request.getRemoteAddress().orElse(null) ); } } diff --git a/src/main/java/org/opensearch/security/http/XFFResolver.java b/src/main/java/org/opensearch/security/http/XFFResolver.java index ddb7255179..e9ad412831 100644 --- a/src/main/java/org/opensearch/security/http/XFFResolver.java +++ b/src/main/java/org/opensearch/security/http/XFFResolver.java @@ -34,9 +34,11 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.http.netty4.Netty4HttpChannel; import org.opensearch.rest.RestRequest; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.OpenSearchRequest; import org.opensearch.security.securityconf.DynamicConfigModel; import org.opensearch.security.support.ConfigConstants; import org.opensearch.threadpool.ThreadPool; @@ -53,20 +55,23 @@ public XFFResolver(final ThreadPool threadPool) { this.threadContext = threadPool.getThreadContext(); } - public TransportAddress resolve(final RestRequest request) throws OpenSearchSecurityException { + public TransportAddress resolve(final SecurityRequest request) throws OpenSearchSecurityException { final boolean isTraceEnabled = log.isTraceEnabled(); if (isTraceEnabled) { - log.trace("resolve {}", request.getHttpChannel().getRemoteAddress()); + log.trace("resolve {}", request.getRemoteAddress().orElse(null)); } - if (enabled - && request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress - && request.getHttpChannel() instanceof Netty4HttpChannel) { + boolean requestFromNetty = false; + if (request instanceof OpenSearchRequest) { + final OpenSearchRequest securityRequestChannel = (OpenSearchRequest) request; + final RestRequest restRequest = securityRequestChannel.breakEncapsulationForRequest(); - final InetSocketAddress isa = new InetSocketAddress( - detector.detect(request, threadContext), - ((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).getPort() - ); + requestFromNetty = restRequest.getHttpChannel() instanceof Netty4HttpChannel; + } + + if (enabled && request.getRemoteAddress().isPresent() && requestFromNetty) { + final InetSocketAddress remoteAddress = request.getRemoteAddress().get(); + final InetSocketAddress isa = new InetSocketAddress(detector.detect(request, threadContext), remoteAddress.getPort()); if (isa.isUnresolved()) { throw new OpenSearchSecurityException("Cannot resolve address " + isa.getHostString()); @@ -74,23 +79,21 @@ public TransportAddress resolve(final RestRequest request) throws OpenSearchSecu if (isTraceEnabled) { if (threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_XFF_DONE) == Boolean.TRUE) { - log.trace("xff resolved {} to {}", request.getHttpChannel().getRemoteAddress(), isa); + log.trace("xff resolved {} to {}", remoteAddress, isa); } else { log.trace("no xff done for {}", request.getClass()); } } return new TransportAddress(isa); - } else if (request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress) { - + } else if (request.getRemoteAddress().isPresent()) { if (isTraceEnabled) { log.trace("no xff done (enabled or no netty request) {},{},{},{}", enabled, request.getClass()); - } - return new TransportAddress((InetSocketAddress) request.getHttpChannel().getRemoteAddress()); + return new TransportAddress((InetSocketAddress) request.getRemoteAddress().get()); } else { throw new OpenSearchSecurityException( "Cannot handle this request. Remote address is " - + request.getHttpChannel().getRemoteAddress() + + request.getRemoteAddress().orElse(null) + " with request class " + request.getClass() ); diff --git a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java index 0423fecefe..9a16e6c61b 100644 --- a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java @@ -37,7 +37,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.http.HTTPProxyAuthenticator; import org.opensearch.security.user.AuthCredentials; @@ -54,7 +54,7 @@ public HTTPExtendedProxyAuthenticator(Settings settings, final Path configPath) } @Override - public AuthCredentials extractCredentials(final RestRequest request, ThreadContext context) { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) { AuthCredentials credentials = super.extractCredentials(request, context); if (credentials == null) { return null; diff --git a/src/main/java/org/opensearch/security/rest/SecurityConfigUpdateAction.java b/src/main/java/org/opensearch/security/rest/SecurityConfigUpdateAction.java index c582c9f51b..bfbc16f98d 100644 --- a/src/main/java/org/opensearch/security/rest/SecurityConfigUpdateAction.java +++ b/src/main/java/org/opensearch/security/rest/SecurityConfigUpdateAction.java @@ -29,6 +29,7 @@ import org.opensearch.security.action.configupdate.ConfigUpdateAction; import org.opensearch.security.action.configupdate.ConfigUpdateRequest; import org.opensearch.security.configuration.AdminDNs; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.util.SSLRequestHelper; import org.opensearch.security.support.ConfigConstants; @@ -73,7 +74,12 @@ public List routes() { protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String[] configTypes = request.paramAsStringArrayOrEmptyIfAll("config_types"); - SSLRequestHelper.SSLInfo sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, request, principalExtractor); + SSLRequestHelper.SSLInfo sslInfo = SSLRequestHelper.getSSLInfo( + settings, + configPath, + SecurityRequestFactory.from(request), + principalExtractor + ); if (sslInfo == null) { return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.FORBIDDEN, "")); diff --git a/src/main/java/org/opensearch/security/rest/SecurityWhoAmIAction.java b/src/main/java/org/opensearch/security/rest/SecurityWhoAmIAction.java index 8bab16484d..4f560f40b6 100644 --- a/src/main/java/org/opensearch/security/rest/SecurityWhoAmIAction.java +++ b/src/main/java/org/opensearch/security/rest/SecurityWhoAmIAction.java @@ -32,6 +32,7 @@ import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.configuration.AdminDNs; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.util.SSLRequestHelper; import org.opensearch.security.ssl.util.SSLRequestHelper.SSLInfo; @@ -97,8 +98,12 @@ public void accept(RestChannel channel) throws Exception { BytesRestResponse response = null; try { - - SSLInfo sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, request, principalExtractor); + SSLInfo sslInfo = SSLRequestHelper.getSSLInfo( + settings, + configPath, + SecurityRequestFactory.from(request), + principalExtractor + ); if (sslInfo == null) { response = new BytesRestResponse(RestStatus.FORBIDDEN, "No security data"); diff --git a/src/main/java/org/opensearch/security/securityconf/impl/AllowlistingSettings.java b/src/main/java/org/opensearch/security/securityconf/impl/AllowlistingSettings.java index 98fc7a266a..ba249e8c63 100644 --- a/src/main/java/org/opensearch/security/securityconf/impl/AllowlistingSettings.java +++ b/src/main/java/org/opensearch/security/securityconf/impl/AllowlistingSettings.java @@ -15,12 +15,13 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; -import org.opensearch.client.node.NodeClient; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; +import org.apache.http.HttpStatus; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.rest.RestStatus; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; public class AllowlistingSettings { private boolean enabled; @@ -75,7 +76,7 @@ public String toString() { * GET /_cluster/settings - OK * GET /_cluster/settings/ - OK */ - private boolean requestIsAllowlisted(RestRequest request) { + private boolean requestIsAllowlisted(final SecurityRequest request) { // ALSO ALLOWS REQUEST TO HAVE TRAILING '/' // pathWithoutTrailingSlash stores the endpoint path without extra '/'. eg: /_cat/nodes @@ -107,21 +108,30 @@ private boolean requestIsAllowlisted(RestRequest request) { * then all PUT /_opendistro/_security/api/rolesmapping/{resource_name} work. * Currently, each resource_name has to be allowlisted separately */ - public boolean checkRequestIsAllowed(RestRequest request, RestChannel channel, NodeClient client) throws IOException { + public Optional checkRequestIsAllowed(final SecurityRequest request) { // if allowlisting is enabled but the request is not allowlisted, then return false, otherwise true. if (this.enabled && !requestIsAllowlisted(request)) { - channel.sendResponse( - new BytesRestResponse( - RestStatus.FORBIDDEN, - channel.newErrorBuilder() - .startObject() - .field("error", request.method() + " " + request.path() + " API not allowlisted") - .field("status", RestStatus.FORBIDDEN) - .endObject() - ) + return Optional.of( + new SecurityResponse(HttpStatus.SC_FORBIDDEN, SecurityResponse.CONTENT_TYPE_APP_JSON, generateFailureMessage(request)) ); - return false; } - return true; + return Optional.empty(); + } + + protected String getVerb() { + return "allowlisted"; + } + + protected String generateFailureMessage(final SecurityRequest request) { + try { + return XContentFactory.jsonBuilder() + .startObject() + .field("error", request.method() + " " + request.path() + " API not " + getVerb()) + .field("status", RestStatus.FORBIDDEN) + .endObject() + .toString(); + } catch (final IOException ioe) { + throw new RuntimeException(ioe); + } } } diff --git a/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java b/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java index 4462bae90f..2e1ab791d2 100644 --- a/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java +++ b/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java @@ -11,16 +11,14 @@ package org.opensearch.security.securityconf.impl; -import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; -import org.opensearch.client.node.NodeClient; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; -import org.opensearch.core.rest.RestStatus; +import org.apache.http.HttpStatus; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; public class WhitelistingSettings extends AllowlistingSettings { private boolean enabled; @@ -75,7 +73,7 @@ public String toString() { * GET /_cluster/settings - OK * GET /_cluster/settings/ - OK */ - private boolean requestIsWhitelisted(RestRequest request) { + private boolean requestIsWhitelisted(final SecurityRequest request) { // ALSO ALLOWS REQUEST TO HAVE TRAILING '/' // pathWithoutTrailingSlash stores the endpoint path without extra '/'. eg: /_cat/nodes @@ -108,21 +106,17 @@ private boolean requestIsWhitelisted(RestRequest request) { * Currently, each resource_name has to be whitelisted separately */ @Override - public boolean checkRequestIsAllowed(RestRequest request, RestChannel channel, NodeClient client) throws IOException { + public Optional checkRequestIsAllowed(final SecurityRequest request) { // if whitelisting is enabled but the request is not whitelisted, then return false, otherwise true. if (this.enabled && !requestIsWhitelisted(request)) { - channel.sendResponse( - new BytesRestResponse( - RestStatus.FORBIDDEN, - channel.newErrorBuilder() - .startObject() - .field("error", request.method() + " " + request.path() + " API not whitelisted") - .field("status", RestStatus.FORBIDDEN) - .endObject() - ) + return Optional.of( + new SecurityResponse(HttpStatus.SC_FORBIDDEN, SecurityResponse.CONTENT_TYPE_APP_JSON, generateFailureMessage(request)) ); - return false; } - return true; + return Optional.empty(); + } + + protected String getVerb() { + return "whitelisted"; } } diff --git a/src/main/java/org/opensearch/security/ssl/SslExceptionHandler.java b/src/main/java/org/opensearch/security/ssl/SslExceptionHandler.java index 90f01df9b5..b274a2ecd3 100644 --- a/src/main/java/org/opensearch/security/ssl/SslExceptionHandler.java +++ b/src/main/java/org/opensearch/security/ssl/SslExceptionHandler.java @@ -17,13 +17,13 @@ package org.opensearch.security.ssl; -import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequestChannel; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequest; public interface SslExceptionHandler { - default void logError(Throwable t, RestRequest request, int type) { + default void logError(Throwable t, SecurityRequestChannel request, int type) { // no-op } diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/ValidatingDispatcher.java b/src/main/java/org/opensearch/security/ssl/http/netty/ValidatingDispatcher.java index e053da7787..dcd25c2837 100644 --- a/src/main/java/org/opensearch/security/ssl/http/netty/ValidatingDispatcher.java +++ b/src/main/java/org/opensearch/security/ssl/http/netty/ValidatingDispatcher.java @@ -33,6 +33,8 @@ import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.security.ssl.util.ExceptionUtils; import org.opensearch.security.ssl.util.SSLRequestHelper; @@ -64,17 +66,17 @@ public ValidatingDispatcher( @Override public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { - checkRequest(request, channel); + checkRequest(SecurityRequestFactory.from(request, channel)); originalDispatcher.dispatchRequest(request, channel, threadContext); } @Override public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext, Throwable cause) { - checkRequest(channel.request(), channel); + checkRequest(SecurityRequestFactory.from(channel.request(), channel)); originalDispatcher.dispatchBadRequest(channel, threadContext, cause); } - protected void checkRequest(final RestRequest request, final RestChannel channel) { + protected void checkRequest(final SecurityRequestChannel request) { if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) { final OpenSearchException exception = ExceptionUtils.createBadHeaderException(); diff --git a/src/main/java/org/opensearch/security/ssl/rest/SecuritySSLInfoAction.java b/src/main/java/org/opensearch/security/ssl/rest/SecuritySSLInfoAction.java index 54d109f497..8e32893eab 100644 --- a/src/main/java/org/opensearch/security/ssl/rest/SecuritySSLInfoAction.java +++ b/src/main/java/org/opensearch/security/ssl/rest/SecuritySSLInfoAction.java @@ -39,6 +39,7 @@ import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.core.rest.RestStatus; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.ssl.SecurityKeyStore; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.util.SSLRequestHelper; @@ -84,8 +85,12 @@ public void accept(RestChannel channel) throws Exception { BytesRestResponse response = null; try { - - SSLInfo sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, request, principalExtractor); + SSLInfo sslInfo = SSLRequestHelper.getSSLInfo( + settings, + configPath, + SecurityRequestFactory.from(request), + principalExtractor + ); X509Certificate[] certs = sslInfo == null ? null : sslInfo.getX509Certs(); X509Certificate[] localCerts = sslInfo == null ? null : sslInfo.getLocalCertificates(); diff --git a/src/main/java/org/opensearch/security/ssl/util/SSLRequestHelper.java b/src/main/java/org/opensearch/security/ssl/util/SSLRequestHelper.java index 1a23d0bb59..df92bfc703 100644 --- a/src/main/java/org/opensearch/security/ssl/util/SSLRequestHelper.java +++ b/src/main/java/org/opensearch/security/ssl/util/SSLRequestHelper.java @@ -36,7 +36,6 @@ import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; -import io.netty.handler.ssl.SslHandler; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -45,8 +44,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.env.Environment; -import org.opensearch.http.netty4.Netty4HttpChannel; -import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.transport.PrincipalExtractor.Type; @@ -121,25 +119,14 @@ public String toString() { public static SSLInfo getSSLInfo( final Settings settings, final Path configPath, - final RestRequest request, + final SecurityRequest request, PrincipalExtractor principalExtractor ) throws SSLPeerUnverifiedException { - - if (request == null || request.getHttpChannel() == null || !(request.getHttpChannel() instanceof Netty4HttpChannel)) { - return null; - } - - final Netty4HttpChannel httpChannel = (Netty4HttpChannel) request.getHttpChannel(); - SslHandler sslhandler = (SslHandler) httpChannel.getNettyChannel().pipeline().get("ssl_http"); - if (sslhandler == null && httpChannel.inboundPipeline() != null) { - sslhandler = (SslHandler) httpChannel.inboundPipeline().get("ssl_http"); - } - - if (sslhandler == null) { + final SSLEngine engine = request.getSSLEngine(); + if (engine == null) { return null; } - final SSLEngine engine = sslhandler.engine(); final SSLSession session = engine.getSession(); X509Certificate[] x509Certs = null; diff --git a/src/main/java/org/opensearch/security/support/HTTPHelper.java b/src/main/java/org/opensearch/security/support/HTTPHelper.java index c3b191f770..fe590f0d34 100644 --- a/src/main/java/org/opensearch/security/support/HTTPHelper.java +++ b/src/main/java/org/opensearch/security/support/HTTPHelper.java @@ -32,8 +32,7 @@ import java.util.Map; import org.apache.logging.log4j.Logger; - -import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.user.AuthCredentials; public class HTTPHelper { @@ -86,7 +85,7 @@ public static AuthCredentials extractCredentials(String authorizationHeader, Log } } - public static boolean containsBadHeader(final RestRequest request) { + public static boolean containsBadHeader(final SecurityRequest request) { final Map> headers; diff --git a/src/main/java/org/opensearch/security/util/AuthTokenUtils.java b/src/main/java/org/opensearch/security/util/AuthTokenUtils.java index 30f331d3a7..3884bf75fe 100644 --- a/src/main/java/org/opensearch/security/util/AuthTokenUtils.java +++ b/src/main/java/org/opensearch/security/util/AuthTokenUtils.java @@ -12,7 +12,7 @@ package org.opensearch.security.util; import org.opensearch.common.settings.Settings; -import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequest; import static org.opensearch.rest.RestRequest.Method.POST; import static org.opensearch.rest.RestRequest.Method.PUT; @@ -21,7 +21,7 @@ public class AuthTokenUtils { private static final String ON_BEHALF_OF_SUFFIX = "api/generateonbehalfoftoken"; private static final String ACCOUNT_SUFFIX = "api/account"; - public static Boolean isAccessToRestrictedEndpoints(final RestRequest request, final String suffix) { + public static Boolean isAccessToRestrictedEndpoints(final SecurityRequest request, final String suffix) { if (suffix == null) { return false; } diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java index 04a30ba2db..4a28c0a752 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java @@ -85,7 +85,10 @@ public void testTokenMissing() throws Exception { HTTPJwtAuthenticator jwtAuth = new HTTPJwtAuthenticator(settings, null); Map headers = new HashMap(); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -101,7 +104,10 @@ public void testInvalid() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer " + jwsToken); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -120,7 +126,10 @@ public void testBearer() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer " + jwsToken); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNotNull(credentials); Assert.assertEquals("Leonard McCoy", credentials.getUsername()); @@ -139,7 +148,10 @@ public void testBearerWrongPosition() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken + "Bearer " + " 123"); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -152,7 +164,10 @@ public void testBasicAuthHeader() throws Exception { String basicAuth = BaseEncoding.base64().encode("user:password".getBytes(StandardCharsets.UTF_8)); Map headers = Collections.singletonMap(HttpHeaders.AUTHORIZATION, "Basic " + basicAuth); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, Collections.emptyMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, Collections.emptyMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -261,7 +276,7 @@ public void testUrlParam() throws Exception { FakeRestRequest req = new FakeRestRequest(headers, new HashMap()); req.params().put("abc", jwsToken); - AuthCredentials credentials = jwtAuth.extractCredentials(req, null); + AuthCredentials credentials = jwtAuth.extractCredentials(req.asSecurityRequest(), null); Assert.assertNotNull(credentials); Assert.assertEquals("Leonard McCoy", credentials.getUsername()); @@ -311,7 +326,10 @@ public void testRS256() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer " + jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNotNull(creds); Assert.assertEquals("Leonard McCoy", creds.getUsername()); @@ -334,7 +352,10 @@ public void testES512() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNotNull(creds); Assert.assertEquals("Leonard McCoy", creds.getUsername()); @@ -411,7 +432,7 @@ private AuthCredentials extractCredentialsFromJwtHeader(final Settings.Builder s final String jwsToken = jwtBuilder.signWith(secretKey, SignatureAlgorithm.HS512).compact(); final HTTPJwtAuthenticator jwtAuth = new HTTPJwtAuthenticator(settings, null); final Map headers = Map.of("Authorization", jwsToken); - return jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap<>()), null); + return jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap<>()).asSecurityRequest(), null); } } diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index ccdab8d77a..d483a2ec81 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -52,7 +52,8 @@ public void basicTest() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()) + .asSecurityRequest(), null ); @@ -74,7 +75,7 @@ public void jwksUriTest() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()).asSecurityRequest(), null ); @@ -95,7 +96,8 @@ public void jwksMissingRequiredIssuerInClaimTest() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_NO_ISSUER_OCT_1), new HashMap<>()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_NO_ISSUER_OCT_1), new HashMap<>()) + .asSecurityRequest(), null ); @@ -109,7 +111,7 @@ public void jwksNotMatchingRequiredIssuerInClaimTest() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()).asSecurityRequest(), null ); @@ -126,7 +128,8 @@ public void jwksMissingRequiredAudienceInClaimTest() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_NO_AUDIENCE_OCT_1), new HashMap<>()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_NO_AUDIENCE_OCT_1), new HashMap<>()) + .asSecurityRequest(), null ); @@ -143,7 +146,7 @@ public void jwksNotMatchingRequiredAudienceInClaimTest() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()).asSecurityRequest(), null ); @@ -155,7 +158,7 @@ public void jwksUriMissingTest() { var exception = Assert.assertThrows(Exception.class, () -> { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(Settings.builder().build(), null); jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(), null ); }); @@ -178,7 +181,7 @@ public void testEscapeKid() { new FakeRestRequest( ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1_INVALID_KID), new HashMap() - ), + ).asSecurityRequest(), null ); @@ -200,7 +203,8 @@ public void bearerTest() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()) + .asSecurityRequest(), null ); @@ -223,7 +227,8 @@ public void testRoles() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()) + .asSecurityRequest(), null ); @@ -242,7 +247,7 @@ public void testExp() throws Exception { new FakeRestRequest( ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_EXPIRED_SIGNED_OCT_1), new HashMap() - ), + ).asSecurityRequest(), null ); @@ -267,7 +272,7 @@ public void testExpInSkew() throws Exception { new FakeRestRequest( ImmutableMap.of("Authorization", "Bearer " + TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), new HashMap() - ), + ).asSecurityRequest(), null ); @@ -292,7 +297,7 @@ public void testNbf() throws Exception { new FakeRestRequest( ImmutableMap.of("Authorization", "Bearer " + TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), new HashMap() - ), + ).asSecurityRequest(), null ); @@ -317,7 +322,7 @@ public void testNbfInSkew() throws Exception { new FakeRestRequest( ImmutableMap.of("Authorization", "Bearer " + TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), new HashMap() - ), + ).asSecurityRequest(), null ); @@ -336,7 +341,8 @@ public void testRS256() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), null ); @@ -355,7 +361,8 @@ public void testBadSignature() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_X), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_X), new HashMap()) + .asSecurityRequest(), null ); @@ -376,7 +383,7 @@ public void testPeculiarJsonEscaping() { new FakeRestRequest( ImmutableMap.of("Authorization", TestJwts.PeculiarEscaping.MC_COY_SIGNED_RSA_1), new HashMap() - ), + ).asSecurityRequest(), null ); diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index a30e43f9fa..6b5c541981 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -32,7 +32,8 @@ public void basicTest() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), null ); @@ -58,7 +59,8 @@ public void wrongSigTest() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_X), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_X), new HashMap()) + .asSecurityRequest(), null ); @@ -80,7 +82,8 @@ public void noAlgTest() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), null ); @@ -105,7 +108,8 @@ public void mismatchedAlgTest() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), null ); @@ -128,7 +132,8 @@ public void keyExchangeTest() throws Exception { try { AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), null ); @@ -139,21 +144,24 @@ public void keyExchangeTest() throws Exception { Assert.assertEquals(4, creds.getAttributes().size()); creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_2), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_2), new HashMap()) + .asSecurityRequest(), null ); Assert.assertNull(creds); creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_X), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_X), new HashMap()) + .asSecurityRequest(), null ); Assert.assertNull(creds); creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), null ); @@ -175,7 +183,8 @@ public void keyExchangeTest() throws Exception { try { AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_2), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_2), new HashMap()) + .asSecurityRequest(), null ); diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index 2594388128..17a2148fa5 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -25,6 +25,7 @@ import java.util.Base64; import java.util.HashMap; import java.util.List; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -34,6 +35,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer; import org.apache.cxf.rs.security.jose.jwt.JwtToken; +import org.hamcrest.Matchers; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -42,24 +44,26 @@ import org.opensaml.saml.saml2.core.NameIDType; import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.rest.RestResponse; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.DefaultObjectMapper; +import org.opensearch.security.filter.SecurityRequestFactory; +import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.test.helper.file.FileHelper; import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.util.FakeRestRequest; import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.IDP_METADATA_CONTENT; import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.IDP_METADATA_URL; +import static org.hamcrest.MatcherAssert.assertThat; public class HTTPSamlAuthenticatorTest { protected MockSamlIdpServer mockSamlIdpServer; @@ -140,12 +144,8 @@ public void basicTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -161,6 +161,17 @@ public void basicTest() throws Exception { Assert.assertEquals("horst", jwt.getClaim("sub")); } + private Optional sendToAuthenticator(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) { + final SecurityRequest tokenRestChannel = SecurityRequestFactory.from(request); + + return samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + } + + private String getResponse(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) throws Exception { + SecurityResponse response = sendToAuthenticator(samlAuthenticator, request).orElseThrow(); + return response.getBody(); + } + @Test public void decryptAssertionsTest() throws Exception { mockSamlIdpServer.setAuthenticateUser("horst"); @@ -188,12 +199,7 @@ public void decryptAssertionsTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -237,12 +243,7 @@ public void shouldUnescapeSamlEntitiesTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -289,12 +290,7 @@ public void shouldUnescapeSamlEntitiesTest2() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -341,12 +337,7 @@ public void shouldNotEscapeSamlEntities() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -393,12 +384,7 @@ public void shouldNotTrimWhitespaceInJwtRoles() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -441,12 +427,7 @@ public void testMetadataBody() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -507,12 +488,7 @@ public void unsolicitedSsoTest() throws Exception { null, "/opendistrosecurity/saml/acs/idpinitiated" ); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -559,12 +535,9 @@ public void badUnsolicitedSsoTest() throws Exception { authenticateHeaders, "/opendistrosecurity/saml/acs/idpinitiated" ); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow(); - Assert.assertEquals(RestStatus.UNAUTHORIZED, tokenRestChannel.response.status()); + Assert.assertEquals(RestStatus.UNAUTHORIZED.getStatus(), response.getStatus()); } @Test @@ -592,12 +565,9 @@ public void wrongCertTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow(); - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); + Assert.assertEquals(401, response.getStatus()); } @Test @@ -622,12 +592,9 @@ public void noSignatureTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow(); - Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); + Assert.assertEquals(401, response.getStatus()); } @SuppressWarnings("unchecked") @@ -656,12 +623,7 @@ public void rolesTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -704,12 +666,7 @@ public void idpEndpointWithQueryStringTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -759,12 +716,7 @@ private void commaSeparatedRoles(final String rolesAsString, final Settings.Buil String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -864,11 +816,9 @@ public void initialConnectionFailureTest() throws Exception { HTTPSamlAuthenticator samlAuthenticator = new HTTPSamlAuthenticator(settings, null); RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); - TestRestChannel restChannel = new TestRestChannel(restRequest); - BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); - restChannel.sendResponse(authenticatorResponse); + Optional maybeResponse = sendToAuthenticator(samlAuthenticator, restRequest); - Assert.assertNull(restChannel.response); + assertThat(maybeResponse.isPresent(), Matchers.equalTo(false)); mockSamlIdpServer.start(); @@ -884,12 +834,7 @@ public void initialConnectionFailureTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); - tokenRestChannel.sendResponse(authenticatorResponse); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); HashMap response = DefaultObjectMapper.objectMapper.readValue( responseJson, new TypeReference>() { @@ -908,17 +853,11 @@ public void initialConnectionFailureTest() throws Exception { private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuthenticator) { RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); - TestRestChannel restChannel = new TestRestChannel(restRequest); - - final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); - restChannel.sendResponse(authenticatorResponse); - - List wwwAuthenticateHeaders = restChannel.response.getHeaders().get("WWW-Authenticate"); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, restRequest).orElseThrow(); - Assert.assertNotNull(wwwAuthenticateHeaders); - Assert.assertEquals("More than one WWW-Authenticate header: " + wwwAuthenticateHeaders, 1, wwwAuthenticateHeaders.size()); + String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate"); - String wwwAuthenticateHeader = wwwAuthenticateHeaders.get(0); + Assert.assertNotNull(wwwAuthenticateHeader); Matcher wwwAuthenticateHeaderMatcher = WWW_AUTHENTICATE_PATTERN.matcher(wwwAuthenticateHeader); diff --git a/src/test/java/org/opensearch/security/auditlog/helper/MockRestRequest.java b/src/test/java/org/opensearch/security/auditlog/helper/MockRestRequest.java index 458e240596..80c8fd7b17 100644 --- a/src/test/java/org/opensearch/security/auditlog/helper/MockRestRequest.java +++ b/src/test/java/org/opensearch/security/auditlog/helper/MockRestRequest.java @@ -16,6 +16,8 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; public class MockRestRequest extends RestRequest { @@ -44,4 +46,8 @@ public boolean hasContent() { public BytesReference content() { return null; } + + public SecurityRequestChannel asSecurityRequest() { + return SecurityRequestFactory.from(this, null); + } } diff --git a/src/test/java/org/opensearch/security/auditlog/impl/AuditlogTest.java b/src/test/java/org/opensearch/security/auditlog/impl/AuditlogTest.java index bbdbc3aefd..935fb924a3 100644 --- a/src/test/java/org/opensearch/security/auditlog/impl/AuditlogTest.java +++ b/src/test/java/org/opensearch/security/auditlog/impl/AuditlogTest.java @@ -21,10 +21,10 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auditlog.AuditTestUtils; import org.opensearch.security.auditlog.helper.RetrySink; import org.opensearch.security.auditlog.integration.TestAuditlogImpl; +import org.opensearch.security.filter.SecurityRequestChannel; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.test.AbstractSecurityUnitTest; import org.opensearch.transport.TransportRequest; @@ -132,7 +132,7 @@ public void testRestFilterEnabledCheck() { final Settings settings = Settings.builder().put(ConfigConstants.OPENDISTRO_SECURITY_AUDIT_ENABLE_REST, false).build(); final AbstractAuditLog al = AuditTestUtils.createAuditLog(settings, null, null, AbstractSecurityUnitTest.MOCK_POOL, null, cs); for (AuditCategory category : AuditCategory.values()) { - Assert.assertFalse(al.checkRestFilter(category, "user", mock(RestRequest.class))); + Assert.assertFalse(al.checkRestFilter(category, "user", mock(SecurityRequestChannel.class))); } } diff --git a/src/test/java/org/opensearch/security/auditlog/impl/DisabledCategoriesTest.java b/src/test/java/org/opensearch/security/auditlog/impl/DisabledCategoriesTest.java index 0d08abaeea..ba4ee7b55d 100644 --- a/src/test/java/org/opensearch/security/auditlog/impl/DisabledCategoriesTest.java +++ b/src/test/java/org/opensearch/security/auditlog/impl/DisabledCategoriesTest.java @@ -227,11 +227,16 @@ protected void logAll(AuditLog auditLog) { } protected void logRestSucceededLogin(AuditLog auditLog) { - auditLog.logSucceededLogin("testuser.rest.succeededlogin", false, "testuser.rest.succeededlogin", new MockRestRequest()); + auditLog.logSucceededLogin( + "testuser.rest.succeededlogin", + false, + "testuser.rest.succeededlogin", + new MockRestRequest().asSecurityRequest() + ); } protected void logRestFailedLogin(AuditLog auditLog) { - auditLog.logFailedLogin("testuser.rest.failedlogin", false, "testuser.rest.failedlogin", new MockRestRequest()); + auditLog.logFailedLogin("testuser.rest.failedlogin", false, "testuser.rest.failedlogin", new MockRestRequest().asSecurityRequest()); } protected void logMissingPrivileges(AuditLog auditLog) { @@ -243,7 +248,7 @@ protected void logTransportBadHeaders(AuditLog auditLog) { } protected void logRestBadHeaders(AuditLog auditLog) { - auditLog.logBadHeaders(new MockRestRequest()); + auditLog.logBadHeaders(new MockRestRequest().asSecurityRequest()); } protected void logSecurityIndexAttempt(AuditLog auditLog) { @@ -251,7 +256,7 @@ protected void logSecurityIndexAttempt(AuditLog auditLog) { } protected void logRestSSLException(AuditLog auditLog) { - auditLog.logSSLException(new MockRestRequest(), new Exception()); + auditLog.logSSLException(new MockRestRequest().asSecurityRequest(), new Exception()); } protected void logTransportSSLException(AuditLog auditLog) { diff --git a/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java b/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java index d563308e31..4072d94436 100644 --- a/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java +++ b/src/test/java/org/opensearch/security/authtoken/jwt/AuthTokenUtilsTest.java @@ -14,6 +14,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.util.AuthTokenUtils; import org.opensearch.test.rest.FakeRestRequest; import org.junit.Test; @@ -33,7 +34,7 @@ public void testIsAccessToRestrictedEndpointsForOnBehalfOfToken() { .withMethod(RestRequest.Method.POST) .build(); - assertTrue(AuthTokenUtils.isAccessToRestrictedEndpoints(request, "api/generateonbehalfoftoken")); + assertTrue(AuthTokenUtils.isAccessToRestrictedEndpoints(SecurityRequestFactory.from(request), "api/generateonbehalfoftoken")); } @Test @@ -44,7 +45,7 @@ public void testIsAccessToRestrictedEndpointsForAccount() { .withMethod(RestRequest.Method.PUT) .build(); - assertTrue(AuthTokenUtils.isAccessToRestrictedEndpoints(request, "api/account")); + assertTrue(AuthTokenUtils.isAccessToRestrictedEndpoints(SecurityRequestFactory.from(request), "api/account")); } @Test @@ -55,7 +56,7 @@ public void testIsAccessToRestrictedEndpointsFalseCase() { .withMethod(RestRequest.Method.GET) .build(); - assertFalse(AuthTokenUtils.isAccessToRestrictedEndpoints(request, "api/someotherendpoint")); + assertFalse(AuthTokenUtils.isAccessToRestrictedEndpoints(SecurityRequestFactory.from(request), "api/someotherendpoint")); } @Test diff --git a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java index 37ac45080b..2cfd23fc23 100644 --- a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java +++ b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java @@ -12,13 +12,14 @@ package org.opensearch.security.cache; import java.nio.file.Path; +import java.util.Optional; import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; public class DummyHTTPAuthenticator implements HTTPAuthenticator { @@ -33,14 +34,15 @@ public String getType() { } @Override - public AuthCredentials extractCredentials(RestRequest request, ThreadContext context) throws OpenSearchSecurityException { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) + throws OpenSearchSecurityException { count++; return new AuthCredentials("dummy").markComplete(); } @Override - public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { - return null; + public Optional reRequestAuthentication(SecurityRequest channel, AuthCredentials credentials) { + return Optional.empty(); } public static long getCount() { diff --git a/src/test/java/org/opensearch/security/http/OnBehalfOfAuthenticatorTest.java b/src/test/java/org/opensearch/security/http/OnBehalfOfAuthenticatorTest.java index fbb03bf7c3..1ff6adee3a 100644 --- a/src/test/java/org/opensearch/security/http/OnBehalfOfAuthenticatorTest.java +++ b/src/test/java/org/opensearch/security/http/OnBehalfOfAuthenticatorTest.java @@ -95,7 +95,10 @@ public void testTokenMissing() throws Exception { OnBehalfOfAuthenticator jwtAuth = new OnBehalfOfAuthenticator(defaultSettings(), clusterName); Map headers = new HashMap(); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -109,7 +112,10 @@ public void testInvalid() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer " + jwsToken); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -126,7 +132,10 @@ public void testDisabled() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer " + jwsToken); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -143,7 +152,10 @@ public void testNonSpecifyOBOSetting() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer " + jwsToken); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNotNull(credentials); } @@ -165,7 +177,10 @@ public void testBearer() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer " + jwsToken); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNotNull(credentials); Assert.assertEquals("Leonard McCoy", credentials.getUsername()); @@ -188,7 +203,10 @@ public void testBearerWrongPosition() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken + "Bearer " + " 123"); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -205,7 +223,10 @@ public void testBasicAuthHeader() throws Exception { Map headers = Collections.singletonMap(HttpHeaders.AUTHORIZATION, "Basic " + jwsToken); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, Collections.emptyMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, Collections.emptyMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -349,7 +370,10 @@ public void testDifferentIssuer() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer " + jwsToken); - AuthCredentials credentials = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); Assert.assertNull(credentials); } @@ -375,7 +399,7 @@ private AuthCredentials extractCredentialsFromJwtHeader( SignatureAlgorithm.HS512 ).compact(); final Map headers = Map.of("Authorization", "Bearer " + jwsToken); - return jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap<>()), null); + return jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap<>()).asSecurityRequest(), null); } private Settings defaultSettings() { diff --git a/src/test/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticatorTest.java b/src/test/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticatorTest.java index d3bf10d943..f7a2011a68 100644 --- a/src/test/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticatorTest.java +++ b/src/test/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticatorTest.java @@ -47,6 +47,8 @@ import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.core.rest.RestStatus; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.user.AuthCredentials; @@ -77,19 +79,19 @@ public void testGetType() { @Test(expected = OpenSearchSecurityException.class) public void testThrowsExceptionWhenMissingXFFDone() { authenticator = new HTTPExtendedProxyAuthenticator(Settings.EMPTY, null); - authenticator.extractCredentials(new TestRestRequest(), new ThreadContext(Settings.EMPTY)); + authenticator.extractCredentials(new TestRestRequest().asSecurityRequest(), new ThreadContext(Settings.EMPTY)); } @Test public void testReturnsNullWhenUserHeaderIsUnconfigured() { authenticator = new HTTPExtendedProxyAuthenticator(Settings.EMPTY, null); - assertNull(authenticator.extractCredentials(new TestRestRequest(), context)); + assertNull(authenticator.extractCredentials(new TestRestRequest().asSecurityRequest(), context)); } @Test public void testReturnsNullWhenUserHeaderIsMissing() { - assertNull(authenticator.extractCredentials(new TestRestRequest(), context)); + assertNull(authenticator.extractCredentials(new TestRestRequest().asSecurityRequest(), context)); } @Test @@ -105,7 +107,7 @@ public void testReturnsCredentials() { settings = Settings.builder().put(settings).put("attr_header_prefix", "proxy_").build(); authenticator = new HTTPExtendedProxyAuthenticator(settings, null); - AuthCredentials creds = authenticator.extractCredentials(new TestRestRequest(headers), context); + AuthCredentials creds = authenticator.extractCredentials(new TestRestRequest(headers).asSecurityRequest(), context); assertNotNull(creds); assertEquals("aValidUser", creds.getUsername()); assertEquals("123,456", creds.getAttributes().get("attr.proxy.uid")); @@ -122,7 +124,7 @@ public void testTrimOnRoles() { settings = Settings.builder().put(settings).put("roles_header", "roles").put("roles_separator", ",").build(); authenticator = new HTTPExtendedProxyAuthenticator(settings, null); - AuthCredentials creds = authenticator.extractCredentials(new TestRestRequest(headers), context); + AuthCredentials creds = authenticator.extractCredentials(new TestRestRequest(headers).asSecurityRequest(), context); assertNotNull(creds); assertEquals("aValidUser", creds.getUsername()); assertEquals(ImmutableSet.of("role1", "role2"), creds.getBackendRoles()); @@ -163,6 +165,9 @@ public boolean hasContent() { return false; } + public SecurityRequestChannel asSecurityRequest() { + return SecurityRequestFactory.from(this, null); + } } static class HttpRequestImpl implements HttpRequest { diff --git a/src/test/java/org/opensearch/security/util/FakeRestRequest.java b/src/test/java/org/opensearch/security/util/FakeRestRequest.java index f8d8aca508..121ddf778e 100644 --- a/src/test/java/org/opensearch/security/util/FakeRestRequest.java +++ b/src/test/java/org/opensearch/security/util/FakeRestRequest.java @@ -18,6 +18,8 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; public class FakeRestRequest extends RestRequest { @@ -117,4 +119,7 @@ private static Map> convert(Map headers) { return ret; } + public SecurityRequestChannel asSecurityRequest() { + return SecurityRequestFactory.from(this, null); + } }