Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to filter certificates by tag before adding it to LB #658

Merged
merged 4 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions aws/acm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aws

import (
"crypto/x509"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/acm"
Expand All @@ -10,16 +11,17 @@ import (
)

type acmCertificateProvider struct {
api acmiface.ACMAPI
api acmiface.ACMAPI
filterTag string
}

func newACMCertProvider(api acmiface.ACMAPI) certs.CertificatesProvider {
return &acmCertificateProvider{api: api}
func newACMCertProvider(api acmiface.ACMAPI, certFilterTag string) certs.CertificatesProvider {
return &acmCertificateProvider{api: api, filterTag: certFilterTag}
}

// GetCertificates returns a list of AWS ACM certificates
func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary, error) {
acmSummaries, err := getACMCertificateSummaries(p.api)
acmSummaries, err := getACMCertificateSummaries(p.api, p.filterTag)
if err != nil {
return nil, err
}
Expand All @@ -34,20 +36,47 @@ func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary,
return result, nil
}

func getACMCertificateSummaries(api acmiface.ACMAPI) ([]*acm.CertificateSummary, error) {
func getACMCertificateSummaries(api acmiface.ACMAPI, filterTag string) ([]*acm.CertificateSummary, error) {
params := &acm.ListCertificatesInput{
CertificateStatuses: []*string{
aws.String(acm.CertificateStatusIssued),
},
}
acmSummaries := make([]*acm.CertificateSummary, 0)

err := api.ListCertificatesPages(params, func(page *acm.ListCertificatesOutput, lastPage bool) bool {
acmSummaries = append(acmSummaries, page.CertificateSummaryList...)
return true
})

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should add a

if filterTag = "" {
return acmSummaries
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could but I believe this case is covered in the if in line 52, because if filterTag == "" then len(tag) != 2, and in this case we will just return acmSummaries in line 56, which is just after the if.

if tag := strings.Split(filterTag, "="); filterTag != "=" && len(tag) == 2 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: I think for readability it's better to create two if's with the same return here.
The first if will check filterTag, the second will split and check len.
With such two if's, the filterTag != "=" won't be lost somewhere in the middle of the line and would be much easier to notice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how filterTag != "=" can help in the condition. I think len(tag)==2 is enough, because if filterTag == "=" then it will have 3 parts after split.

Copy link
Member

@szuecs szuecs Nov 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah https://go.dev/play/p/oGRkFVn_zu9 now I got it :D
So I am fine with the condition.

return filterCertificatesByTag(api, acmSummaries, tag[0], tag[1])
}

return acmSummaries, err
}

func filterCertificatesByTag(api acmiface.ACMAPI, allSummaries []*acm.CertificateSummary, key, value string) ([]*acm.CertificateSummary, error) {
prodSummaries := make([]*acm.CertificateSummary, 0)
RomanZavodskikh marked this conversation as resolved.
Show resolved Hide resolved
for _, summary := range allSummaries {
in := &acm.ListTagsForCertificateInput{
CertificateArn: summary.CertificateArn,
}
out, err := api.ListTagsForCertificate(in)
if err != nil {
return nil, err
}

for _, tag := range out.Tags {
if *tag.Key == key && *tag.Value == value {
prodSummaries = append(prodSummaries, summary)
}
}
}

return prodSummaries, nil
}

func getCertificateSummaryFromACM(api acmiface.ACMAPI, arn *string) (*certs.CertificateSummary, error) {
params := &acm.GetCertificateInput{CertificateArn: arn}
resp, err := api.GetCertificate(params)
Expand Down
78 changes: 70 additions & 8 deletions aws/acm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@ type acmExpect struct {
DomainNames []string
Chain int
Error error
EmptyList bool
}

func TestACM(t *testing.T) {
cert := mustRead("acm.txt")
chain := mustRead("chain.txt")

for _, ti := range []struct {
msg string
api acmiface.ACMAPI
expect acmExpect
msg string
api acmiface.ACMAPI
filterTag string
expect acmExpect
}{
{
msg: "Found ACM Cert foobar and a chain",
Expand Down Expand Up @@ -69,9 +71,64 @@ func TestACM(t *testing.T) {
Error: nil,
},
},
{
msg: "Found ACM Cert with correct filter tag",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you like to add tests with several certificates in the list.
For example, "2 certs with corresponding tag" or "1 cert with corresponding tag + 1 cert with not corresponding tag"?

api: fake.NewACMClientWithTags(
acm.ListCertificatesOutput{
CertificateSummaryList: []*acm.CertificateSummary{
{
CertificateArn: aws.String("foobar"),
DomainName: aws.String("foobar.de"),
},
},
},
acm.GetCertificateOutput{
Certificate: aws.String(cert),
},
map[string]*acm.ListTagsForCertificateOutput{
"foobar": {
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("true")}},
},
},
),
filterTag: "production=true",
expect: acmExpect{
ARN: "foobar",
DomainNames: []string{"foobar.de"},
Error: nil,
},
},
{
msg: "ACM Cert with incorrect filter tag should not be found",
api: fake.NewACMClientWithTags(
acm.ListCertificatesOutput{
CertificateSummaryList: []*acm.CertificateSummary{
{
CertificateArn: aws.String("foobar"),
DomainName: aws.String("foobar.de"),
},
},
},
acm.GetCertificateOutput{
Certificate: aws.String(cert),
},
map[string]*acm.ListTagsForCertificateOutput{
"foobar": {
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("false")}},
},
},
),
filterTag: "production=true",
expect: acmExpect{
EmptyList: true,
ARN: "foobar",
DomainNames: []string{"foobar.de"},
Error: nil,
},
},
} {
t.Run(ti.msg, func(t *testing.T) {
provider := newACMCertProvider(ti.api)
provider := newACMCertProvider(ti.api, ti.filterTag)
list, err := provider.GetCertificates()

if ti.expect.Error != nil {
Expand All @@ -80,11 +137,16 @@ func TestACM(t *testing.T) {
require.NoError(t, err)
}

require.Equal(t, 1, len(list))
if ti.expect.EmptyList {
require.Equal(t, 0, len(list))

cert := list[0]
require.Equal(t, ti.expect.ARN, cert.ID())
require.Equal(t, ti.expect.DomainNames, cert.DomainNames())
} else {
require.Equal(t, 1, len(list))

cert := list[0]
require.Equal(t, ti.expect.ARN, cert.ID())
require.Equal(t, ti.expect.DomainNames, cert.DomainNames())
}
})
}
}
8 changes: 4 additions & 4 deletions aws/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,12 @@ func (a *Adapter) UpdateManifest(clusterID, vpcID string) (*Adapter, error) {
return a, err
}

func (a *Adapter) NewACMCertificateProvider() certs.CertificatesProvider {
return newACMCertProvider(a.acm)
func (a *Adapter) NewACMCertificateProvider(certFilterTag string) certs.CertificatesProvider {
return newACMCertProvider(a.acm, certFilterTag)
}

func (a *Adapter) NewIAMCertificateProvider() certs.CertificatesProvider {
return newIAMCertProvider(a.iam)
func (a *Adapter) NewIAMCertificateProvider(certFilterTag string) certs.CertificatesProvider {
return newIAMCertProvider(a.iam, certFilterTag)
}

// WithHealthCheckPath returns the receiver adapter after changing the health check path that will be used by
Expand Down
23 changes: 23 additions & 0 deletions aws/fake/acm.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package fake

import (
"fmt"

"github.com/aws/aws-sdk-go/service/acm"
"github.com/aws/aws-sdk-go/service/acm/acmiface"
)
Expand All @@ -9,6 +11,7 @@ type ACMClient struct {
acmiface.ACMAPI
output acm.ListCertificatesOutput
cert acm.GetCertificateOutput
tags map[string]*acm.ListTagsForCertificateOutput
}

func (m ACMClient) ListCertificates(in *acm.ListCertificatesInput) (*acm.ListCertificatesOutput, error) {
Expand All @@ -24,9 +27,29 @@ func (m ACMClient) GetCertificate(input *acm.GetCertificateInput) (*acm.GetCerti
return &m.cert, nil
}

func (m ACMClient) ListTagsForCertificate(in *acm.ListTagsForCertificateInput) (*acm.ListTagsForCertificateOutput, error) {
if in.CertificateArn == nil {
return nil, fmt.Errorf("expected a valid CertificateArn, got: nil")
}
arn := *in.CertificateArn
return m.tags[arn], nil
}

func NewACMClient(output acm.ListCertificatesOutput, cert acm.GetCertificateOutput) ACMClient {
return ACMClient{
output: output,
cert: cert,
}
}

func NewACMClientWithTags(
output acm.ListCertificatesOutput,
cert acm.GetCertificateOutput,
tags map[string]*acm.ListTagsForCertificateOutput,
) ACMClient {
return ACMClient{
output: output,
cert: cert,
tags: tags,
}
}
26 changes: 26 additions & 0 deletions aws/fake/iam.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package fake

import (
"fmt"

"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
)
Expand All @@ -9,6 +11,7 @@ type IAMClient struct {
iamiface.IAMAPI
list iam.ListServerCertificatesOutput
cert iam.GetServerCertificateOutput
tags map[string]*iam.ListServerCertificateTagsOutput
}

func (m IAMClient) ListServerCertificates(*iam.ListServerCertificatesInput) (*iam.ListServerCertificatesOutput, error) {
Expand All @@ -20,6 +23,17 @@ func (m IAMClient) ListServerCertificatesPages(input *iam.ListServerCertificates
return nil
}

func (m IAMClient) ListServerCertificateTags(
in *iam.ListServerCertificateTagsInput,
) (*iam.ListServerCertificateTagsOutput, error) {

if in.ServerCertificateName == nil {
return nil, fmt.Errorf("expected a valid CertificateArn, got: nil")
}
name := *in.ServerCertificateName
return m.tags[name], nil
}

func (m IAMClient) GetServerCertificate(*iam.GetServerCertificateInput) (*iam.GetServerCertificateOutput, error) {
return &m.cert, nil
}
Expand All @@ -30,3 +44,15 @@ func NewIAMClient(list iam.ListServerCertificatesOutput, cert iam.GetServerCerti
cert: cert,
}
}

func NewIAMClientWithTag(
list iam.ListServerCertificatesOutput,
cert iam.GetServerCertificateOutput,
tags map[string]*iam.ListServerCertificateTagsOutput,
) IAMClient {
return IAMClient{
list: list,
cert: cert,
tags: tags,
}
}
35 changes: 32 additions & 3 deletions aws/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aws

import (
"crypto/x509"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
Expand All @@ -10,11 +11,12 @@ import (
)

type iamCertificateProvider struct {
api iamiface.IAMAPI
api iamiface.IAMAPI
filterTag string
}

func newIAMCertProvider(api iamiface.IAMAPI) certs.CertificatesProvider {
return &iamCertificateProvider{api: api}
func newIAMCertProvider(api iamiface.IAMAPI, filterTag string) certs.CertificatesProvider {
return &iamCertificateProvider{api: api, filterTag: filterTag}
}

// GetCertificates returns a list of AWS IAM certificates
Expand All @@ -25,6 +27,17 @@ func (p *iamCertificateProvider) GetCertificates() ([]*certs.CertificateSummary,
}
list := make([]*certs.CertificateSummary, 0)
for _, o := range serverCertificatesMetadata {
if kv := strings.Split(p.filterTag, "="); p.filterTag != "=" && len(kv) == 2 {
hasTag, err := certHasTag(p.api, *o.ServerCertificateName, kv[0], kv[1])
if err != nil {
return nil, err
}

if !hasTag {
continue
}
}

certDetail, err := getCertificateSummaryFromIAM(p.api, aws.StringValue(o.ServerCertificateName))
if err != nil {
return nil, err
Expand All @@ -34,6 +47,22 @@ func (p *iamCertificateProvider) GetCertificates() ([]*certs.CertificateSummary,
return list, nil
}

func certHasTag(api iamiface.IAMAPI, certName, key, value string) (bool, error) {
t, err := api.ListServerCertificateTags(&iam.ListServerCertificateTagsInput{
ServerCertificateName: &certName,
})
if err != nil {
return false, err
}
for _, tag := range t.Tags {
if *tag.Key == key && *tag.Value == value {
return true, nil
}
}

return false, nil
}

func getIAMServerCertificateMetadata(api iamiface.IAMAPI) ([]*iam.ServerCertificateMetadata, error) {
params := &iam.ListServerCertificatesInput{
PathPrefix: aws.String("/"),
Expand Down
Loading
Loading