diff --git a/pkg/repository/config/aws.go b/pkg/repository/config/aws.go index 567fec54c3..96dde03d79 100644 --- a/pkg/repository/config/aws.go +++ b/pkg/repository/config/aws.go @@ -34,6 +34,9 @@ import ( "github.com/pkg/errors" ) +// getS3CredentialsFunc is used to make testing more convenient +var getS3CredentialsFunc = GetS3Credentials + const ( // AWS specific environment variable awsProfileEnvVar = "AWS_PROFILE" @@ -63,7 +66,7 @@ func GetS3ResticEnvVars(config map[string]string) (map[string]string, error) { // GetS3ResticEnvVars reads the AWS config, from files and envs // if needed assumes the role and returns the session credentials // setting these variables emulates what would happen for example when using kube2iam - if creds, err := GetS3Credentials(config); err == nil && creds != nil { + if creds, err := getS3CredentialsFunc(config); err == nil && creds != nil { result[awsKeyIDEnvVar] = creds.AccessKeyID result[awsSecretKeyEnvVar] = creds.SecretAccessKey result[awsSessTokenEnvVar] = creds.SessionToken diff --git a/pkg/repository/config/aws_test.go b/pkg/repository/config/aws_test.go index ba7d00f6b9..8c75273709 100644 --- a/pkg/repository/config/aws_test.go +++ b/pkg/repository/config/aws_test.go @@ -27,14 +27,18 @@ import ( func TestGetS3ResticEnvVars(t *testing.T) { testCases := []struct { - name string - config map[string]string - expected map[string]string + name string + config map[string]string + expected map[string]string + getS3Credentials func(config map[string]string) (*aws.Credentials, error) }{ { name: "when config is empty, no env vars are returned", config: map[string]string{}, expected: map[string]string{}, + getS3Credentials: func(config map[string]string) (*aws.Credentials, error) { + return nil, nil + }, }, { name: "when config contains profile key, profile env var is set with profile value", @@ -53,16 +57,39 @@ func TestGetS3ResticEnvVars(t *testing.T) { expected: map[string]string{ "AWS_SHARED_CREDENTIALS_FILE": "/tmp/credentials/path/to/secret", }, + getS3Credentials: func(config map[string]string) (*aws.Credentials, error) { + return nil, nil + }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + // Mock GetS3Credentials + if tc.getS3Credentials != nil { + getS3CredentialsFunc = tc.getS3Credentials + } else { + getS3CredentialsFunc = GetS3Credentials + } + actual, err := GetS3ResticEnvVars(tc.config) require.NoError(t, err) - require.Equal(t, tc.expected, actual) + // Avoid direct comparison of expected and actual to prevent exposing secrets. + // This may occur if the test doesn't set getS3Credentials func correctly. + if !reflect.DeepEqual(tc.expected, actual) { + t.Errorf("Expected and actual results do not match for test case %q", tc.name) + for key, value := range actual { + if expVal, err := tc.expected[key]; !err || expVal != value { + if actualVal, ok := actual[key]; !ok { + t.Errorf("Key %q is missing in actual result", key) + } else if expVal != actualVal { + t.Errorf("Key %q: expected value %q", key, expVal) + } + } + } + } }) } } @@ -117,6 +144,11 @@ func TestGetS3CredentialsCorrectlyUseProfile(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Ensure env variables do not set AWS config entries + t.Setenv("AWS_ACCESS_KEY_ID", "") + t.Setenv("AWS_SECRET_ACCESS_KEY", "") + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "") + tmpFile, err := os.CreateTemp("", "velero-test-aws-credentials") defer os.Remove(tmpFile.Name()) if err != nil { @@ -129,6 +161,7 @@ func TestGetS3CredentialsCorrectlyUseProfile(t *testing.T) { t.Errorf("GetS3Credentials() error = %v", err) return } + tt.args.config["credentialsFile"] = tmpFile.Name() got, err := GetS3Credentials(tt.args.config) if (err != nil) != tt.wantErr { @@ -136,10 +169,10 @@ func TestGetS3CredentialsCorrectlyUseProfile(t *testing.T) { return } if !reflect.DeepEqual(got.AccessKeyID, tt.want.AccessKeyID) { - t.Errorf("GetS3Credentials() got = %v, want %v", got.AccessKeyID, tt.want.AccessKeyID) + t.Errorf("GetS3Credentials() want %v", tt.want.AccessKeyID) } if !reflect.DeepEqual(got.SecretAccessKey, tt.want.SecretAccessKey) { - t.Errorf("GetS3Credentials() got = %v, want %v", got.SecretAccessKey, tt.want.SecretAccessKey) + t.Errorf("GetS3Credentials() want %v", tt.want.SecretAccessKey) } }) }