diff --git a/cosmos/profiles/athena/access_key.py b/cosmos/profiles/athena/access_key.py index 02de2be24..8dc14f839 100644 --- a/cosmos/profiles/athena/access_key.py +++ b/cosmos/profiles/athena/access_key.py @@ -66,9 +66,11 @@ def profile(self) -> dict[str, Any | None]: **self.profile_args, "aws_access_key_id": self.temporary_credentials.access_key, "aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"), - "aws_session_token": self.get_env_var_format("aws_session_token"), } + if self.temporary_credentials.token: + profile["aws_session_token"] = self.get_env_var_format("aws_session_token") + return self.filter_null(profile) @property diff --git a/tests/profiles/athena/test_athena_access_key.py b/tests/profiles/athena/test_athena_access_key.py index 71ba1eb05..c0a25b7e9 100644 --- a/tests/profiles/athena/test_athena_access_key.py +++ b/tests/profiles/athena/test_athena_access_key.py @@ -1,8 +1,10 @@ "Tests for the Athena profile." +from __future__ import annotations import json import sys from collections import namedtuple +from unittest import mock from unittest.mock import MagicMock, patch import pytest @@ -39,12 +41,7 @@ def get_credentials(self) -> Credentials: yield mock_aws_hook -@pytest.fixture() -def mock_athena_conn(): # type: ignore - """ - Sets the connection as an environment variable. - """ - +def mock_conn_value(token: str | None = None) -> Connection: conn = Connection( conn_id="my_athena_connection", conn_type="aws", @@ -52,7 +49,7 @@ def mock_athena_conn(): # type: ignore password="my_aws_secret_key", extra=json.dumps( { - "aws_session_token": "token123", + "aws_session_token": token, "database": "my_database", "region_name": "us-east-1", "s3_staging_dir": "s3://my_bucket/dbt/", @@ -60,7 +57,25 @@ def mock_athena_conn(): # type: ignore } ), ) + return conn + +@pytest.fixture() +def mock_athena_conn(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = mock_conn_value(token="token123") + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +@pytest.fixture() +def mock_athena_conn_without_token(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = mock_conn_value(token=None) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): yield conn @@ -151,6 +166,28 @@ def test_athena_profile_args( } +@mock.patch("cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping._get_temporary_credentials") +def test_athena_profile_args_without_token(mock_temp_cred, mock_athena_conn_without_token: Connection) -> None: + """ + Tests that the profile values get set correctly for Athena. + """ + ReadOnlyCredentials = namedtuple("ReadOnlyCredentials", ["access_key", "secret_key", "token"]) + credentials = ReadOnlyCredentials(access_key="my_aws_access_key", secret_key="my_aws_secret_key", token=None) + mock_temp_cred.return_value = credentials + + profile_mapping = get_automatic_profile_mapping(mock_athena_conn_without_token.conn_id) + + assert profile_mapping.profile == { + "type": "athena", + "aws_access_key_id": "my_aws_access_key", + "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", + "database": mock_athena_conn_without_token.extra_dejson.get("database"), + "region_name": mock_athena_conn_without_token.extra_dejson.get("region_name"), + "s3_staging_dir": mock_athena_conn_without_token.extra_dejson.get("s3_staging_dir"), + "schema": mock_athena_conn_without_token.extra_dejson.get("schema"), + } + + def test_athena_profile_args_overrides( mock_athena_conn: Connection, ) -> None: