diff --git a/.changelog/99f6ee13815544af9c5dcf4c6fcd6453.json b/.changelog/99f6ee13815544af9c5dcf4c6fcd6453.json new file mode 100644 index 00000000000..f86c5af90f8 --- /dev/null +++ b/.changelog/99f6ee13815544af9c5dcf4c6fcd6453.json @@ -0,0 +1,8 @@ +{ + "id": "99f6ee13-8155-44af-9c5d-cf4c6fcd6453", + "type": "feature", + "description": "Add env and shared config settings for disabling IMDSv1 fallback.", + "modules": [ + "config" + ] +} \ No newline at end of file diff --git a/config/env_config.go b/config/env_config.go index 38d7ef72bac..78bc1493372 100644 --- a/config/env_config.go +++ b/config/env_config.go @@ -57,7 +57,8 @@ const ( awsEc2MetadataServiceEndpointEnvVar = "AWS_EC2_METADATA_SERVICE_ENDPOINT" - awsEc2MetadataDisabled = "AWS_EC2_METADATA_DISABLED" + awsEc2MetadataDisabled = "AWS_EC2_METADATA_DISABLED" + awsEc2MetadataV1DisabledEnvVar = "AWS_EC2_METADATA_V1_DISABLED" awsS3DisableMultiRegionAccessPointEnvVar = "AWS_S3_DISABLE_MULTIREGION_ACCESS_POINTS" @@ -209,6 +210,11 @@ type EnvConfig struct { // AWS_EC2_METADATA_DISABLED=true EC2IMDSClientEnableState imds.ClientEnableState + // Specifies if EC2 IMDSv1 fallback is disabled. + // + // AWS_EC2_METADATA_V1_DISABLED=true + EC2IMDSv1Disabled *bool + // Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6) // // AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6 @@ -317,6 +323,9 @@ func NewEnvConfig() (EnvConfig, error) { return cfg, err } cfg.EC2IMDSEndpoint = os.Getenv(awsEc2MetadataServiceEndpointEnvVar) + if err := setBoolPtrFromEnvVal(&cfg.EC2IMDSv1Disabled, []string{awsEc2MetadataV1DisabledEnvVar}); err != nil { + return cfg, err + } if err := setBoolPtrFromEnvVal(&cfg.S3DisableMultiRegionAccessPoints, []string{awsS3DisableMultiRegionAccessPointEnvVar}); err != nil { return cfg, err @@ -717,3 +726,13 @@ func (c EnvConfig) GetEC2IMDSEndpoint() (string, bool, error) { return c.EC2IMDSEndpoint, true, nil } + +// GetEC2IMDSV1FallbackDisabled implements an EC2IMDSV1FallbackDisabled option +// resolver interface. +func (c EnvConfig) GetEC2IMDSV1FallbackDisabled() (bool, bool) { + if c.EC2IMDSv1Disabled == nil { + return false, false + } + + return *c.EC2IMDSv1Disabled, true +} diff --git a/config/env_config_test.go b/config/env_config_test.go index 988f31a017f..14bead837a1 100644 --- a/config/env_config_test.go +++ b/config/env_config_test.go @@ -439,6 +439,23 @@ func TestNewEnvConfig(t *testing.T) { IgnoreConfiguredEndpoints: ptr.Bool(true), }, }, + 40: { + Env: map[string]string{ + "AWS_EC2_METADATA_V1_DISABLED": "tRuE", + }, + Config: EnvConfig{ + EC2IMDSv1Disabled: aws.Bool(true), + }, + }, + 41: { + Env: map[string]string{ + "AWS_EC2_METADATA_V1_DISABLED": "invalid", + }, + Config: EnvConfig{ + EC2IMDSv1Disabled: aws.Bool(false), // setBoolPtrFromEnvVal new()s the bool even if it errors + }, + WantErr: true, + }, } for i, c := range cases { diff --git a/config/shared_config.go b/config/shared_config.go index 435051735b4..20683bf5f07 100644 --- a/config/shared_config.go +++ b/config/shared_config.go @@ -79,6 +79,8 @@ const ( ec2MetadataServiceEndpointKey = "ec2_metadata_service_endpoint" + ec2MetadataV1DisabledKey = "ec2_metadata_v1_disabled" + // Use DualStack Endpoint Resolution useDualStackEndpoint = "use_dualstack_endpoint" @@ -246,6 +248,12 @@ type SharedConfig struct { // ec2_metadata_service_endpoint=http://fd00:ec2::254 EC2IMDSEndpoint string + // Specifies that IMDS clients should not fallback to IMDSv1 if token + // requests fail. + // + // ec2_metadata_v1_disabled=true + EC2IMDSv1Disabled *bool + // Specifies if the S3 service should disable support for Multi-Region // access-points // @@ -397,6 +405,16 @@ func (c SharedConfig) GetEC2IMDSEndpoint() (string, bool, error) { return c.EC2IMDSEndpoint, true, nil } +// GetEC2IMDSV1FallbackDisabled implements an EC2IMDSV1FallbackDisabled option +// resolver interface. +func (c SharedConfig) GetEC2IMDSV1FallbackDisabled() (bool, bool) { + if c.EC2IMDSv1Disabled == nil { + return false, false + } + + return *c.EC2IMDSv1Disabled, true +} + // GetUseDualStackEndpoint returns whether the service's dual-stack endpoint should be // used for requests. func (c SharedConfig) GetUseDualStackEndpoint(ctx context.Context) (value aws.DualStackEndpointState, found bool, err error) { @@ -807,6 +825,7 @@ func mergeSections(dst *ini.Sections, src ini.Sections) error { s3DisableMultiRegionAccessPointsKey, ec2MetadataServiceEndpointModeKey, ec2MetadataServiceEndpointKey, + ec2MetadataV1DisabledKey, useDualStackEndpoint, useFIPSEndpointKey, defaultsModeKey, @@ -1040,6 +1059,7 @@ func (c *SharedConfig) setFromIniSection(profile string, section ini.Section) er return fmt.Errorf("failed to load %s from shared config, %v", ec2MetadataServiceEndpointModeKey, err) } updateString(&c.EC2IMDSEndpoint, section, ec2MetadataServiceEndpointKey) + updateBoolPtr(&c.EC2IMDSv1Disabled, section, ec2MetadataV1DisabledKey) updateUseDualStackEndpoint(&c.UseDualStackEndpoint, section, useDualStackEndpoint) updateUseFIPSEndpoint(&c.UseFIPSEndpoint, section, useFIPSEndpointKey) diff --git a/config/shared_config_test.go b/config/shared_config_test.go index 0c935cd366d..417ee2aff9b 100644 --- a/config/shared_config_test.go +++ b/config/shared_config_test.go @@ -660,7 +660,32 @@ func TestNewSharedConfig(t *testing.T) { BaseEndpoint: "https://example.com", IgnoreConfiguredEndpoints: ptr.Bool(true), }, - }} + }, + "imdsv1 disabled = false": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "ec2-metadata-v1-disabled-false", + Expected: SharedConfig{ + Profile: "ec2-metadata-v1-disabled-false", + EC2IMDSv1Disabled: aws.Bool(false), + }, + }, + "imdsv1 disabled = true": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "ec2-metadata-v1-disabled-true", + Expected: SharedConfig{ + Profile: "ec2-metadata-v1-disabled-true", + EC2IMDSv1Disabled: aws.Bool(true), + }, + }, + "imdsv1 disabled = invalid": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "ec2-metadata-v1-disabled-invalid", + Expected: SharedConfig{ + Profile: "ec2-metadata-v1-disabled-invalid", + EC2IMDSv1Disabled: aws.Bool(false), + }, + }, + } for name, c := range cases { t.Run(name, func(t *testing.T) { diff --git a/config/testdata/shared_config b/config/testdata/shared_config index 0260b7139e9..eefb5158262 100644 --- a/config/testdata/shared_config +++ b/config/testdata/shared_config @@ -282,3 +282,12 @@ sdk_ua_app_id = 12345 [profile endpoint_config] ignore_configured_endpoint_urls = true endpoint_url = https://example.com + +[profile ec2-metadata-v1-disabled-false] +ec2_metadata_v1_disabled=False + +[profile ec2-metadata-v1-disabled-true] +ec2_metadata_v1_disabled=True + +[profile ec2-metadata-v1-disabled-invalid] +ec2_metadata_v1_disabled=invalid diff --git a/feature/ec2/imds/api_client.go b/feature/ec2/imds/api_client.go index e55edd992e2..46e144d9363 100644 --- a/feature/ec2/imds/api_client.go +++ b/feature/ec2/imds/api_client.go @@ -119,6 +119,7 @@ func NewFromConfig(cfg aws.Config, optFns ...func(*Options)) *Client { resolveClientEnableState(cfg, &opts) resolveEndpointConfig(cfg, &opts) resolveEndpointModeConfig(cfg, &opts) + resolveEnableFallback(cfg, &opts) return New(opts, optFns...) } @@ -328,3 +329,20 @@ func resolveEndpointConfig(cfg aws.Config, options *Options) error { options.Endpoint = value return nil } + +func resolveEnableFallback(cfg aws.Config, options *Options) { + if options.EnableFallback != aws.UnknownTernary { + return + } + + disabled, ok := internalconfig.ResolveV1FallbackDisabled(cfg.ConfigSources) + if !ok { + return + } + + if disabled { + options.EnableFallback = aws.FalseTernary + } else { + options.EnableFallback = aws.TrueTernary + } +} diff --git a/feature/ec2/imds/internal/config/resolvers.go b/feature/ec2/imds/internal/config/resolvers.go index d72fcb5626f..ce774558932 100644 --- a/feature/ec2/imds/internal/config/resolvers.go +++ b/feature/ec2/imds/internal/config/resolvers.go @@ -58,6 +58,10 @@ type EndpointResolver interface { GetEC2IMDSEndpoint() (string, bool, error) } +type v1FallbackDisabledResolver interface { + GetEC2IMDSV1FallbackDisabled() (bool, bool) +} + // ResolveClientEnableState resolves the ClientEnableState from a list of configuration sources. func ResolveClientEnableState(sources []interface{}) (value ClientEnableState, found bool, err error) { for _, source := range sources { @@ -96,3 +100,15 @@ func ResolveEndpointConfig(sources []interface{}) (value string, found bool, err } return value, found, err } + +// ResolveV1FallbackDisabled ... +func ResolveV1FallbackDisabled(sources []interface{}) (bool, bool) { + for _, source := range sources { + if resolver, ok := source.(v1FallbackDisabledResolver); ok { + if v, found := resolver.GetEC2IMDSV1FallbackDisabled(); found { + return v, true + } + } + } + return false, false +}