diff --git a/Cargo.lock b/Cargo.lock
index 512b46c24f..64bb4347b1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -975,10 +975,22 @@ dependencies = [
 name = "common-io-config"
 version = "0.2.0-dev0"
 dependencies = [
+ "aws-credential-types",
+ "chrono",
  "common-error",
+ "common-py-serde",
  "pyo3",
  "serde",
  "serde_json",
+ "typetag",
+]
+
+[[package]]
+name = "common-py-serde"
+version = "0.2.0-dev0"
+dependencies = [
+ "pyo3",
+ "serde",
 ]
 
 [[package]]
@@ -1455,6 +1467,7 @@ dependencies = [
  "common-daft-config",
  "common-error",
  "common-io-config",
+ "common-py-serde",
  "common-treenode",
  "daft-core",
  "daft-dsl",
@@ -1479,6 +1492,7 @@ dependencies = [
  "common-daft-config",
  "common-error",
  "common-io-config",
+ "common-py-serde",
  "daft-core",
  "daft-csv",
  "daft-dsl",
@@ -1611,6 +1625,16 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
 
+[[package]]
+name = "erased-serde"
+version = "0.4.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24e2389d65ab4fab27dc2a5de7b191e1f6617d1f1c8855c0dc569c94a4cbb18d"
+dependencies = [
+ "serde",
+ "typeid",
+]
+
 [[package]]
 name = "errno"
 version = "0.3.5"
@@ -4133,12 +4157,42 @@ version = "0.2.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
 
+[[package]]
+name = "typeid"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "059d83cc991e7a42fc37bd50941885db0888e34209f8cfd9aab07ddec03bc9cf"
+
 [[package]]
 name = "typenum"
 version = "1.17.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
 
+[[package]]
+name = "typetag"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "661d18414ec032a49ece2d56eee03636e43c4e8d577047ab334c0ba892e29aaf"
+dependencies = [
+ "erased-serde",
+ "inventory",
+ "once_cell",
+ "serde",
+ "typetag-impl",
+]
+
+[[package]]
+name = "typetag-impl"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac73887f47b9312552aa90ef477927ff014d63d1920ca8037c6c1951eab64bb1"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.57",
+]
+
 [[package]]
 name = "unicase"
 version = "2.7.0"
diff --git a/daft/daft.pyi b/daft/daft.pyi
index 89ae597598..2c759eba0a 100644
--- a/daft/daft.pyi
+++ b/daft/daft.pyi
@@ -1,4 +1,5 @@
 import builtins
+import datetime
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable
@@ -417,6 +418,7 @@ class S3Config:
     key_id: str | None
     session_token: str | None
     access_key: str | None
+    credentials_provider: Callable[[], S3Credentials] | None
     max_connections: int
     retry_initial_backoff_ms: int
     connect_timeout_ms: int
@@ -438,6 +440,8 @@
         key_id: str | None = None,
         session_token: str | None = None,
         access_key: str | None = None,
+        credentials_provider: Callable[[], S3Credentials] | None = None,
+        buffer_time: int | None = None,
         max_connections: int | None = None,
         retry_initial_backoff_ms: int | None = None,
         connect_timeout_ms: int | None = None,
@@ -459,6 +463,7 @@
         key_id: str | None = None,
         session_token: str | None = None,
         access_key: str | None = None,
+        credentials_provider: Callable[[], S3Credentials] | None = None,
         max_connections: int | None = None,
         retry_initial_backoff_ms: int | None = None,
         connect_timeout_ms: int | None = None,
@@ -481,6 +486,16 @@ class S3Config:
     """Creates an S3Config, retrieving credentials and configurations from the current environment"""
     ...
 
+class S3Credentials:
+    key_id: str
+    access_key: str
+    session_token: str | None
+    expiry: datetime.datetime | None
+
+    def __init__(
+        self, key_id: str, access_key: str, session_token: str | None = None, expiry: datetime.datetime | None = None
+    ): ...
+
 class AzureConfig:
     """
    I/O configuration for accessing Azure Blob Storage.
diff --git a/daft/io/__init__.py b/daft/io/__init__.py
index 166fe12bdc..b56e6919b6 100644
--- a/daft/io/__init__.py
+++ b/daft/io/__init__.py
@@ -7,6 +7,7 @@
     GCSConfig,
     IOConfig,
     S3Config,
+    S3Credentials,
     set_io_pool_num_threads,
 )
 from daft.io._csv import read_csv
@@ -47,6 +48,7 @@ def _set_linux_cert_paths():
     "read_sql",
     "IOConfig",
     "S3Config",
+    "S3Credentials",
     "AzureConfig",
     "GCSConfig",
     "set_io_pool_num_threads",
diff --git a/docs/source/api_docs/configs.rst b/docs/source/api_docs/configs.rst
index 9282bbe61c..cc1322b08d 100644
--- a/docs/source/api_docs/configs.rst
+++ b/docs/source/api_docs/configs.rst
@@ -38,5 +38,6 @@ These configurations are most often used as inputs to Daft DataFrame reading I/O
 
     daft.io.IOConfig
     daft.io.S3Config
+    daft.io.S3Credentials
     daft.io.GCSConfig
     daft.io.AzureConfig
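For reference, a minimal usage sketch of the user-facing API added above. Only S3Credentials, credentials_provider, and buffer_time come from this patch; the STS role ARN and the boto3 wiring are hypothetical illustrations, not part of this change:

import boto3

import daft

sts = boto3.client("sts")


def get_credentials() -> daft.io.S3Credentials:
    # Invoked by Daft whenever the cached credentials are considered expired.
    creds = sts.assume_role(
        RoleArn="arn:aws:iam::123456789012:role/my-role",  # hypothetical role
        RoleSessionName="daft",
    )["Credentials"]
    return daft.io.S3Credentials(
        key_id=creds["AccessKeyId"],
        access_key=creds["SecretAccessKey"],
        session_token=creds["SessionToken"],
        expiry=creds["Expiration"],  # boto3 returns a datetime.datetime here
    )


io_config = daft.io.IOConfig(
    s3=daft.io.S3Config(
        credentials_provider=get_credentials,
        buffer_time=60,  # treat credentials as expired 60s before their stated expiry
    )
)
df = daft.read_parquet("s3://my-bucket/data/*.parquet", io_config=io_config)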
diff --git a/src/common/io-config/Cargo.toml b/src/common/io-config/Cargo.toml
index 307de25498..49d94f718e 100644
--- a/src/common/io-config/Cargo.toml
+++ b/src/common/io-config/Cargo.toml
@@ -1,12 +1,16 @@
 [dependencies]
+aws-credential-types = {version = "0.55.3"}
+chrono = {workspace = true}
 common-error = {path = "../error", default-features = false}
+common-py-serde = {path = "../py-serde", default-features = false}
 pyo3 = {workspace = true, optional = true}
 serde = {workspace = true}
 serde_json = {workspace = true}
+typetag = "0.2.16"
 
 [features]
 default = ["python"]
-python = ["dep:pyo3", "common-error/python"]
+python = ["dep:pyo3", "common-error/python", "common-py-serde/python"]
 
 [package]
 edition = {workspace = true}
diff --git a/src/common/io-config/src/lib.rs b/src/common/io-config/src/lib.rs
index 535ac72d40..d0952541b4 100644
--- a/src/common/io-config/src/lib.rs
+++ b/src/common/io-config/src/lib.rs
@@ -6,4 +6,6 @@ mod config;
 mod gcs;
 mod s3;
 
-pub use crate::{azure::AzureConfig, config::IOConfig, gcs::GCSConfig, s3::S3Config};
+pub use crate::{
+    azure::AzureConfig, config::IOConfig, gcs::GCSConfig, s3::S3Config, s3::S3Credentials,
+};
diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs
index ebc13a96b6..f3b7aa16da 100644
--- a/src/common/io-config/src/python.rs
+++ b/src/common/io-config/src/python.rs
@@ -1,32 +1,44 @@
-use std::hash::Hasher;
-
+use std::{
+    any::Any,
+    hash::{Hash, Hasher},
+    time::{Duration, SystemTime},
+};
+
+use aws_credential_types::{
+    provider::{error::CredentialsError, ProvideCredentials},
+    Credentials,
+};
 use common_error::DaftError;
+use common_py_serde::{deserialize_py_object, serialize_py_object};
 use pyo3::prelude::*;
+use serde::{Deserialize, Serialize};
 
-use crate::config;
+use crate::{config, s3::S3CredentialsProvider};
 
 /// Create configurations to be used when accessing an S3-compatible system
 ///
 /// Args:
-///     region_name: Name of the region to be used (used when accessing AWS S3), defaults to "us-east-1".
+///     region_name (str, optional): Name of the region to be used (used when accessing AWS S3), defaults to "us-east-1".
 ///         If wrongly provided, Daft will attempt to auto-detect the buckets' region at the cost of extra S3 requests.
-///     endpoint_url: URL to the S3 endpoint, defaults to endpoints to AWS
-///     key_id: AWS Access Key ID, defaults to auto-detection from the current environment
-///     access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
-///     max_connections: Maximum number of connections to S3 at any time, defaults to 64
-///     session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
-///     retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
-///     connect_timeout_ms: Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 10 seconds
-///     read_timeout_ms: Timeout duration to wait to read the first byte from S3 in milliseconds, defaults to 10 seconds
-///     num_tries: Number of attempts to make a connection, defaults to 5
-///     retry_mode: Retry Mode when a request fails, current supported values are `standard` and `adaptive`, defaults to `adaptive`
-///     anonymous: Whether or not to use "anonymous mode", which will access S3 without any credentials
-///     use_ssl: Whether or not to use SSL, which require accessing S3 over HTTPS rather than HTTP, defaults to True
-///     verify_ssl: Whether or not to verify ssl certificates, which will access S3 without checking if the certs are valid, defaults to True
-///     check_hostname_ssl: Whether or not to verify the hostname when verifying ssl certificates, this was the legacy behavior for openssl, defaults to True
-///     requester_pays: Whether or not the authenticated user will assume transfer costs, which is required by some providers of bulk data, defaults to False
-///     force_virtual_addressing: Force S3 client to use virtual addressing in all cases. If False, virtual addressing will only be used if `endpoint_url` is empty, defaults to False
-///     profile_name: Name of AWS_PROFILE to load, defaults to None which will then check the Environment Variable `AWS_PROFILE` then fall back to `default`
+///     endpoint_url (str, optional): URL to the S3 endpoint, defaults to the AWS endpoints
+///     key_id (str, optional): AWS Access Key ID, defaults to auto-detection from the current environment
+///     access_key (str, optional): AWS Secret Access Key, defaults to auto-detection from the current environment
+///     credentials_provider (Callable[[], S3Credentials], optional): Custom credentials provider function, should return an `S3Credentials` object
+///     buffer_time (int, optional): Amount of time in seconds before the actual credential expiration time where credentials given by `credentials_provider` are considered expired, defaults to 10s
+///     max_connections (int, optional): Maximum number of connections to S3 at any time, defaults to 64
+///     session_token (str, optional): AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
+///     retry_initial_backoff_ms (int, optional): Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
+///     connect_timeout_ms (int, optional): Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 10 seconds
+///     read_timeout_ms (int, optional): Timeout duration to wait to read the first byte from S3 in milliseconds, defaults to 10 seconds
+///     num_tries (int, optional): Number of attempts to make a connection, defaults to 5
+///     retry_mode (str, optional): Retry Mode when a request fails, current supported values are `standard` and `adaptive`, defaults to `adaptive`
+///     anonymous (bool, optional): Whether or not to use "anonymous mode", which will access S3 without any credentials
+///     use_ssl (bool, optional): Whether or not to use SSL, which requires accessing S3 over HTTPS rather than HTTP, defaults to True
+///     verify_ssl (bool, optional): Whether or not to verify ssl certificates, which will access S3 without checking if the certs are valid, defaults to True
+///     check_hostname_ssl (bool, optional): Whether or not to verify the hostname when verifying ssl certificates, this was the legacy behavior for openssl, defaults to True
+///     requester_pays (bool, optional): Whether or not the authenticated user will assume transfer costs, which is required by some providers of bulk data, defaults to False
+///     force_virtual_addressing (bool, optional): Force S3 client to use virtual addressing in all cases. If False, virtual addressing will only be used if `endpoint_url` is empty, defaults to False
+///     profile_name (str, optional): Name of AWS_PROFILE to load, defaults to None which will then check the Environment Variable `AWS_PROFILE` then fall back to `default`
 ///
 /// Example:
 ///     >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx"))
@@ -36,6 +48,29 @@ use crate::config;
 pub struct S3Config {
     pub config: crate::S3Config,
 }
+
+/// Create credentials to be used when accessing an S3-compatible system
+///
+/// Args:
+///     key_id (str): AWS Access Key ID
+///     access_key (str): AWS Secret Access Key
+///     session_token (str, optional): AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
+///     expiry (datetime.datetime, optional): Expiry time of the credentials, credentials are assumed to be permanent if not provided
+///
+/// Example:
+///     >>> get_credentials = lambda: S3Credentials(
+///     ...     key_id="xxx",
+///     ...     access_key="xxx",
+///     ...     expiry=(datetime.datetime.now() + datetime.timedelta(hours=1))
+///     ... )
+///     >>> io_config = IOConfig(s3=S3Config(credentials_provider=get_credentials))
+///     >>> daft.read_parquet("s3://some-path", io_config=io_config)
+#[derive(Clone)]
+#[pyclass]
+pub struct S3Credentials {
+    pub credentials: crate::S3Credentials,
+}
+
 /// Create configurations to be used when accessing Azure Blob Storage
 ///
 /// Args:
@@ -172,11 +207,14 @@ impl S3Config {
     #[allow(clippy::too_many_arguments)]
     #[new]
     pub fn new(
+        py: Python,
         region_name: Option<String>,
         endpoint_url: Option<String>,
         key_id: Option<String>,
         session_token: Option<String>,
         access_key: Option<String>,
+        credentials_provider: Option<&PyAny>,
+        buffer_time: Option<u64>,
         max_connections: Option<u32>,
         retry_initial_backoff_ms: Option<u64>,
         connect_timeout_ms: Option<u64>,
@@ -190,15 +228,23 @@ impl S3Config {
         requester_pays: Option<bool>,
         force_virtual_addressing: Option<bool>,
         profile_name: Option<String>,
-    ) -> Self {
+    ) -> PyResult<Self> {
         let def = crate::S3Config::default();
-        S3Config {
+        Ok(S3Config {
             config: crate::S3Config {
                 region_name: region_name.or(def.region_name),
                 endpoint_url: endpoint_url.or(def.endpoint_url),
                 key_id: key_id.or(def.key_id),
                 session_token: session_token.or(def.session_token),
                 access_key: access_key.or(def.access_key),
+                credentials_provider: credentials_provider
+                    .map(|p| {
+                        Ok::<_, PyErr>(Box::new(PyS3CredentialsProvider::new(py, p)?)
+                            as Box<dyn S3CredentialsProvider>)
+                    })
+                    .transpose()?
+                    .or(def.credentials_provider),
+                buffer_time: buffer_time.or(def.buffer_time),
                 max_connections_per_io_thread: max_connections
                     .unwrap_or(def.max_connections_per_io_thread),
                 retry_initial_backoff_ms: retry_initial_backoff_ms
@@ -216,17 +262,20 @@ impl S3Config {
                     .unwrap_or(def.force_virtual_addressing),
                 profile_name: profile_name.or(def.profile_name),
             },
-        }
+        })
     }
 
     #[allow(clippy::too_many_arguments)]
     pub fn replace(
         &self,
+        py: Python,
         region_name: Option<String>,
         endpoint_url: Option<String>,
         key_id: Option<String>,
         session_token: Option<String>,
         access_key: Option<String>,
+        credentials_provider: Option<&PyAny>,
+        buffer_time: Option<u64>,
         max_connections: Option<u32>,
         retry_initial_backoff_ms: Option<u64>,
         connect_timeout_ms: Option<u64>,
@@ -240,14 +289,22 @@ impl S3Config {
         requester_pays: Option<bool>,
         force_virtual_addressing: Option<bool>,
         profile_name: Option<String>,
-    ) -> Self {
-        S3Config {
+    ) -> PyResult<Self> {
+        Ok(S3Config {
             config: crate::S3Config {
                 region_name: region_name.or_else(|| self.config.region_name.clone()),
                 endpoint_url: endpoint_url.or_else(|| self.config.endpoint_url.clone()),
                 key_id: key_id.or_else(|| self.config.key_id.clone()),
                 session_token: session_token.or_else(|| self.config.session_token.clone()),
                 access_key: access_key.or_else(|| self.config.access_key.clone()),
+                credentials_provider: credentials_provider
+                    .map(|p| {
+                        Ok::<_, PyErr>(Box::new(PyS3CredentialsProvider::new(py, p)?)
+                            as Box<dyn S3CredentialsProvider>)
+                    })
+                    .transpose()?
+                    .or_else(|| self.config.credentials_provider.clone()),
+                buffer_time: buffer_time.or(self.config.buffer_time),
                 max_connections_per_io_thread: max_connections
                     .unwrap_or(self.config.max_connections_per_io_thread),
                 retry_initial_backoff_ms: retry_initial_backoff_ms
@@ -265,7 +322,7 @@ impl S3Config {
                     .unwrap_or(self.config.force_virtual_addressing),
                 profile_name: profile_name.or_else(|| self.config.profile_name.clone()),
             },
-        }
+        })
     }
 
     /// Creates an S3Config from the current environment, auto-discovering variables such as
@@ -323,6 +380,22 @@ impl S3Config {
         Ok(self.config.max_connections_per_io_thread)
     }
 
+    /// Custom credentials provider function
+    #[getter]
+    pub fn credentials_provider(&self, py: Python) -> PyResult<Option<PyObject>> {
+        Ok(self.config.credentials_provider.as_ref().and_then(|p| {
+            p.as_any()
+                .downcast_ref::<PyS3CredentialsProvider>()
+                .map(|p| p.provider.as_ref(py).into())
+        }))
+    }
+
+    /// AWS Buffer Time in Seconds
+    #[getter]
+    pub fn buffer_time(&self) -> PyResult<Option<u64>> {
+        Ok(self.config.buffer_time)
+    }
+
     /// AWS Retry Initial Backoff Time in Milliseconds
     #[getter]
     pub fn retry_initial_backoff_ms(&self) -> PyResult<u64> {
@@ -396,6 +469,151 @@ impl S3Config {
     }
 }
 
+#[pymethods]
+impl S3Credentials {
+    #[new]
+    pub fn new(
+        key_id: String,
+        access_key: String,
+        session_token: Option<String>,
+        expiry: Option<&PyAny>,
+    ) -> PyResult<Self> {
+        // TODO(Kevin): Refactor when upgrading to PyO3 0.21 (https://github.com/Eventual-Inc/Daft/issues/2288)
+        let expiry = expiry
+            .map(|e| {
+                let ts = e.call_method0("timestamp")?.extract()?;
+
+                Ok::<_, PyErr>(SystemTime::UNIX_EPOCH + Duration::from_secs_f64(ts))
+            })
+            .transpose()?;
+
+        Ok(S3Credentials {
+            credentials: crate::S3Credentials {
+                key_id,
+                access_key,
+                session_token,
+                expiry,
+            },
+        })
+    }
+
+    pub fn __repr__(&self) -> PyResult<String> {
+        Ok(format!("{}", self.credentials))
+    }
+
+    /// AWS Access Key ID
+    #[getter]
+    pub fn key_id(&self) -> PyResult<String> {
+        Ok(self.credentials.key_id.clone())
+    }
+
+    /// AWS Secret Access Key
+    #[getter]
+    pub fn access_key(&self) -> PyResult<String> {
+        Ok(self.credentials.access_key.clone())
+    }
+
+    /// AWS Session Token
+    #[getter]
+    pub fn session_token(&self) -> PyResult<Option<String>> {
+        Ok(self.credentials.session_token.clone())
+    }
+
+    /// Expiry time of the credentials
+    #[getter]
+    pub fn expiry<'a>(&self, py: Python<'a>) -> PyResult<Option<&'a PyAny>> {
+        // TODO(Kevin): Refactor when upgrading to PyO3 0.21 (https://github.com/Eventual-Inc/Daft/issues/2288)
+        self.credentials
+            .expiry
+            .map(|e| {
+                let datetime = py.import("datetime")?;
+
+                datetime.getattr("datetime")?.call_method1(
+                    "fromtimestamp",
+                    (e.duration_since(SystemTime::UNIX_EPOCH)
+                        .unwrap()
+                        .as_secs_f64(),),
+                )
+            })
+            .transpose()
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct PyS3CredentialsProvider {
+    #[serde(
+        serialize_with = "serialize_py_object",
+        deserialize_with = "deserialize_py_object"
+    )]
+    pub provider: PyObject,
+    pub hash: isize,
+}
+
+impl PyS3CredentialsProvider {
+    pub fn new(py: Python, provider: &PyAny) -> PyResult<Self> {
+        Ok(PyS3CredentialsProvider {
+            provider: provider.to_object(py),
+            hash: provider.hash()?,
+        })
+    }
+}
+
+impl ProvideCredentials for PyS3CredentialsProvider {
+    fn provide_credentials<'a>(
+        &'a self,
+    ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
+    where
+        Self: 'a,
+    {
+        aws_credential_types::provider::future::ProvideCredentials::ready(
+            Python::with_gil(|py| {
+                let py_creds = self.provider.call0(py)?;
+                py_creds.extract::<S3Credentials>(py)
+            })
+            .map_err(|e| CredentialsError::provider_error(Box::new(e)))
+            .map(|creds| {
+                Credentials::new(
+                    creds.credentials.key_id,
+                    creds.credentials.access_key,
+                    creds.credentials.session_token,
+                    creds.credentials.expiry,
+                    "daft_custom_provider",
+                )
+            }),
+        )
+    }
+}
+
+impl PartialEq for PyS3CredentialsProvider {
+    fn eq(&self, other: &Self) -> bool {
+        self.hash == other.hash
+    }
+}
+
+impl Eq for PyS3CredentialsProvider {}
+
+impl Hash for PyS3CredentialsProvider {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.hash.hash(state);
+    }
+}
+
+#[typetag::serde]
+impl S3CredentialsProvider for PyS3CredentialsProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn clone_box(&self) -> Box<dyn S3CredentialsProvider> {
+        Box::new(self.clone())
+    }
+
+    fn dyn_eq(&self, other: &dyn S3CredentialsProvider) -> bool {
+        other
+            .as_any()
+            .downcast_ref::<PyS3CredentialsProvider>()
+            .map_or(false, |other| self == other)
+    }
+
+    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
+        self.hash(&mut state)
+    }
+}
+
 #[pymethods]
 impl AzureConfig {
     #[allow(clippy::too_many_arguments)]
@@ -523,6 +741,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
     parent.add_class::<AzureConfig>()?;
     parent.add_class::<GCSConfig>()?;
     parent.add_class::<S3Config>()?;
+    parent.add_class::<S3Credentials>()?;
     parent.add_class::<IOConfig>()?;
     Ok(())
 }
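A note on the expiry round-trip implemented above, with a minimal sketch using only the stdlib and the API from this patch: the binding converts `expiry` by calling `datetime.timestamp()` and reconstructs it with `datetime.fromtimestamp()`, so a naive datetime is interpreted in local time and comes back as a naive local-time datetime. Passing a timezone-aware datetime avoids ambiguity:

import datetime

import daft

expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)
creds = daft.io.S3Credentials(key_id="xxx", access_key="xxx", expiry=expiry)

# creds.expiry is rebuilt via fromtimestamp(), i.e. a naive local-time datetime,
# so compare on the epoch timestamp rather than on the datetime objects directly.
assert abs(creds.expiry.timestamp() - expiry.timestamp()) < 1e-3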
diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs
index f7efdf3607..0f54ac92a9 100644
--- a/src/common/io-config/src/s3.rs
+++ b/src/common/io-config/src/s3.rs
@@ -1,16 +1,25 @@
-use std::fmt::Display;
-use std::fmt::Formatter;
-
+use aws_credential_types::provider::ProvideCredentials;
+use chrono::offset::Utc;
+use chrono::DateTime;
 use serde::Deserialize;
 use serde::Serialize;
+use std::any::Any;
+use std::fmt::Debug;
+use std::fmt::Display;
+use std::fmt::Formatter;
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::time::SystemTime;
 
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
 pub struct S3Config {
     pub region_name: Option<String>,
     pub endpoint_url: Option<String>,
     pub key_id: Option<String>,
     pub session_token: Option<String>,
     pub access_key: Option<String>,
+    pub credentials_provider: Option<Box<dyn S3CredentialsProvider>>,
+    pub buffer_time: Option<u64>,
     pub max_connections_per_io_thread: u32,
     pub retry_initial_backoff_ms: u64,
     pub connect_timeout_ms: u64,
@@ -26,6 +35,53 @@ pub struct S3Config {
     pub profile_name: Option<String>,
 }
 
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct S3Credentials {
+    pub key_id: String,
+    pub access_key: String,
+    pub session_token: Option<String>,
+    pub expiry: Option<SystemTime>,
+}
+
+#[typetag::serde(tag = "type")]
+pub trait S3CredentialsProvider: ProvideCredentials + Debug {
+    fn as_any(&self) -> &dyn Any;
+    fn clone_box(&self) -> Box<dyn S3CredentialsProvider>;
+    fn dyn_eq(&self, other: &dyn S3CredentialsProvider) -> bool;
+    fn dyn_hash(&self, state: &mut dyn Hasher);
+}
+
+impl Clone for Box<dyn S3CredentialsProvider> {
+    fn clone(&self) -> Self {
+        self.clone_box()
+    }
+}
+
+impl PartialEq for Box<dyn S3CredentialsProvider> {
+    fn eq(&self, other: &Self) -> bool {
+        self.dyn_eq(other.as_ref())
+    }
+}
+
+impl Eq for Box<dyn S3CredentialsProvider> {}
+
+impl Hash for Box<dyn S3CredentialsProvider> {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.dyn_hash(state)
+    }
+}
+
+impl ProvideCredentials for Box<dyn S3CredentialsProvider> {
+    fn provide_credentials<'a>(
+        &'a self,
+    ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
+    where
+        Self: 'a,
+    {
+        self.as_ref().provide_credentials()
+    }
+}
+
 impl S3Config {
     pub fn multiline_display(&self) -> Vec<String> {
         let mut res = vec![];
@@ -44,6 +100,12 @@ impl S3Config {
         if let Some(access_key) = &self.access_key {
             res.push(format!("Access key = {}", access_key));
         }
+        if let Some(credentials_provider) = &self.credentials_provider {
+            res.push(format!("Credentials provider = {:?}", credentials_provider));
+        }
+        if let Some(buffer_time) = &self.buffer_time {
+            res.push(format!("Buffer time = {}", buffer_time));
+        }
         res.push(format!(
             "Max connections = {}",
             self.max_connections_per_io_thread
@@ -82,6 +144,8 @@ impl Default for S3Config {
             key_id: None,
             session_token: None,
             access_key: None,
+            credentials_provider: None,
+            buffer_time: None,
             max_connections_per_io_thread: 8,
             retry_initial_backoff_ms: 1000,
             connect_timeout_ms: 30_000,
@@ -111,6 +175,8 @@ impl Display for S3Config {
     key_id: {:?}
     session_token: {:?},
     access_key: {:?}
+    credentials_provider: {:?}
+    buffer_time: {:?}
     max_connections: {},
     retry_initial_backoff_ms: {},
     connect_timeout_ms: {},
@@ -122,13 +188,14 @@ impl Display for S3Config {
     verify_ssl: {},
     check_hostname_ssl: {}
     requester_pays: {}
-    force_virtual_addressing: {}
-    profile_name: {:?}",
+    force_virtual_addressing: {}",
             self.region_name,
             self.endpoint_url,
             self.key_id,
             self.session_token,
             self.access_key,
+            self.credentials_provider,
+            self.buffer_time,
             self.max_connections_per_io_thread,
             self.retry_initial_backoff_ms,
             self.connect_timeout_ms,
@@ -140,8 +207,40 @@ impl Display for S3Config {
             self.verify_ssl,
             self.check_hostname_ssl,
             self.requester_pays,
-            self.force_virtual_addressing,
-            self.profile_name
+            self.force_virtual_addressing
+        )?;
+        Ok(())
+    }
+}
+
+impl S3Credentials {
+    pub fn multiline_display(&self) -> Vec<String> {
+        let mut res = vec![];
+        res.push(format!("Key ID = {}", self.key_id));
+        res.push(format!("Access key = {}", self.access_key));
+
+        if let Some(session_token) = &self.session_token {
+            res.push(format!("Session token = {}", session_token));
+        }
+        if let Some(expiry) = &self.expiry {
+            let expiry: DateTime<Utc> = (*expiry).into();
+
+            res.push(format!("Expiry = {}", expiry.format("%Y-%m-%dT%H:%M:%S")));
+        }
+        res
+    }
+}
+
+impl Display for S3Credentials {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
+        write!(
+            f,
+            "S3Credentials
+    key_id: {:?}
+    session_token: {:?},
+    access_key: {:?}
+    expiry: {:?}",
+            self.key_id, self.session_token, self.access_key, self.expiry,
+        )
+    }
+}
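The dyn_eq/dyn_hash plumbing above exists so S3Config can stay Clone, PartialEq, Eq, and Hash even though the provider is an opaque trait object. On the Python side, PyS3CredentialsProvider (in python.rs above) captures hash(provider) at construction time, so provider equality follows Python's own hash(); for a plain function that defaults to object identity. A hedged sketch of what this enables, where RoleProvider is a hypothetical user-defined class, not part of this patch:

import daft


class RoleProvider:
    """Hypothetical callable provider whose equality tracks the role it assumes."""

    def __init__(self, role_arn: str):
        self.role_arn = role_arn

    def __hash__(self):
        # This is the value PyS3CredentialsProvider captures, and therefore
        # what provider equality/hashing inside S3Config is based on.
        return hash(self.role_arn)

    def __call__(self) -> daft.io.S3Credentials:
        # Fetch fresh credentials for self.role_arn here; static placeholder shown.
        return daft.io.S3Credentials(key_id="xxx", access_key="xxx")


config = daft.io.S3Config(credentials_provider=RoleProvider("arn:aws:iam::123:role/r"))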
@@ -0,0 +1,12 @@
+[dependencies]
+pyo3 = {workspace = true, optional = true}
+serde = {workspace = true}
+
+[features]
+default = ["python"]
+python = ["dep:pyo3"]
+
+[package]
+edition = {workspace = true}
+name = "common-py-serde"
+version = {workspace = true}
diff --git a/src/common/py-serde/src/lib.rs b/src/common/py-serde/src/lib.rs
new file mode 100644
index 0000000000..bb14399064
--- /dev/null
+++ b/src/common/py-serde/src/lib.rs
@@ -0,0 +1,5 @@
+#[cfg(feature = "python")]
+mod python;
+
+#[cfg(feature = "python")]
+pub use crate::{python::deserialize_py_object, python::serialize_py_object};
diff --git a/src/daft-scan/src/py_object_serde.rs b/src/common/py-serde/src/python.rs
similarity index 100%
rename from src/daft-scan/src/py_object_serde.rs
rename to src/common/py-serde/src/python.rs
diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs
index e582ee6493..dbde82bded 100644
--- a/src/daft-io/src/s3_like.rs
+++ b/src/daft-io/src/s3_like.rs
@@ -14,7 +14,9 @@ use crate::stats::IOStatsRef;
 use crate::stream_utils::io_stats_on_bytestream;
 use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType};
 use aws_config::SdkConfig;
-use aws_credential_types::cache::{ProvideCachedCredentials, SharedCredentialsCache};
+use aws_credential_types::cache::{
+    CredentialsCache, ProvideCachedCredentials, SharedCredentialsCache,
+};
 use aws_credential_types::provider::error::CredentialsError;
 use aws_sig_auth::signer::SigningRequirements;
 use common_io_config::S3Config;
@@ -354,6 +356,8 @@
             .or_default_provider()
             .await;
         Some(aws_credential_types::provider::SharedCredentialsProvider::new(provider))
+    } else if let Some(provider) = &config.credentials_provider {
+        Some(aws_credential_types::provider::SharedCredentialsProvider::new(provider.clone()))
     } else if config.access_key.is_some() && config.key_id.is_some() {
         let creds = Credentials::from_keys(
             config.key_id.clone().unwrap(),
@@ -374,25 +378,28 @@
         builder.set_credentials_provider(provider);
         builder.build()
     } else {
-        let loader = aws_config::from_env();
-        let loader = if let Some(profile_name) = &config.profile_name {
-            loader.profile_name(profile_name)
-        } else {
-            loader
-        };
+        let mut loader = aws_config::from_env();
+        if let Some(profile_name) = &config.profile_name {
+            loader = loader.profile_name(profile_name);
+        }
+
         // Set region now to avoid imds
-        let loader = if let Some(region) = &config.region_name {
-            loader.region(Region::new(region.to_owned()))
-        } else {
-            loader
-        };
+        if let Some(region) = &config.region_name {
+            loader = loader.region(Region::new(region.to_owned()));
+        }
+
         // Set creds now to avoid imds
-        let loader = if let Some(provider) = provider {
-            loader.credentials_provider(provider)
-        } else {
-            loader
-        };
+        if let Some(provider) = provider {
+            loader = loader.credentials_provider(provider);
+        }
+
+        if let Some(buffer_time) = &config.buffer_time {
+            loader = loader.credentials_cache(
+                CredentialsCache::lazy_builder()
+                    .buffer_time(Duration::from_secs(*buffer_time))
+                    .into_credentials_cache(),
+            )
+        }
 
         loader.load().await
     };
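The lazy CredentialsCache wired up above treats credentials as expired buffer_time seconds before their stated expiry, which is what lets the provider be re-invoked just ahead of the cutoff. A trivial sketch of the effective refresh point (the 10s default comes from the S3Config docstring earlier in this patch):

import datetime


def refresh_at(expiry: datetime.datetime, buffer_time: int = 10) -> datetime.datetime:
    # Credentials are treated as expired `buffer_time` seconds early.
    return expiry - datetime.timedelta(seconds=buffer_time)


expiry = datetime.datetime(2024, 1, 1, 12, 0, 0)
assert refresh_at(expiry, buffer_time=60) == datetime.datetime(2024, 1, 1, 11, 59, 0)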
diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml
index 00f4d08297..ccfbbf1bd7 100644
--- a/src/daft-plan/Cargo.toml
+++ b/src/daft-plan/Cargo.toml
@@ -4,6 +4,7 @@ bincode = {workspace = true}
 common-daft-config = {path = "../common/daft-config", default-features = false}
 common-error = {path = "../common/error", default-features = false}
 common-io-config = {path = "../common/io-config", default-features = false}
+common-py-serde = {path = "../common/py-serde", default-features = false}
 common-treenode = {path = "../common/treenode", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
 daft-dsl = {path = "../daft-dsl", default-features = false}
diff --git a/src/daft-plan/src/sink_info.rs b/src/daft-plan/src/sink_info.rs
index e3db25ed7f..7e9f130407 100644
--- a/src/daft-plan/src/sink_info.rs
+++ b/src/daft-plan/src/sink_info.rs
@@ -11,7 +11,7 @@ use crate::FileFormat;
 use serde::{Deserialize, Serialize};
 
 #[cfg(feature = "python")]
-use daft_scan::py_object_serde::{deserialize_py_object, serialize_py_object};
+use common_py_serde::{deserialize_py_object, serialize_py_object};
 
 #[allow(clippy::large_enum_variant)]
 #[derive(Debug, PartialEq, Eq, Hash)]
diff --git a/src/daft-plan/src/source_info/mod.rs b/src/daft-plan/src/source_info/mod.rs
index c625f5a475..49c701b6ae 100644
--- a/src/daft-plan/src/source_info/mod.rs
+++ b/src/daft-plan/src/source_info/mod.rs
@@ -11,7 +11,7 @@ use std::hash::Hasher;
 
 #[cfg(feature = "python")]
 use {
-    daft_scan::py_object_serde::{deserialize_py_object, serialize_py_object},
+    common_py_serde::{deserialize_py_object, serialize_py_object},
     pyo3::PyObject,
 };
diff --git a/src/daft-scan/Cargo.toml b/src/daft-scan/Cargo.toml
index 17dd9a59a2..bb7b0adea5 100644
--- a/src/daft-scan/Cargo.toml
+++ b/src/daft-scan/Cargo.toml
@@ -3,6 +3,7 @@ bincode = {workspace = true}
 common-daft-config = {path = "../common/daft-config", default-features = false}
 common-error = {path = "../common/error", default-features = false}
 common-io-config = {path = "../common/io-config", default-features = false}
+common-py-serde = {path = "../common/py-serde", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
 daft-csv = {path = "../daft-csv", default-features = false}
 daft-dsl = {path = "../daft-dsl", default-features = false}
diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs
index a9b5ca8ea0..d557322d5c 100644
--- a/src/daft-scan/src/file_format.rs
+++ b/src/daft-scan/src/file_format.rs
@@ -8,7 +8,7 @@ use std::hash::Hash;
 use std::{collections::BTreeMap, str::FromStr, sync::Arc};
 #[cfg(feature = "python")]
 use {
-    crate::py_object_serde::{deserialize_py_object, serialize_py_object},
+    common_py_serde::{deserialize_py_object, serialize_py_object},
     daft_core::python::{datatype::PyTimeUnit, field::PyField},
     pyo3::{
         pyclass, pyclass::CompareOp, pymethods, types::PyBytes, IntoPy, PyObject, PyResult,
diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs
index 496357a839..7f8f33a4bc 100644
--- a/src/daft-scan/src/lib.rs
+++ b/src/daft-scan/src/lib.rs
@@ -23,8 +23,6 @@ pub use anonymous::AnonymousScanOperator;
 pub mod file_format;
 mod glob;
 use common_daft_config::DaftExecutionConfig;
-#[cfg(feature = "python")]
-pub mod py_object_serde;
 pub mod scan_task_iters;
 
 #[cfg(feature = "python")]
diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs
index 89ea0bafe5..741fd84dc1 100644
--- a/src/daft-scan/src/python.rs
+++ b/src/daft-scan/src/python.rs
@@ -1,7 +1,7 @@
 use pyo3::{prelude::*, types::PyTuple, AsPyPointer};
 use serde::{Deserialize, Serialize};
 
-use crate::py_object_serde::{deserialize_py_object, serialize_py_object};
+use common_py_serde::{deserialize_py_object, serialize_py_object};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 struct PyObjectSerializableWrapper(
diff --git a/tests/io/delta_lake/conftest.py b/tests/io/delta_lake/conftest.py
index bd2dd735a1..175cace7ea 100644
--- a/tests/io/delta_lake/conftest.py
+++ b/tests/io/delta_lake/conftest.py
@@ -20,7 +20,7 @@ import daft
 from daft import DataCatalogTable, DataCatalogType
 from daft.io.object_store_options import io_config_to_storage_options
 
-from tests.io.delta_lake.mock_aws_server import start_service, stop_process
+from tests.io.mock_aws_server import start_service, stop_process
 
 
 @pytest.fixture(params=[1, 2, 8])
diff --git a/tests/io/delta_lake/mock_aws_server.py b/tests/io/mock_aws_server.py
similarity index 100%
rename from tests/io/delta_lake/mock_aws_server.py
rename to tests/io/mock_aws_server.py
diff --git a/tests/io/test_s3_credentials_refresh.py b/tests/io/test_s3_credentials_refresh.py
new file mode 100644
index 0000000000..16a98fadf0
--- /dev/null
+++ b/tests/io/test_s3_credentials_refresh.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+import datetime
+import io
+import os
+import time
+from collections.abc import Iterator
+
+import boto3
+import pytest
+
+import daft
+from tests.io.mock_aws_server import start_service, stop_process
+
+
+@pytest.fixture(scope="session")
+def aws_log_file(tmp_path_factory: pytest.TempPathFactory) -> Iterator[io.IOBase]:
+    # NOTE(Clark): We have to use a log file for the mock AWS server's stdout/stderr.
+    # - If we use None, then the server output will spam stdout.
+    # - If we use PIPE, then the server will deadlock if the (relatively small) buffer fills, and the server is pretty
+    #   noisy.
+    # - If we use DEVNULL, all log output is lost.
+    # With a tmp_path log file, we can prevent spam and deadlocks while also providing an avenue for debuggability, via
+    # changing this fixture to something persistent, or dumping the file to stdout before closing the file, etc.
+    tmp_path = tmp_path_factory.mktemp("aws_logging")
+    with open(tmp_path / "aws_log.txt", "w") as f:
+        yield f
+
+
+def test_s3_credentials_refresh(aws_log_file: io.IOBase):
+    host = "127.0.0.1"
+    port = 5000
+
+    server_url = f"http://{host}:{port}"
+
+    bucket_name = "mybucket"
+    file_name = "test.parquet"
+
+    s3_file_path = f"s3://{bucket_name}/{file_name}"
+
+    old_env = os.environ.copy()
+    # Set required AWS environment variables before starting server.
+    # Required to opt out of concurrent writing, since we don't provide a LockClient.
+    os.environ["AWS_S3_ALLOW_UNSAFE_RENAME"] = "true"
+
+    # Start moto server.
+    process = start_service(host, port, aws_log_file)
+
+    aws_credentials = {
+        "AWS_ACCESS_KEY_ID": "testing",
+        "AWS_SECRET_ACCESS_KEY": "testing",
+        "AWS_SESSION_TOKEN": "testing",
+    }
+
+    s3 = boto3.resource(
+        "s3",
+        region_name="us-west-2",
+        use_ssl=False,
+        endpoint_url=server_url,
+        aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+        aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+        aws_session_token=aws_credentials["AWS_SESSION_TOKEN"],
+    )
+    bucket = s3.Bucket(bucket_name)
+    bucket.create(CreateBucketConfiguration={"LocationConstraint": "us-west-2"})
+
+    count_get_credentials = 0
+
+    def get_credentials():
+        nonlocal count_get_credentials
+        count_get_credentials += 1
+        return daft.io.S3Credentials(
+            key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+            access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+            session_token=aws_credentials["AWS_SESSION_TOKEN"],
+            expiry=(datetime.datetime.now() + datetime.timedelta(seconds=1)),
+        )
+
+    static_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            endpoint_url=server_url,
+            region_name="us-west-2",
+            key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+            access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+            session_token=aws_credentials["AWS_SESSION_TOKEN"],
+            use_ssl=False,
+        )
+    )
+
+    dynamic_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            endpoint_url=server_url,
+            region_name="us-west-2",
+            credentials_provider=get_credentials,
+            buffer_time=0,
+            use_ssl=False,
+        )
+    )
+
+    df = daft.from_pydict({"a": [1, 2, 3]})
+    df.write_parquet(s3_file_path, io_config=static_config)
+
+    df = daft.read_parquet(s3_file_path, io_config=dynamic_config)
+    assert count_get_credentials == 1
+
+    df.collect()
+    assert count_get_credentials == 1
+
+    df = daft.read_parquet(s3_file_path, io_config=dynamic_config)
+    assert count_get_credentials == 1
+
+    time.sleep(1)
+    df.collect()
+    assert count_get_credentials == 2
+
+    # Shutdown moto server.
+    stop_process(process)
+    # Restore old set of environment variables.
+    os.environ = old_env