From 5fcee8ffce98b88b2b0a070c5959050902be9ae0 Mon Sep 17 00:00:00 2001 From: Benjamin Dornel <62495124+benjamin-awd@users.noreply.github.com> Date: Fri, 10 Nov 2023 23:37:18 +0800 Subject: [PATCH] Add aws_session_token for Athena mapping (#663) Adds the `aws_session_token` argument to Athena, which was added to dbt-athena 1.6.4 in https://github.com/dbt-athena/dbt-athena/pull/459 Closes: #609 Also addresses this comment: https://github.com/astronomer/astronomer-cosmos/pull/578#discussion_r1378301372 --- cosmos/profiles/__init__.py | 1 + cosmos/profiles/athena/access_key.py | 8 ++++---- tests/profiles/athena/test_athena_access_key.py | 8 +++++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 8280cd950..1f39a91a0 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -64,6 +64,7 @@ def get_automatic_profile_mapping( __all__ = [ + "AthenaAccessKeyProfileMapping", "BaseProfileMapping", "GoogleCloudServiceAccountFileProfileMapping", "GoogleCloudServiceAccountDictProfileMapping", diff --git a/cosmos/profiles/athena/access_key.py b/cosmos/profiles/athena/access_key.py index b79bb793a..a8f71c2b7 100644 --- a/cosmos/profiles/athena/access_key.py +++ b/cosmos/profiles/athena/access_key.py @@ -26,12 +26,11 @@ class AthenaAccessKeyProfileMapping(BaseProfileMapping): "s3_staging_dir", "schema", ] - secret_fields = [ - "aws_secret_access_key", - ] + secret_fields = ["aws_secret_access_key", "aws_session_token"] airflow_param_mapping = { "aws_access_key_id": "login", "aws_secret_access_key": "password", + "aws_session_token": "extra.aws_session_token", "aws_profile_name": "extra.aws_profile_name", "database": "extra.database", "debug_query_state": "extra.debug_query_state", @@ -53,7 +52,8 @@ def profile(self) -> dict[str, Any | None]: profile = { **self.mapped_params, **self.profile_args, - # aws_secret_access_key should always get set as env var + # aws_secret_access_key and aws_session_token should always get set as env var "aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"), + "aws_session_token": self.get_env_var_format("aws_session_token"), } return self.filter_null(profile) diff --git a/tests/profiles/athena/test_athena_access_key.py b/tests/profiles/athena/test_athena_access_key.py index 2063ef6ed..22c8efa2c 100644 --- a/tests/profiles/athena/test_athena_access_key.py +++ b/tests/profiles/athena/test_athena_access_key.py @@ -22,6 +22,7 @@ def mock_athena_conn(): # type: ignore password="my_aws_secret_key", extra=json.dumps( { + "aws_session_token": "token123", "database": "my_database", "region_name": "my_region", "s3_staging_dir": "s3://my_bucket/dbt/", @@ -107,6 +108,7 @@ def test_athena_profile_args( "type": "athena", "aws_access_key_id": mock_athena_conn.login, "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", + "aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}", "database": mock_athena_conn.extra_dejson.get("database"), "region_name": mock_athena_conn.extra_dejson.get("region_name"), "s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"), @@ -122,17 +124,20 @@ def test_athena_profile_args_overrides( """ profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, - profile_args={"schema": "my_custom_schema", "database": "my_custom_db"}, + profile_args={"schema": "my_custom_schema", "database": "my_custom_db", "aws_session_token": "override_token"}, ) + assert profile_mapping.profile_args == { "schema": "my_custom_schema", "database": "my_custom_db", + "aws_session_token": "override_token", } assert profile_mapping.profile == { "type": "athena", "aws_access_key_id": mock_athena_conn.login, "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", + "aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}", "database": "my_custom_db", "region_name": mock_athena_conn.extra_dejson.get("region_name"), "s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"), @@ -151,4 +156,5 @@ def test_athena_profile_env_vars( ) assert profile_mapping.env_vars == { "COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password, + "COSMOS_CONN_AWS_AWS_SESSION_TOKEN": mock_athena_conn.extra_dejson.get("aws_session_token"), }