Skip to content

Commit

Permalink
[FEAT] Custom S3 Credentials Provider (#2233)
Browse files Browse the repository at this point in the history
This PR allows the user to provide their own S3 credentials provider as
a function that returns S3 credentials. If provided, static credentials
are ignored and the function is called at the start and every time S3
credentials are used and the provided credential expiry has passed.

```python
import daft

def get_creds() -> daft.io.S3Credentials:
    ...
    return daft.io.S3Credentials(
        key_id="[KEY_ID]"
        access_key="[ACCES KEY]"
        session_token="[SESSIONT_TOKEN]"
        expiry=expire_datetime
    )

io_config = daft.io.IOConfig(s3=daft.io.S3Config(credentials_provider=get_creds, ...)

df = daft.read_parquet("s3://path/to/file.parquet", io_config=io_config)
```
  • Loading branch information
kevinzwang authored May 31, 2024
1 parent 0ba9a19 commit 540e65c
Show file tree
Hide file tree
Showing 22 changed files with 600 additions and 61 deletions.
54 changes: 54 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import builtins
import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable

Expand Down Expand Up @@ -417,6 +418,7 @@ class S3Config:
key_id: str | None
session_token: str | None
access_key: str | None
credentials_provider: Callable[[], S3Credentials] | None
max_connections: int
retry_initial_backoff_ms: int
connect_timeout_ms: int
Expand All @@ -438,6 +440,8 @@ class S3Config:
key_id: str | None = None,
session_token: str | None = None,
access_key: str | None = None,
credentials_provider: Callable[[], S3Credentials] | None = None,
buffer_time: int | None = None,
max_connections: int | None = None,
retry_initial_backoff_ms: int | None = None,
connect_timeout_ms: int | None = None,
Expand All @@ -459,6 +463,7 @@ class S3Config:
key_id: str | None = None,
session_token: str | None = None,
access_key: str | None = None,
credentials_provider: Callable[[], S3Credentials] | None = None,
max_connections: int | None = None,
retry_initial_backoff_ms: int | None = None,
connect_timeout_ms: int | None = None,
Expand All @@ -481,6 +486,16 @@ class S3Config:
"""Creates an S3Config, retrieving credentials and configurations from the current environment"""
...

class S3Credentials:
key_id: str
access_key: str
session_token: str | None
expiry: datetime.datetime | None

def __init__(
self, key_id: str, access_key: str, session_token: str | None = None, expiry: datetime.datetime | None = None
): ...

class AzureConfig:
"""
I/O configuration for accessing Azure Blob Storage.
Expand Down
2 changes: 2 additions & 0 deletions daft/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
GCSConfig,
IOConfig,
S3Config,
S3Credentials,
set_io_pool_num_threads,
)
from daft.io._csv import read_csv
Expand Down Expand Up @@ -47,6 +48,7 @@ def _set_linux_cert_paths():
"read_sql",
"IOConfig",
"S3Config",
"S3Credentials",
"AzureConfig",
"GCSConfig",
"set_io_pool_num_threads",
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@ These configurations are most often used as inputs to Daft DataFrame reading I/O

daft.io.IOConfig
daft.io.S3Config
daft.io.S3Credentials
daft.io.GCSConfig
daft.io.AzureConfig
6 changes: 5 additions & 1 deletion src/common/io-config/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
[dependencies]
aws-credential-types = {version = "0.55.3"}
chrono = {workspace = true}
common-error = {path = "../error", default-features = false}
common-py-serde = {path = "../py-serde", default-features = false}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true}
serde_json = {workspace = true}
typetag = "0.2.16"

[features]
default = ["python"]
python = ["dep:pyo3", "common-error/python"]
python = ["dep:pyo3", "common-error/python", "common-py-serde/python"]

[package]
edition = {workspace = true}
Expand Down
4 changes: 3 additions & 1 deletion src/common/io-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ mod config;
mod gcs;
mod s3;

pub use crate::{azure::AzureConfig, config::IOConfig, gcs::GCSConfig, s3::S3Config};
pub use crate::{
azure::AzureConfig, config::IOConfig, gcs::GCSConfig, s3::S3Config, s3::S3Credentials,
};
Loading

0 comments on commit 540e65c

Please sign in to comment.