Skip to content

Commit

Permalink
feat: add env support for disabling v1 fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws committed Nov 1, 2023
1 parent ee5e3f0 commit 728ec2f
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 2 deletions.
8 changes: 8 additions & 0 deletions .changelog/99f6ee13815544af9c5dcf4c6fcd6453.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "99f6ee13-8155-44af-9c5d-cf4c6fcd6453",
"type": "feature",
"description": "Add env and shared config settings for disabling IMDSv1 fallback.",
"modules": [
"config"
]
}
21 changes: 20 additions & 1 deletion config/env_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ const (

awsEc2MetadataServiceEndpointEnvVar = "AWS_EC2_METADATA_SERVICE_ENDPOINT"

awsEc2MetadataDisabled = "AWS_EC2_METADATA_DISABLED"
awsEc2MetadataDisabled = "AWS_EC2_METADATA_DISABLED"
awsEc2MetadataV1DisabledEnvVar = "AWS_EC2_METADATA_V1_DISABLED"

awsS3DisableMultiRegionAccessPointEnvVar = "AWS_S3_DISABLE_MULTIREGION_ACCESS_POINTS"

Expand Down Expand Up @@ -209,6 +210,11 @@ type EnvConfig struct {
// AWS_EC2_METADATA_DISABLED=true
EC2IMDSClientEnableState imds.ClientEnableState

// Specifies if EC2 IMDSv1 fallback is disabled.
//
// AWS_EC2_METADATA_V1_DISABLED=true
EC2IMDSv1Disabled *bool

// Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6)
//
// AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6
Expand Down Expand Up @@ -317,6 +323,9 @@ func NewEnvConfig() (EnvConfig, error) {
return cfg, err
}
cfg.EC2IMDSEndpoint = os.Getenv(awsEc2MetadataServiceEndpointEnvVar)
if err := setBoolPtrFromEnvVal(&cfg.EC2IMDSv1Disabled, []string{awsEc2MetadataV1DisabledEnvVar}); err != nil {
return cfg, err
}

if err := setBoolPtrFromEnvVal(&cfg.S3DisableMultiRegionAccessPoints, []string{awsS3DisableMultiRegionAccessPointEnvVar}); err != nil {
return cfg, err
Expand Down Expand Up @@ -717,3 +726,13 @@ func (c EnvConfig) GetEC2IMDSEndpoint() (string, bool, error) {

return c.EC2IMDSEndpoint, true, nil
}

// GetEC2IMDSV1FallbackDisabled implements an EC2IMDSV1FallbackDisabled option
// resolver interface.
func (c EnvConfig) GetEC2IMDSV1FallbackDisabled() (bool, bool) {
if c.EC2IMDSv1Disabled == nil {
return false, false
}

return *c.EC2IMDSv1Disabled, true
}
17 changes: 17 additions & 0 deletions config/env_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,23 @@ func TestNewEnvConfig(t *testing.T) {
IgnoreConfiguredEndpoints: ptr.Bool(true),
},
},
40: {
Env: map[string]string{
"AWS_EC2_METADATA_V1_DISABLED": "tRuE",
},
Config: EnvConfig{
EC2IMDSv1Disabled: aws.Bool(true),
},
},
41: {
Env: map[string]string{
"AWS_EC2_METADATA_V1_DISABLED": "invalid",
},
Config: EnvConfig{
EC2IMDSv1Disabled: aws.Bool(false), // setBoolPtrFromEnvVal new()s the bool even if it errors
},
WantErr: true,
},
}

for i, c := range cases {
Expand Down
20 changes: 20 additions & 0 deletions config/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ const (

ec2MetadataServiceEndpointKey = "ec2_metadata_service_endpoint"

ec2MetadataV1DisabledKey = "ec2_metadata_v1_disabled"

// Use DualStack Endpoint Resolution
useDualStackEndpoint = "use_dualstack_endpoint"

Expand Down Expand Up @@ -246,6 +248,12 @@ type SharedConfig struct {
// ec2_metadata_service_endpoint=http://fd00:ec2::254
EC2IMDSEndpoint string

// Specifies that IMDS clients should not fallback to IMDSv1 if token
// requests fail.
//
// ec2_metadata_v1_disabled=true
EC2IMDSv1Disabled *bool

// Specifies if the S3 service should disable support for Multi-Region
// access-points
//
Expand Down Expand Up @@ -397,6 +405,16 @@ func (c SharedConfig) GetEC2IMDSEndpoint() (string, bool, error) {
return c.EC2IMDSEndpoint, true, nil
}

// GetEC2IMDSV1FallbackDisabled implements an EC2IMDSV1FallbackDisabled option
// resolver interface.
func (c SharedConfig) GetEC2IMDSV1FallbackDisabled() (bool, bool) {
if c.EC2IMDSv1Disabled == nil {
return false, false
}

return *c.EC2IMDSv1Disabled, true
}

// GetUseDualStackEndpoint returns whether the service's dual-stack endpoint should be
// used for requests.
func (c SharedConfig) GetUseDualStackEndpoint(ctx context.Context) (value aws.DualStackEndpointState, found bool, err error) {
Expand Down Expand Up @@ -807,6 +825,7 @@ func mergeSections(dst *ini.Sections, src ini.Sections) error {
s3DisableMultiRegionAccessPointsKey,
ec2MetadataServiceEndpointModeKey,
ec2MetadataServiceEndpointKey,
ec2MetadataV1DisabledKey,
useDualStackEndpoint,
useFIPSEndpointKey,
defaultsModeKey,
Expand Down Expand Up @@ -1040,6 +1059,7 @@ func (c *SharedConfig) setFromIniSection(profile string, section ini.Section) er
return fmt.Errorf("failed to load %s from shared config, %v", ec2MetadataServiceEndpointModeKey, err)
}
updateString(&c.EC2IMDSEndpoint, section, ec2MetadataServiceEndpointKey)
updateBoolPtr(&c.EC2IMDSv1Disabled, section, ec2MetadataV1DisabledKey)

updateUseDualStackEndpoint(&c.UseDualStackEndpoint, section, useDualStackEndpoint)
updateUseFIPSEndpoint(&c.UseFIPSEndpoint, section, useFIPSEndpointKey)
Expand Down
27 changes: 26 additions & 1 deletion config/shared_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,32 @@ func TestNewSharedConfig(t *testing.T) {
BaseEndpoint: "https://example.com",
IgnoreConfiguredEndpoints: ptr.Bool(true),
},
}}
},
"imdsv1 disabled = false": {
ConfigFilenames: []string{testConfigFilename},
Profile: "ec2-metadata-v1-disabled-false",
Expected: SharedConfig{
Profile: "ec2-metadata-v1-disabled-false",
EC2IMDSv1Disabled: aws.Bool(false),
},
},
"imdsv1 disabled = true": {
ConfigFilenames: []string{testConfigFilename},
Profile: "ec2-metadata-v1-disabled-true",
Expected: SharedConfig{
Profile: "ec2-metadata-v1-disabled-true",
EC2IMDSv1Disabled: aws.Bool(true),
},
},
"imdsv1 disabled = invalid": {
ConfigFilenames: []string{testConfigFilename},
Profile: "ec2-metadata-v1-disabled-invalid",
Expected: SharedConfig{
Profile: "ec2-metadata-v1-disabled-invalid",
EC2IMDSv1Disabled: aws.Bool(false),
},
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
9 changes: 9 additions & 0 deletions config/testdata/shared_config
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,12 @@ sdk_ua_app_id = 12345
[profile endpoint_config]
ignore_configured_endpoint_urls = true
endpoint_url = https://example.com

[profile ec2-metadata-v1-disabled-false]
ec2_metadata_v1_disabled=False

[profile ec2-metadata-v1-disabled-true]
ec2_metadata_v1_disabled=True

[profile ec2-metadata-v1-disabled-invalid]
ec2_metadata_v1_disabled=invalid
18 changes: 18 additions & 0 deletions feature/ec2/imds/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func NewFromConfig(cfg aws.Config, optFns ...func(*Options)) *Client {
resolveClientEnableState(cfg, &opts)
resolveEndpointConfig(cfg, &opts)
resolveEndpointModeConfig(cfg, &opts)
resolveEnableFallback(cfg, &opts)

return New(opts, optFns...)
}
Expand Down Expand Up @@ -328,3 +329,20 @@ func resolveEndpointConfig(cfg aws.Config, options *Options) error {
options.Endpoint = value
return nil
}

func resolveEnableFallback(cfg aws.Config, options *Options) {
if options.EnableFallback != aws.UnknownTernary {
return
}

disabled, ok := internalconfig.ResolveV1FallbackDisabled(cfg.ConfigSources)
if !ok {
return
}

if disabled {
options.EnableFallback = aws.FalseTernary
} else {
options.EnableFallback = aws.TrueTernary
}
}
16 changes: 16 additions & 0 deletions feature/ec2/imds/internal/config/resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ type EndpointResolver interface {
GetEC2IMDSEndpoint() (string, bool, error)
}

type v1FallbackDisabledResolver interface {
GetEC2IMDSV1FallbackDisabled() (bool, bool)
}

// ResolveClientEnableState resolves the ClientEnableState from a list of configuration sources.
func ResolveClientEnableState(sources []interface{}) (value ClientEnableState, found bool, err error) {
for _, source := range sources {
Expand Down Expand Up @@ -96,3 +100,15 @@ func ResolveEndpointConfig(sources []interface{}) (value string, found bool, err
}
return value, found, err
}

// ResolveV1FallbackDisabled ...
func ResolveV1FallbackDisabled(sources []interface{}) (bool, bool) {
for _, source := range sources {
if resolver, ok := source.(v1FallbackDisabledResolver); ok {
if v, found := resolver.GetEC2IMDSV1FallbackDisabled(); found {
return v, true
}
}
}
return false, false
}

0 comments on commit 728ec2f

Please sign in to comment.