Skip to content

Commit

Permalink
athena-connector: support sts assumerole (#6212)
Browse files Browse the repository at this point in the history
* athena-connector: support sts assumerole

* rename local variable
  • Loading branch information
rohithreddykota authored and begelundmuller committed Dec 9, 2024
1 parent 6813b08 commit d43ec45
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
5 changes: 4 additions & 1 deletion runtime/drivers/athena/athena.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,12 @@ type driver struct{}

type configProperties struct {
AccessKeyID string `mapstructure:"aws_access_key_id"`
AllowHostAccess bool `mapstructure:"allow_host_access"`
ExternalID string `mapstructure:"external_id"`
RoleARN string `mapstructure:"role_arn"`
RoleSessionName string `mapstructure:"role_session_name"`
SecretAccessKey string `mapstructure:"aws_secret_access_key"`
SessionToken string `mapstructure:"aws_access_token"`
AllowHostAccess bool `mapstructure:"allow_host_access"`
}

func (d driver) Open(instanceID string, config map[string]any, client *activity.Client, logger *zap.Logger) (drivers.Handle, error) {
Expand Down
26 changes: 25 additions & 1 deletion runtime/drivers/athena/warehouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/athena"
types2 "github.com/aws/aws-sdk-go-v2/service/athena/types"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/google/uuid"
"github.com/mitchellh/mapstructure"
"github.com/rilldata/rill/runtime/drivers"
Expand Down Expand Up @@ -123,7 +125,29 @@ func (c *Connection) awsConfig(ctx context.Context, awsRegion string) (aws.Confi
return aws.Config{}, fmt.Errorf("static creds are not provided, and host access is not allowed")
}

return config.LoadDefaultConfig(ctx, loadOptions...)
awsConfig, err := config.LoadDefaultConfig(ctx, loadOptions...)
if err != nil {
return aws.Config{}, err
}

if c.config.RoleARN != "" {
stsClient := sts.NewFromConfig(awsConfig)
assumeRoleOptions := []func(*stscreds.AssumeRoleOptions){}
if c.config.RoleSessionName != "" {
assumeRoleOptions = append(assumeRoleOptions, func(o *stscreds.AssumeRoleOptions) {
o.RoleSessionName = c.config.RoleSessionName
})
}
if c.config.ExternalID != "" {
assumeRoleOptions = append(assumeRoleOptions, func(o *stscreds.AssumeRoleOptions) {
o.ExternalID = &c.config.ExternalID
})
}
provider := stscreds.NewAssumeRoleProvider(stsClient, c.config.RoleARN, assumeRoleOptions...)
awsConfig.Credentials = aws.NewCredentialsCache(provider)
}

return awsConfig, nil
}

func (c *Connection) unload(ctx context.Context, client *athena.Client, conf *sourceProperties, unloadLocation string) error {
Expand Down

1 comment on commit d43ec45

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉 Published on https://ui.rilldata.com as production
🚀 Deployed on https://67570cd841b4d409ebba40c5--rill-ui.netlify.app

Please sign in to comment.