Skip to content

Commit

Permalink
feat: leverage rust aws sdk to get credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
timvw committed Mar 27, 2024
1 parent 7cf8129 commit d713ba3
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 31 deletions.
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ exclude = [
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
aws-config = "0.55"
aws-sdk-glue = "0.28"
aws-types = "0.55"
aws-config = "1.1"
aws-sdk-glue = "1.22"
aws-types = "1.1"
aws-credential-types = "1.1"
chrono = "0.4"
clap = { version = "4.5", features = ["derive"] }
datafusion = { version = "36.0", features = ["avro"] }
Expand Down
25 changes: 0 additions & 25 deletions src/args.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
//use crate::GlobbingPath;
use aws_sdk_glue::Client;
use aws_types::SdkConfig;
use chrono::{DateTime, Utc};
use clap::Parser;
use datafusion::common::{DataFusionError, Result};
use regex::Regex;
use std::collections::HashMap;
use std::env;
use url::Url;

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -83,29 +81,6 @@ impl Args {
}*/
}

/// Builds the AWS `SdkConfig` for this process.
///
/// Applies the optional profile carried by `args` and a fallback region to
/// the process environment, then awaits `aws_config::load_from_env()` and
/// returns its result.
#[allow(dead_code)]
async fn get_sdk_config(args: &Args) -> SdkConfig {
// Both helpers write process environment variables. They run before the
// load call, presumably so the loader observes them (confirm against the
// aws-config documentation).
set_aws_profile_when_needed(args);
set_aws_region_when_needed();

aws_config::load_from_env().await
}

/// Exports the profile selected in `args` as the `AWS_PROFILE` environment
/// variable.
///
/// Does nothing when `args.profile` is `None`, leaving any `AWS_PROFILE`
/// already present in the environment untouched.
#[allow(dead_code)]
fn set_aws_profile_when_needed(args: &Args) {
    // Guard clause: no profile was supplied, so there is nothing to export.
    let Some(profile_name) = &args.profile else {
        return;
    };

    env::set_var("AWS_PROFILE", profile_name);
}

/// Ensures `AWS_DEFAULT_REGION` is present in the process environment.
///
/// When the variable cannot be read (it is unset, or its value is not valid
/// Unicode) it is set to `eu-central-1`. A readable existing value is left
/// as it is.
#[allow(dead_code)]
fn set_aws_region_when_needed() {
    const REGION_VAR: &str = "AWS_DEFAULT_REGION";
    const FALLBACK_REGION: &str = "eu-central-1";

    // Only fill in the fallback; never override a region the user chose.
    if env::var(REGION_VAR).is_err() {
        env::set_var(REGION_VAR, FALLBACK_REGION);
    }
}

#[allow(dead_code)]
async fn get_storage_location(
sdk_config: &SdkConfig,
Expand Down
44 changes: 41 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ use clap::Parser;
use datafusion::catalog::TableReference;
use std::env;
use std::sync::Arc;

Check warning on line 4 in src/main.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/qv/qv/src/main.rs
use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;

use aws_types::SdkConfig;

use datafusion::common::{DataFusionError, Result};
use datafusion::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl};
Expand All @@ -19,13 +23,24 @@ use opendal::services::S3;
use opendal::Operator;
use url::Url;

Check warning on line 24 in src/main.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/qv/qv/src/main.rs

fn init_s3_operator_via_builder(url: &Url) -> Result<Operator> {
async fn init_s3_operator_via_builder(url: &Url, sdk_config: &SdkConfig) -> Result<Operator> {

let cp = sdk_config.credentials_provider().unwrap();
let creds = cp.provide_credentials().await
.map_err(|e| DataFusionError::Execution(format!("Failed to get credentials: {e}")))?;

let mut builder = S3::default();
let bucket_name = url.host_str().unwrap();
builder.bucket(bucket_name);

//https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html
builder.access_key_id(creds.access_key_id());
builder.secret_access_key(creds.secret_access_key());

if let Some(session_token) = creds.session_token() {
builder.security_token(session_token);
}

//https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html
if let Ok(aws_endpoint_url) = env::var("AWS_ENDPOINT_URL") {
builder.endpoint(&aws_endpoint_url);
}
Expand All @@ -45,10 +60,13 @@ async fn main() -> Result<()> {

let data_path = &args.path.clone();

Check warning on line 61 in src/main.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/qv/qv/src/main.rs

let sdk_config = get_sdk_config(&args).await;


if data_path.starts_with("s3://") {
let s3_url = Url::parse(data_path)
.map_err(|e| DataFusionError::Execution(format!("Failed to parse url, {e}")))?;
let op = init_s3_operator_via_builder(&s3_url)?;
let op = init_s3_operator_via_builder(&s3_url, &sdk_config).await?;
ctx.runtime_env()
.register_object_store(&s3_url, Arc::new(OpendalStore::new(op)));
}
Expand All @@ -75,3 +93,23 @@ async fn main() -> Result<()> {

Ok(())
}

/// Builds the AWS `SdkConfig` for this process.
///
/// Applies the optional profile carried by `args` and a fallback region to
/// the process environment, then awaits
/// `aws_config::load_defaults(BehaviorVersion::latest())` and returns its
/// result.
async fn get_sdk_config(args: &Args) -> SdkConfig {
// Both helpers write process environment variables. They run before the
// load call, presumably so the loader observes them (confirm against the
// aws-config documentation).
set_aws_profile_when_needed(args);
set_aws_region_when_needed();

aws_config::load_defaults(BehaviorVersion::latest()).await
}

/// Exports the profile selected in `args` as the `AWS_PROFILE` environment
/// variable.
///
/// Does nothing when `args.profile` is `None`, leaving any `AWS_PROFILE`
/// already present in the environment untouched.
fn set_aws_profile_when_needed(args: &Args) {
    // Guard clause: no profile was supplied, so there is nothing to export.
    let Some(profile_name) = &args.profile else {
        return;
    };

    env::set_var("AWS_PROFILE", profile_name);
}

/// Ensures `AWS_DEFAULT_REGION` is present in the process environment.
///
/// When the variable cannot be read (it is unset, or its value is not valid
/// Unicode) it is set to `eu-central-1`. A readable existing value is left
/// as it is.
fn set_aws_region_when_needed() {
    const REGION_VAR: &str = "AWS_DEFAULT_REGION";
    const FALLBACK_REGION: &str = "eu-central-1";

    // Only fill in the fallback; never override a region the user chose.
    if env::var(REGION_VAR).is_err() {
        env::set_var(REGION_VAR, FALLBACK_REGION);
    }
}

0 comments on commit d713ba3

Please sign in to comment.