From d43ec45b42c16decddff0cd301880025a979b827 Mon Sep 17 00:00:00 2001 From: Rohith Reddy Kota Date: Thu, 5 Dec 2024 04:45:28 -0500 Subject: [PATCH] athena-connector: support sts assumerole (#6212) * athena-connector: support sts assumerole * rename local variable --- runtime/drivers/athena/athena.go | 5 ++++- runtime/drivers/athena/warehouse.go | 26 +++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/runtime/drivers/athena/athena.go b/runtime/drivers/athena/athena.go index bb687306076..3b7f6b48f34 100644 --- a/runtime/drivers/athena/athena.go +++ b/runtime/drivers/athena/athena.go @@ -80,9 +80,12 @@ type driver struct{} type configProperties struct { AccessKeyID string `mapstructure:"aws_access_key_id"` + AllowHostAccess bool `mapstructure:"allow_host_access"` + ExternalID string `mapstructure:"external_id"` + RoleARN string `mapstructure:"role_arn"` + RoleSessionName string `mapstructure:"role_session_name"` SecretAccessKey string `mapstructure:"aws_secret_access_key"` SessionToken string `mapstructure:"aws_access_token"` - AllowHostAccess bool `mapstructure:"allow_host_access"` } func (d driver) Open(instanceID string, config map[string]any, client *activity.Client, logger *zap.Logger) (drivers.Handle, error) { diff --git a/runtime/drivers/athena/warehouse.go b/runtime/drivers/athena/warehouse.go index 59457c9b00f..8724c1612b9 100644 --- a/runtime/drivers/athena/warehouse.go +++ b/runtime/drivers/athena/warehouse.go @@ -12,10 +12,12 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/athena" types2 "github.com/aws/aws-sdk-go-v2/service/athena/types" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/google/uuid" "github.com/mitchellh/mapstructure" "github.com/rilldata/rill/runtime/drivers" @@ -123,7 +125,29 @@ func (c *Connection) awsConfig(ctx context.Context, awsRegion string) (aws.Confi return aws.Config{}, fmt.Errorf("static creds are not provided, and host access is not allowed") } - return config.LoadDefaultConfig(ctx, loadOptions...) + awsConfig, err := config.LoadDefaultConfig(ctx, loadOptions...) + if err != nil { + return aws.Config{}, err + } + + if c.config.RoleARN != "" { + stsClient := sts.NewFromConfig(awsConfig) + assumeRoleOptions := []func(*stscreds.AssumeRoleOptions){} + if c.config.RoleSessionName != "" { + assumeRoleOptions = append(assumeRoleOptions, func(o *stscreds.AssumeRoleOptions) { + o.RoleSessionName = c.config.RoleSessionName + }) + } + if c.config.ExternalID != "" { + assumeRoleOptions = append(assumeRoleOptions, func(o *stscreds.AssumeRoleOptions) { + o.ExternalID = &c.config.ExternalID + }) + } + provider := stscreds.NewAssumeRoleProvider(stsClient, c.config.RoleARN, assumeRoleOptions...) + awsConfig.Credentials = aws.NewCredentialsCache(provider) + } + + return awsConfig, nil } func (c *Connection) unload(ctx context.Context, client *athena.Client, conf *sourceProperties, unloadLocation string) error {