Skip to content

Commit

Permalink
Remove dependency on Dropwizard Auth
Browse files Browse the repository at this point in the history
  • Loading branch information
oneonestar authored and mosabua committed May 29, 2024
1 parent f666602 commit 6dcd506
Show file tree
Hide file tree
Showing 22 changed files with 536 additions and 162 deletions.
16 changes: 0 additions & 16 deletions gateway-ha/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -156,22 +156,6 @@
</exclusions>
</dependency>

<dependency>
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-auth</artifactId>
<exclusions>
<exclusion>
<!-- pulls in 2.0.1.MR -->
<groupId>jakarta.inject</groupId>
<artifactId>jakarta.inject-api</artifactId>
</exclusion>
<exclusion>
<groupId>jakarta.servlet</groupId>
<artifactId>jakarta.servlet-api</artifactId>
</exclusion>
</exclusions>
</dependency>

<dependency>
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import com.google.inject.Injector;
import com.google.inject.Module;
import io.airlift.log.Logger;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.AuthFilter;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.server.DefaultServerFactory;
Expand All @@ -41,6 +39,7 @@
import io.trino.gateway.ha.resource.PublicResource;
import io.trino.gateway.ha.resource.TrinoResource;
import io.trino.gateway.ha.security.AuthorizedExceptionMapper;
import io.trino.gateway.ha.security.ResourceSecurityDynamicFeature;
import org.glassfish.jersey.server.filter.RolesAllowedDynamicFeature;

import java.lang.reflect.Constructor;
Expand Down Expand Up @@ -226,8 +225,8 @@ private void registerAuthFilters(Environment environment, Injector injector)
{
environment
.jersey()
.register(new AuthDynamicFeature(injector.getInstance(AuthFilter.class)));
logger.info("op=register type=auth filter item=%s", AuthFilter.class);
.register(injector.getInstance(ResourceSecurityDynamicFeature.class));
logger.info("op=register type=auth filter item=%s", ResourceSecurityDynamicFeature.class);
environment.jersey().register(RolesAllowedDynamicFeature.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
*/
package io.trino.gateway.ha.module;

import com.google.common.collect.ImmutableList;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
import com.google.inject.Singleton;
import io.dropwizard.auth.AuthFilter;
import io.dropwizard.auth.Authorizer;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.auth.chained.ChainedAuthFilter;
import io.dropwizard.core.server.DefaultServerFactory;
import io.dropwizard.core.server.SimpleServerFactory;
import io.dropwizard.core.setup.Environment;
Expand All @@ -41,22 +38,24 @@
import io.trino.gateway.ha.router.RoutingManager;
import io.trino.gateway.ha.security.ApiAuthenticator;
import io.trino.gateway.ha.security.AuthorizationManager;
import io.trino.gateway.ha.security.BasicAuthFilter;
import io.trino.gateway.ha.security.FormAuthenticator;
import io.trino.gateway.ha.security.LbAuthenticator;
import io.trino.gateway.ha.security.LbAuthorizer;
import io.trino.gateway.ha.security.LbFilter;
import io.trino.gateway.ha.security.LbFormAuthManager;
import io.trino.gateway.ha.security.LbOAuthManager;
import io.trino.gateway.ha.security.LbPrincipal;
import io.trino.gateway.ha.security.LbUnauthorizedHandler;
import io.trino.gateway.ha.security.NoopAuthenticator;
import io.trino.gateway.ha.security.NoopAuthorizer;
import io.trino.gateway.ha.security.NoopFilter;
import io.trino.gateway.ha.security.ResourceSecurityDynamicFeature;
import io.trino.gateway.ha.security.util.Authorizer;
import io.trino.gateway.ha.security.util.ChainedAuthFilter;
import io.trino.gateway.proxyserver.ProxyHandler;
import io.trino.gateway.proxyserver.ProxyServer;
import io.trino.gateway.proxyserver.ProxyServerConfiguration;
import jakarta.ws.rs.container.ContainerRequestFilter;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
Expand All @@ -70,7 +69,7 @@ public class HaGatewayProviderModule
private final LbFormAuthManager formAuthManager;
private final AuthorizationManager authorizationManager;
private final BackendStateManager backendStateConnectionManager;
private final AuthFilter authenticationFilter;
private final ResourceSecurityDynamicFeature resourceSecurityDynamicFeature;
private final List<String> extraWhitelistPaths;
private final HaGatewayConfiguration configuration;
private final Environment environment;
Expand All @@ -85,7 +84,7 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration, Environment
formAuthManager = getFormAuthManager(configuration);

authorizationManager = new AuthorizationManager(configuration.getAuthorization(), presetUsers);
authenticationFilter = getAuthFilter(configuration);
resourceSecurityDynamicFeature = getAuthFilter(configuration);
backendStateConnectionManager = new BackendStateManager();
extraWhitelistPaths = configuration.getExtraWhitelistPaths();

Expand Down Expand Up @@ -115,37 +114,32 @@ private LbFormAuthManager getFormAuthManager(HaGatewayConfiguration configuratio
return null;
}

private ChainedAuthFilter getAuthenticationFilters(AuthenticationConfiguration config,
Authorizer<LbPrincipal> authorizer)
private ChainedAuthFilter getAuthenticationFilters(AuthenticationConfiguration config, Authorizer authorizer)
{
List<AuthFilter> authFilters = new ArrayList<>();
ImmutableList.Builder<ContainerRequestFilter> authFilters = ImmutableList.builder();
String defaultType = config.getDefaultType();
if (oauthManager != null) {
authFilters.add(new LbFilter.Builder<LbPrincipal>()
.setAuthenticator(new LbAuthenticator(oauthManager, authorizationManager))
.setAuthorizer(authorizer)
.setUnauthorizedHandler(new LbUnauthorizedHandler(defaultType))
.setPrefix("Bearer")
.buildAuthFilter());
authFilters.add(new LbFilter(
new LbAuthenticator(oauthManager, authorizationManager),
authorizer,
"Bearer",
new LbUnauthorizedHandler(defaultType)));
}

if (formAuthManager != null) {
authFilters.add(new LbFilter.Builder<LbPrincipal>()
.setAuthenticator(new FormAuthenticator(formAuthManager, authorizationManager))
.setAuthorizer(authorizer)
.setUnauthorizedHandler(new LbUnauthorizedHandler(defaultType))
.setPrefix("Bearer")
.buildAuthFilter());
authFilters.add(new LbFilter(
new FormAuthenticator(formAuthManager, authorizationManager),
authorizer,
"Bearer",
new LbUnauthorizedHandler(defaultType)));

authFilters.add(new BasicCredentialAuthFilter.Builder<LbPrincipal>()
.setAuthenticator(new ApiAuthenticator(formAuthManager, authorizationManager))
.setAuthorizer(authorizer)
.setUnauthorizedHandler(new LbUnauthorizedHandler(defaultType))
.setPrefix("Basic")
.buildAuthFilter());
authFilters.add(new BasicAuthFilter(
new ApiAuthenticator(formAuthManager, authorizationManager),
authorizer,
new LbUnauthorizedHandler(defaultType)));
}

return new ChainedAuthFilter(authFilters);
return new ChainedAuthFilter(authFilters.build());
}

private ProxyHandler getProxyHandler(QueryHistoryManager queryHistoryManager,
Expand Down Expand Up @@ -190,22 +184,19 @@ private int getApplicationPort()
.orElseThrow(IllegalStateException::new);
}

private AuthFilter getAuthFilter(HaGatewayConfiguration configuration)
private ResourceSecurityDynamicFeature getAuthFilter(HaGatewayConfiguration configuration)
{
AuthorizationConfiguration authorizationConfig = configuration.getAuthorization();
Authorizer<LbPrincipal> authorizer = (authorizationConfig != null)
Authorizer authorizer = (authorizationConfig != null)
? new LbAuthorizer(authorizationConfig) : new NoopAuthorizer();

AuthenticationConfiguration authenticationConfig = configuration.getAuthentication();

if (authenticationConfig != null) {
return getAuthenticationFilters(authenticationConfig, authorizer);
return new ResourceSecurityDynamicFeature(getAuthenticationFilters(authenticationConfig, authorizer));
}

return new NoopFilter.Builder<LbPrincipal>()
.setAuthenticator(new NoopAuthenticator())
.setAuthorizer(authorizer)
.buildAuthFilter();
return new ResourceSecurityDynamicFeature(new NoopFilter());
}

@Provides
Expand Down Expand Up @@ -261,9 +252,9 @@ public AuthorizationManager getAuthorizationManager()

@Provides
@Singleton
public AuthFilter getAuthenticationFilter()
public ResourceSecurityDynamicFeature getResourceSecurityDynamicFeature()
{
return authenticationFilter;
return resourceSecurityDynamicFeature;
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
*/
package io.trino.gateway.ha.security;

import io.dropwizard.auth.AuthenticationException;
import io.dropwizard.auth.Authenticator;
import io.dropwizard.auth.basic.BasicCredentials;
import io.trino.gateway.ha.security.util.AuthenticationException;
import io.trino.gateway.ha.security.util.BasicCredentials;

import java.util.Optional;

public class ApiAuthenticator
implements Authenticator<BasicCredentials, LbPrincipal>
{
private final LbFormAuthManager formAuthManager;
private final AuthorizationManager authorizationManager;
Expand All @@ -32,13 +30,12 @@ public ApiAuthenticator(LbFormAuthManager formAuthManager,
this.authorizationManager = authorizationManager;
}

@Override
public Optional<LbPrincipal> authenticate(BasicCredentials credentials)
throws AuthenticationException
{
if (formAuthManager.authenticate(credentials)) {
return Optional.of(new LbPrincipal(credentials.getUsername(),
authorizationManager.getPrivileges(credentials.getUsername())));
return Optional.of(new LbPrincipal(credentials.username(),
authorizationManager.getPrivileges(credentials.username())));
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.security;

import io.trino.gateway.ha.security.util.AuthenticationException;
import io.trino.gateway.ha.security.util.Authorizer;
import io.trino.gateway.ha.security.util.BasicCredentials;
import jakarta.annotation.Priority;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.container.ContainerRequestFilter;
import jakarta.ws.rs.core.SecurityContext;

import java.io.IOException;
import java.security.Principal;

import static io.trino.gateway.ha.security.util.BasicCredentials.extractBasicAuthCredentials;
import static jakarta.ws.rs.Priorities.AUTHENTICATION;
import static java.util.Objects.requireNonNull;

@Priority(AUTHENTICATION)
public class BasicAuthFilter
implements ContainerRequestFilter
{
private final ApiAuthenticator apiAuthenticator;
private final Authorizer lbAuthorizer;
private final LbUnauthorizedHandler lbUnauthorizedHandler;

public BasicAuthFilter(ApiAuthenticator apiAuthenticator, Authorizer lbAuthorizer, LbUnauthorizedHandler lbUnauthorizedHandler)
{
this.apiAuthenticator = requireNonNull(apiAuthenticator);
this.lbAuthorizer = requireNonNull(lbAuthorizer);
this.lbUnauthorizedHandler = requireNonNull(lbUnauthorizedHandler, "lbUnauthorizedHandler is null");
}

@Override
public void filter(ContainerRequestContext requestContext)
throws IOException
{
try {
BasicCredentials basicCredentials = extractBasicAuthCredentials(requestContext);
LbPrincipal principal = apiAuthenticator.authenticate(basicCredentials)
.orElseThrow(() -> new AuthenticationException("Authentication error"));
requestContext.setSecurityContext(new SecurityContext()
{
@Override
public Principal getUserPrincipal()
{
return principal;
}

@Override
public boolean isUserInRole(String role)
{
return lbAuthorizer.authorize(principal, role, requestContext);
}

@Override
public boolean isSecure()
{
return requestContext.getSecurityContext().isSecure();
}

@Override
public String getAuthenticationScheme()
{
return SecurityContext.BASIC_AUTH;
}
});
}
catch (Exception e) {
throw new WebApplicationException(lbUnauthorizedHandler.buildResponse());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
*/
package io.trino.gateway.ha.security;

import io.dropwizard.auth.AuthenticationException;
import io.dropwizard.auth.Authenticator;
import io.trino.gateway.ha.security.util.AuthenticationException;
import io.trino.gateway.ha.security.util.IdTokenAuthenticator;

import java.util.Optional;

public class FormAuthenticator
implements Authenticator<String, LbPrincipal>
implements IdTokenAuthenticator
{
private final LbFormAuthManager formAuthManager;
private final AuthorizationManager authorizationManager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
*/
package io.trino.gateway.ha.security;

import io.dropwizard.auth.AuthenticationException;
import io.dropwizard.auth.Authenticator;
import io.trino.gateway.ha.security.util.AuthenticationException;
import io.trino.gateway.ha.security.util.IdTokenAuthenticator;

import java.util.Optional;

public class LbAuthenticator
implements Authenticator<String, LbPrincipal>
implements IdTokenAuthenticator
{
private final LbOAuthManager oauthManager;
private final AuthorizationManager authorizationManager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
package io.trino.gateway.ha.security;

import io.airlift.log.Logger;
import io.dropwizard.auth.Authorizer;
import io.trino.gateway.ha.config.AuthorizationConfiguration;
import io.trino.gateway.ha.security.util.Authorizer;
import jakarta.annotation.Nullable;
import jakarta.ws.rs.container.ContainerRequestContext;

public class LbAuthorizer
implements Authorizer<LbPrincipal>
implements Authorizer
{
private static final Logger log = Logger.get(LbAuthorizer.class);
private final AuthorizationConfiguration configuration;
Expand Down
Loading

0 comments on commit 6dcd506

Please sign in to comment.