From e4755d7f7972df1322193b061784f0fabb2e53f5 Mon Sep 17 00:00:00 2001
From: Sammy Sidhu <samster25@users.noreply.github.com>
Date: Fri, 24 Nov 2023 21:23:40 -0800
Subject: [PATCH] [FEAT] add retries to s3 credential provider timeouts (#1663)

* Adds retries with exponential backoff when the S3 credentials provider times out
---
 Cargo.lock                 |  1 +
 src/daft-io/Cargo.toml     |  1 +
 src/daft-io/src/s3_like.rs | 52 ++++++++++++++++++++++++++++----------
 3 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 311234cbe7..8d50e330dc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1232,6 +1232,7 @@ dependencies = [
  "openssl-sys",
  "pyo3",
  "pyo3-log",
+ "rand 0.8.5",
  "regex",
  "reqwest",
  "serde",
diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml
index 06dd071925..fe6bfdfc07 100644
--- a/src/daft-io/Cargo.toml
+++ b/src/daft-io/Cargo.toml
@@ -26,6 +26,7 @@ log = {workspace = true}
 openssl-sys = {version = "0.9.93", features = ["vendored"]}
 pyo3 = {workspace = true, optional = true}
 pyo3-log = {workspace = true, optional = true}
+rand = "0.8.5"
 regex = {version = "1.9.5"}
 serde = {workspace = true}
 serde_json = {workspace = true}
diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs
index 20f2351a60..51a3a992d8 100644
--- a/src/daft-io/src/s3_like.rs
+++ b/src/daft-io/src/s3_like.rs
@@ -323,20 +323,46 @@ async fn build_s3_client(
 
     let builder_copy = builder.clone();
     let s3_conf = builder.build();
-    if !config.anonymous {
+    const CRED_TRIES: u64 = 4;
+    const JITTER_MS: u64 = 2_500;
+    const MAX_BACKOFF_MS: u64 = 20_000;
+    const MAX_WAITTIME_MS: u64 = 45_000;
+    let check_creds = async || -> super::Result<bool> {
+        use rand::Rng;
         use CredentialsError::*;
-        match s3_conf
-            .credentials_cache()
-            .provide_cached_credentials()
-            .await {
-            Ok(_) => Ok(()),
-            Err(err @ CredentialsNotLoaded(..)) => {
-                log::warn!("S3 Credentials not provided or found when making client for {}! Reverting to Anonymous mode. {err}", s3_conf.region().unwrap_or(&DEFAULT_REGION));
-                anonymous = true;
-                Ok(())
-            },
-            Err(err) => Err(err),
-        }.with_context(|_| UnableToLoadCredentialsSnafu {})?;
+        let mut attempt = 0;
+        let first_attempt_time = std::time::Instant::now();
+        loop {
+            let creds = s3_conf
+                .credentials_cache()
+                .provide_cached_credentials()
+                .await;
+            attempt += 1;
+            match creds {
+                Ok(_) => return Ok(false),
+                Err(err @  ProviderTimedOut(..)) => {
+                    let total_time_waited_ms: u64 = first_attempt_time.elapsed().as_millis().try_into().unwrap();
+                    if attempt < CRED_TRIES && (total_time_waited_ms < MAX_WAITTIME_MS) {
+                        let jitter = rand::thread_rng().gen_range(0..((2<<attempt) * JITTER_MS)) as u64;
+                        let jitter = jitter.min(MAX_BACKOFF_MS);
+                        log::warn!("S3 Credentials Provider timed out when making client for {}! Attempt {attempt} out of {CRED_TRIES} tries. Trying again in {jitter}ms. {err}", s3_conf.region().unwrap_or(&DEFAULT_REGION));
+                        tokio::time::sleep(Duration::from_millis(jitter)).await;
+                        continue;
+                    } else {
+                        Err(err)
+                    }
+                }
+                Err(err @ CredentialsNotLoaded(..)) => {
+                    log::warn!("S3 Credentials not provided or found when making client for {}! Reverting to Anonymous mode. {err}", s3_conf.region().unwrap_or(&DEFAULT_REGION));
+                    return Ok(true)
+                },
+                Err(err) => Err(err),
+            }.with_context(|_| UnableToLoadCredentialsSnafu {})?;
+        }
+    };
+
+    if !config.anonymous {
+        anonymous = check_creds().await?;
     };
 
     let s3_conf = if s3_conf.region().is_none() {