Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure: Support vended credentials refresh in ADLSFileIO. #11577

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,40 @@
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import org.apache.iceberg.azure.adlsv2.VendedAzureSasCredentialProvider;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.base.Strings;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.util.PropertyUtil;

public class AzureProperties implements Serializable {
public static final String ADLS_SAS_TOKEN_PREFIX = "adls.sas-token.";
public static final String ADLS_SAS_TOKEN_EXPIRE_AT_MS_PREFIX = "adls.sas-token-expire-at-ms.";
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved
public static final String ADLS_CONNECTION_STRING_PREFIX = "adls.connection-string.";
public static final String ADLS_READ_BLOCK_SIZE = "adls.read.block-size-bytes";
public static final String ADLS_WRITE_BLOCK_SIZE = "adls.write.block-size-bytes";
public static final String ADLS_SHARED_KEY_ACCOUNT_NAME = "adls.auth.shared-key.account.name";
public static final String ADLS_SHARED_KEY_ACCOUNT_KEY = "adls.auth.shared-key.account.key";

/**
* When set, the {@link org.apache.iceberg.azure.adlsv2.VendedAzureSasCredentialProvider} will be
* used to fetch and refresh vended credentials from this endpoint.
*/
public static final String REFRESH_CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint";
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved

/** Controls whether vended credentials should be refreshed or not. Defaults to true. */
public static final String REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled";
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved

private Map<String, String> adlsSasTokens = Collections.emptyMap();
private Map<String, String> adlsConnectionStrings = Collections.emptyMap();
private Map.Entry<String, String> namedKeyCreds;
private Integer adlsReadBlockSize;
private Long adlsWriteBlockSize;

private VendedAzureSasCredentialProvider vendedAzureSasCredentialProvider;
private String refreshCredentialsEndpoint;
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved
private boolean refreshCredentialsEnabled;
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved

public AzureProperties() {}

public AzureProperties(Map<String, String> properties) {
Expand All @@ -67,6 +83,16 @@ public AzureProperties(Map<String, String> properties) {
if (properties.containsKey(ADLS_WRITE_BLOCK_SIZE)) {
this.adlsWriteBlockSize = Long.parseLong(properties.get(ADLS_WRITE_BLOCK_SIZE));
}
this.refreshCredentialsEndpoint = properties.get(REFRESH_CREDENTIALS_ENDPOINT);
this.refreshCredentialsEnabled =
PropertyUtil.propertyAsBoolean(properties, REFRESH_CREDENTIALS_ENABLED, true);
Map<String, String> credentialProviderProperties = Maps.newHashMap(properties);
nastra marked this conversation as resolved.
Show resolved Hide resolved
if (refreshCredentialsEnabled && !Strings.isNullOrEmpty(refreshCredentialsEndpoint)) {
credentialProviderProperties.put(
VendedAzureSasCredentialProvider.URI, refreshCredentialsEndpoint);
this.vendedAzureSasCredentialProvider =
new VendedAzureSasCredentialProvider(credentialProviderProperties);
}
}

public Optional<Integer> adlsReadBlockSize() {
Expand All @@ -90,7 +116,9 @@ public Optional<Long> adlsWriteBlockSize() {
*/
public void applyClientConfiguration(String account, DataLakeFileSystemClientBuilder builder) {
String sasToken = adlsSasTokens.get(account);
if (sasToken != null && !sasToken.isEmpty()) {
if (refreshCredentialsEnabled && !Strings.isNullOrEmpty(refreshCredentialsEndpoint)) {
builder.credential(vendedAzureSasCredentialProvider.credentialForAccount(account));
} else if (sasToken != null && !sasToken.isEmpty()) {
builder.sasToken(sasToken);
} else if (namedKeyCreds != null) {
builder.credential(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.azure.adlsv2;

import com.azure.core.credential.AzureSasCredential;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair;
nastra marked this conversation as resolved.
Show resolved Hide resolved

public class AzureSasCredentialRefresher {
nastra marked this conversation as resolved.
Show resolved Hide resolved
private final Supplier<Pair<String, Long>> sasTokenWithExpirationSupplier;
private final ScheduledExecutorService refreshExecutor;
private final AzureSasCredential azureSasCredential;

private static final long MAX_REFRESH_WINDOW_MILLIS = 300_000; // 5 minutes;
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved
private static final long MIN_REFRESH_WAIT_MILLIS = 10;

public AzureSasCredentialRefresher(
Supplier<Pair<String, Long>> sasTokenWithExpirationSupplier,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this actually need to be a supplier given that it's being immediately used in L40?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the supplier is being used during initialization and also during each scheduled refresh to fetch the new credentials. Supplier holds the logic to fetch new credentials from the API endpoint, since we are going to use it multiple times I modelled it as supplier instead of single method call. Please suggest if there is a cleaner way to achieve the same.

ScheduledExecutorService refreshExecutor) {
this.sasTokenWithExpirationSupplier = sasTokenWithExpirationSupplier;
this.refreshExecutor = refreshExecutor;
Pair<String, Long> sasTokenWithExpiration = sasTokenWithExpirationSupplier.get();
this.azureSasCredential = new AzureSasCredential(sasTokenWithExpiration.getLeft());
scheduleRefresh(sasTokenWithExpiration.getRight());
}

public AzureSasCredential azureSasCredential() {
return this.azureSasCredential;
}

private void scheduleRefresh(Long expireAtMillis) {
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved
this.refreshExecutor.schedule(
() -> {
Pair<String, Long> sasTokenWithExpiration = sasTokenWithExpirationSupplier.get();
azureSasCredential.update(sasTokenWithExpiration.getLeft());
if (sasTokenWithExpiration.getRight() != null) {
nastra marked this conversation as resolved.
Show resolved Hide resolved
this.scheduleRefresh(sasTokenWithExpiration.getRight());
}
},
refreshDelay(expireAtMillis),
TimeUnit.MILLISECONDS);
}

private long refreshDelay(Long expireAtMillis) {
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved
long expireInMillis = expireAtMillis - System.currentTimeMillis();
// how much ahead of time to start the request to allow it to complete
long refreshWindowMillis = Math.min(expireInMillis / 10, MAX_REFRESH_WINDOW_MILLIS);
// how much time to wait before expiration
long waitIntervalMillis = expireInMillis - refreshWindowMillis;
// how much time to actually wait
return Math.max(waitIntervalMillis, MIN_REFRESH_WAIT_MILLIS);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.azure.adlsv2;

import com.azure.core.credential.AzureSasCredential;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.iceberg.azure.AzureProperties;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.rest.ErrorHandlers;
import org.apache.iceberg.rest.HTTPClient;
import org.apache.iceberg.rest.RESTClient;
import org.apache.iceberg.rest.RESTUtil;
import org.apache.iceberg.rest.auth.OAuth2Properties;
import org.apache.iceberg.rest.auth.OAuth2Util;
import org.apache.iceberg.rest.credentials.Credential;
import org.apache.iceberg.rest.responses.LoadCredentialsResponse;
import org.apache.iceberg.util.SerializableMap;
import org.apache.iceberg.util.ThreadPools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class VendedAzureSasCredentialProvider implements Serializable, AutoCloseable {
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved
private static final Logger LOG = LoggerFactory.getLogger(VendedAzureSasCredentialProvider.class);

private static final String THREAD_PREFIX = "adls-fileio-credential-refresh";
ChaladiMohanVamsi marked this conversation as resolved.
Show resolved Hide resolved
public static final String URI = "credentials.uri";

private final SerializableMap<String, String> properties;
private transient volatile Map<String, AzureSasCredentialRefresher>
azureSasCredentialRefresherMap;
private transient volatile RESTClient client;
private transient volatile ScheduledExecutorService refreshExecutor;

public VendedAzureSasCredentialProvider(Map<String, String> properties) {
Preconditions.checkArgument(null != properties, "Invalid properties: null");
Preconditions.checkArgument(null != properties.get(URI), "Invalid URI: null");
this.properties = SerializableMap.copyOf(properties);
azureSasCredentialRefresherMap = Maps.newHashMap();
}

public AzureSasCredential credentialForAccount(String storageAccount) {
Map<String, AzureSasCredentialRefresher> refresherForAccountMap =
azureSasCredentialRefresherMap();
if (refresherForAccountMap.containsKey(storageAccount)) {
return refresherForAccountMap.get(storageAccount).azureSasCredential();
} else {
AzureSasCredentialRefresher azureSasCredentialRefresher =
new AzureSasCredentialRefresher(
() -> this.sasTokenWithExpiration(storageAccount), credentialRefreshExecutor());
refresherForAccountMap.put(storageAccount, azureSasCredentialRefresher);
return azureSasCredentialRefresher.azureSasCredential();
}
}

private Pair<String, Long> sasTokenWithExpiration(String storageAccount) {
LoadCredentialsResponse response = fetchCredentials();
List<Credential> adlsCredentials =
response.credentials().stream()
.filter(c -> c.prefix().contains(storageAccount))
.collect(Collectors.toList());
Preconditions.checkState(
!adlsCredentials.isEmpty(),
String.format("Invalid ADLS Credentials for storage-account %s: empty", storageAccount));
Preconditions.checkState(
adlsCredentials.size() == 1,
"Invalid ADLS Credentials: only one ADLS credential should exist per storage-account");

Credential adlsCredential = adlsCredentials.get(0);
checkCredential(adlsCredential, AzureProperties.ADLS_SAS_TOKEN_PREFIX + storageAccount);
checkCredential(
adlsCredential, AzureProperties.ADLS_SAS_TOKEN_EXPIRE_AT_MS_PREFIX + storageAccount);

String updatedSasToken =
adlsCredential.config().get(AzureProperties.ADLS_SAS_TOKEN_PREFIX + storageAccount);
Long tokenExpiresAtMillis =
Long.parseLong(
adlsCredential
.config()
.get(AzureProperties.ADLS_SAS_TOKEN_EXPIRE_AT_MS_PREFIX + storageAccount));

return Pair.of(updatedSasToken, tokenExpiresAtMillis);
}

private Map<String, AzureSasCredentialRefresher> azureSasCredentialRefresherMap() {
if (this.azureSasCredentialRefresherMap == null) {
synchronized (this) {
if (this.azureSasCredentialRefresherMap == null) {
this.azureSasCredentialRefresherMap = Maps.newHashMap();
}
}
}
return this.azureSasCredentialRefresherMap;
}

private ScheduledExecutorService credentialRefreshExecutor() {
if (this.refreshExecutor == null) {
synchronized (this) {
if (this.refreshExecutor == null) {
this.refreshExecutor = ThreadPools.newScheduledPool(THREAD_PREFIX, 1);
}
}
}
return this.refreshExecutor;
}

private RESTClient httpClient() {
if (null == client) {
synchronized (this) {
if (null == client) {
client = HTTPClient.builder(properties).uri(properties.get(URI)).build();
}
}
}

return client;
}

private LoadCredentialsResponse fetchCredentials() {
Map<String, String> headers =
RESTUtil.merge(
configHeaders(properties),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you elaborate why this is needed here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to implement the similar behaviour present in RESTSessionCatalog, where catalog can be configured to pass explicit headers to server by setting the configuration with header. prefix.

OAuth2Util.authHeaders(properties.get(OAuth2Properties.TOKEN)));
return httpClient()
.get(
properties.get(URI),
null,
LoadCredentialsResponse.class,
headers,
ErrorHandlers.defaultErrorHandler());
}

private Map<String, String> configHeaders(Map<String, String> props) {
return RESTUtil.extractPrefixMap(props, "header.");
}

private void checkCredential(Credential credential, String property) {
Preconditions.checkState(
credential.config().containsKey(property),
"Invalid ADLS Credentials: %s not set",
property);
}

@Override
public void close() {
IOUtils.closeQuietly(client);
shutdownRefreshExecutor();
}

private void shutdownRefreshExecutor() {
if (refreshExecutor != null) {
ScheduledExecutorService service = refreshExecutor;
this.refreshExecutor = null;

List<Runnable> tasks = service.shutdownNow();
tasks.forEach(
task -> {
if (task instanceof Future) {
((Future<?>) task).cancel(true);
}
});

try {
if (!service.awaitTermination(1, TimeUnit.MINUTES)) {
LOG.warn("Timed out waiting for refresh executor to terminate");
}
} catch (InterruptedException e) {
LOG.warn("Interrupted while waiting for refresh executor to terminate", e);
Thread.currentThread().interrupt();
}
}
}
}
Loading