diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 4df15407f1..4f92201e0c 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -244,11 +244,55 @@ async fn build_s3_client( let mut anonymous = config.anonymous; + let cached_creds = if let Some(credentials_cache) = credentials_cache { + let creds = credentials_cache.provide_cached_credentials().await; + creds.ok() + } else { + None + }; + + let provider = if let Some(cached_creds) = cached_creds { + let provider = CredentialsProviderChain::first_try("different_region_cache", cached_creds) + .or_default_provider() + .await; + Some(aws_credential_types::provider::SharedCredentialsProvider::new(provider)) + } else if config.access_key.is_some() && config.key_id.is_some() { + let creds = Credentials::from_keys( + config.key_id.clone().unwrap(), + config.access_key.clone().unwrap(), + config.session_token.clone(), + ); + Some(aws_credential_types::provider::SharedCredentialsProvider::new(creds)) + } else if config.access_key.is_some() || config.key_id.is_some() { + return Err(super::Error::InvalidArgument { + msg: "Must provide both access_key and key_id when building S3-Like Client".to_string(), + }); + } else { + None + }; + let conf: SdkConfig = if anonymous { - aws_config::SdkConfig::builder().build() + let mut builder = aws_config::SdkConfig::builder(); + builder.set_credentials_provider(provider); + builder.build() } else { - aws_config::load_from_env().await + let loader = aws_config::from_env(); + // Set region now to avoid imds + let loader = if let Some(region) = &config.region_name { + loader.region(Region::new(region.to_owned())) + } else { + loader + }; + // Set creds now to avoid imds + let loader = if let Some(provider) = provider { + loader.credentials_provider(provider) + } else { + loader + }; + + loader.load().await }; + let builder = aws_sdk_s3::config::Builder::from(&conf); let builder = match &config.endpoint_url { None => builder, @@ -294,33 +338,6 @@ async fn build_s3_client( .build(); let builder = builder.timeout_config(timeout_config); - let cached_creds = if let Some(credentials_cache) = credentials_cache { - let creds = credentials_cache.provide_cached_credentials().await; - creds.ok() - } else { - None - }; - - let builder = if let Some(cached_creds) = cached_creds { - let provider = CredentialsProviderChain::first_try("different_region_cache", cached_creds) - .or_default_provider() - .await; - builder.credentials_provider(provider) - } else if config.access_key.is_some() && config.key_id.is_some() { - let creds = Credentials::from_keys( - config.key_id.clone().unwrap(), - config.access_key.clone().unwrap(), - config.session_token.clone(), - ); - builder.credentials_provider(creds) - } else if config.access_key.is_some() || config.key_id.is_some() { - return Err(super::Error::InvalidArgument { - msg: "Must provide both access_key and key_id when building S3-Like Client".to_string(), - }); - } else { - builder - }; - let builder_copy = builder.clone(); let s3_conf = builder.build(); const CRED_TRIES: u64 = 4;