diff --git a/daft/daft.pyi b/daft/daft.pyi index 91171522c2..5bb906f67d 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -426,6 +426,7 @@ class S3Config: verify_ssl: bool check_hostname_ssl: bool requester_pays: bool | None + force_virtual_addressing: bool | None def __init__( self, @@ -445,6 +446,7 @@ class S3Config: verify_ssl: bool | None = None, check_hostname_ssl: bool | None = None, requester_pays: bool | None = None, + force_virtual_addressing: bool | None = None, ): ... def replace( self, @@ -464,6 +466,7 @@ class S3Config: verify_ssl: bool | None = None, check_hostname_ssl: bool | None = None, requester_pays: bool | None = None, + force_virtual_addressing: bool | None = None, ) -> S3Config: """Replaces values if provided, returning a new S3Config""" ... diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index d80452856f..ecfeb19c25 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -25,6 +25,7 @@ use crate::config; /// verify_ssl: Whether or not to verify ssl certificates, which will access S3 without checking if the certs are valid, defaults to True /// check_hostname_ssl: Whether or not to verify the hostname when verifying ssl certificates, this was the legacy behavior for openssl, defaults to True /// requester_pays: Whether or not the authenticated user will assume transfer costs, which is required by some providers of bulk data, defaults to False +/// force_virtual_addressing: Force S3 client to use virtual addressing in all cases. If False, virtual addressing will only be used if `endpoint_url` is empty, defaults to False /// /// Example: /// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx")) @@ -186,6 +187,7 @@ impl S3Config { verify_ssl: Option, check_hostname_ssl: Option, requester_pays: Option, + force_virtual_addressing: Option, ) -> Self { let def = crate::S3Config::default(); S3Config { @@ -208,6 +210,8 @@ impl S3Config { verify_ssl: verify_ssl.unwrap_or(def.verify_ssl), check_hostname_ssl: check_hostname_ssl.unwrap_or(def.check_hostname_ssl), requester_pays: requester_pays.unwrap_or(def.requester_pays), + force_virtual_addressing: force_virtual_addressing + .unwrap_or(def.force_virtual_addressing), }, } } @@ -231,6 +235,7 @@ impl S3Config { verify_ssl: Option, check_hostname_ssl: Option, requester_pays: Option, + force_virtual_addressing: Option, ) -> Self { S3Config { config: crate::S3Config { @@ -252,6 +257,8 @@ impl S3Config { verify_ssl: verify_ssl.unwrap_or(self.config.verify_ssl), check_hostname_ssl: check_hostname_ssl.unwrap_or(self.config.check_hostname_ssl), requester_pays: requester_pays.unwrap_or(self.config.requester_pays), + force_virtual_addressing: force_virtual_addressing + .unwrap_or(self.config.force_virtual_addressing), }, } } @@ -355,6 +362,12 @@ impl S3Config { pub fn requester_pays(&self) -> PyResult> { Ok(Some(self.config.requester_pays)) } + + /// AWS force virtual addressing + #[getter] + pub fn force_virtual_addressing(&self) -> PyResult> { + Ok(Some(self.config.force_virtual_addressing)) + } } #[pymethods] diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs index 7641a19dbd..0f66d0285b 100644 --- a/src/common/io-config/src/s3.rs +++ b/src/common/io-config/src/s3.rs @@ -22,6 +22,7 @@ pub struct S3Config { pub verify_ssl: bool, pub check_hostname_ssl: bool, pub requester_pays: bool, + pub force_virtual_addressing: bool, } impl S3Config { @@ -61,6 +62,10 @@ impl S3Config { res.push(format!("Verify SSL = {}", self.verify_ssl)); res.push(format!("Check hostname SSL = {}", self.check_hostname_ssl)); res.push(format!("Requester pays = {}", self.requester_pays)); + res.push(format!( + "Force Virtual Addressing = {}", + self.force_virtual_addressing + )); res } } @@ -86,6 +91,7 @@ impl Default for S3Config { verify_ssl: true, check_hostname_ssl: true, requester_pays: false, + force_virtual_addressing: false, } } } @@ -110,7 +116,8 @@ impl Display for S3Config { use_ssl: {}, verify_ssl: {}, check_hostname_ssl: {} - requester_pays: {}", + requester_pays: {} + force_virtual_addressing: {}", self.region_name, self.endpoint_url, self.key_id, @@ -126,7 +133,8 @@ impl Display for S3Config { self.use_ssl, self.verify_ssl, self.check_hostname_ssl, - self.requester_pays + self.requester_pays, + self.force_virtual_addressing ) } } diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index a31cf79a85..1a143f9040 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -307,8 +307,15 @@ async fn build_s3_client( let builder = aws_sdk_s3::config::Builder::from(&conf); let builder = match &config.endpoint_url { None => builder, - Some(endpoint) => builder.endpoint_url(endpoint).force_path_style(true), + Some(endpoint) => builder.endpoint_url(endpoint), }; + + let builder = if config.endpoint_url.is_some() && !config.force_virtual_addressing { + builder.force_path_style(true) + } else { + builder.force_path_style(false) + }; + let builder = if let Some(region) = &config.region_name { builder.region(Region::new(region.to_owned())) } else {