diff --git a/docs/changes.rst b/docs/changes.rst index c4436d0..18694a2 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -1,6 +1,11 @@ Changelog ********* +1.9.0 +===== +* Add `session_token` url param that can be set when creating a `[h]s3://` store + via `get_store_from_url`. + 1.8.6 ===== * We undeprecated ``url2dict`` and ``extract_params`` as these functions turned diff --git a/minimalkv/net/s3fsstore.py b/minimalkv/net/s3fsstore.py index 28ff2c3..0519843 100644 --- a/minimalkv/net/s3fsstore.py +++ b/minimalkv/net/s3fsstore.py @@ -1,6 +1,6 @@ import os import warnings -from typing import Dict +from typing import Dict, NamedTuple, Optional from uritools import SplitResult @@ -22,10 +22,27 @@ ) +class Credentials(NamedTuple): + """Dataclass to hold AWS credentials.""" + + access_key_id: Optional[str] + secret_access_key: Optional[str] + session_token: Optional[str] + + def as_boto3_params(self): + """Return the credentials as a dictionary suitable for boto3 authentication.""" + return { + "aws_access_key_id": self.access_key_id, + "aws_secret_access_key": self.secret_access_key, + "aws_session_token": self.session_token, + } + + class S3FSStore(FSSpecStore, UrlMixin): # noqa D def __init__( self, bucket, + credentials: Optional[Credentials] = None, object_prefix="", url_valid_time=0, reduced_redundancy=False, @@ -37,12 +54,14 @@ def __init__( if isinstance(bucket, str): import boto3 - s3_resource = boto3.resource("s3") + boto3_params = credentials.as_boto3_params() if credentials else {} + s3_resource = boto3.resource("s3", **boto3_params) bucket = s3_resource.Bucket(bucket) if bucket not in s3_resource.buckets.all(): raise ValueError("invalid s3 bucket name") self.bucket = bucket + self.credentials = credentials self.object_prefix = object_prefix.strip().lstrip("/") self.url_valid_time = url_valid_time self.reduced_redundancy = reduced_redundancy @@ -74,6 +93,16 @@ def _create_filesystem(self) -> "S3FileSystem": client_kwargs["endpoint_url"] = self.endpoint_url if self.region_name: client_kwargs["region_name"] = self.region_name + + if self.credentials: + return S3FileSystem( + key=self.credentials.access_key_id, + secret=self.credentials.secret_access_key, + token=self.credentials.session_token, + anon=False, + client_kwargs=client_kwargs, + ) + return S3FileSystem( anon=False, client_kwargs=client_kwargs, @@ -117,10 +146,22 @@ def _from_parsed_url( ``region_name`` (default: ``None``): If set the AWS region name is applied as location constraint during bucket creation. + + ``session_token``(default: ``None``): If set this token will be used in conjunction + with access_key_id and secret_access_key for authentication. + **Notes**: If the scheme is ``hs3``, an ``HS3FSStore`` is returned which allows ``/`` in key names. + If the credentials are not provided through the url, they are attempted to be + loaded from the environment variables `AWS_ACCESS_KEY_ID`, + `AWS_SECRET_ACCESS_KEY`, and `AWS_SESSION_TOKEN`. If these variables are not set, + the search for credentials will be delegated to boto(core). + + Positional arguments should be encoded by `urllib.parse.quote_plus` + if they contain special characters e.g. "/". + Parameters ---------- parsed_url: SplitResult @@ -137,6 +178,7 @@ def _from_parsed_url( url_access_key_id = _get_username(parsed_url) url_secret_access_key = _get_password(parsed_url) + url_session_token = query.get("session_token", None) if url_access_key_id is None: url_secret_access_key = os.environ.get("AWS_ACCESS_KEY_ID") @@ -148,10 +190,13 @@ def _from_parsed_url( else: os.environ["AWS_SECRET_ACCESS_KEY"] = url_secret_access_key - boto3_params = { - "aws_access_key_id": url_access_key_id, - "aws_secret_access_key": url_secret_access_key, - } + credentials = Credentials( + access_key_id=url_access_key_id, + secret_access_key=url_secret_access_key, + session_token=url_session_token, + ) + + boto3_params = credentials.as_boto3_params() host = parsed_url.gethost() port = parsed_url.getport() @@ -198,4 +243,6 @@ def _from_parsed_url( verify = query.get("verify", "true").lower() == "true" - return cls(bucket, verify=verify, region_name=region_name) + return cls( + bucket, credentials=credentials, verify=verify, region_name=region_name + )