Skip to content

Commit

Permalink
Add aws_session_token for Athena mapping (astronomer#663)
Browse files Browse the repository at this point in the history
Adds the `aws_session_token` argument to Athena, which was added to
dbt-athena 1.6.4 in dbt-labs/dbt-athena#459

Closes: astronomer#609

Also addresses this comment:
astronomer#578 (comment)
  • Loading branch information
benjamin-awd authored and arojasb3 committed Jul 14, 2024
1 parent 5e95cfd commit 5fcee8f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
1 change: 1 addition & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def get_automatic_profile_mapping(


__all__ = [
"AthenaAccessKeyProfileMapping",
"BaseProfileMapping",
"GoogleCloudServiceAccountFileProfileMapping",
"GoogleCloudServiceAccountDictProfileMapping",
Expand Down
8 changes: 4 additions & 4 deletions cosmos/profiles/athena/access_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ class AthenaAccessKeyProfileMapping(BaseProfileMapping):
"s3_staging_dir",
"schema",
]
secret_fields = [
"aws_secret_access_key",
]
secret_fields = ["aws_secret_access_key", "aws_session_token"]
airflow_param_mapping = {
"aws_access_key_id": "login",
"aws_secret_access_key": "password",
"aws_session_token": "extra.aws_session_token",
"aws_profile_name": "extra.aws_profile_name",
"database": "extra.database",
"debug_query_state": "extra.debug_query_state",
Expand All @@ -53,7 +52,8 @@ def profile(self) -> dict[str, Any | None]:
profile = {
**self.mapped_params,
**self.profile_args,
# aws_secret_access_key should always get set as env var
# aws_secret_access_key and aws_session_token should always get set as env var
"aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"),
"aws_session_token": self.get_env_var_format("aws_session_token"),
}
return self.filter_null(profile)
8 changes: 7 additions & 1 deletion tests/profiles/athena/test_athena_access_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def mock_athena_conn(): # type: ignore
password="my_aws_secret_key",
extra=json.dumps(
{
"aws_session_token": "token123",
"database": "my_database",
"region_name": "my_region",
"s3_staging_dir": "s3://my_bucket/dbt/",
Expand Down Expand Up @@ -107,6 +108,7 @@ def test_athena_profile_args(
"type": "athena",
"aws_access_key_id": mock_athena_conn.login,
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}",
"aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}",
"database": mock_athena_conn.extra_dejson.get("database"),
"region_name": mock_athena_conn.extra_dejson.get("region_name"),
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"),
Expand All @@ -122,17 +124,20 @@ def test_athena_profile_args_overrides(
"""
profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
profile_args={"schema": "my_custom_schema", "database": "my_custom_db"},
profile_args={"schema": "my_custom_schema", "database": "my_custom_db", "aws_session_token": "override_token"},
)

assert profile_mapping.profile_args == {
"schema": "my_custom_schema",
"database": "my_custom_db",
"aws_session_token": "override_token",
}

assert profile_mapping.profile == {
"type": "athena",
"aws_access_key_id": mock_athena_conn.login,
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}",
"aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}",
"database": "my_custom_db",
"region_name": mock_athena_conn.extra_dejson.get("region_name"),
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"),
Expand All @@ -151,4 +156,5 @@ def test_athena_profile_env_vars(
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password,
"COSMOS_CONN_AWS_AWS_SESSION_TOKEN": mock_athena_conn.extra_dejson.get("aws_session_token"),
}

0 comments on commit 5fcee8f

Please sign in to comment.