diff --git a/Cargo.toml b/Cargo.toml index 8d1ac2e..d7b3a22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,9 +19,10 @@ exclude = [ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -aws-config = "0.55" -aws-sdk-glue = "0.28" -aws-types = "0.55" +aws-config = "1.1" +aws-sdk-glue = "1.22" +aws-types = "1.1" +aws-credential-types = "1.1" chrono = "0.4" clap = { version = "4.5", features = ["derive"] } datafusion = { version = "36.0", features = ["avro"] } diff --git a/src/args.rs b/src/args.rs index 0524b41..e598080 100644 --- a/src/args.rs +++ b/src/args.rs @@ -1,4 +1,3 @@ -//use crate::GlobbingPath; use aws_sdk_glue::Client; use aws_types::SdkConfig; use chrono::{DateTime, Utc}; @@ -6,7 +5,6 @@ use clap::Parser; use datafusion::common::{DataFusionError, Result}; use regex::Regex; use std::collections::HashMap; -use std::env; use url::Url; #[derive(Parser, Debug)] @@ -83,29 +81,6 @@ impl Args { }*/ } -#[allow(dead_code)] -async fn get_sdk_config(args: &Args) -> SdkConfig { - set_aws_profile_when_needed(args); - set_aws_region_when_needed(); - - aws_config::load_from_env().await -} - -#[allow(dead_code)] -fn set_aws_profile_when_needed(args: &Args) { - if let Some(aws_profile) = &args.profile { - env::set_var("AWS_PROFILE", aws_profile); - } -} - -#[allow(dead_code)] -fn set_aws_region_when_needed() { - match env::var("AWS_DEFAULT_REGION") { - Ok(_) => {} - Err(_) => env::set_var("AWS_DEFAULT_REGION", "eu-central-1"), - } -} - #[allow(dead_code)] async fn get_storage_location( sdk_config: &SdkConfig, diff --git a/src/main.rs b/src/main.rs index 021d3a7..a50e613 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,10 @@ use clap::Parser; use datafusion::catalog::TableReference; use std::env; use std::sync::Arc; +use aws_config::BehaviorVersion; +use aws_credential_types::provider::ProvideCredentials; + +use aws_types::SdkConfig; use datafusion::common::{DataFusionError, Result}; use datafusion::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; @@ -19,13 +23,24 @@ use opendal::services::S3; use opendal::Operator; use url::Url; -fn init_s3_operator_via_builder(url: &Url) -> Result { +async fn init_s3_operator_via_builder(url: &Url, sdk_config: &SdkConfig) -> Result { + + let cp = sdk_config.credentials_provider().unwrap(); + let creds = cp.provide_credentials().await + .map_err(|e| DataFusionError::Execution(format!("Failed to get credentials: {e}")))?; + let mut builder = S3::default(); let bucket_name = url.host_str().unwrap(); builder.bucket(bucket_name); - //https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html + builder.access_key_id(creds.access_key_id()); + builder.secret_access_key(creds.secret_access_key()); + if let Some(session_token) = creds.session_token() { + builder.security_token(session_token); + } + + //https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html if let Ok(aws_endpoint_url) = env::var("AWS_ENDPOINT_URL") { builder.endpoint(&aws_endpoint_url); } @@ -45,10 +60,13 @@ async fn main() -> Result<()> { let data_path = &args.path.clone(); + let sdk_config = get_sdk_config(&args).await; + + if data_path.starts_with("s3://") { let s3_url = Url::parse(data_path) .map_err(|e| DataFusionError::Execution(format!("Failed to parse url, {e}")))?; - let op = init_s3_operator_via_builder(&s3_url)?; + let op = init_s3_operator_via_builder(&s3_url, &sdk_config).await?; ctx.runtime_env() .register_object_store(&s3_url, Arc::new(OpendalStore::new(op))); } @@ -75,3 +93,23 @@ async fn main() -> Result<()> { Ok(()) } + +async fn get_sdk_config(args: &Args) -> SdkConfig { + set_aws_profile_when_needed(args); + set_aws_region_when_needed(); + + aws_config::load_defaults(BehaviorVersion::latest()).await +} + +fn set_aws_profile_when_needed(args: &Args) { + if let Some(aws_profile) = &args.profile { + env::set_var("AWS_PROFILE", aws_profile); + } +} + +fn set_aws_region_when_needed() { + match env::var("AWS_DEFAULT_REGION") { + Ok(_) => {} + Err(_) => env::set_var("AWS_DEFAULT_REGION", "eu-central-1"), + } +}