Skip to content

Commit

Permalink
embedded AWS region into Authorization metadata
Browse files Browse the repository at this point in the history
Signed-off-by: Maksymilian Boguń <[email protected]>
  • Loading branch information
maxbog committed Nov 6, 2024
1 parent efeb28e commit 1fe7209
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 99 deletions.
2 changes: 1 addition & 1 deletion pkg/scalers/apache_kafka_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func getApacheKafkaClient(ctx context.Context, metadata apacheKafkaMetadata, log
case KafkaSASLTypeOAuthbearer:
return nil, errors.New("SASL/OAUTHBEARER is not implemented yet")
case KafkaSASLTypeMskIam:
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AWSRegion, metadata.AWSAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AWSAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/scalers/aws/aws_authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ type AuthorizationMetadata struct {
AwsSecretAccessKey string
AwsSessionToken string

AwsRegion string

// Deprecated
PodIdentityOwner bool
// Pod identity owner is confusing and it'll be removed when we get
Expand Down
32 changes: 14 additions & 18 deletions pkg/scalers/aws/aws_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,43 +39,34 @@ import (
// ErrAwsNoAccessKey is returned when awsAccessKeyID is missing.
var ErrAwsNoAccessKey = errors.New("awsAccessKeyID not found")

type awsConfigMetadata struct {
awsRegion string
awsAuthorization AuthorizationMetadata
}

var awsSharedCredentialsCache = newSharedConfigsCache()

// GetAwsConfig returns an *aws.Config for a given AuthorizationMetadata
// If AuthorizationMetadata uses static credentials or `aws` auth,
// we recover the *aws.Config from the shared cache. If not, we generate
// a new entry on each request
func GetAwsConfig(ctx context.Context, awsRegion string, awsAuthorization AuthorizationMetadata) (*aws.Config, error) {
metadata := &awsConfigMetadata{
awsRegion: awsRegion,
awsAuthorization: awsAuthorization,
}
func GetAwsConfig(ctx context.Context, awsAuthorization AuthorizationMetadata) (*aws.Config, error) {

if metadata.awsAuthorization.UsingPodIdentity ||
(metadata.awsAuthorization.AwsAccessKeyID != "" && metadata.awsAuthorization.AwsSecretAccessKey != "") {
return awsSharedCredentialsCache.GetCredentials(ctx, metadata.awsRegion, metadata.awsAuthorization)
if awsAuthorization.UsingPodIdentity ||
(awsAuthorization.AwsAccessKeyID != "" && awsAuthorization.AwsSecretAccessKey != "") {
return awsSharedCredentialsCache.GetCredentials(ctx, awsAuthorization)
}

// TODO, remove when aws-eks are removed
configOptions := make([]func(*config.LoadOptions) error, 0)
configOptions = append(configOptions, config.WithRegion(metadata.awsRegion))
configOptions = append(configOptions, config.WithRegion(awsAuthorization.AwsRegion))
cfg, err := config.LoadDefaultConfig(ctx, configOptions...)
if err != nil {
return nil, err
}

if !metadata.awsAuthorization.PodIdentityOwner {
if !awsAuthorization.PodIdentityOwner {
return &cfg, nil
}

if metadata.awsAuthorization.AwsRoleArn != "" {
if awsAuthorization.AwsRoleArn != "" {
stsSvc := sts.NewFromConfig(cfg)
stsCredentialProvider := stscreds.NewAssumeRoleProvider(stsSvc, metadata.awsAuthorization.AwsRoleArn, func(_ *stscreds.AssumeRoleOptions) {})
stsCredentialProvider := stscreds.NewAssumeRoleProvider(stsSvc, awsAuthorization.AwsRoleArn, func(_ *stscreds.AssumeRoleOptions) {})
cfg.Credentials = aws.NewCredentialsCache(stsCredentialProvider)
}
return &cfg, err
Expand All @@ -88,13 +79,18 @@ func GetAwsAuthorization(uniqueKey string, podIdentity kedav1alpha1.AuthPodIdent
TriggerUniqueKey: uniqueKey,
}

if val, ok := authParams["awsRegion"]; ok && val != "" {
meta.AwsRegion = val
}

if podIdentity.Provider == kedav1alpha1.PodIdentityProviderAws {
meta.UsingPodIdentity = true
if val, ok := authParams["awsRoleArn"]; ok && val != "" {
meta.AwsRoleArn = val
}
return meta, nil
}

// TODO, remove all the logic below and just keep the logic for
// parsing awsAccessKeyID, awsSecretAccessKey and awsSessionToken
// when aws-eks are removed
Expand Down Expand Up @@ -138,5 +134,5 @@ func GetAwsAuthorization(uniqueKey string, podIdentity kedav1alpha1.AuthPodIdent

// ClearAwsConfig wraps the removal of the config from the cache
func ClearAwsConfig(awsRegion string, awsAuthorization AuthorizationMetadata) {
awsSharedCredentialsCache.RemoveCachedEntry(awsRegion, awsAuthorization)
awsSharedCredentialsCache.RemoveCachedEntry(awsAuthorization)
}
18 changes: 9 additions & 9 deletions pkg/scalers/aws/aws_config_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ func newSharedConfigsCache() sharedConfigCache {

// getCacheKey returns a unique key based on given AuthorizationMetadata.
// As it can contain sensitive data, the key is hashed to not expose secrets
func (a *sharedConfigCache) getCacheKey(awsRegion string, awsAuthorization AuthorizationMetadata) string {
key := "keda-" + awsRegion
func (a *sharedConfigCache) getCacheKey(awsAuthorization AuthorizationMetadata) string {
key := "keda-" + awsAuthorization.AwsRegion
if awsAuthorization.AwsAccessKeyID != "" {
key = fmt.Sprintf("%s-%s-%s-%s", awsAuthorization.AwsAccessKeyID, awsAuthorization.AwsSecretAccessKey, awsAuthorization.AwsSessionToken, awsRegion)
key = fmt.Sprintf("%s-%s-%s-%s", awsAuthorization.AwsAccessKeyID, awsAuthorization.AwsSecretAccessKey, awsAuthorization.AwsSessionToken, awsAuthorization.AwsRegion)
} else if awsAuthorization.AwsRoleArn != "" {
key = fmt.Sprintf("%s-%s", awsAuthorization.AwsRoleArn, awsRegion)
key = fmt.Sprintf("%s-%s", awsAuthorization.AwsRoleArn, awsAuthorization.AwsRegion)
}
// to avoid sensitive data as key and to use a constant key size,
// we hash the key with sha3
Expand All @@ -86,18 +86,18 @@ func (a *sharedConfigCache) getCacheKey(awsRegion string, awsAuthorization Autho
// sharing it between all the requests. To track if the *aws.Config is used by whom,
// every time when an scaler requests *aws.Config we register it inside
// the cached item.
func (a *sharedConfigCache) GetCredentials(ctx context.Context, awsRegion string, awsAuthorization AuthorizationMetadata) (*aws.Config, error) {
func (a *sharedConfigCache) GetCredentials(ctx context.Context, awsAuthorization AuthorizationMetadata) (*aws.Config, error) {
a.Lock()
defer a.Unlock()
key := a.getCacheKey(awsRegion, awsAuthorization)
key := a.getCacheKey(awsAuthorization)
if cachedEntry, exists := a.items[key]; exists {
cachedEntry.usages[awsAuthorization.TriggerUniqueKey] = true
a.items[key] = cachedEntry
return cachedEntry.config, nil
}

configOptions := make([]func(*config.LoadOptions) error, 0)
configOptions = append(configOptions, config.WithRegion(awsRegion))
configOptions = append(configOptions, config.WithRegion(awsAuthorization.AwsRegion))
cfg, err := config.LoadDefaultConfig(ctx, configOptions...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -125,10 +125,10 @@ func (a *sharedConfigCache) GetCredentials(ctx context.Context, awsRegion string
// RemoveCachedEntry removes the usage of an AuthorizationMetadata from the cached item.
// If there isn't any usage of a given cached item (because there isn't any trigger using the aws.Config),
// we also remove it from the cache
func (a *sharedConfigCache) RemoveCachedEntry(awsRegion string, awsAuthorization AuthorizationMetadata) {
func (a *sharedConfigCache) RemoveCachedEntry(awsAuthorization AuthorizationMetadata) {
a.Lock()
defer a.Unlock()
key := a.getCacheKey(awsRegion, awsAuthorization)
key := a.getCacheKey(awsAuthorization)
if cachedEntry, exists := a.items[key]; exists {
// Delete the TriggerUniqueKey from usages
delete(cachedEntry.usages, awsAuthorization.TriggerUniqueKey)
Expand Down
78 changes: 33 additions & 45 deletions pkg/scalers/aws/aws_config_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,105 +28,93 @@ import (
func TestGetCredentialsReturnNewItemAndStoreItIfNotExist(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
config := awsConfigMetadata{
awsRegion: "test-region",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test-key",
},
awsAuthorization := AuthorizationMetadata{
TriggerUniqueKey: "test-key",
AwsRegion: "test-region",
}
cacheKey := cache.getCacheKey(config.awsRegion, config.awsAuthorization)
_, err := cache.GetCredentials(context.Background(), config.awsRegion, config.awsAuthorization)
cacheKey := cache.getCacheKey(awsAuthorization)
_, err := cache.GetCredentials(context.Background(), awsAuthorization)
assert.NoError(t, err)
assert.Contains(t, cache.items, cacheKey)
assert.Contains(t, cache.items[cacheKey].usages, config.awsAuthorization.TriggerUniqueKey)
assert.Contains(t, cache.items[cacheKey].usages, awsAuthorization.TriggerUniqueKey)
}

func TestGetCredentialsReturnCachedItemIfExist(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
config := awsConfigMetadata{
awsRegion: "test1-region",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test1-key",
},
awsAuthorization := AuthorizationMetadata{
TriggerUniqueKey: "test1-key",
AwsRegion: "test1-region",
}
cfg := aws.Config{}
cfg.AppID = "test1-app"
cacheKey := cache.getCacheKey(config.awsRegion, config.awsAuthorization)
cacheKey := cache.getCacheKey(awsAuthorization)
cache.items[cacheKey] = cacheEntry{
config: &cfg,
usages: map[string]bool{
"other-usage": true,
},
}
configFromCache, err := cache.GetCredentials(context.Background(), config.awsRegion, config.awsAuthorization)
configFromCache, err := cache.GetCredentials(context.Background(), awsAuthorization)
assert.NoError(t, err)
assert.Equal(t, &cfg, configFromCache)
assert.Contains(t, cache.items[cacheKey].usages, config.awsAuthorization.TriggerUniqueKey)
assert.Contains(t, cache.items[cacheKey].usages, awsAuthorization.TriggerUniqueKey)
}

func TestRemoveCachedEntryRemovesCachedItemIfNotUsages(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
config := awsConfigMetadata{
awsRegion: "test2-region",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test2-key",
},
awsAuthorization := AuthorizationMetadata{
TriggerUniqueKey: "test2-key",
AwsRegion: "test2-region",
}
cfg := aws.Config{}
cfg.AppID = "test2-app"
cacheKey := cache.getCacheKey(config.awsRegion, config.awsAuthorization)
cacheKey := cache.getCacheKey(awsAuthorization)
cache.items[cacheKey] = cacheEntry{
config: &cfg,
usages: map[string]bool{
config.awsAuthorization.TriggerUniqueKey: true,
awsAuthorization.TriggerUniqueKey: true,
},
}
cache.RemoveCachedEntry(config.awsRegion, config.awsAuthorization)
cache.RemoveCachedEntry(awsAuthorization)
assert.NotContains(t, cache.items, cacheKey)
}

func TestRemoveCachedEntryNotRemoveCachedItemIfUsages(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
config := awsConfigMetadata{
awsRegion: "test3-region",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test3-key",
},
awsAuthorization := AuthorizationMetadata{
TriggerUniqueKey: "test3-key",
AwsRegion: "test3-region",
}
cfg := aws.Config{}
cfg.AppID = "test3-app"
cacheKey := cache.getCacheKey(config.awsRegion, config.awsAuthorization)
cacheKey := cache.getCacheKey(awsAuthorization)
cache.items[cacheKey] = cacheEntry{
config: &cfg,
usages: map[string]bool{
config.awsAuthorization.TriggerUniqueKey: true,
"other-usage": true,
awsAuthorization.TriggerUniqueKey: true,
"other-usage": true,
},
}
cache.RemoveCachedEntry(config.awsRegion, config.awsAuthorization)
cache.RemoveCachedEntry(awsAuthorization)
assert.Contains(t, cache.items, cacheKey)
}

func TestCredentialsShouldBeCachedPerRegion(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
config1 := awsConfigMetadata{
awsRegion: "test4-region1",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test4-key1",
},
awsAuthorization1 := AuthorizationMetadata{
TriggerUniqueKey: "test1-key",
AwsRegion: "test4-region1",
}
config2 := awsConfigMetadata{
awsRegion: "test4-region2",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test4-key2",
},
awsAuthorization2 := AuthorizationMetadata{
TriggerUniqueKey: "test1-key",
AwsRegion: "test4-region2",
}
cred1, err1 := cache.GetCredentials(context.Background(), config1.awsRegion, config1.awsAuthorization)
cred2, err2 := cache.GetCredentials(context.Background(), config2.awsRegion, config2.awsAuthorization)
cred1, err1 := cache.GetCredentials(context.Background(), awsAuthorization1)
cred2, err2 := cache.GetCredentials(context.Background(), awsAuthorization2)

assert.NoError(t, err1)
assert.NoError(t, err2)
Expand Down
23 changes: 4 additions & 19 deletions pkg/scalers/aws/aws_sigv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,8 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
}

// parseAwsAMPMetadata parses the data to get the AWS sepcific auth info and metadata
func parseAwsAMPMetadata(config *scalersconfig.ScalerConfig) (*awsConfigMetadata, error) {
meta := awsConfigMetadata{}

if val, ok := config.TriggerMetadata["awsRegion"]; ok && val != "" {
meta.awsRegion = val
}

auth, err := GetAwsAuthorization(config.TriggerUniqueKey, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv)
if err != nil {
return nil, err
}

meta.awsAuthorization = auth
return &meta, nil
func parseAwsAMPMetadata(config *scalersconfig.ScalerConfig) (AuthorizationMetadata, error) {
return GetAwsAuthorization(config.TriggerUniqueKey, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv)
}

// NewSigV4RoundTripper returns a new http.RoundTripper that will sign requests
Expand All @@ -100,11 +88,8 @@ func NewSigV4RoundTripper(config *scalersconfig.ScalerConfig) (http.RoundTripper
// which is probably the reason to create a SigV4RoundTripper.
// To prevent failures we check if the metadata is nil
// (missing AWS info) and we hide the error
metadata, _ := parseAwsAMPMetadata(config)
if metadata == nil {
return nil, nil
}
awsCfg, err := GetAwsConfig(context.Background(), metadata.awsRegion, metadata.awsAuthorization)
awsAuthorization, _ := parseAwsAMPMetadata(config)
awsCfg, err := GetAwsConfig(context.Background(), awsAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_cloudwatch_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func NewAwsCloudwatchScaler(ctx context.Context, config *scalersconfig.ScalerCon
}

func createCloudwatchClient(ctx context.Context, metadata *awsCloudwatchMetadata) (*cloudwatch.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AwsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)

if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_dynamodb_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func parseAwsDynamoDBMetadata(config *scalersconfig.ScalerConfig) (*awsDynamoDBM
}

func createDynamoDBClient(ctx context.Context, metadata *awsDynamoDBMetadata) (*dynamodb.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AwsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_dynamodb_streams_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig, logger
}

func createClientsForDynamoDBStreamsScaler(ctx context.Context, metadata *awsDynamoDBStreamsMetadata) (*dynamodb.Client, *dynamodbstreams.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_kinesis_stream_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func parseAwsKinesisStreamMetadata(config *scalersconfig.ScalerConfig, logger lo
}

func createKinesisClient(ctx context.Context, metadata *awsKinesisStreamMetadata) (*kinesis.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_sqs_queue_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func parseAwsSqsQueueMetadata(config *scalersconfig.ScalerConfig, logger logr.Lo
}

func createSqsClient(ctx context.Context, metadata *awsSqsQueueMetadata) (*sqs.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/kafka_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ func getKafkaClientConfig(ctx context.Context, metadata kafkaMetadata) (*sarama.
case KafkaSASLOAuthTokenProviderBearer:
config.Net.SASL.TokenProvider = kafka.OAuthBearerTokenProvider(metadata.username, metadata.password, metadata.oauthTokenEndpointURI, metadata.scopes, metadata.oauthExtensions)
case KafkaSASLOAuthTokenProviderAWSMSKIAM:
awsAuth, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization)
awsAuth, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, fmt.Errorf("error getting AWS config: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/scaling/resolver/aws_secretmanager_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (ash *AwsSecretManagerHandler) Initialize(ctx context.Context, client clien
if ash.secretManager.Region != "" {
awsRegion = ash.secretManager.Region
}
ash.awsMetadata.AwsRegion = awsRegion
podIdentity := ash.secretManager.PodIdentity
if podIdentity == nil {
podIdentity = &kedav1alpha1.AuthPodIdentity{}
Expand Down Expand Up @@ -100,7 +101,7 @@ func (ash *AwsSecretManagerHandler) Initialize(ctx context.Context, client clien
return fmt.Errorf("pod identity provider %s not supported", podIdentity.Provider)
}

config, err := awsutils.GetAwsConfig(ctx, awsRegion, ash.awsMetadata)
config, err := awsutils.GetAwsConfig(ctx, ash.awsMetadata)
if err != nil {
logger.Error(err, "Error getting credentials")
return err
Expand Down

0 comments on commit 1fe7209

Please sign in to comment.