From 540e65c284d8b5270ae2b4fd393c325a1e293ba2 Mon Sep 17 00:00:00 2001
From: Kev Wang
Date: Fri, 31 May 2024 14:54:50 -0700
Subject: [PATCH] [FEAT] Custom S3 Credentials Provider (#2233)

This PR allows the user to provide their own S3 credentials provider as a
function that returns S3 credentials. If provided, static credentials are
ignored; the function is called at the start and then again whenever S3
credentials are used after the provided credential expiry has passed.

```python
import daft

def get_creds() -> daft.io.S3Credentials:
    ...
    return daft.io.S3Credentials(
        key_id="[KEY_ID]",
        access_key="[ACCESS_KEY]",
        session_token="[SESSION_TOKEN]",
        expiry=expire_datetime,
    )

io_config = daft.io.IOConfig(s3=daft.io.S3Config(credentials_provider=get_creds, ...))

df = daft.read_parquet("s3://path/to/file.parquet", io_config=io_config)
```
---
 Cargo.lock                                    |  54 ++++
 daft/daft.pyi                                 |  15 +
 daft/io/__init__.py                           |   2 +
 docs/source/api_docs/configs.rst              |   1 +
 src/common/io-config/Cargo.toml               |   6 +-
 src/common/io-config/src/lib.rs               |   4 +-
 src/common/io-config/src/python.rs            | 273 ++++++++++++++++--
 src/common/io-config/src/s3.rs                | 115 +++++++-
 src/common/py-serde/Cargo.toml                |  12 +
 src/common/py-serde/src/lib.rs                |   5 +
 .../py-serde/src/python.rs}                   |   0
 src/daft-io/src/s3_like.rs                    |  41 +--
 src/daft-plan/Cargo.toml                      |   1 +
 src/daft-plan/src/sink_info.rs                |   2 +-
 src/daft-plan/src/source_info/mod.rs          |   2 +-
 src/daft-scan/Cargo.toml                      |   1 +
 src/daft-scan/src/file_format.rs              |   2 +-
 src/daft-scan/src/lib.rs                      |   2 -
 src/daft-scan/src/python.rs                   |   2 +-
 tests/io/delta_lake/conftest.py               |   2 +-
 tests/io/{delta_lake => }/mock_aws_server.py  |   0
 tests/io/test_s3_credentials_refresh.py       | 119 ++++++++
 22 files changed, 600 insertions(+), 61 deletions(-)
 create mode 100644 src/common/py-serde/Cargo.toml
 create mode 100644 src/common/py-serde/src/lib.rs
 rename src/{daft-scan/src/py_object_serde.rs => common/py-serde/src/python.rs} (100%)
 rename tests/io/{delta_lake => }/mock_aws_server.py (100%)
 create mode 100644 tests/io/test_s3_credentials_refresh.py

diff --git a/Cargo.lock b/Cargo.lock
index 512b46c24f..64bb4347b1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -975,10 +975,22 @@ dependencies = [
 name = "common-io-config"
 version = "0.2.0-dev0"
 dependencies = [
+ "aws-credential-types",
+ "chrono",
  "common-error",
+ "common-py-serde",
  "pyo3",
  "serde",
  "serde_json",
+ "typetag",
+]
+
+[[package]]
+name = "common-py-serde"
+version = "0.2.0-dev0"
+dependencies = [
+ "pyo3",
+ "serde",
 ]
 
 [[package]]
@@ -1455,6 +1467,7 @@ dependencies = [
  "common-daft-config",
  "common-error",
  "common-io-config",
+ "common-py-serde",
  "common-treenode",
  "daft-core",
  "daft-dsl",
@@ -1479,6 +1492,7 @@ dependencies = [
  "common-daft-config",
  "common-error",
  "common-io-config",
+ "common-py-serde",
  "daft-core",
  "daft-csv",
  "daft-dsl",
@@ -1611,6 +1625,16 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
 
+[[package]]
+name = "erased-serde"
+version = "0.4.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24e2389d65ab4fab27dc2a5de7b191e1f6617d1f1c8855c0dc569c94a4cbb18d"
+dependencies = [
+ "serde",
+ "typeid",
+]
+
 [[package]]
 name = "errno"
 version = "0.3.5"
@@ -4133,12 +4157,42 @@ version = "0.2.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
 
+[[package]]
+name = "typeid"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "059d83cc991e7a42fc37bd50941885db0888e34209f8cfd9aab07ddec03bc9cf"
+
 [[package]]
 name = "typenum"
 version = "1.17.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
 
+[[package]]
+name = "typetag"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "661d18414ec032a49ece2d56eee03636e43c4e8d577047ab334c0ba892e29aaf"
+dependencies = [
+ "erased-serde",
+ "inventory",
+ "once_cell",
+ "serde",
+ "typetag-impl",
+]
+
+[[package]]
+name = "typetag-impl"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac73887f47b9312552aa90ef477927ff014d63d1920ca8037c6c1951eab64bb1"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.57",
+]
+
 [[package]]
 name = "unicase"
 version = "2.7.0"
diff --git a/daft/daft.pyi b/daft/daft.pyi
index 89ae597598..2c759eba0a 100644
--- a/daft/daft.pyi
+++ b/daft/daft.pyi
@@ -1,4 +1,5 @@
 import builtins
+import datetime
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable
 
@@ -417,6 +418,7 @@ class S3Config:
     key_id: str | None
     session_token: str | None
     access_key: str | None
+    credentials_provider: Callable[[], S3Credentials] | None
     max_connections: int
     retry_initial_backoff_ms: int
     connect_timeout_ms: int
@@ -438,6 +440,8 @@ class S3Config:
         key_id: str | None = None,
         session_token: str | None = None,
         access_key: str | None = None,
+        credentials_provider: Callable[[], S3Credentials] | None = None,
+        buffer_time: int | None = None,
         max_connections: int | None = None,
         retry_initial_backoff_ms: int | None = None,
         connect_timeout_ms: int | None = None,
@@ -459,6 +463,7 @@ class S3Config:
         key_id: str | None = None,
         session_token: str | None = None,
         access_key: str | None = None,
+        credentials_provider: Callable[[], S3Credentials] | None = None,
         max_connections: int | None = None,
         retry_initial_backoff_ms: int | None = None,
         connect_timeout_ms: int | None = None,
@@ -481,6 +486,16 @@ class S3Config:
         """Creates an S3Config, retrieving credentials and configurations from the current environment"""
         ...
 
+class S3Credentials:
+    key_id: str
+    access_key: str
+    session_token: str | None
+    expiry: datetime.datetime | None
+
+    def __init__(
+        self, key_id: str, access_key: str, session_token: str | None = None, expiry: datetime.datetime | None = None
+    ): ...
+
 class AzureConfig:
     """
     I/O configuration for accessing Azure Blob Storage.
diff --git a/daft/io/__init__.py b/daft/io/__init__.py
index 166fe12bdc..b56e6919b6 100644
--- a/daft/io/__init__.py
+++ b/daft/io/__init__.py
@@ -7,6 +7,7 @@
     GCSConfig,
     IOConfig,
     S3Config,
+    S3Credentials,
     set_io_pool_num_threads,
 )
 from daft.io._csv import read_csv
@@ -47,6 +48,7 @@ def _set_linux_cert_paths():
     "read_sql",
     "IOConfig",
     "S3Config",
+    "S3Credentials",
     "AzureConfig",
     "GCSConfig",
     "set_io_pool_num_threads",
diff --git a/docs/source/api_docs/configs.rst b/docs/source/api_docs/configs.rst
index 9282bbe61c..cc1322b08d 100644
--- a/docs/source/api_docs/configs.rst
+++ b/docs/source/api_docs/configs.rst
@@ -38,5 +38,6 @@ These configurations are most often used as inputs to Daft DataFrame reading I/O
 
     daft.io.IOConfig
     daft.io.S3Config
+    daft.io.S3Credentials
     daft.io.GCSConfig
     daft.io.AzureConfig
diff --git a/src/common/io-config/Cargo.toml b/src/common/io-config/Cargo.toml
index 307de25498..49d94f718e 100644
--- a/src/common/io-config/Cargo.toml
+++ b/src/common/io-config/Cargo.toml
@@ -1,12 +1,16 @@
 [dependencies]
+aws-credential-types = {version = "0.55.3"}
+chrono = {workspace = true}
 common-error = {path = "../error", default-features = false}
+common-py-serde = {path = "../py-serde", default-features = false}
 pyo3 = {workspace = true, optional = true}
 serde = {workspace = true}
 serde_json = {workspace = true}
+typetag = "0.2.16"
 
 [features]
 default = ["python"]
-python = ["dep:pyo3", "common-error/python"]
+python = ["dep:pyo3", "common-error/python", "common-py-serde/python"]
 
 [package]
 edition = {workspace = true}
diff --git a/src/common/io-config/src/lib.rs b/src/common/io-config/src/lib.rs
index 535ac72d40..d0952541b4 100644
--- a/src/common/io-config/src/lib.rs
+++ b/src/common/io-config/src/lib.rs
@@ -6,4 +6,6 @@ mod config;
 mod gcs;
 mod s3;
 
-pub use crate::{azure::AzureConfig, config::IOConfig, gcs::GCSConfig, s3::S3Config};
+pub use crate::{
+    azure::AzureConfig, config::IOConfig, gcs::GCSConfig, s3::S3Config, s3::S3Credentials,
+};
diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs
index ebc13a96b6..f3b7aa16da 100644
--- a/src/common/io-config/src/python.rs
+++ b/src/common/io-config/src/python.rs
@@ -1,32 +1,44 @@
-use std::hash::Hasher;
-
+use std::{
+    any::Any,
+    hash::{Hash, Hasher},
+    time::{Duration, SystemTime},
+};
+
+use aws_credential_types::{
+    provider::{error::CredentialsError, ProvideCredentials},
+    Credentials,
+};
 use common_error::DaftError;
+use common_py_serde::{deserialize_py_object, serialize_py_object};
 use pyo3::prelude::*;
+use serde::{Deserialize, Serialize};
 
-use crate::config;
+use crate::{config, s3::S3CredentialsProvider};
 
 /// Create configurations to be used when accessing an S3-compatible system
 ///
 /// Args:
-///     region_name: Name of the region to be used (used when accessing AWS S3), defaults to "us-east-1".
+///     region_name (str, optional): Name of the region to be used (used when accessing AWS S3), defaults to "us-east-1".
 ///         If wrongly provided, Daft will attempt to auto-detect the buckets' region at the cost of extra S3 requests.
-///     endpoint_url: URL to the S3 endpoint, defaults to endpoints to AWS
-///     key_id: AWS Access Key ID, defaults to auto-detection from the current environment
-///     access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
-///     max_connections: Maximum number of connections to S3 at any time, defaults to 64
-///     session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
-///     retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
-///     connect_timeout_ms: Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 10 seconds
-///     read_timeout_ms: Timeout duration to wait to read the first byte from S3 in milliseconds, defaults to 10 seconds
-///     num_tries: Number of attempts to make a connection, defaults to 5
-///     retry_mode: Retry Mode when a request fails, current supported values are `standard` and `adaptive`, defaults to `adaptive`
-///     anonymous: Whether or not to use "anonymous mode", which will access S3 without any credentials
-///     use_ssl: Whether or not to use SSL, which require accessing S3 over HTTPS rather than HTTP, defaults to True
-///     verify_ssl: Whether or not to verify ssl certificates, which will access S3 without checking if the certs are valid, defaults to True
-///     check_hostname_ssl: Whether or not to verify the hostname when verifying ssl certificates, this was the legacy behavior for openssl, defaults to True
-///     requester_pays: Whether or not the authenticated user will assume transfer costs, which is required by some providers of bulk data, defaults to False
-///     force_virtual_addressing: Force S3 client to use virtual addressing in all cases. If False, virtual addressing will only be used if `endpoint_url` is empty, defaults to False
-///     profile_name: Name of AWS_PROFILE to load, defaults to None which will then check the Environment Variable `AWS_PROFILE` then fall back to `default`
+///     endpoint_url (str, optional): URL to the S3 endpoint, defaults to the AWS endpoints
+///     key_id (str, optional): AWS Access Key ID, defaults to auto-detection from the current environment
+///     access_key (str, optional): AWS Secret Access Key, defaults to auto-detection from the current environment
+///     credentials_provider (Callable[[], S3Credentials], optional): Custom credentials provider function; should return an `S3Credentials` object
+///     buffer_time (int, optional): Amount of time in seconds before the actual credential expiration time at which credentials given by `credentials_provider` are considered expired, defaults to 10s
+///     max_connections (int, optional): Maximum number of connections to S3 at any time, defaults to 64
+///     session_token (str, optional): AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
+///     retry_initial_backoff_ms (int, optional): Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
+///     connect_timeout_ms (int, optional): Timeout duration to wait to make a connection to S3 in milliseconds, defaults to 10 seconds
+///     read_timeout_ms (int, optional): Timeout duration to wait to read the first byte from S3 in milliseconds, defaults to 10 seconds
+///     num_tries (int, optional): Number of attempts to make a connection, defaults to 5
+///     retry_mode (str, optional): Retry Mode when a request fails, current supported values are `standard` and `adaptive`, defaults to `adaptive`
+///     anonymous (bool, optional): Whether or not to use "anonymous mode", which will access S3 without any credentials
+///     use_ssl (bool, optional): Whether or not to use SSL, which requires accessing S3 over HTTPS rather than HTTP, defaults to True
+///     verify_ssl (bool, optional): Whether or not to verify ssl certificates, which will access S3 without checking if the certs are valid, defaults to True
+///     check_hostname_ssl (bool, optional): Whether or not to verify the hostname when verifying ssl certificates, this was the legacy behavior for openssl, defaults to True
+///     requester_pays (bool, optional): Whether or not the authenticated user will assume transfer costs, which is required by some providers of bulk data, defaults to False
+///     force_virtual_addressing (bool, optional): Force S3 client to use virtual addressing in all cases. If False, virtual addressing will only be used if `endpoint_url` is empty, defaults to False
+///     profile_name (str, optional): Name of AWS_PROFILE to load, defaults to None which will then check the Environment Variable `AWS_PROFILE` then fall back to `default`
 ///
 /// Example:
 ///     >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx"))
@@ -36,6 +48,29 @@ use crate::config;
 pub struct S3Config {
     pub config: crate::S3Config,
 }
+
+/// Create credentials to be used when accessing an S3-compatible system
+///
+/// Args:
+///     key_id (str): AWS Access Key ID
+///     access_key (str): AWS Secret Access Key
+///     session_token (str, optional): AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
+///     expiry (datetime.datetime, optional): Expiry time of the credentials; credentials are assumed to be permanent if not provided
+///
+/// Example:
+///     >>> get_credentials = lambda: S3Credentials(
+///     ...     key_id="xxx",
+///     ...     access_key="xxx",
+///     ...     expiry=(datetime.datetime.now() + datetime.timedelta(hours=1))
+///     ... )
+///     >>> io_config = IOConfig(s3=S3Config(credentials_provider=get_credentials))
+///     >>> daft.read_parquet("s3://some-path", io_config=io_config)
+#[derive(Clone)]
+#[pyclass]
+pub struct S3Credentials {
+    pub credentials: crate::S3Credentials,
+}
+
 /// Create configurations to be used when accessing Azure Blob Storage
 ///
 /// Args:
@@ -172,11 +207,14 @@ impl S3Config {
     #[allow(clippy::too_many_arguments)]
     #[new]
     pub fn new(
+        py: Python,
         region_name: Option<String>,
         endpoint_url: Option<String>,
         key_id: Option<String>,
         session_token: Option<String>,
         access_key: Option<String>,
+        credentials_provider: Option<&PyAny>,
+        buffer_time: Option<u64>,
         max_connections: Option<u32>,
         retry_initial_backoff_ms: Option<u64>,
         connect_timeout_ms: Option<u64>,
@@ -190,15 +228,23 @@ impl S3Config {
         requester_pays: Option<bool>,
         force_virtual_addressing: Option<bool>,
         profile_name: Option<String>,
-    ) -> Self {
+    ) -> PyResult<Self> {
         let def = crate::S3Config::default();
-        S3Config {
+        Ok(S3Config {
             config: crate::S3Config {
                 region_name: region_name.or(def.region_name),
                 endpoint_url: endpoint_url.or(def.endpoint_url),
                 key_id: key_id.or(def.key_id),
                 session_token: session_token.or(def.session_token),
                 access_key: access_key.or(def.access_key),
+                credentials_provider: credentials_provider
+                    .map(|p| {
+                        Ok::<_, PyErr>(Box::new(PyS3CredentialsProvider::new(py, p)?)
+                            as Box<dyn S3CredentialsProvider>)
+                    })
+                    .transpose()?
+                    .or(def.credentials_provider),
+                buffer_time: buffer_time.or(def.buffer_time),
                 max_connections_per_io_thread: max_connections
                     .unwrap_or(def.max_connections_per_io_thread),
                 retry_initial_backoff_ms: retry_initial_backoff_ms
@@ -216,17 +262,20 @@ impl S3Config {
                     .unwrap_or(def.force_virtual_addressing),
                 profile_name: profile_name.or(def.profile_name),
             },
-        }
+        })
     }
 
     #[allow(clippy::too_many_arguments)]
     pub fn replace(
         &self,
+        py: Python,
         region_name: Option<String>,
         endpoint_url: Option<String>,
         key_id: Option<String>,
         session_token: Option<String>,
         access_key: Option<String>,
+        credentials_provider: Option<&PyAny>,
+        buffer_time: Option<u64>,
         max_connections: Option<u32>,
         retry_initial_backoff_ms: Option<u64>,
         connect_timeout_ms: Option<u64>,
@@ -240,14 +289,22 @@ impl S3Config {
         requester_pays: Option<bool>,
         force_virtual_addressing: Option<bool>,
         profile_name: Option<String>,
-    ) -> Self {
-        S3Config {
+    ) -> PyResult<Self> {
+        Ok(S3Config {
             config: crate::S3Config {
                 region_name: region_name.or_else(|| self.config.region_name.clone()),
                 endpoint_url: endpoint_url.or_else(|| self.config.endpoint_url.clone()),
                 key_id: key_id.or_else(|| self.config.key_id.clone()),
                 session_token: session_token.or_else(|| self.config.session_token.clone()),
                 access_key: access_key.or_else(|| self.config.access_key.clone()),
+                credentials_provider: credentials_provider
+                    .map(|p| {
+                        Ok::<_, PyErr>(Box::new(PyS3CredentialsProvider::new(py, p)?)
+                            as Box<dyn S3CredentialsProvider>)
+                    })
+                    .transpose()?
+                    .or_else(|| self.config.credentials_provider.clone()),
+                buffer_time: buffer_time.or(self.config.buffer_time),
                 max_connections_per_io_thread: max_connections
                     .unwrap_or(self.config.max_connections_per_io_thread),
                 retry_initial_backoff_ms: retry_initial_backoff_ms
@@ -265,7 +322,7 @@
                     .unwrap_or(self.config.force_virtual_addressing),
                 profile_name: profile_name.or_else(|| self.config.profile_name.clone()),
             },
-        }
+        })
     }
 
     /// Creates an S3Config from the current environment, auto-discovering variables such as
@@ -323,6 +380,22 @@ impl S3Config {
         Ok(self.config.max_connections_per_io_thread)
     }
 
+    /// Custom credentials provider function
+    #[getter]
+    pub fn credentials_provider(&self, py: Python) -> PyResult<Option<PyObject>> {
+        Ok(self.config.credentials_provider.as_ref().and_then(|p| {
+            p.as_any()
+                .downcast_ref::<PyS3CredentialsProvider>()
+                .map(|p| p.provider.as_ref(py).into())
+        }))
+    }
+
+    /// AWS Buffer Time in Seconds
+    #[getter]
+    pub fn buffer_time(&self) -> PyResult<Option<u64>> {
+        Ok(self.config.buffer_time)
+    }
+
     /// AWS Retry Initial Backoff Time in Milliseconds
     #[getter]
     pub fn retry_initial_backoff_ms(&self) -> PyResult<u64> {
@@ -396,6 +469,151 @@ impl S3Config {
     }
 }
 
+#[pymethods]
+impl S3Credentials {
+    #[new]
+    pub fn new(
+        key_id: String,
+        access_key: String,
+        session_token: Option<String>,
+        expiry: Option<&PyAny>,
+    ) -> PyResult<Self> {
+        // TODO(Kevin): Refactor when upgrading to PyO3 0.21 (https://github.com/Eventual-Inc/Daft/issues/2288)
+        let expiry = expiry
+            .map(|e| {
+                let ts = e.call_method0("timestamp")?.extract()?;
+
+                Ok::<_, PyErr>(SystemTime::UNIX_EPOCH + Duration::from_secs_f64(ts))
+            })
+            .transpose()?;
+
+        Ok(S3Credentials {
+            credentials: crate::S3Credentials {
+                key_id,
+                access_key,
+                session_token,
+                expiry,
+            },
+        })
+    }
+
+    pub fn __repr__(&self) -> PyResult<String> {
+        Ok(format!("{}", self.credentials))
+    }
+
+    /// AWS Access Key ID
+    #[getter]
+    pub fn key_id(&self) -> PyResult<String> {
+        Ok(self.credentials.key_id.clone())
+    }
+
+    /// AWS Secret Access Key
+    #[getter]
+    pub fn access_key(&self) -> PyResult<String> {
+        Ok(self.credentials.access_key.clone())
+    }
+
+    /// Expiry time of the credentials
+    #[getter]
+    pub fn expiry<'a>(&self, py: Python<'a>) -> PyResult<Option<&'a PyAny>> {
+        // TODO(Kevin): Refactor when upgrading to PyO3 0.21 (https://github.com/Eventual-Inc/Daft/issues/2288)
+        self.credentials
+            .expiry
+            .map(|e| {
+                let datetime = py.import("datetime")?;
+
+                datetime.getattr("datetime")?.call_method1(
+                    "fromtimestamp",
+                    (e.duration_since(SystemTime::UNIX_EPOCH)
+                        .unwrap()
+                        .as_secs_f64(),),
+                )
+            })
+            .transpose()
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct PyS3CredentialsProvider {
+    #[serde(
+        serialize_with = "serialize_py_object",
+        deserialize_with = "deserialize_py_object"
+    )]
+    pub provider: PyObject,
+    pub hash: isize,
+}
+
+impl PyS3CredentialsProvider {
+    pub fn new(py: Python, provider: &PyAny) -> PyResult<Self> {
+        Ok(PyS3CredentialsProvider {
+            provider: provider.to_object(py),
+            hash: provider.hash()?,
+        })
+    }
+}
+
+impl ProvideCredentials for PyS3CredentialsProvider {
+    fn provide_credentials<'a>(
+        &'a self,
+    ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
+    where
+        Self: 'a,
+    {
+        aws_credential_types::provider::future::ProvideCredentials::ready(
+            Python::with_gil(|py| {
+                let py_creds = self.provider.call0(py)?;
+                py_creds.extract::<S3Credentials>(py)
+            })
+            .map_err(|e| CredentialsError::provider_error(Box::new(e)))
+            .map(|creds| {
+                Credentials::new(
+                    creds.credentials.key_id,
+                    creds.credentials.access_key,
+                    creds.credentials.session_token,
+                    creds.credentials.expiry,
+                    "daft_custom_provider",
+                )
+            }),
+        )
+    }
+}
+
+impl PartialEq for PyS3CredentialsProvider {
+    fn eq(&self, other: &Self) -> bool {
+        self.hash == other.hash
+    }
+}
+
+impl Eq for PyS3CredentialsProvider {}
+
+impl Hash for PyS3CredentialsProvider {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.hash.hash(state);
+    }
+}
+
+#[typetag::serde]
+impl S3CredentialsProvider for PyS3CredentialsProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn clone_box(&self) -> Box<dyn S3CredentialsProvider> {
+        Box::new(self.clone())
+    }
+
+    fn dyn_eq(&self, other: &dyn S3CredentialsProvider) -> bool {
+        other
+            .as_any()
+            .downcast_ref::<PyS3CredentialsProvider>()
+            .map_or(false, |other| self == other)
+    }
+
+    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
+        self.hash(&mut state)
+    }
+}
+
 #[pymethods]
 impl AzureConfig {
     #[allow(clippy::too_many_arguments)]
@@ -523,6 +741,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
     parent.add_class::<AzureConfig>()?;
     parent.add_class::<GCSConfig>()?;
     parent.add_class::<S3Config>()?;
+    parent.add_class::<S3Credentials>()?;
     parent.add_class::<IOConfig>()?;
     Ok(())
 }
diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs
index f7efdf3607..0f54ac92a9 100644
--- a/src/common/io-config/src/s3.rs
+++ b/src/common/io-config/src/s3.rs
@@ -1,16 +1,25 @@
-use std::fmt::Display;
-use std::fmt::Formatter;
-
+use aws_credential_types::provider::ProvideCredentials;
+use chrono::offset::Utc;
+use chrono::DateTime;
 use serde::Deserialize;
 use serde::Serialize;
+use std::any::Any;
+use std::fmt::Debug;
+use std::fmt::Display;
+use std::fmt::Formatter;
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::time::SystemTime;
 
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
 pub struct S3Config {
     pub region_name: Option<String>,
     pub endpoint_url: Option<String>,
     pub key_id: Option<String>,
     pub session_token: Option<String>,
     pub access_key: Option<String>,
+    pub credentials_provider: Option<Box<dyn S3CredentialsProvider>>,
+    pub buffer_time: Option<u64>,
     pub max_connections_per_io_thread: u32,
     pub retry_initial_backoff_ms: u64,
     pub connect_timeout_ms: u64,
@@ -26,6 +35,53 @@ pub struct S3Config {
     pub profile_name: Option<String>,
 }
 
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct S3Credentials {
+    pub key_id: String,
+    pub access_key: String,
+    pub session_token: Option<String>,
+    pub expiry: Option<SystemTime>,
+}
+
+#[typetag::serde(tag = "type")]
+pub trait S3CredentialsProvider: ProvideCredentials + Debug {
+    fn as_any(&self) -> &dyn Any;
+    fn clone_box(&self) -> Box<dyn S3CredentialsProvider>;
+    fn dyn_eq(&self, other: &dyn S3CredentialsProvider) -> bool;
+    fn dyn_hash(&self, state: &mut dyn Hasher);
+}
+
+impl Clone for Box<dyn S3CredentialsProvider> {
+    fn clone(&self) -> Self {
+        self.clone_box()
+    }
+}
+
+impl PartialEq for Box<dyn S3CredentialsProvider> {
+    fn eq(&self, other: &Self) -> bool {
+        self.dyn_eq(other.as_ref())
+    }
+}
+
+impl Eq for Box<dyn S3CredentialsProvider> {}
+
+impl Hash for Box<dyn S3CredentialsProvider> {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.dyn_hash(state)
+    }
+}
+
+impl ProvideCredentials for Box<dyn S3CredentialsProvider> {
+    fn provide_credentials<'a>(
+        &'a self,
+    ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
+    where
+        Self: 'a,
+    {
+        self.as_ref().provide_credentials()
+    }
+}
+
 impl S3Config {
     pub fn multiline_display(&self) -> Vec<String> {
         let mut res = vec![];
@@ -44,6 +100,12 @@ impl S3Config {
         if let Some(access_key) = &self.access_key {
             res.push(format!("Access key = {}", access_key));
         }
+        if let Some(credentials_provider) = &self.credentials_provider {
+            res.push(format!("Credentials provider = {:?}", credentials_provider));
+        }
+        if let Some(buffer_time) = &self.buffer_time {
+            res.push(format!("Buffer time = {}", buffer_time));
+        }
         res.push(format!(
             "Max connections = {}",
             self.max_connections_per_io_thread
@@ -82,6 +144,8 @@ impl Default for S3Config {
             key_id: None,
             session_token: None,
             access_key: None,
+            credentials_provider: None,
+            buffer_time: None,
             max_connections_per_io_thread: 8,
             retry_initial_backoff_ms: 1000,
             connect_timeout_ms: 30_000,
@@ -111,6 +175,8 @@ impl Display for S3Config {
     key_id: {:?}
     session_token: {:?},
     access_key: {:?}
+    credentials_provider: {:?}
+    buffer_time: {:?}
     max_connections: {},
     retry_initial_backoff_ms: {},
     connect_timeout_ms: {},
@@ -122,13 +188,14 @@ impl Display for S3Config {
     verify_ssl: {},
     check_hostname_ssl: {}
     requester_pays: {}
-    force_virtual_addressing: {}
-    profile_name: {:?}",
+    force_virtual_addressing: {}",
             self.region_name,
             self.endpoint_url,
             self.key_id,
             self.session_token,
             self.access_key,
+            self.credentials_provider,
+            self.buffer_time,
             self.max_connections_per_io_thread,
             self.retry_initial_backoff_ms,
             self.connect_timeout_ms,
@@ -140,8 +207,40 @@ impl Display for S3Config {
             self.verify_ssl,
             self.check_hostname_ssl,
             self.requester_pays,
-            self.force_virtual_addressing,
-            self.profile_name
+            self.force_virtual_addressing
+        )?;
+        Ok(())
+    }
+}
+
+impl S3Credentials {
+    pub fn multiline_display(&self) -> Vec<String> {
+        let mut res = vec![];
+        res.push(format!("Key ID = {}", self.key_id));
+        res.push(format!("Access key = {}", self.access_key));
+
+        if let Some(session_token) = &self.session_token {
+            res.push(format!("Session token = {}", session_token));
+        }
+        if let Some(expiry) = &self.expiry {
+            let expiry: DateTime<Utc> = (*expiry).into();
+
+            res.push(format!("Expiry = {}", expiry.format("%Y-%m-%dT%H:%M:%S")));
+        }
+        res
+    }
+}
+
+impl Display for S3Credentials {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
+        write!(
+            f,
+            "S3Credentials
+    key_id: {:?}
+    session_token: {:?},
+    access_key: {:?}
+    expiry: {:?}",
+            self.key_id, self.session_token, self.access_key, self.expiry,
         )
     }
 }
diff --git a/src/common/py-serde/Cargo.toml b/src/common/py-serde/Cargo.toml
new file mode 100644
index 0000000000..a341bd1ab8
--- /dev/null
+++ b/src/common/py-serde/Cargo.toml
@@ -0,0 +1,12 @@
+[dependencies]
+pyo3 = {workspace = true, optional = true}
+serde = {workspace = true}
+
+[features]
+default = ["python"]
+python = ["dep:pyo3"]
+
+[package]
+edition = {workspace = true}
+name = "common-py-serde"
+version = {workspace = true}
diff --git a/src/common/py-serde/src/lib.rs b/src/common/py-serde/src/lib.rs
new file mode 100644
index 0000000000..bb14399064
--- /dev/null
+++ b/src/common/py-serde/src/lib.rs
@@ -0,0 +1,5 @@
+#[cfg(feature = "python")]
+mod python;
+
+#[cfg(feature = "python")]
+pub use crate::{python::deserialize_py_object, python::serialize_py_object};
diff --git a/src/daft-scan/src/py_object_serde.rs b/src/common/py-serde/src/python.rs
similarity index 100%
rename from src/daft-scan/src/py_object_serde.rs
rename to src/common/py-serde/src/python.rs
diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs
index e582ee6493..dbde82bded 100644
--- a/src/daft-io/src/s3_like.rs
+++ b/src/daft-io/src/s3_like.rs
@@ -14,7 +14,9 @@ use crate::stats::IOStatsRef;
 use crate::stream_utils::io_stats_on_bytestream;
 use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType};
 use aws_config::SdkConfig;
-use aws_credential_types::cache::{ProvideCachedCredentials, SharedCredentialsCache};
+use aws_credential_types::cache::{
+    CredentialsCache, ProvideCachedCredentials, SharedCredentialsCache,
+};
 use aws_credential_types::provider::error::CredentialsError;
 use aws_sig_auth::signer::SigningRequirements;
 use common_io_config::S3Config;
@@ -354,6 +356,8 @@ async fn build_s3_conf(
             .or_default_provider()
             .await;
         Some(aws_credential_types::provider::SharedCredentialsProvider::new(provider))
+    } else if let Some(provider) = &config.credentials_provider {
+        Some(aws_credential_types::provider::SharedCredentialsProvider::new(provider.clone()))
     } else if config.access_key.is_some() && config.key_id.is_some() {
         let creds = Credentials::from_keys(
             config.key_id.clone().unwrap(),
@@ -374,25 +378,28 @@ async fn build_s3_conf(
         builder.set_credentials_provider(provider);
         builder.build()
     } else {
-        let loader = aws_config::from_env();
-        let loader = if let Some(profile_name) = &config.profile_name {
-            loader.profile_name(profile_name)
-        } else {
-            loader
-        };
+        let mut loader = aws_config::from_env();
+        if let Some(profile_name) = &config.profile_name {
+            loader = loader.profile_name(profile_name);
+        }
+
         // Set region now to avoid imds
-        let loader = if let Some(region) = &config.region_name {
-            loader.region(Region::new(region.to_owned()))
-        } else {
-            loader
-        };
+        if let Some(region) = &config.region_name {
+            loader = loader.region(Region::new(region.to_owned()));
+        }
+
         // Set creds now to avoid imds
-        let loader = if let Some(provider) = provider {
-            loader.credentials_provider(provider)
-        } else {
-            loader
-        };
+        if let Some(provider) = provider {
+            loader = loader.credentials_provider(provider);
+        }
+
+        if let Some(buffer_time) = &config.buffer_time {
+            loader = loader.credentials_cache(
+                CredentialsCache::lazy_builder()
+                    .buffer_time(Duration::from_secs(*buffer_time))
+                    .into_credentials_cache(),
+            )
+        }
 
         loader.load().await
     };
diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml
index 00f4d08297..ccfbbf1bd7 100644
--- a/src/daft-plan/Cargo.toml
+++ b/src/daft-plan/Cargo.toml
@@ -4,6 +4,7 @@ bincode = {workspace = true}
 common-daft-config = {path = "../common/daft-config", default-features = false}
 common-error = {path = "../common/error", default-features = false}
 common-io-config = {path = "../common/io-config", default-features = false}
+common-py-serde = {path = "../common/py-serde", default-features = false}
 common-treenode = {path = "../common/treenode", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
 daft-dsl = {path = "../daft-dsl", default-features = false}
diff --git a/src/daft-plan/src/sink_info.rs b/src/daft-plan/src/sink_info.rs
index e3db25ed7f..7e9f130407 100644
--- a/src/daft-plan/src/sink_info.rs
+++ b/src/daft-plan/src/sink_info.rs
@@ -11,7 +11,7 @@ use crate::FileFormat;
 use serde::{Deserialize, Serialize};
 
 #[cfg(feature = "python")]
-use daft_scan::py_object_serde::{deserialize_py_object, serialize_py_object};
+use common_py_serde::{deserialize_py_object, serialize_py_object};
 
 #[allow(clippy::large_enum_variant)]
 #[derive(Debug, PartialEq, Eq, Hash)]
diff --git a/src/daft-plan/src/source_info/mod.rs b/src/daft-plan/src/source_info/mod.rs
index c625f5a475..49c701b6ae 100644
--- a/src/daft-plan/src/source_info/mod.rs
+++ b/src/daft-plan/src/source_info/mod.rs
@@ -11,7 +11,7 @@ use std::hash::Hasher;
 
 #[cfg(feature = "python")]
 use {
-    daft_scan::py_object_serde::{deserialize_py_object, serialize_py_object},
+    common_py_serde::{deserialize_py_object, serialize_py_object},
     pyo3::PyObject,
 };
 
diff --git a/src/daft-scan/Cargo.toml b/src/daft-scan/Cargo.toml
index 17dd9a59a2..bb7b0adea5 100644
--- a/src/daft-scan/Cargo.toml
+++ b/src/daft-scan/Cargo.toml
@@ -3,6 +3,7 @@ bincode = {workspace = true}
 common-daft-config = {path = "../common/daft-config", default-features = false}
 common-error = {path = "../common/error", default-features = false}
 common-io-config = {path = "../common/io-config", default-features = false}
+common-py-serde = {path = "../common/py-serde", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
 daft-csv = {path = "../daft-csv", default-features = false}
 daft-dsl = {path = "../daft-dsl", default-features = false}
diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs
index a9b5ca8ea0..d557322d5c 100644
--- a/src/daft-scan/src/file_format.rs
+++ b/src/daft-scan/src/file_format.rs
@@ -8,7 +8,7 @@ use std::hash::Hash;
 use std::{collections::BTreeMap, str::FromStr, sync::Arc};
 #[cfg(feature = "python")]
 use {
-    crate::py_object_serde::{deserialize_py_object, serialize_py_object},
+    common_py_serde::{deserialize_py_object, serialize_py_object},
     daft_core::python::{datatype::PyTimeUnit, field::PyField},
     pyo3::{
         pyclass, pyclass::CompareOp, pymethods, types::PyBytes, IntoPy, PyObject, PyResult,
diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs
index 496357a839..7f8f33a4bc 100644
--- a/src/daft-scan/src/lib.rs
+++ b/src/daft-scan/src/lib.rs
@@ -23,8 +23,6 @@ pub use anonymous::AnonymousScanOperator;
 pub mod file_format;
 mod glob;
 use common_daft_config::DaftExecutionConfig;
-#[cfg(feature = "python")]
-pub mod py_object_serde;
 pub mod scan_task_iters;
 
 #[cfg(feature = "python")]
diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs
index 89ea0bafe5..741fd84dc1 100644
--- a/src/daft-scan/src/python.rs
+++ b/src/daft-scan/src/python.rs
@@ -1,7 +1,7 @@
 use pyo3::{prelude::*, types::PyTuple, AsPyPointer};
 use serde::{Deserialize, Serialize};
 
-use crate::py_object_serde::{deserialize_py_object, serialize_py_object};
+use common_py_serde::{deserialize_py_object, serialize_py_object};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 struct PyObjectSerializableWrapper(
diff --git a/tests/io/delta_lake/conftest.py b/tests/io/delta_lake/conftest.py
index bd2dd735a1..175cace7ea 100644
--- a/tests/io/delta_lake/conftest.py
+++ b/tests/io/delta_lake/conftest.py
@@ -20,7 +20,7 @@
 import daft
 from daft import DataCatalogTable, DataCatalogType
 from daft.io.object_store_options import io_config_to_storage_options
-from tests.io.delta_lake.mock_aws_server import start_service, stop_process
+from tests.io.mock_aws_server import start_service, stop_process
 
 
 @pytest.fixture(params=[1, 2, 8])
diff --git a/tests/io/delta_lake/mock_aws_server.py b/tests/io/mock_aws_server.py
similarity index 100%
rename from tests/io/delta_lake/mock_aws_server.py
rename to tests/io/mock_aws_server.py
diff --git a/tests/io/test_s3_credentials_refresh.py b/tests/io/test_s3_credentials_refresh.py
new file mode 100644
index 0000000000..16a98fadf0
--- /dev/null
+++ b/tests/io/test_s3_credentials_refresh.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+import datetime
+import io
+import os
+import time
+from collections.abc import Iterator
+
+import boto3
+import pytest
+
+import daft
+from tests.io.mock_aws_server import start_service, stop_process
+
+
+@pytest.fixture(scope="session")
+def aws_log_file(tmp_path_factory: pytest.TempPathFactory) -> Iterator[io.IOBase]:
+    # NOTE(Clark): We have to use a log file for the mock AWS server's stdout/stderr.
+    # - If we use None, then the server output will spam stdout.
+    # - If we use PIPE, then the server will deadlock if the (relatively small) buffer fills, and the server is pretty
+    #   noisy.
+    # - If we use DEVNULL, all log output is lost.
+    # With a tmp_path log file, we can prevent spam and deadlocks while also providing an avenue for debuggability, via
+    # changing this fixture to something persistent, or dumping the file to stdout before closing the file, etc.
+    tmp_path = tmp_path_factory.mktemp("aws_logging")
+    with open(tmp_path / "aws_log.txt", "w") as f:
+        yield f
+
+
+def test_s3_credentials_refresh(aws_log_file: io.IOBase):
+    host = "127.0.0.1"
+    port = 5000
+
+    server_url = f"http://{host}:{port}"
+
+    bucket_name = "mybucket"
+    file_name = "test.parquet"
+
+    s3_file_path = f"s3://{bucket_name}/{file_name}"
+
+    old_env = os.environ.copy()
+    # Set required AWS environment variables before starting server.
+    # Required to opt out of concurrent writing, since we don't provide a LockClient.
+    os.environ["AWS_S3_ALLOW_UNSAFE_RENAME"] = "true"
+
+    # Start moto server.
+    process = start_service(host, port, aws_log_file)
+
+    aws_credentials = {
+        "AWS_ACCESS_KEY_ID": "testing",
+        "AWS_SECRET_ACCESS_KEY": "testing",
+        "AWS_SESSION_TOKEN": "testing",
+    }
+
+    s3 = boto3.resource(
+        "s3",
+        region_name="us-west-2",
+        use_ssl=False,
+        endpoint_url=server_url,
+        aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+        aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+        aws_session_token=aws_credentials["AWS_SESSION_TOKEN"],
+    )
+    bucket = s3.Bucket(bucket_name)
+    bucket.create(CreateBucketConfiguration={"LocationConstraint": "us-west-2"})
+
+    count_get_credentials = 0
+
+    def get_credentials():
+        nonlocal count_get_credentials
+        count_get_credentials += 1
+        return daft.io.S3Credentials(
+            key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+            access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+            session_token=aws_credentials["AWS_SESSION_TOKEN"],
+            expiry=(datetime.datetime.now() + datetime.timedelta(seconds=1)),
+        )
+
+    static_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            endpoint_url=server_url,
+            region_name="us-west-2",
+            key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+            access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+            session_token=aws_credentials["AWS_SESSION_TOKEN"],
+            use_ssl=False,
+        )
+    )
+
+    dynamic_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            endpoint_url=server_url,
+            region_name="us-west-2",
+            credentials_provider=get_credentials,
+            buffer_time=0,
+            use_ssl=False,
+        )
+    )
+
+    df = daft.from_pydict({"a": [1, 2, 3]})
+    df.write_parquet(s3_file_path, io_config=static_config)
+
+    df = daft.read_parquet(s3_file_path, io_config=dynamic_config)
+    assert count_get_credentials == 1
+
+    df.collect()
+    assert count_get_credentials == 1
+
+    df = daft.read_parquet(s3_file_path, io_config=dynamic_config)
+    assert count_get_credentials == 1
+
+    time.sleep(1)
+    df.collect()
+    assert count_get_credentials == 2
+
+    # Shutdown moto server.
+    stop_process(process)
+    # Restore old set of environment variables.
+    os.environ = old_env
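
As a closing usage note: in practice a credentials provider will usually wrap a real token source rather than returning fixed values as the test above does. The following is a minimal sketch assuming boto3 is installed and an STS endpoint is reachable; the name `sts_provider` and the 900-second duration are illustrative and not part of this patch.

```python
import boto3

import daft

sts = boto3.client("sts")

def sts_provider() -> daft.io.S3Credentials:
    # Fetch fresh temporary credentials. Daft calls this function again once
    # the returned expiry (minus S3Config.buffer_time seconds) has passed.
    resp = sts.get_session_token(DurationSeconds=900)
    c = resp["Credentials"]
    return daft.io.S3Credentials(
        key_id=c["AccessKeyId"],
        access_key=c["SecretAccessKey"],
        session_token=c["SessionToken"],
        expiry=c["Expiration"],  # a timezone-aware datetime from boto3
    )

io_config = daft.io.IOConfig(s3=daft.io.S3Config(credentials_provider=sts_provider))
```

Pairing a provider like this with a small `buffer_time` keeps a long-running read from starting with credentials that would lapse mid-request.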