diff --git a/aws/acm.go b/aws/acm.go index 03977609..309b41a2 100644 --- a/aws/acm.go +++ b/aws/acm.go @@ -2,6 +2,7 @@ package aws import ( "crypto/x509" + "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/acm" @@ -10,16 +11,17 @@ import ( ) type acmCertificateProvider struct { - api acmiface.ACMAPI + api acmiface.ACMAPI + filterTag string } -func newACMCertProvider(api acmiface.ACMAPI) certs.CertificatesProvider { - return &acmCertificateProvider{api: api} +func newACMCertProvider(api acmiface.ACMAPI, certFilterTag string) certs.CertificatesProvider { + return &acmCertificateProvider{api: api, filterTag: certFilterTag} } // GetCertificates returns a list of AWS ACM certificates func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary, error) { - acmSummaries, err := getACMCertificateSummaries(p.api) + acmSummaries, err := getACMCertificateSummaries(p.api, p.filterTag) if err != nil { return nil, err } @@ -34,20 +36,47 @@ func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary, return result, nil } -func getACMCertificateSummaries(api acmiface.ACMAPI) ([]*acm.CertificateSummary, error) { +func getACMCertificateSummaries(api acmiface.ACMAPI, filterTag string) ([]*acm.CertificateSummary, error) { params := &acm.ListCertificatesInput{ CertificateStatuses: []*string{ aws.String(acm.CertificateStatusIssued), }, } acmSummaries := make([]*acm.CertificateSummary, 0) + err := api.ListCertificatesPages(params, func(page *acm.ListCertificatesOutput, lastPage bool) bool { acmSummaries = append(acmSummaries, page.CertificateSummaryList...) return true }) + + if tag := strings.Split(filterTag, "="); filterTag != "=" && len(tag) == 2 { + return filterCertificatesByTag(api, acmSummaries, tag[0], tag[1]) + } + return acmSummaries, err } +func filterCertificatesByTag(api acmiface.ACMAPI, allSummaries []*acm.CertificateSummary, key, value string) ([]*acm.CertificateSummary, error) { + prodSummaries := make([]*acm.CertificateSummary, 0) + for _, summary := range allSummaries { + in := &acm.ListTagsForCertificateInput{ + CertificateArn: summary.CertificateArn, + } + out, err := api.ListTagsForCertificate(in) + if err != nil { + return nil, err + } + + for _, tag := range out.Tags { + if *tag.Key == key && *tag.Value == value { + prodSummaries = append(prodSummaries, summary) + } + } + } + + return prodSummaries, nil +} + func getCertificateSummaryFromACM(api acmiface.ACMAPI, arn *string) (*certs.CertificateSummary, error) { params := &acm.GetCertificateInput{CertificateArn: arn} resp, err := api.GetCertificate(params) diff --git a/aws/acm_test.go b/aws/acm_test.go index 3dc206e5..9a48b1b6 100644 --- a/aws/acm_test.go +++ b/aws/acm_test.go @@ -15,6 +15,7 @@ type acmExpect struct { DomainNames []string Chain int Error error + EmptyList bool } func TestACM(t *testing.T) { @@ -22,9 +23,10 @@ func TestACM(t *testing.T) { chain := mustRead("chain.txt") for _, ti := range []struct { - msg string - api acmiface.ACMAPI - expect acmExpect + msg string + api acmiface.ACMAPI + filterTag string + expect acmExpect }{ { msg: "Found ACM Cert foobar and a chain", @@ -37,9 +39,11 @@ func TestACM(t *testing.T) { }, }, }, - acm.GetCertificateOutput{ - Certificate: aws.String(cert), - CertificateChain: aws.String(chain), + map[string]*acm.GetCertificateOutput{ + "foobar": { + Certificate: aws.String(cert), + CertificateChain: aws.String(chain), + }, }, ), expect: acmExpect{ @@ -59,11 +63,82 @@ func TestACM(t *testing.T) { }, }, }, - acm.GetCertificateOutput{ - Certificate: aws.String(cert), + map[string]*acm.GetCertificateOutput{ + "foobar": { + Certificate: aws.String(cert), + }, + }, + ), + expect: acmExpect{ + ARN: "foobar", + DomainNames: []string{"foobar.de"}, + Error: nil, + }, + }, + { + msg: "Found one ACM Cert with correct filter tag", + api: fake.NewACMClientWithTags( + acm.ListCertificatesOutput{ + CertificateSummaryList: []*acm.CertificateSummary{ + { + CertificateArn: aws.String("foobar"), + DomainName: aws.String("foobar.de"), + }, + { + CertificateArn: aws.String("foobaz"), + DomainName: aws.String("foobar.de"), + }, + }, + }, + map[string]*acm.GetCertificateOutput{ + "foobar": { + Certificate: aws.String(cert), + }, + "foobaz": { + Certificate: aws.String(cert), + }, + }, + map[string]*acm.ListTagsForCertificateOutput{ + "foobar": { + Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("true")}}, + }, + "foobaz": { + Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("false")}}, + }, + }, + ), + filterTag: "production=true", + expect: acmExpect{ + ARN: "foobar", + DomainNames: []string{"foobar.de"}, + Error: nil, + }, + }, + { + msg: "ACM Cert with incorrect filter tag should not be found", + api: fake.NewACMClientWithTags( + acm.ListCertificatesOutput{ + CertificateSummaryList: []*acm.CertificateSummary{ + { + CertificateArn: aws.String("foobar"), + DomainName: aws.String("foobar.de"), + }, + }, + }, + map[string]*acm.GetCertificateOutput{ + "foobar": { + Certificate: aws.String(cert), + }, + }, + map[string]*acm.ListTagsForCertificateOutput{ + "foobar": { + Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("false")}}, + }, }, ), + filterTag: "production=true", expect: acmExpect{ + EmptyList: true, ARN: "foobar", DomainNames: []string{"foobar.de"}, Error: nil, @@ -71,7 +146,7 @@ func TestACM(t *testing.T) { }, } { t.Run(ti.msg, func(t *testing.T) { - provider := newACMCertProvider(ti.api) + provider := newACMCertProvider(ti.api, ti.filterTag) list, err := provider.GetCertificates() if ti.expect.Error != nil { @@ -80,11 +155,16 @@ func TestACM(t *testing.T) { require.NoError(t, err) } - require.Equal(t, 1, len(list)) + if ti.expect.EmptyList { + require.Equal(t, 0, len(list)) - cert := list[0] - require.Equal(t, ti.expect.ARN, cert.ID()) - require.Equal(t, ti.expect.DomainNames, cert.DomainNames()) + } else { + require.Equal(t, 1, len(list)) + + cert := list[0] + require.Equal(t, ti.expect.ARN, cert.ID()) + require.Equal(t, ti.expect.DomainNames, cert.DomainNames()) + } }) } } diff --git a/aws/adapter.go b/aws/adapter.go index 3a0fdb30..19df8977 100644 --- a/aws/adapter.go +++ b/aws/adapter.go @@ -273,12 +273,12 @@ func (a *Adapter) UpdateManifest(clusterID, vpcID string) (*Adapter, error) { return a, err } -func (a *Adapter) NewACMCertificateProvider() certs.CertificatesProvider { - return newACMCertProvider(a.acm) +func (a *Adapter) NewACMCertificateProvider(certFilterTag string) certs.CertificatesProvider { + return newACMCertProvider(a.acm, certFilterTag) } -func (a *Adapter) NewIAMCertificateProvider() certs.CertificatesProvider { - return newIAMCertProvider(a.iam) +func (a *Adapter) NewIAMCertificateProvider(certFilterTag string) certs.CertificatesProvider { + return newIAMCertProvider(a.iam, certFilterTag) } // WithHealthCheckPath returns the receiver adapter after changing the health check path that will be used by diff --git a/aws/fake/acm.go b/aws/fake/acm.go index e50556f5..18e7f967 100644 --- a/aws/fake/acm.go +++ b/aws/fake/acm.go @@ -1,6 +1,8 @@ package fake import ( + "fmt" + "github.com/aws/aws-sdk-go/service/acm" "github.com/aws/aws-sdk-go/service/acm/acmiface" ) @@ -8,7 +10,8 @@ import ( type ACMClient struct { acmiface.ACMAPI output acm.ListCertificatesOutput - cert acm.GetCertificateOutput + cert map[string]*acm.GetCertificateOutput + tags map[string]*acm.ListTagsForCertificateOutput } func (m ACMClient) ListCertificates(in *acm.ListCertificatesInput) (*acm.ListCertificatesOutput, error) { @@ -21,12 +24,32 @@ func (m ACMClient) ListCertificatesPages(input *acm.ListCertificatesInput, fn fu } func (m ACMClient) GetCertificate(input *acm.GetCertificateInput) (*acm.GetCertificateOutput, error) { - return &m.cert, nil + return m.cert[*input.CertificateArn], nil +} + +func (m ACMClient) ListTagsForCertificate(in *acm.ListTagsForCertificateInput) (*acm.ListTagsForCertificateOutput, error) { + if in.CertificateArn == nil { + return nil, fmt.Errorf("expected a valid CertificateArn, got: nil") + } + arn := *in.CertificateArn + return m.tags[arn], nil +} + +func NewACMClient(output acm.ListCertificatesOutput, cert map[string]*acm.GetCertificateOutput) ACMClient { + return ACMClient{ + output: output, + cert: cert, + } } -func NewACMClient(output acm.ListCertificatesOutput, cert acm.GetCertificateOutput) ACMClient { +func NewACMClientWithTags( + output acm.ListCertificatesOutput, + cert map[string]*acm.GetCertificateOutput, + tags map[string]*acm.ListTagsForCertificateOutput, +) ACMClient { return ACMClient{ output: output, cert: cert, + tags: tags, } } diff --git a/aws/fake/iam.go b/aws/fake/iam.go index ae6ff3d5..676fb964 100644 --- a/aws/fake/iam.go +++ b/aws/fake/iam.go @@ -1,6 +1,8 @@ package fake import ( + "fmt" + "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" ) @@ -9,6 +11,7 @@ type IAMClient struct { iamiface.IAMAPI list iam.ListServerCertificatesOutput cert iam.GetServerCertificateOutput + tags map[string]*iam.ListServerCertificateTagsOutput } func (m IAMClient) ListServerCertificates(*iam.ListServerCertificatesInput) (*iam.ListServerCertificatesOutput, error) { @@ -20,6 +23,17 @@ func (m IAMClient) ListServerCertificatesPages(input *iam.ListServerCertificates return nil } +func (m IAMClient) ListServerCertificateTags( + in *iam.ListServerCertificateTagsInput, +) (*iam.ListServerCertificateTagsOutput, error) { + + if in.ServerCertificateName == nil { + return nil, fmt.Errorf("expected a valid CertificateArn, got: nil") + } + name := *in.ServerCertificateName + return m.tags[name], nil +} + func (m IAMClient) GetServerCertificate(*iam.GetServerCertificateInput) (*iam.GetServerCertificateOutput, error) { return &m.cert, nil } @@ -30,3 +44,15 @@ func NewIAMClient(list iam.ListServerCertificatesOutput, cert iam.GetServerCerti cert: cert, } } + +func NewIAMClientWithTag( + list iam.ListServerCertificatesOutput, + cert iam.GetServerCertificateOutput, + tags map[string]*iam.ListServerCertificateTagsOutput, +) IAMClient { + return IAMClient{ + list: list, + cert: cert, + tags: tags, + } +} diff --git a/aws/iam.go b/aws/iam.go index d6ae94ff..f72a8ca8 100644 --- a/aws/iam.go +++ b/aws/iam.go @@ -2,6 +2,7 @@ package aws import ( "crypto/x509" + "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/iam" @@ -10,11 +11,12 @@ import ( ) type iamCertificateProvider struct { - api iamiface.IAMAPI + api iamiface.IAMAPI + filterTag string } -func newIAMCertProvider(api iamiface.IAMAPI) certs.CertificatesProvider { - return &iamCertificateProvider{api: api} +func newIAMCertProvider(api iamiface.IAMAPI, filterTag string) certs.CertificatesProvider { + return &iamCertificateProvider{api: api, filterTag: filterTag} } // GetCertificates returns a list of AWS IAM certificates @@ -25,6 +27,17 @@ func (p *iamCertificateProvider) GetCertificates() ([]*certs.CertificateSummary, } list := make([]*certs.CertificateSummary, 0) for _, o := range serverCertificatesMetadata { + if kv := strings.Split(p.filterTag, "="); p.filterTag != "=" && len(kv) == 2 { + hasTag, err := certHasTag(p.api, *o.ServerCertificateName, kv[0], kv[1]) + if err != nil { + return nil, err + } + + if !hasTag { + continue + } + } + certDetail, err := getCertificateSummaryFromIAM(p.api, aws.StringValue(o.ServerCertificateName)) if err != nil { return nil, err @@ -34,6 +47,22 @@ func (p *iamCertificateProvider) GetCertificates() ([]*certs.CertificateSummary, return list, nil } +func certHasTag(api iamiface.IAMAPI, certName, key, value string) (bool, error) { + t, err := api.ListServerCertificateTags(&iam.ListServerCertificateTagsInput{ + ServerCertificateName: &certName, + }) + if err != nil { + return false, err + } + for _, tag := range t.Tags { + if *tag.Key == key && *tag.Value == value { + return true, nil + } + } + + return false, nil +} + func getIAMServerCertificateMetadata(api iamiface.IAMAPI) ([]*iam.ServerCertificateMetadata, error) { params := &iam.ListServerCertificatesInput{ PathPrefix: aws.String("/"), diff --git a/aws/iam_test.go b/aws/iam_test.go index 86b68cad..dc315689 100644 --- a/aws/iam_test.go +++ b/aws/iam_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go/service/iam" "github.com/stretchr/testify/require" "github.com/zalando-incubator/kube-ingress-aws-controller/aws/fake" + "github.com/zalando-incubator/kube-ingress-aws-controller/certs" ) func TestIAM(t *testing.T) { @@ -193,3 +194,87 @@ func TestIAMParseError(t *testing.T) { _, err := provider.GetCertificates() require.Equal(t, ErrNoCertificates, err) } + +func TestIAMTagFiltering(t *testing.T) { + foobarNotBefore := time.Date(2017, 3, 29, 16, 11, 32, 0, time.UTC) + foobarNotAfter := time.Date(2027, 3, 27, 16, 11, 32, 0, time.UTC) + foobarIAMCertificate := &iam.ServerCertificate{ + CertificateBody: aws.String(mustRead("foo-iam.txt")), + ServerCertificateMetadata: &iam.ServerCertificateMetadata{ + Arn: aws.String("foobar-arn"), + ServerCertificateName: aws.String("foobar"), + }, + } + + createProviderwithTag := func(key, value string) certs.CertificatesProvider { + api := fake.NewIAMClientWithTag( + iam.ListServerCertificatesOutput{ + ServerCertificateMetadataList: []*iam.ServerCertificateMetadata{ + { + Arn: aws.String("foobar-arn"), + Path: aws.String("/"), + ServerCertificateName: aws.String("foobar"), + }, + }, + }, + iam.GetServerCertificateOutput{ServerCertificate: foobarIAMCertificate}, + map[string]*iam.ListServerCertificateTagsOutput{ + "foobar": { + Tags: []*iam.Tag{{Key: aws.String(key), Value: aws.String(value)}}, + }, + }, + ) + return newIAMCertProvider(api, "production=true") + } + + type expectedValues struct { + EmptyList bool + ARN string + NotBefore time.Time + NotAfter time.Time + DomainNames []string + } + + for _, ti := range []struct { + msg string + provider certs.CertificatesProvider + expect expectedValues + }{ + { + msg: "Certificate with correct key", + provider: createProviderwithTag("production", "true"), + expect: expectedValues{ + EmptyList: false, + ARN: "foobar-arn", + NotBefore: foobarNotBefore, + NotAfter: foobarNotAfter, + DomainNames: []string{"foobar.de"}, + }, + }, + { + msg: "Certificate with incorrect key", + provider: createProviderwithTag("production", "false"), + expect: expectedValues{ + EmptyList: true, + ARN: "foobar-arn", + NotBefore: foobarNotBefore, + NotAfter: foobarNotAfter, + DomainNames: []string{"foobar.de"}, + }, + }, + } { + t.Run(ti.msg, func(tt *testing.T) { + list, err := ti.provider.GetCertificates() + if ti.expect.EmptyList { + require.Equal(tt, 0, len(list)) + } else { + require.Equal(tt, 1, len(list)) + require.NoError(tt, err) + require.Equal(tt, ti.expect.ARN, list[0].ID()) + require.Equal(tt, ti.expect.DomainNames, list[0].DomainNames()) + require.Equal(tt, ti.expect.NotBefore, list[0].NotBefore()) + require.Equal(tt, ti.expect.NotAfter, list[0].NotAfter()) + } + }) + } +} diff --git a/controller.go b/controller.go index 84d9e1a2..c5710764 100644 --- a/controller.go +++ b/controller.go @@ -51,6 +51,7 @@ var ( disableSNISupport bool disableInstrumentedHttpClient bool certTTL time.Duration + certFilterTag string stackTerminationProtection bool additionalStackTags = make(map[string]string) idleConnectionTimeout time.Duration @@ -115,6 +116,9 @@ func loadSettings() error { StringMapVar(&additionalStackTags) kingpin.Flag("cert-ttl-timeout", "sets the timeout of how long a certificate is kept on an old ALB to be decommissioned."). Default(defaultCertTTL).DurationVar(&certTTL) + + kingpin.Flag("cert-filter-tag", "sets a tag so the ingress controller only consider ACM or IAM certificates that have this tag set when adding a certificate to a load balancer."). + Default("").StringVar(&certFilterTag) kingpin.Flag("health-check-path", "sets the health check path for the created target groups"). Default(aws.DefaultHealthCheckPath).StringVar(&healthCheckPath) kingpin.Flag("health-check-port", "sets the health check port for the created target groups"). @@ -256,6 +260,10 @@ func loadSettings() error { cwAlarmConfigMapLocation = loc } + if kv := strings.Split(certFilterTag, "="); len(kv) != 2 && certFilterTag != "" { + log.Errorf("Certificate filter tag should be in the format \"key=value\", instead it is set to: %s", certFilterTag) + } + if quietFlag && debugFlag { log.Warn("--quiet and --debug flags are both set. Debug will be used as logging level.") } @@ -344,8 +352,8 @@ func main() { certificatesProvider, err := certs.NewCachingProvider( certPollingInterval, blacklistCertArnMap, - awsAdapter.NewACMCertificateProvider(), - awsAdapter.NewIAMCertificateProvider(), + awsAdapter.NewACMCertificateProvider(certFilterTag), + awsAdapter.NewIAMCertificateProvider(certFilterTag), ) if err != nil { log.Fatal(err)