diff --git a/Cargo.lock b/Cargo.lock index d605a7b9f8..983385050d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -719,7 +719,7 @@ dependencies = [ "http 0.2.12", "http-body 0.4.6", "hyper 0.14.30", - "hyper-tls", + "hyper-tls 0.5.0", "pin-project-lite", "tokio", "tower", @@ -898,7 +898,7 @@ dependencies = [ "pin-project", "quick-xml", "rand 0.8.5", - "reqwest", + "reqwest 0.11.27", "rustc_version", "serde", "serde_json", @@ -1494,6 +1494,7 @@ dependencies = [ "chrono", "common-error", "common-py-serde", + "derive_more", "pyo3", "secrecy", "serde", @@ -1755,7 +1756,7 @@ dependencies = [ "bitflags 2.6.0", "crossterm_winapi", "libc", - "parking_lot", + "parking_lot 0.12.3", "winapi", ] @@ -1945,7 +1946,7 @@ dependencies = [ "daft-table", "futures", "memchr", - "parking_lot", + "parking_lot 0.12.3", "pyo3", "rayon", "rstest", @@ -2089,7 +2090,7 @@ dependencies = [ "google-cloud-token", "home", "hyper 0.14.30", - "hyper-tls", + "hyper-tls 0.5.0", "itertools 0.11.0", "lazy_static", "log", @@ -2098,12 +2099,16 @@ dependencies = [ "pyo3", "rand 0.8.5", "regex", - "reqwest", + "reqwest 0.11.27", + "reqwest-middleware", + "reqwest-retry", + "retry-policies", "serde", "snafu", "tempfile", "tokio", "tokio-stream", + "tracing", "url", ] @@ -3046,9 +3051,9 @@ dependencies = [ [[package]] name = "google-cloud-auth" -version = "0.13.2" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bf7cb7864f08a92e77c26bb230d021ea57691788fb5dd51793f96965d19e7f9" +checksum = "357160f51a60ec3e32169ad687f4abe0ee1e90c73b449aa5d11256c4f1cf2ff6" dependencies = [ "async-trait", "base64 0.21.7", @@ -3056,7 +3061,7 @@ dependencies = [ "google-cloud-token", "home", "jsonwebtoken", - "reqwest", + "reqwest 0.12.9", "serde", "serde_json", "thiserror", @@ -3068,21 +3073,22 @@ dependencies = [ [[package]] name = "google-cloud-metadata" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc279bfb50487d7bcd900e8688406475fc750fe474a835b2ab9ade9eb1fc90e2" +checksum = "04f945a208886a13d07636f38fb978da371d0abc3e34bad338124b9f8c135a8f" dependencies = [ - "reqwest", + "reqwest 0.12.9", "thiserror", "tokio", ] [[package]] name = "google-cloud-storage" -version = "0.15.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac04b29849ebdeb9fb008988cc1c4d1f0c9d121b4c7f1ddeb8061df124580e93" +checksum = "c7347a3d65cd64db51e5b4aebf0c68c484042948c6d53f856f58269bc9816360" dependencies = [ + "anyhow", "async-stream", "async-trait", "base64 0.21.7", @@ -3096,7 +3102,8 @@ dependencies = [ "percent-encoding", "pkcs8", "regex", - "reqwest", + "reqwest 0.12.9", + "reqwest-middleware", "ring 0.17.8", "serde", "serde_json", @@ -3448,6 +3455,22 @@ dependencies = [ "tokio-native-tls", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.5.0", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.10" @@ -3578,6 +3601,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" dependencies = [ "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", ] [[package]] @@ -4382,6 +4408,17 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" +[[package]] +name = "parking_lot" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core 0.8.6", +] + [[package]] name = "parking_lot" version = "0.12.3" @@ -4389,7 +4426,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" dependencies = [ "lock_api", - "parking_lot_core", + "parking_lot_core 0.9.10", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall 0.2.16", + "smallvec", + "winapi", ] [[package]] @@ -4400,7 +4451,7 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.3", "smallvec", "windows-targets 0.52.6", ] @@ -4777,7 +4828,7 @@ dependencies = [ "inventory", "libc", "memoffset", - "parking_lot", + "parking_lot 0.12.3", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -4999,6 +5050,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.5.3" @@ -5099,17 +5159,16 @@ dependencies = [ "http 0.2.12", "http-body 0.4.6", "hyper 0.14.30", - "hyper-tls", + "hyper-tls 0.5.0", "ipnet", "js-sys", "log", "mime", - "mime_guess", "native-tls", "once_cell", "percent-encoding", "pin-project-lite", - "rustls-pemfile", + "rustls-pemfile 1.0.4", "serde", "serde_json", "serde_urlencoded", @@ -5127,6 +5186,94 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest" +version = "0.12.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.5.0", + "hyper-tls 0.6.0", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "mime_guess", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls-pemfile 2.2.0", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "tokio", + "tokio-native-tls", + "tokio-util", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "windows-registry", +] + +[[package]] +name = "reqwest-middleware" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562ceb5a604d3f7c885a792d42c199fd8af239d0a51b2fa6a78aafa092452b04" +dependencies = [ + "anyhow", + "async-trait", + "http 1.1.0", + "reqwest 0.12.9", + "serde", + "thiserror", + "tower-service", +] + +[[package]] +name = "reqwest-retry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a83df1aaec00176d0fabb65dea13f832d2a446ca99107afc17c5d2d4981221d0" +dependencies = [ + "anyhow", + "async-trait", + "futures", + "getrandom 0.2.15", + "http 1.1.0", + "hyper 1.5.0", + "parking_lot 0.11.2", + "reqwest 0.12.9", + "reqwest-middleware", + "retry-policies", + "tokio", + "tracing", + "wasm-timer", +] + +[[package]] +name = "retry-policies" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5875471e6cab2871bc150ecb8c727db5113c9338cc3354dc5ee3425b6aa40a1c" +dependencies = [ + "rand 0.8.5", +] + [[package]] name = "ring" version = "0.16.20" @@ -5235,6 +5382,21 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" + [[package]] name = "rustversion" version = "1.0.17" @@ -5832,6 +5994,9 @@ name = "sync_wrapper" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +dependencies = [ + "futures-core", +] [[package]] name = "synstructure" @@ -6022,7 +6187,7 @@ dependencies = [ "bstr", "fancy-regex", "lazy_static", - "parking_lot", + "parking_lot 0.12.3", "rustc-hash", ] @@ -6124,7 +6289,7 @@ dependencies = [ "bytes", "libc", "mio", - "parking_lot", + "parking_lot 0.12.3", "pin-project-lite", "signal-hook-registry", "socket2", @@ -6674,6 +6839,21 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasm-timer" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be0ecb0db480561e9a7642b5d3e4187c128914e58aa84330b9493e3eb68c5e7f" +dependencies = [ + "futures", + "js-sys", + "parking_lot 0.11.2", + "pin-utils", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.69" @@ -6740,6 +6920,36 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index d4580cb3ca..8ec06319d3 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -600,6 +600,11 @@ class GCSConfig: credentials: str | None token: str | None anonymous: bool + max_connections: int + retry_initial_backoff_ms: int + connect_timeout_ms: int + read_timeout_ms: int + num_tries: int def __init__( self, @@ -607,6 +612,11 @@ class GCSConfig: credentials: str | None = None, token: str | None = None, anonymous: bool | None = None, + max_connections: int | None = None, + retry_initial_backoff_ms: int | None = None, + connect_timeout_ms: int | None = None, + read_timeout_ms: int | None = None, + num_tries: int | None = None, ): ... def replace( self, @@ -614,6 +624,11 @@ class GCSConfig: credentials: str | None = None, token: str | None = None, anonymous: bool | None = None, + max_connections: int | None = None, + retry_initial_backoff_ms: int | None = None, + connect_timeout_ms: int | None = None, + read_timeout_ms: int | None = None, + num_tries: int | None = None, ) -> GCSConfig: """Replaces values if provided, returning a new GCSConfig""" ... diff --git a/src/common/io-config/Cargo.toml b/src/common/io-config/Cargo.toml index a66e9bfa23..5903132c46 100644 --- a/src/common/io-config/Cargo.toml +++ b/src/common/io-config/Cargo.toml @@ -3,6 +3,7 @@ aws-credential-types = {version = "0.55.3"} chrono = {workspace = true} common-error = {path = "../error", default-features = false} common-py-serde = {path = "../py-serde", default-features = false} +derive_more = {workspace = true} pyo3 = {workspace = true, optional = true} secrecy = {version = "0.8.0", features = ["alloc"], default-features = false} serde = {workspace = true} diff --git a/src/common/io-config/src/gcs.rs b/src/common/io-config/src/gcs.rs index cd5e8628a3..a471298fb8 100644 --- a/src/common/io-config/src/gcs.rs +++ b/src/common/io-config/src/gcs.rs @@ -1,15 +1,45 @@ -use std::fmt::{Display, Formatter}; - +use derive_more::Display; use serde::{Deserialize, Serialize}; use crate::ObfuscatedString; -#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Display)] +#[display( + "GCSConfig + project_id: {project_id:?} + anonymous: {anonymous} + max_connections_per_io_thread: {max_connections_per_io_thread} + retry_initial_backoff_ms: {retry_initial_backoff_ms} + connect_timeout_ms: {connect_timeout_ms} + read_timeout_ms: {read_timeout_ms} + num_tries: {num_tries}" +)] pub struct GCSConfig { pub project_id: Option, pub credentials: Option, pub token: Option, pub anonymous: bool, + pub max_connections_per_io_thread: u32, + pub retry_initial_backoff_ms: u64, + pub connect_timeout_ms: u64, + pub read_timeout_ms: u64, + pub num_tries: u32, +} + +impl Default for GCSConfig { + fn default() -> Self { + Self { + project_id: None, + credentials: None, + token: None, + anonymous: false, + max_connections_per_io_thread: 8, + retry_initial_backoff_ms: 1000, + connect_timeout_ms: 30_000, + read_timeout_ms: 30_000, + num_tries: 5, + } + } } impl GCSConfig { @@ -20,18 +50,17 @@ impl GCSConfig { res.push(format!("Project ID = {project_id}")); } res.push(format!("Anonymous = {}", self.anonymous)); + res.push(format!( + "Max connections = {}", + self.max_connections_per_io_thread + )); + res.push(format!( + "Retry initial backoff ms = {}", + self.retry_initial_backoff_ms + )); + res.push(format!("Connect timeout ms = {}", self.connect_timeout_ms)); + res.push(format!("Read timeout ms = {}", self.read_timeout_ms)); + res.push(format!("Max retries = {}", self.num_tries)); res } } - -impl Display for GCSConfig { - fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { - write!( - f, - "GCSConfig - project_id: {:?} - anonymous: {:?}", - self.project_id, self.anonymous - ) - } -} diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index 2632ebecdd..aeae87a714 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -25,12 +25,12 @@ use crate::{config, s3::S3CredentialsProvider}; /// access_key (str, optional): AWS Secret Access Key, defaults to auto-detection from the current environment /// credentials_provider (Callable[[], S3Credentials], optional): Custom credentials provider function, should return a `S3Credentials` object /// buffer_time (int, optional): Amount of time in seconds before the actual credential expiration time where credentials given by `credentials_provider` are considered expired, defaults to 10s -/// max_connections (int, optional): Maximum number of connections to S3 at any time, defaults to 64 +/// max_connections (int, optional): Maximum number of connections to S3 at any time per io thread, defaults to 8 /// session_token (str, optional): AWS Session Token, required only if `key_id` and `access_key` are temporary credentials /// retry_initial_backoff_ms (int, optional): Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms -/// connect_timeout_ms (int, optional): Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 10 seconds -/// read_timeout_ms (int, optional): Timeout duration to wait to read the first byte from S3 in milliseconds, defaults to 10 seconds -/// num_tries (int, optional): Number of attempts to make a connection, defaults to 5 +/// connect_timeout_ms (int, optional): Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 30 seconds +/// read_timeout_ms (int, optional): Timeout duration to wait to read the first byte from S3 in milliseconds, defaults to 30 seconds +/// num_tries (int, optional): Number of attempts to make a connection, defaults to 25 /// retry_mode (str, optional): Retry Mode when a request fails, current supported values are `standard` and `adaptive`, defaults to `adaptive` /// anonymous (bool, optional): Whether or not to use "anonymous mode", which will access S3 without any credentials /// use_ssl (bool, optional): Whether or not to use SSL, which require accessing S3 over HTTPS rather than HTTP, defaults to True @@ -107,6 +107,11 @@ pub struct AzureConfig { /// credentials (str, optional): Path to credentials file or JSON string with credentials /// token (str, optional): OAuth2 token to use for authentication. You likely want to use `credentials` instead, since it can be used to refresh the token. This value is used when vended by a data catalog. /// anonymous (bool, optional): Whether or not to use "anonymous mode", which will access Google Storage without any credentials. Defaults to false +/// max_connections (int, optional): Maximum number of connections to GCS at any time per io thread, defaults to 8 +/// retry_initial_backoff_ms (int, optional): Initial backoff duration in milliseconds for an GCS retry, defaults to 1000ms +/// connect_timeout_ms (int, optional): Timeout duration to wait to make a connection to GCS in milliseconds, defaults to 30 seconds +/// read_timeout_ms (int, optional): Timeout duration to wait to read the first byte from GCS in milliseconds, defaults to 30 seconds +/// num_tries (int, optional): Number of attempts to make a connection, defaults to 5 /// /// Example: /// >>> io_config = IOConfig(gcs=GCSConfig(anonymous=True)) @@ -848,6 +853,11 @@ impl GCSConfig { credentials: Option, token: Option, anonymous: Option, + max_connections: Option, + retry_initial_backoff_ms: Option, + connect_timeout_ms: Option, + read_timeout_ms: Option, + num_tries: Option, ) -> Self { let def = crate::GCSConfig::default(); Self { @@ -858,10 +868,17 @@ impl GCSConfig { .or(def.credentials), token: token.or(def.token), anonymous: anonymous.unwrap_or(def.anonymous), + max_connections_per_io_thread: max_connections + .unwrap_or(def.max_connections_per_io_thread), + retry_initial_backoff_ms: retry_initial_backoff_ms + .unwrap_or(def.retry_initial_backoff_ms), + connect_timeout_ms: connect_timeout_ms.unwrap_or(def.connect_timeout_ms), + read_timeout_ms: read_timeout_ms.unwrap_or(def.read_timeout_ms), + num_tries: num_tries.unwrap_or(def.num_tries), }, } } - + #[allow(clippy::too_many_arguments)] #[must_use] pub fn replace( &self, @@ -869,6 +886,11 @@ impl GCSConfig { credentials: Option, token: Option, anonymous: Option, + max_connections: Option, + retry_initial_backoff_ms: Option, + connect_timeout_ms: Option, + read_timeout_ms: Option, + num_tries: Option, ) -> Self { Self { config: crate::GCSConfig { @@ -878,6 +900,13 @@ impl GCSConfig { .or_else(|| self.config.credentials.clone()), token: token.or_else(|| self.config.token.clone()), anonymous: anonymous.unwrap_or(self.config.anonymous), + max_connections_per_io_thread: max_connections + .unwrap_or(self.config.max_connections_per_io_thread), + retry_initial_backoff_ms: retry_initial_backoff_ms + .unwrap_or(self.config.retry_initial_backoff_ms), + connect_timeout_ms: connect_timeout_ms.unwrap_or(self.config.connect_timeout_ms), + read_timeout_ms: read_timeout_ms.unwrap_or(self.config.read_timeout_ms), + num_tries: num_tries.unwrap_or(self.config.num_tries), }, } } @@ -913,6 +942,31 @@ impl GCSConfig { pub fn anonymous(&self) -> PyResult { Ok(self.config.anonymous) } + + #[getter] + pub fn max_connections(&self) -> PyResult { + Ok(self.config.max_connections_per_io_thread) + } + + #[getter] + pub fn retry_initial_backoff_ms(&self) -> PyResult { + Ok(self.config.retry_initial_backoff_ms) + } + + #[getter] + pub fn connect_timeout_ms(&self) -> PyResult { + Ok(self.config.connect_timeout_ms) + } + + #[getter] + pub fn read_timeout_ms(&self) -> PyResult { + Ok(self.config.read_timeout_ms) + } + + #[getter] + pub fn num_tries(&self) -> PyResult { + Ok(self.config.num_tries) + } } impl From for IOConfig { diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index d904b8b655..ebab7448fd 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -20,7 +20,7 @@ common-runtime = {path = "../common/runtime", default-features = false} derive_builder = {workspace = true} futures = {workspace = true} globset = "0.4" -google-cloud-storage = {version = "0.15.0", default-features = false, features = ["default-tls", "auth"]} +google-cloud-storage = {version = "0.22.1", default-features = false, features = ["default-tls", "auth"]} google-cloud-token = {version = "0.1.2"} home = "0.5.9" hyper = "0.14.27" @@ -32,10 +32,14 @@ openssl-sys = {version = "0.9.102", features = ["vendored"]} pyo3 = {workspace = true, optional = true} rand = "0.8.5" regex = {version = "1.10.4"} +reqwest-middleware = "0.3.3" +reqwest-retry = "0.6.1" +retry-policies = "0.4.0" serde = {workspace = true} snafu = {workspace = true} tokio = {workspace = true} tokio-stream = {workspace = true} +tracing = {workspace = true} url = {workspace = true} [dependencies.reqwest] diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs index d74484fa27..41ca0261d0 100644 --- a/src/daft-io/src/google_cloud.rs +++ b/src/daft-io/src/google_cloud.rs @@ -1,7 +1,8 @@ -use std::{ops::Range, sync::Arc}; +use std::{ops::Range, sync::Arc, time::Duration}; use async_trait::async_trait; use common_io_config::GCSConfig; +use common_runtime::get_io_pool_num_threads; use futures::{stream::BoxStream, TryStreamExt}; use google_cloud_storage::{ client::{google_cloud_auth::credentials::CredentialsFile, Client, ClientConfig}, @@ -12,6 +13,7 @@ use google_cloud_storage::{ }; use google_cloud_token::{TokenSource, TokenSourceProvider}; use snafu::{IntoError, ResultExt, Snafu}; +use tokio::sync::Semaphore; use crate::{ object_io::{FileMetadata, FileType, LSResult, ObjectSource}, @@ -48,13 +50,20 @@ enum Error { NotAFile { path: String }, #[snafu(display("Not a File: \"{}\"", path))] NotFound { path: String }, + #[snafu(display("Unable to grab semaphore. {}", source))] + UnableToGrabSemaphore { source: tokio::sync::AcquireError }, + + #[snafu(display("Unable to create Http Client {}", source))] + UnableToCreateClient { + source: reqwest_middleware::reqwest::Error, + }, } impl From for super::Error { fn from(error: Error) -> Self { use Error::{ - InvalidUrl, NotAFile, NotFound, UnableToListObjects, UnableToLoadCredentials, - UnableToOpenFile, UnableToReadBytes, + InvalidUrl, NotAFile, NotFound, UnableToCreateClient, UnableToGrabSemaphore, + UnableToListObjects, UnableToLoadCredentials, UnableToOpenFile, UnableToReadBytes, }; match error { UnableToReadBytes { path, source } @@ -70,10 +79,24 @@ impl From for super::Error { path, source: err.into(), }, - _ => Self::UnableToOpenFile { - path, - source: err.into(), - }, + _ => { + if err.is_connect() { + Self::ConnectTimeout { + path, + source: err.into(), + } + } else if err.is_timeout() { + Self::ReadTimeout { + path, + source: err.into(), + } + } else { + Self::UnableToOpenFile { + path, + source: err.into(), + } + } + } }, GError::Response(err) => match err.code { 404 | 410 => Self::NotFound { @@ -94,6 +117,12 @@ impl From for super::Error { store: super::SourceType::GCS, source: err, }, + err @ GError::HttpMiddleware(_) | err @ GError::InvalidRangeHeader(_) => { + Self::UnableToOpenFile { + path, + source: err.into(), + } + } }, NotFound { ref path } => Self::NotFound { path: path.into(), @@ -105,11 +134,25 @@ impl From for super::Error { source: source.into(), }, NotAFile { path } => Self::NotAFile { path }, + UnableToGrabSemaphore { .. } => Self::Generic { + store: crate::SourceType::GCS, + source: error.into(), + }, + UnableToCreateClient { .. } => Self::UnableToCreateClient { + store: crate::SourceType::GCS, + source: error.into(), + }, } } } -struct GCSClientWrapper(Client); +struct GCSClientWrapper { + client: Client, + /// Used to limit the concurrent connections to GCS at any given time. + /// Acquired when we initiate a connection to GCS + /// Released when the stream for that connection is exhausted + connection_pool_sema: Arc, +} fn parse_uri(uri: &url::Url) -> super::Result<(&str, &str)> { let bucket = match uri.host_str() { @@ -136,8 +179,13 @@ impl GCSClientWrapper { if key.is_empty() { return Err(Error::NotAFile { path: uri.into() }.into()); } - - let client = &self.0; + let permit = self + .connection_pool_sema + .clone() + .acquire_owned() + .await + .context(UnableToGrabSemaphoreSnafu)?; + let client = &self.client; let req = GetObjectRequest { bucket: bucket.into(), object: key.into(), @@ -172,7 +220,7 @@ impl GCSClientWrapper { Ok(GetResult::Stream( io_stats_on_bytestream(response, io_stats), size, - None, + Some(permit), None, )) } @@ -183,7 +231,14 @@ impl GCSClientWrapper { if key.is_empty() { return Err(Error::NotAFile { path: uri.into() }.into()); } - let client = &self.0; + + let _permit = self + .connection_pool_sema + .acquire() + .await + .context(UnableToGrabSemaphoreSnafu)?; + + let client = &self.client; let req = GetObjectRequest { bucket: bucket.into(), object: key.into(), @@ -262,7 +317,14 @@ impl GCSClientWrapper { ) -> super::Result { let uri = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; let (bucket, key) = parse_uri(&uri)?; - let client = &self.0; + + let _permit = self + .connection_pool_sema + .acquire() + .await + .context(UnableToGrabSemaphoreSnafu)?; + + let client = &self.client; if posix { // Attempt to forcefully ls the key as a directory (by ensuring a "/" suffix) @@ -393,10 +455,48 @@ impl GCSSource { if config.project_id.is_some() { client_config.project_id.clone_from(&config.project_id); } + client_config.http = Some({ + use reqwest_middleware::ClientBuilder; + use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; + use retry_policies::Jitter; + let retry_policy = ExponentialBackoff::builder() + .base(2) + .jitter(Jitter::Bounded) + .retry_bounds( + Duration::from_millis(config.retry_initial_backoff_ms), + Duration::from_secs(60), + ) + .build_with_max_retries(config.num_tries); + + let base_client = reqwest_middleware::reqwest::ClientBuilder::default() + .connect_timeout(Duration::from_millis(config.connect_timeout_ms)) + .read_timeout(Duration::from_millis(config.read_timeout_ms)) + .pool_idle_timeout(Duration::from_secs(60)) + .pool_max_idle_per_host(70) + .build() + .context(UnableToCreateClientSnafu)?; + + ClientBuilder::new(base_client) + // reqwest-retry already comes with a default retry strategy that matches http standards + // override it only if you need a custom one due to non standard behavior + .with( + RetryTransientMiddleware::new_with_policy(retry_policy) + .with_retry_log_level(tracing::Level::DEBUG), + ) + .build() + }); let client = Client::new(client_config); + + let connection_pool_sema = Arc::new(tokio::sync::Semaphore::new( + (config.max_connections_per_io_thread as usize) + * get_io_pool_num_threads().expect("Should be running in tokio pool"), + )); Ok(Self { - client: GCSClientWrapper(client), + client: GCSClientWrapper { + client, + connection_pool_sema, + }, } .into()) } diff --git a/src/daft-sql/src/modules/config.rs b/src/daft-sql/src/modules/config.rs index 9a540d3025..fdd550160b 100644 --- a/src/daft-sql/src/modules/config.rs +++ b/src/daft-sql/src/modules/config.rs @@ -372,14 +372,27 @@ pub(crate) fn expr_to_iocfg(expr: &ExprRef) -> SQLPlannerResult { let credentials = get_value!("credentials", Utf8)?; let token = get_value!("token", Utf8)?; let anonymous = get_value!("anonymous", Boolean)?; - let default = GCSConfig::default(); + let max_connections_per_io_thread = + get_value!("max_connections_per_io_thread", UInt32)?; + let retry_initial_backoff_ms = get_value!("retry_initial_backoff_ms", UInt64)?; + let connect_timeout_ms = get_value!("connect_timeout_ms", UInt64)?; + let read_timeout_ms = get_value!("read_timeout_ms", UInt64)?; + let num_tries = get_value!("num_tries", UInt32)?; + let default = GCSConfig::default(); Ok(IOConfig { gcs: GCSConfig { project_id, credentials: credentials.map(|s| s.into()), token, anonymous: anonymous.unwrap_or(default.anonymous), + max_connections_per_io_thread: max_connections_per_io_thread + .unwrap_or(default.max_connections_per_io_thread), + retry_initial_backoff_ms: retry_initial_backoff_ms + .unwrap_or(default.retry_initial_backoff_ms), + connect_timeout_ms: connect_timeout_ms.unwrap_or(default.connect_timeout_ms), + read_timeout_ms: read_timeout_ms.unwrap_or(default.read_timeout_ms), + num_tries: num_tries.unwrap_or(default.num_tries), }, ..Default::default() }) diff --git a/tests/integration/io/parquet/test_reads_public_data.py b/tests/integration/io/parquet/test_reads_public_data.py index ec9c9a397f..dff82b5fb0 100644 --- a/tests/integration/io/parquet/test_reads_public_data.py +++ b/tests/integration/io/parquet/test_reads_public_data.py @@ -408,7 +408,7 @@ def test_row_groups_selection_into_pyarrow_bulk(public_storage_io_config, multit "multithreaded_io", [False, True], ) -def test_connect_timeout(multithreaded_io): +def test_connect_timeout_s3(multithreaded_io): url = "s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet" connect_timeout_config = daft.io.IOConfig( s3=daft.io.S3Config( @@ -429,7 +429,7 @@ def test_connect_timeout(multithreaded_io): "multithreaded_io", [False, True], ) -def test_read_timeout(multithreaded_io): +def test_read_timeout_s3(multithreaded_io): url = "s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet" read_timeout_config = daft.io.IOConfig( s3=daft.io.S3Config( @@ -459,3 +459,43 @@ def test_read_file_level_timeout(): with pytest.raises((ReadTimeoutError), match=f"Parquet reader timed out while trying to read: {url}"): daft.table.read_parquet_into_pyarrow(url, io_config=read_timeout_config, file_timeout_ms=2) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "multithreaded_io", + [False, True], +) +def test_connect_timeout_gcs(multithreaded_io): + url = "gs://daft-public-data-gs/mvp.parquet" + connect_timeout_config = daft.io.IOConfig( + gcs=daft.io.GCSConfig( + anonymous=True, + connect_timeout_ms=1, + retry_initial_backoff_ms=10, + num_tries=3, + ) + ) + + with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match=f"timed out when trying to connect to {url}"): + MicroPartition.read_parquet(url, io_config=connect_timeout_config, multithreaded_io=multithreaded_io).to_arrow() + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "multithreaded_io", + [False, True], +) +def test_read_timeout_gcs(multithreaded_io): + url = "gs://daft-public-data-gs/mvp.parquet" + read_timeout_config = daft.io.IOConfig( + gcs=daft.io.GCSConfig( + anonymous=True, + read_timeout_ms=1, + retry_initial_backoff_ms=10, + num_tries=3, + ) + ) + + with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match=f"Read timed out when trying to read {url}"): + MicroPartition.read_parquet(url, io_config=read_timeout_config, multithreaded_io=multithreaded_io).to_arrow() diff --git a/tools/check_for_rustls.sh b/tools/check_for_rustls.sh index 264937c1ba..fdd6009de6 100755 --- a/tools/check_for_rustls.sh +++ b/tools/check_for_rustls.sh @@ -1,2 +1,2 @@ #!/bin/bash -cargo tree --workspace --all-features | grep -v 'rustls-pemfile' | grep -vzq 'rustls' +cargo tree --workspace --all-features | grep -v 'rustls-pemfile' | grep -v 'rustls-pki-types' | grep -vzq 'rustls'