From 6c3453c9b68dda88de6b7e23c37652bed12f5067 Mon Sep 17 00:00:00 2001 From: Alexandre Dutra Date: Thu, 12 Dec 2024 16:01:42 +0100 Subject: [PATCH] Auth Manager API part 3: OAuth2 Manager --- .../apache/iceberg/rest/auth/AuthConfig.java | 47 +++- .../iceberg/rest/auth/AuthManagers.java | 22 +- .../iceberg/rest/auth/AuthProperties.java | 3 + .../iceberg/rest/auth/AuthSessionCache.java | 112 ++++++++ .../iceberg/rest/auth/OAuth2Manager.java | 239 ++++++++++++++++++ .../apache/iceberg/rest/auth/OAuth2Util.java | 26 +- .../rest/auth/RefreshingAuthManager.java | 88 +++++++ .../iceberg/rest/auth/TestAuthManagers.java | 39 +++ .../rest/auth/TestAuthSessionCache.java | 91 +++++++ 9 files changed, 656 insertions(+), 11 deletions(-) create mode 100644 core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java create mode 100644 core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java create mode 100644 core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java create mode 100644 core/src/test/java/org/apache/iceberg/rest/auth/TestAuthSessionCache.java diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java index 275884e1184a..1709d4a3e514 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java @@ -21,15 +21,16 @@ import java.util.Map; import javax.annotation.Nullable; import org.apache.iceberg.rest.ResourcePaths; +import org.apache.iceberg.util.PropertyUtil; import org.immutables.value.Value; /** - * The purpose of this class is to hold configuration options for {@link - * org.apache.iceberg.rest.auth.OAuth2Util.AuthSession}. + * The purpose of this class is to hold OAuth configuration options for {@link + * OAuth2Util.AuthSession}. */ @Value.Style(redactedMask = "****") -@SuppressWarnings("ImmutablesStyle") @Value.Immutable +@SuppressWarnings({"ImmutablesStyle", "SafeLoggingPropagation"}) public interface AuthConfig { @Nullable @Value.Redacted @@ -47,7 +48,7 @@ default String scope() { return OAuth2Properties.CATALOG_SCOPE; } - @Value.Lazy + @Value.Default @Nullable default Long expiresAtMillis() { return OAuth2Util.expiresAtMillis(token()); @@ -69,4 +70,42 @@ default String oauth2ServerUri() { static ImmutableAuthConfig.Builder builder() { return ImmutableAuthConfig.builder(); } + + static AuthConfig fromProperties(Map properties) { + return builder() + .credential(properties.get(OAuth2Properties.CREDENTIAL)) + .token(properties.get(OAuth2Properties.TOKEN)) + .scope(properties.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE)) + .oauth2ServerUri( + properties.getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens())) + .optionalOAuthParams(OAuth2Util.buildOptionalParam(properties)) + .keepRefreshed( + PropertyUtil.propertyAsBoolean( + properties, + OAuth2Properties.TOKEN_REFRESH_ENABLED, + OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT)) + .expiresAtMillis(expiresAtMillis(properties)) + .build(); + } + + private static Long expiresAtMillis(Map props) { + Long expiresAtMillis = null; + + if (props.containsKey(OAuth2Properties.TOKEN)) { + expiresAtMillis = OAuth2Util.expiresAtMillis(props.get(OAuth2Properties.TOKEN)); + } + + if (expiresAtMillis == null) { + if (props.containsKey(OAuth2Properties.TOKEN_EXPIRES_IN_MS)) { + long millis = + PropertyUtil.propertyAsLong( + props, + OAuth2Properties.TOKEN_EXPIRES_IN_MS, + OAuth2Properties.TOKEN_EXPIRES_IN_MS_DEFAULT); + expiresAtMillis = System.currentTimeMillis() + millis; + } + } + + return expiresAtMillis; + } } diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java index 42c2b1eeba83..46188c1281c5 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java @@ -31,8 +31,23 @@ public class AuthManagers { private AuthManagers() {} public static AuthManager loadAuthManager(String name, Map properties) { - String authType = - properties.getOrDefault(AuthProperties.AUTH_TYPE, AuthProperties.AUTH_TYPE_NONE); + String authType = properties.get(AuthProperties.AUTH_TYPE); + if (authType == null) { + boolean hasCredential = properties.containsKey(OAuth2Properties.CREDENTIAL); + boolean hasToken = properties.containsKey(OAuth2Properties.TOKEN); + if (hasCredential || hasToken) { + LOG.warn( + "Inferring {}={} since property {} was provided. " + + "Please explicitly set {} to avoid this warning.", + AuthProperties.AUTH_TYPE, + AuthProperties.AUTH_TYPE_OAUTH2, + hasCredential ? OAuth2Properties.CREDENTIAL : OAuth2Properties.TOKEN, + AuthProperties.AUTH_TYPE); + authType = AuthProperties.AUTH_TYPE_OAUTH2; + } else { + authType = AuthProperties.AUTH_TYPE_NONE; + } + } String impl; switch (authType.toLowerCase(Locale.ROOT)) { @@ -42,6 +57,9 @@ public static AuthManager loadAuthManager(String name, Map prope case AuthProperties.AUTH_TYPE_BASIC: impl = AuthProperties.AUTH_MANAGER_IMPL_BASIC; break; + case AuthProperties.AUTH_TYPE_OAUTH2: + impl = AuthProperties.AUTH_MANAGER_IMPL_OAUTH2; + break; default: impl = authType; } diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java index bf94311d5578..a4ba2db586a7 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java @@ -26,11 +26,14 @@ private AuthProperties() {} public static final String AUTH_TYPE_NONE = "none"; public static final String AUTH_TYPE_BASIC = "basic"; + public static final String AUTH_TYPE_OAUTH2 = "oauth2"; public static final String AUTH_MANAGER_IMPL_NONE = "org.apache.iceberg.rest.auth.NoopAuthManager"; public static final String AUTH_MANAGER_IMPL_BASIC = "org.apache.iceberg.rest.auth.BasicAuthManager"; + public static final String AUTH_MANAGER_IMPL_OAUTH2 = + "org.apache.iceberg.rest.auth.OAuth2Manager"; public static final String BASIC_USERNAME = "rest.auth.basic.username"; public static final String BASIC_PASSWORD = "rest.auth.basic.password"; diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java new file mode 100644 index 000000000000..4f7760aff815 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.RemovalListener; +import com.github.benmanes.caffeine.cache.Ticker; +import java.time.Duration; +import java.util.concurrent.Executor; +import java.util.concurrent.ForkJoinPool; +import java.util.function.Function; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; + +/** A cache for {@link AuthSession} instances. */ +public class AuthSessionCache implements AutoCloseable { + + private final Duration sessionTimeout; + private final Executor executor; + private final Ticker ticker; + + private volatile Cache sessionCache; + + /** + * Creates a new cache with the given session timeout, and with default executor and ticker for + * eviction tasks. + */ + public AuthSessionCache(Duration sessionTimeout) { + this(sessionTimeout, ForkJoinPool.commonPool(), Ticker.systemTicker()); + } + + /** + * Creates a new cache with the given session timeout, executor, and ticker. This method is useful + * for testing mostly. + * + *

The executor is used to perform cache eviction; the ticker is used to measure access time. + * The executor will not be closed when this cache is closed. + */ + public AuthSessionCache(Duration sessionTimeout, Executor executor, Ticker ticker) { + this.sessionTimeout = sessionTimeout; + this.executor = executor; + this.ticker = ticker; + } + + /** + * Returns a cached session for the given key, loading it with the given loader if it is not + * already cached. + * + * @param key the key to use for the session + * @param loader the loader to use to load the session if it is not already cached + * @param the type of the session + * @return the cached session + */ + @SuppressWarnings("unchecked") + public T cachedSession(String key, Function loader) { + return (T) sessionCache().get(key, loader); + } + + @Override + public void close() { + Cache cache = sessionCache; + this.sessionCache = null; + if (cache != null) { + cache.invalidateAll(); + cache.cleanUp(); + } + } + + @VisibleForTesting + Cache sessionCache() { + if (sessionCache == null) { + synchronized (this) { + if (sessionCache == null) { + this.sessionCache = newSessionCache(sessionTimeout, executor, ticker); + } + } + } + return sessionCache; + } + + private static Cache newSessionCache( + Duration sessionTimeout, Executor executor, Ticker ticker) { + return Caffeine.newBuilder() + .ticker(ticker) + .executor(executor) + .expireAfterAccess(sessionTimeout) + .removalListener( + (RemovalListener) + (id, auth, cause) -> { + if (auth != null) { + auth.close(); + } + }) + .build(); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java new file mode 100644 index 000000000000..17467b97323b --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.catalog.SessionCatalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.RESTUtil; +import org.apache.iceberg.rest.ResourcePaths; +import org.apache.iceberg.rest.responses.OAuthTokenResponse; +import org.apache.iceberg.util.PropertyUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings("unused") // loaded by reflection +public class OAuth2Manager extends RefreshingAuthManager { + + private static final Logger LOG = LoggerFactory.getLogger(OAuth2Manager.class); + + private static final List TOKEN_PREFERENCE_ORDER = + ImmutableList.of( + OAuth2Properties.ID_TOKEN_TYPE, + OAuth2Properties.ACCESS_TOKEN_TYPE, + OAuth2Properties.JWT_TOKEN_TYPE, + OAuth2Properties.SAML2_TOKEN_TYPE, + OAuth2Properties.SAML1_TOKEN_TYPE); + + // Auth-related properties that are allowed to be passed to the table session + private static final Set TABLE_SESSION_ALLOW_LIST = + ImmutableSet.builder() + .add(OAuth2Properties.TOKEN) + .addAll(TOKEN_PREFERENCE_ORDER) + .build(); + + private RESTClient client; + private long startTimeMillis; + private OAuthTokenResponse authResponse; + private AuthSessionCache sessionCache; + + public OAuth2Manager(String name) { + super(name + "-token-refresh"); + } + + @Override + public OAuth2Util.AuthSession initSession(RESTClient initClient, Map properties) { + warnIfDeprecatedTokenEndpointUsed(properties); + AuthConfig config = AuthConfig.fromProperties(properties); + Map headers = OAuth2Util.authHeaders(config.token()); + OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(headers, config); + if (config.credential() != null && !config.credential().isEmpty()) { + // keep track of the start time for token refresh + this.startTimeMillis = System.currentTimeMillis(); + this.authResponse = + OAuth2Util.fetchToken( + initClient, + headers, + config.credential(), + config.scope(), + config.oauth2ServerUri(), + config.optionalOAuthParams()); + return OAuth2Util.AuthSession.fromTokenResponse( + initClient, null, authResponse, startTimeMillis, session); + } else if (config.token() != null) { + return OAuth2Util.AuthSession.fromAccessToken( + initClient, null, config.token(), null, session); + } + return session; + } + + @Override + public OAuth2Util.AuthSession catalogSession( + RESTClient sharedClient, Map properties) { + this.client = sharedClient; + this.sessionCache = new AuthSessionCache(sessionTimeout(properties)); + AuthConfig config = AuthConfig.fromProperties(properties); + Map headers = OAuth2Util.authHeaders(config.token()); + OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(headers, config); + keepRefreshed(config.keepRefreshed()); + // authResponse comes from the init phase + if (authResponse != null) { + return OAuth2Util.AuthSession.fromTokenResponse( + client, refreshExecutor(), authResponse, startTimeMillis, session); + } else if (config.token() != null) { + return OAuth2Util.AuthSession.fromAccessToken( + client, refreshExecutor(), config.token(), config.expiresAtMillis(), session); + } + return session; + } + + @Override + public OAuth2Util.AuthSession contextualSession( + SessionCatalog.SessionContext context, AuthSession parent) { + return maybeCreateChildSession( + context.credentials(), + context.properties(), + ignored -> context.sessionId(), + (OAuth2Util.AuthSession) parent); + } + + @Override + public OAuth2Util.AuthSession tableSession( + TableIdentifier table, Map properties, AuthSession parent) { + return maybeCreateChildSession( + Maps.filterKeys(properties, TABLE_SESSION_ALLOW_LIST::contains), + properties, + properties::get, + (OAuth2Util.AuthSession) parent); + } + + @Override + public void close() { + try { + super.close(); + } finally { + AuthSessionCache cache = sessionCache; + this.sessionCache = null; + if (cache != null) { + cache.close(); + } + } + } + + protected OAuth2Util.AuthSession maybeCreateChildSession( + Map credentials, + Map properties, + Function cacheKeyFunc, + OAuth2Util.AuthSession parent) { + if (credentials != null) { + // use the bearer token without exchanging + if (credentials.containsKey(OAuth2Properties.TOKEN)) { + String token = credentials.get(OAuth2Properties.TOKEN); + return sessionCache.cachedSession( + cacheKeyFunc.apply(OAuth2Properties.TOKEN), + k -> newSessionFromAccessToken(token, properties, parent)); + } + + if (credentials.containsKey(OAuth2Properties.CREDENTIAL)) { + // fetch a token using the client credentials flow + String credential = credentials.get(OAuth2Properties.CREDENTIAL); + return sessionCache.cachedSession( + cacheKeyFunc.apply(OAuth2Properties.CREDENTIAL), + k -> newSessionFromCredential(credential, parent)); + } + + for (String tokenType : TOKEN_PREFERENCE_ORDER) { + if (credentials.containsKey(tokenType)) { + // exchange the token for an access token using the token exchange flow + String token = credentials.get(tokenType); + return sessionCache.cachedSession( + cacheKeyFunc.apply(tokenType), + k -> newSessionFromTokenExchange(token, tokenType, parent)); + } + } + } + + return parent; + } + + protected OAuth2Util.AuthSession newSessionFromAccessToken( + String token, Map properties, OAuth2Util.AuthSession parent) { + Long expiresAtMillis = AuthConfig.fromProperties(properties).expiresAtMillis(); + return OAuth2Util.AuthSession.fromAccessToken( + client, refreshExecutor(), token, expiresAtMillis, parent); + } + + protected OAuth2Util.AuthSession newSessionFromCredential( + String credential, OAuth2Util.AuthSession parent) { + return OAuth2Util.AuthSession.fromCredential(client, refreshExecutor(), credential, parent); + } + + protected OAuth2Util.AuthSession newSessionFromTokenExchange( + String token, String tokenType, OAuth2Util.AuthSession parent) { + return OAuth2Util.AuthSession.fromTokenExchange( + client, refreshExecutor(), token, tokenType, parent); + } + + private static void warnIfDeprecatedTokenEndpointUsed(Map properties) { + if (usesDeprecatedTokenEndpoint(properties)) { + String credential = properties.get(OAuth2Properties.CREDENTIAL); + String initToken = properties.get(OAuth2Properties.TOKEN); + boolean hasCredential = credential != null && !credential.isEmpty(); + boolean hasInitToken = initToken != null; + if (hasInitToken || hasCredential) { + LOG.warn( + "Iceberg REST client is missing the OAuth2 server URI configuration and defaults to {}/{}. " + + "This automatic fallback will be removed in a future Iceberg release." + + "It is recommended to configure the OAuth2 endpoint using the '{}' property to be prepared. " + + "This warning will disappear if the OAuth2 endpoint is explicitly configured. " + + "See https://github.com/apache/iceberg/issues/10537", + RESTUtil.stripTrailingSlash(properties.get(CatalogProperties.URI)), + ResourcePaths.tokens(), + OAuth2Properties.OAUTH2_SERVER_URI); + } + } + } + + private static boolean usesDeprecatedTokenEndpoint(Map properties) { + if (properties.containsKey(OAuth2Properties.OAUTH2_SERVER_URI)) { + String oauth2ServerUri = properties.get(OAuth2Properties.OAUTH2_SERVER_URI); + boolean relativePath = !oauth2ServerUri.startsWith("http"); + boolean sameHost = oauth2ServerUri.startsWith(properties.get(CatalogProperties.URI)); + return relativePath || sameHost; + } + return true; + } + + private static Duration sessionTimeout(Map props) { + return Duration.ofMillis( + PropertyUtil.propertyAsLong( + props, + CatalogProperties.AUTH_SESSION_TIMEOUT_MS, + CatalogProperties.AUTH_SESSION_TIMEOUT_MS_DEFAULT)); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java index 1757ae653cc9..2bcf592d2aab 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java @@ -43,6 +43,9 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.rest.ErrorHandlers; +import org.apache.iceberg.rest.HTTPHeaders; +import org.apache.iceberg.rest.HTTPRequest; +import org.apache.iceberg.rest.ImmutableHTTPRequest; import org.apache.iceberg.rest.RESTClient; import org.apache.iceberg.rest.RESTUtil; import org.apache.iceberg.rest.ResourcePaths; @@ -451,18 +454,26 @@ static Long expiresAtMillis(String token) { } /** Class to handle authorization headers and token refresh. */ - public static class AuthSession { + public static class AuthSession implements org.apache.iceberg.rest.auth.AuthSession { private static int tokenRefreshNumRetries = 5; private static final long MAX_REFRESH_WINDOW_MILLIS = 300_000; // 5 minutes private static final long MIN_REFRESH_WAIT_MILLIS = 10; private volatile Map headers; private volatile AuthConfig config; - public AuthSession(Map baseHeaders, AuthConfig config) { - this.headers = RESTUtil.merge(baseHeaders, authHeaders(config.token())); + public AuthSession(Map headers, AuthConfig config) { + this.headers = ImmutableMap.copyOf(headers); this.config = config; } + @Override + public HTTPRequest authenticate(HTTPRequest request) { + HTTPHeaders newHeaders = request.headers().putIfAbsent(HTTPHeaders.of(headers())); + return newHeaders.equals(request.headers()) + ? request + : ImmutableHTTPRequest.builder().from(request).headers(newHeaders).build(); + } + public Map headers() { return headers; } @@ -487,6 +498,11 @@ public synchronized void stopRefreshing() { this.config = ImmutableAuthConfig.copyOf(config).withKeepRefreshed(false); } + @Override + public void close() { + stopRefreshing(); + } + public String credential() { return config.credential(); } @@ -647,7 +663,7 @@ public static AuthSession fromAccessToken( AuthSession parent) { AuthSession session = new AuthSession( - parent.headers(), + RESTUtil.merge(parent.headers(), authHeaders(token)), AuthConfig.builder() .from(parent.config()) .token(token) @@ -727,7 +743,7 @@ private static AuthSession fromTokenResponse( } AuthSession session = new AuthSession( - parent.headers(), + RESTUtil.merge(parent.headers(), authHeaders(response.token())), AuthConfig.builder() .from(parent.config()) .token(response.token()) diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java b/core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java new file mode 100644 index 000000000000..2b443e0ea5c1 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.util.List; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.iceberg.util.ThreadPools; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An {@link AuthManager} that provides machinery for refreshing authentication data asynchronously, + * using a background thread pool. + */ +public abstract class RefreshingAuthManager implements AuthManager { + + private static final Logger LOG = LoggerFactory.getLogger(RefreshingAuthManager.class); + + private final String executorNamePrefix; + private boolean keepRefreshed = true; + private volatile ScheduledExecutorService refreshExecutor; + + protected RefreshingAuthManager(String executorNamePrefix) { + this.executorNamePrefix = executorNamePrefix; + } + + public void keepRefreshed(boolean keep) { + this.keepRefreshed = keep; + } + + @Override + public void close() { + ScheduledExecutorService service = refreshExecutor; + this.refreshExecutor = null; + if (service != null) { + List tasks = service.shutdownNow(); + tasks.forEach( + task -> { + if (task instanceof Future) { + ((Future) task).cancel(true); + } + }); + + try { + if (!service.awaitTermination(1, TimeUnit.MINUTES)) { + LOG.warn("Timed out waiting for refresh executor to terminate"); + } + } catch (InterruptedException e) { + LOG.warn("Interrupted while waiting for refresh executor to terminate", e); + Thread.currentThread().interrupt(); + } + } + } + + @Nullable + protected ScheduledExecutorService refreshExecutor() { + if (!keepRefreshed) { + return null; + } + if (refreshExecutor == null) { + synchronized (this) { + if (refreshExecutor == null) { + this.refreshExecutor = ThreadPools.newScheduledPool(executorNamePrefix, 1); + } + } + } + return refreshExecutor; + } +} diff --git a/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java index 21bd8c1b2963..d49f398d7a47 100644 --- a/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java +++ b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java @@ -43,6 +43,45 @@ public void after() { System.setErr(standardErr); } + @Test + void oauth2Explicit() { + try (AuthManager manager = + AuthManagers.loadAuthManager( + "test", Map.of(AuthProperties.AUTH_TYPE, AuthProperties.AUTH_TYPE_OAUTH2))) { + assertThat(manager).isInstanceOf(OAuth2Manager.class); + } + assertThat(streamCaptor.toString()) + .contains("Loading AuthManager implementation: org.apache.iceberg.rest.auth.OAuth2Manager"); + } + + @Test + void oauth2InferredFromToken() { + try (AuthManager manager = + AuthManagers.loadAuthManager("test", Map.of(OAuth2Properties.TOKEN, "irrelevant"))) { + assertThat(manager).isInstanceOf(OAuth2Manager.class); + } + assertThat(streamCaptor.toString()) + .contains( + "Inferring rest.auth.type=oauth2 since property token was provided. " + + "Please explicitly set rest.auth.type to avoid this warning."); + assertThat(streamCaptor.toString()) + .contains("Loading AuthManager implementation: org.apache.iceberg.rest.auth.OAuth2Manager"); + } + + @Test + void oauth2InferredFromCredential() { + try (AuthManager manager = + AuthManagers.loadAuthManager("test", Map.of(OAuth2Properties.CREDENTIAL, "irrelevant"))) { + assertThat(manager).isInstanceOf(OAuth2Manager.class); + } + assertThat(streamCaptor.toString()) + .contains( + "Inferring rest.auth.type=oauth2 since property credential was provided. " + + "Please explicitly set rest.auth.type to avoid this warning."); + assertThat(streamCaptor.toString()) + .contains("Loading AuthManager implementation: org.apache.iceberg.rest.auth.OAuth2Manager"); + } + @Test void noop() { try (AuthManager manager = AuthManagers.loadAuthManager("test", Map.of())) { diff --git a/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthSessionCache.java b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthSessionCache.java new file mode 100644 index 000000000000..52b742d2c536 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthSessionCache.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; + +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +class TestAuthSessionCache { + + @Test + void cachedHitsAndMisses() { + AuthSessionCache cache = + new AuthSessionCache(Duration.ofHours(1), Runnable::run, System::nanoTime); + AuthSession session1 = Mockito.mock(AuthSession.class); + AuthSession session2 = Mockito.mock(AuthSession.class); + + @SuppressWarnings("unchecked") + Function loader = Mockito.mock(Function.class); + Mockito.when(loader.apply("key1")).thenReturn(session1); + Mockito.when(loader.apply("key2")).thenReturn(session2); + + AuthSession session = cache.cachedSession("key1", loader); + assertThat(session).isNotNull().isSameAs(session1); + + session = cache.cachedSession("key1", loader); + assertThat(session).isNotNull().isSameAs(session1); + + session = cache.cachedSession("key2", loader); + assertThat(session).isNotNull().isSameAs(session2); + + session = cache.cachedSession("key2", loader); + assertThat(session).isNotNull().isSameAs(session2); + + Mockito.verify(loader, times(1)).apply("key1"); + Mockito.verify(loader, times(1)).apply("key2"); + + assertThat(cache.sessionCache().asMap()).hasSize(2); + cache.close(); + assertThat(cache.sessionCache().asMap()).isEmpty(); + + Mockito.verify(session1).close(); + Mockito.verify(session2).close(); + } + + @Test + void cacheEviction() { + AtomicLong ticker = new AtomicLong(0); + AuthSessionCache cache = new AuthSessionCache(Duration.ofHours(1), Runnable::run, ticker::get); + AuthSession session1 = Mockito.mock(AuthSession.class); + + @SuppressWarnings("unchecked") + Function loader = Mockito.mock(Function.class); + Mockito.when(loader.apply("key1")).thenReturn(session1); + + AuthSession session = cache.cachedSession("key1", loader); + assertThat(session).isNotNull().isSameAs(session1); + + Mockito.verify(loader, times(1)).apply("key1"); + Mockito.verify(session1, never()).close(); + + ticker.set(TimeUnit.HOURS.toNanos(1)); + cache.sessionCache().cleanUp(); + Mockito.verify(session1).close(); + + cache.close(); + } +}