Skip to content

Commit

Permalink
Merge pull request #114 from ThomasMarwitzQC/add-session-support-s3
Browse files Browse the repository at this point in the history
Allow `session_token` to be set for S3 buckets. Truly isolates stores by explicitly setting credentials for `S3FS`.
  • Loading branch information
xhochy authored Feb 27, 2024
2 parents 789a5f8 + 25806b2 commit 077305c
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 7 deletions.
5 changes: 5 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changelog
*********

1.9.0
=====
* Add ``session_token`` URL parameter that can be set when creating a ``[h]s3://`` store
via ``get_store_from_url``.

1.8.6
=====
* We undeprecated ``url2dict`` and ``extract_params`` as these functions turned
Expand Down
61 changes: 54 additions & 7 deletions minimalkv/net/s3fsstore.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import warnings
from typing import Dict
from typing import Dict, NamedTuple, Optional

from uritools import SplitResult

Expand All @@ -22,10 +22,27 @@
)


class Credentials(NamedTuple):
    """Immutable named tuple holding a set of AWS credentials.

    Every field is optional; a field left as ``None`` is still passed on
    as ``None`` by :meth:`as_boto3_params`.
    """

    # AWS access key id, or ``None`` if not provided.
    access_key_id: Optional[str]
    # AWS secret access key, or ``None`` if not provided.
    secret_access_key: Optional[str]
    # Temporary session token, or ``None`` if not provided.
    session_token: Optional[str]

    def as_boto3_params(self) -> Dict[str, Optional[str]]:
        """Return the credentials as a dictionary suitable for boto3 authentication.

        Returns
        -------
        Dict[str, Optional[str]]
            Mapping with the keys ``aws_access_key_id``,
            ``aws_secret_access_key`` and ``aws_session_token``. All three
            keys are always present; unset credentials map to ``None``.
        """
        return {
            "aws_access_key_id": self.access_key_id,
            "aws_secret_access_key": self.secret_access_key,
            "aws_session_token": self.session_token,
        }


class S3FSStore(FSSpecStore, UrlMixin): # noqa D
def __init__(
self,
bucket,
credentials: Optional[Credentials] = None,
object_prefix="",
url_valid_time=0,
reduced_redundancy=False,
Expand All @@ -37,12 +54,14 @@ def __init__(
if isinstance(bucket, str):
import boto3

s3_resource = boto3.resource("s3")
boto3_params = credentials.as_boto3_params() if credentials else {}
s3_resource = boto3.resource("s3", **boto3_params)
bucket = s3_resource.Bucket(bucket)
if bucket not in s3_resource.buckets.all():
raise ValueError("invalid s3 bucket name")

self.bucket = bucket
self.credentials = credentials
self.object_prefix = object_prefix.strip().lstrip("/")
self.url_valid_time = url_valid_time
self.reduced_redundancy = reduced_redundancy
Expand Down Expand Up @@ -74,6 +93,16 @@ def _create_filesystem(self) -> "S3FileSystem":
client_kwargs["endpoint_url"] = self.endpoint_url
if self.region_name:
client_kwargs["region_name"] = self.region_name

if self.credentials:
return S3FileSystem(
key=self.credentials.access_key_id,
secret=self.credentials.secret_access_key,
token=self.credentials.session_token,
anon=False,
client_kwargs=client_kwargs,
)

return S3FileSystem(
anon=False,
client_kwargs=client_kwargs,
Expand Down Expand Up @@ -117,10 +146,22 @@ def _from_parsed_url(
``region_name`` (default: ``None``): If set the AWS region name is applied as location
constraint during bucket creation.
``session_token`` (default: ``None``): If set this token will be used in conjunction
with access_key_id and secret_access_key for authentication.
**Notes**:
If the scheme is ``hs3``, an ``HS3FSStore`` is returned which allows ``/`` in key names.
If the credentials are not provided through the URL, an attempt is made to load
them from the environment variables ``AWS_ACCESS_KEY_ID``,
``AWS_SECRET_ACCESS_KEY``, and ``AWS_SESSION_TOKEN``. If these variables are not set,
the search for credentials is delegated to boto(core).
Positional arguments should be encoded by `urllib.parse.quote_plus`
if they contain special characters e.g. "/".
Parameters
----------
parsed_url: SplitResult
Expand All @@ -137,6 +178,7 @@ def _from_parsed_url(

url_access_key_id = _get_username(parsed_url)
url_secret_access_key = _get_password(parsed_url)
url_session_token = query.get("session_token", None)

if url_access_key_id is None:
url_secret_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
Expand All @@ -148,10 +190,13 @@ def _from_parsed_url(
else:
os.environ["AWS_SECRET_ACCESS_KEY"] = url_secret_access_key

boto3_params = {
"aws_access_key_id": url_access_key_id,
"aws_secret_access_key": url_secret_access_key,
}
credentials = Credentials(
access_key_id=url_access_key_id,
secret_access_key=url_secret_access_key,
session_token=url_session_token,
)

boto3_params = credentials.as_boto3_params()
host = parsed_url.gethost()
port = parsed_url.getport()

Expand Down Expand Up @@ -198,4 +243,6 @@ def _from_parsed_url(

verify = query.get("verify", "true").lower() == "true"

return cls(bucket, verify=verify, region_name=region_name)
return cls(
bucket, credentials=credentials, verify=verify, region_name=region_name
)

0 comments on commit 077305c

Please sign in to comment.