diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..388b1943547 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -3,3 +3,5 @@ ### SDK Enhancements ### SDK Bugs +* `aws/defaults`: Feature updates to endpoint credentials provider. + * Add support for dynamic auth token from file and EKS container host in configured URI. \ No newline at end of file diff --git a/aws/credentials/endpointcreds/provider.go b/aws/credentials/endpointcreds/provider.go index 785f30d8e6c..329f788a38a 100644 --- a/aws/credentials/endpointcreds/provider.go +++ b/aws/credentials/endpointcreds/provider.go @@ -31,6 +31,8 @@ package endpointcreds import ( "encoding/json" + "fmt" + "strings" "time" "github.com/aws/aws-sdk-go/aws" @@ -69,7 +71,37 @@ type Provider struct { // Optional authorization token value if set will be used as the value of // the Authorization header of the endpoint credential request. + // + // When constructed from environment, the provider will use the value of + // AWS_CONTAINER_AUTHORIZATION_TOKEN environment variable as the token + // + // Will be overridden if AuthorizationTokenProvider is configured AuthorizationToken string + + // Optional auth provider func to dynamically load the auth token from a file + // everytime a credential is retrieved + // + // When constructed from environment, the provider will read and use the content + // of the file pointed to by AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE environment variable + // as the auth token everytime credentials are retrieved + // + // Will override AuthorizationToken if configured + AuthorizationTokenProvider AuthTokenProvider +} + +// AuthTokenProvider defines an interface to dynamically load a value to be passed +// for the Authorization header of a credentials request. +type AuthTokenProvider interface { + GetToken() (string, error) +} + +// TokenProviderFunc is a func type implementing AuthTokenProvider interface +// and enables customizing token provider behavior +type TokenProviderFunc func() (string, error) + +// GetToken func retrieves auth token according to TokenProviderFunc implementation +func (p TokenProviderFunc) GetToken() (string, error) { + return p() } // NewProviderClient returns a credentials Provider for retrieving AWS credentials @@ -164,7 +196,20 @@ func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error req := p.Client.NewRequest(op, nil, out) req.SetContext(ctx) req.HTTPRequest.Header.Set("Accept", "application/json") - if authToken := p.AuthorizationToken; len(authToken) != 0 { + + authToken := p.AuthorizationToken + var err error + if p.AuthorizationTokenProvider != nil { + authToken, err = p.AuthorizationTokenProvider.GetToken() + if err != nil { + return nil, fmt.Errorf("get authorization token: %v", err) + } + } + + if strings.ContainsAny(authToken, "\r\n") { + return nil, fmt.Errorf("authorization token contains invalid newline sequence") + } + if len(authToken) != 0 { req.HTTPRequest.Header.Set("Authorization", authToken) } diff --git a/aws/credentials/endpointcreds/provider_test.go b/aws/credentials/endpointcreds/provider_test.go index fceb077792b..00c0fac23b8 100644 --- a/aws/credentials/endpointcreds/provider_test.go +++ b/aws/credentials/endpointcreds/provider_test.go @@ -159,64 +159,108 @@ func TestFailedRetrieveCredentials(t *testing.T) { } func TestAuthorizationToken(t *testing.T) { - const expectAuthToken = "Basic abc123" + cases := map[string]struct { + ExpectPath string + ServerPath string + AuthToken string + AuthTokenProvider endpointcreds.AuthTokenProvider + ExpectAuthToken string + ExpectError bool + }{ + "AuthToken": { + ExpectPath: "/path/to/endpoint", + ServerPath: "/path/to/endpoint?something=else", + AuthToken: "Basic abc123", + ExpectAuthToken: "Basic abc123", + }, + "AuthFileToken": { + ExpectPath: "/path/to/endpoint", + ServerPath: "/path/to/endpoint?something=else", + AuthToken: "Basic abc123", + AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) { + return "Hello %20world", nil + }), + ExpectAuthToken: "Hello %20world", + }, + "RetrieveFileTokenError": { + ExpectPath: "/path/to/endpoint", + ServerPath: "/path/to/endpoint?something=else", + AuthToken: "Basic abc123", + AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) { + return "", fmt.Errorf("test error") + }), + ExpectAuthToken: "Hello %20world", + ExpectError: true, + }, + } - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if e, a := "/path/to/endpoint", r.URL.Path; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "application/json", r.Header.Get("Accept"); e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := expectAuthToken, r.Header.Get("Authorization"); e != a { - t.Fatalf("expect %v, got %v", e, a) - } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if e, a := c.ExpectPath, r.URL.Path; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "application/json", r.Header.Get("Accept"); e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := c.ExpectAuthToken, r.Header.Get("Authorization"); e != a { + t.Fatalf("expect %v, got %v", e, a) + } - encoder := json.NewEncoder(w) - err := encoder.Encode(map[string]interface{}{ - "AccessKeyID": "AKID", - "SecretAccessKey": "SECRET", - "Token": "TOKEN", - "Expiration": time.Now().Add(1 * time.Hour), - }) + encoder := json.NewEncoder(w) + err := encoder.Encode(map[string]interface{}{ + "AccessKeyID": "AKID", + "SecretAccessKey": "SECRET", + "Token": "TOKEN", + "Expiration": time.Now().Add(1 * time.Hour), + }) - if err != nil { - fmt.Println("failed to write out creds", err) - } - })) - defer server.Close() + if err != nil { + fmt.Println("failed to write out creds", err) + } + })) + defer server.Close() - client := endpointcreds.NewProviderClient(*unit.Session.Config, - unit.Session.Handlers, - server.URL+"/path/to/endpoint?something=else", - func(p *endpointcreds.Provider) { - p.AuthorizationToken = expectAuthToken - }, - ) - creds, err := client.Retrieve() + client := endpointcreds.NewProviderClient(*unit.Session.Config, + unit.Session.Handlers, + server.URL+c.ServerPath, + func(p *endpointcreds.Provider) { + p.AuthorizationToken = c.AuthToken + p.AuthorizationTokenProvider = c.AuthTokenProvider + }, + ) + creds, err := client.Retrieve() - if err != nil { - t.Errorf("expect no error, got %v", err) - } + if err != nil && !c.ExpectError { + t.Errorf("expect no error, got %v", err) + } else if err == nil && c.ExpectError { + t.Errorf("expect error, got nil") + } - if e, a := "AKID", creds.AccessKeyID; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "SECRET", creds.SecretAccessKey; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "TOKEN", creds.SessionToken; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if client.IsExpired() { - t.Errorf("expect not expired, was") - } + if c.ExpectError { + return + } - client.(*endpointcreds.Provider).CurrentTime = func() time.Time { - return time.Now().Add(2 * time.Hour) - } + if e, a := "AKID", creds.AccessKeyID; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "SECRET", creds.SecretAccessKey; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "TOKEN", creds.SessionToken; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if client.IsExpired() { + t.Errorf("expect not expired, was") + } - if !client.IsExpired() { - t.Errorf("expect expired, wasn't") + client.(*endpointcreds.Provider).CurrentTime = func() time.Time { + return time.Now().Add(2 * time.Hour) + } + + if !client.IsExpired() { + t.Errorf("expect expired, wasn't") + } + }) } } diff --git a/aws/defaults/defaults.go b/aws/defaults/defaults.go index 23bb639e018..98278141935 100644 --- a/aws/defaults/defaults.go +++ b/aws/defaults/defaults.go @@ -9,6 +9,7 @@ package defaults import ( "fmt" + "io/ioutil" "net" "net/http" "net/url" @@ -114,9 +115,31 @@ func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Pro const ( httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN" + httpProviderAuthFileEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI" ) +// direct representation of the IPv4 address for the ECS container +// "169.254.170.2" +var ecsContainerIPv4 net.IP = []byte{ + 169, 254, 170, 2, +} + +// direct representation of the IPv4 address for the EKS container +// "169.254.170.23" +var eksContainerIPv4 net.IP = []byte{ + 169, 254, 170, 23, +} + +// direct representation of the IPv6 address for the EKS container +// "fd00:ec2::23" +var eksContainerIPv6 net.IP = []byte{ + 0xFD, 0, 0xE, 0xC2, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0x23, +} + // RemoteCredProvider returns a credentials provider for the default remote // endpoints such as EC2 or ECS Roles. func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider { @@ -134,19 +157,22 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P var lookupHostFn = net.LookupHost -func isLoopbackHost(host string) (bool, error) { - ip := net.ParseIP(host) - if ip != nil { - return ip.IsLoopback(), nil +// isAllowedHost allows host to be loopback or known ECS/EKS container IPs +// +// host can either be an IP address OR an unresolved hostname - resolution will +// be automatically performed in the latter case +func isAllowedHost(host string) (bool, error) { + if ip := net.ParseIP(host); ip != nil { + return isIPAllowed(ip), nil } - // Host is not an ip, perform lookup addrs, err := lookupHostFn(host) if err != nil { return false, err } + for _, addr := range addrs { - if !net.ParseIP(addr).IsLoopback() { + if ip := net.ParseIP(addr); ip == nil || !isIPAllowed(ip) { return false, nil } } @@ -154,6 +180,13 @@ func isLoopbackHost(host string) (bool, error) { return true, nil } +func isIPAllowed(ip net.IP) bool { + return ip.IsLoopback() || + ip.Equal(ecsContainerIPv4) || + ip.Equal(eksContainerIPv4) || + ip.Equal(eksContainerIPv6) +} + func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider { var errMsg string @@ -164,10 +197,12 @@ func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) host := aws.URLHostname(parsed) if len(host) == 0 { errMsg = "unable to parse host from local HTTP cred provider URL" - } else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil { - errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, loopbackErr) - } else if !isLoopback { - errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback hosts are allowed.", host) + } else if parsed.Scheme == "http" { + if isAllowedHost, allowHostErr := isAllowedHost(host); allowHostErr != nil { + errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, allowHostErr) + } else if !isAllowedHost { + errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback/ecs/eks hosts are allowed.", host) + } } } @@ -189,6 +224,15 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede func(p *endpointcreds.Provider) { p.ExpiryWindow = 5 * time.Minute p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar) + if authFilePath := os.Getenv(httpProviderAuthFileEnvVar); authFilePath != "" { + p.AuthorizationTokenProvider = endpointcreds.TokenProviderFunc(func() (string, error) { + if contents, err := ioutil.ReadFile(authFilePath); err != nil { + return "", fmt.Errorf("failed to read authorization token from %v: %v", authFilePath, err) + } else { + return string(contents), nil + } + }) + } }, ) } diff --git a/aws/defaults/defaults_test.go b/aws/defaults/defaults_test.go index 2ab6ec24607..ee29a221104 100644 --- a/aws/defaults/defaults_test.go +++ b/aws/defaults/defaults_test.go @@ -23,10 +23,12 @@ func TestHTTPCredProvider(t *testing.T) { Addrs []string Err error }{ - "localhost": {Addrs: []string{"::1", "127.0.0.1"}}, - "actuallylocal": {Addrs: []string{"127.0.0.2"}}, - "notlocal": {Addrs: []string{"::1", "127.0.0.1", "192.168.1.10"}}, - "www.example.com": {Addrs: []string{"10.10.10.10"}}, + "localhost": {Addrs: []string{"::1", "127.0.0.1"}}, + "actuallylocal": {Addrs: []string{"127.0.0.2"}}, + "notlocal": {Addrs: []string{"::1", "127.0.0.1", "192.168.1.10"}}, + "www.example.com": {Addrs: []string{"10.10.10.10"}}, + "www.eks.legit.com": {Addrs: []string{"fd00:ec2::23"}}, + "www.eks.scary.com": {Addrs: []string{"fd00:ec3::23"}}, } h, ok := m[host] @@ -49,7 +51,13 @@ func TestHTTPCredProvider(t *testing.T) { {Host: "127.1.1.1", Fail: false}, {Host: "[::1]", Fail: false}, {Host: "www.example.com", Fail: true}, - {Host: "169.254.170.2", Fail: true}, + {Host: "169.254.170.2", Fail: false}, + {Host: "169.254.170.23", Fail: false}, + {Host: "[fd00:ec2::23]", Fail: false}, + {Host: "[fd00:ec2:0::23]", Fail: false}, + {Host: "[fd00:ec2:0:1::23]", Fail: true}, + {Host: "www.eks.legit.com", Fail: false}, + {Host: "www.eks.scary.com", Fail: true}, {Host: "localhost", Fail: false, AuthToken: "Basic abc123"}, } @@ -91,6 +99,27 @@ func TestHTTPCredProvider(t *testing.T) { } } +func TestHTTPAuthTokenFile(t *testing.T) { + restoreEnvFn := sdktesting.StashEnv() + defer restoreEnvFn() + os.Setenv(httpProviderAuthFileEnvVar, "path/to/file") + os.Setenv(httpProviderEnvVar, "http://169.254.170.23/abc/123") + + provider := RemoteCredProvider(aws.Config{}, request.Handlers{}) + if provider == nil { + t.Fatalf("expect provider not to be nil, but was") + } + + httpProvider := provider.(*endpointcreds.Provider) + if httpProvider == nil { + t.Fatalf("expect provider not to be nil, but was") + } + + if httpProvider.AuthorizationTokenProvider == nil { + t.Fatalf("expect auth token provider no to be nil, but was") + } +} + func TestECSCredProvider(t *testing.T) { restoreEnvFn := sdktesting.StashEnv() defer restoreEnvFn()