Skip to content

Commit

Permalink
Support using custom KMS keys to build private stemcells
Browse files Browse the repository at this point in the history
When building a private stemcells, the builder currently uses
a managed KMS key which is default AWS account key. Using this key
prevents sharing stemcells across accounts. Therefore we add the
custom KMS key support.
  • Loading branch information
mvach committed Nov 16, 2023
1 parent 2fa49d1 commit fed3ae5
Show file tree
Hide file tree
Showing 21 changed files with 23,830 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.idea
config.json
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type AmiConfiguration struct {
VirtualizationType string `json:"virtualization_type"`
Encrypted bool `json:"encrypted"`
KmsKeyId string `json:"kms_key_id"`
KmsKeyAliasName string `json:"kms_key_alias_name"`
Visibility string `json:"visibility"`
Tags map[string]string `json:"tags,omitempty"`
}
Expand Down
2 changes: 1 addition & 1 deletion driver/copy_ami_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (d *SDKCopyAmiDriver) Create(driverConfig resources.AmiDriverConfig) (resou
Encrypted: &driverConfig.Encrypted,
}
if driverConfig.KmsKeyId != "" {
input.KmsKeyId = &driverConfig.KmsKeyId
input.KmsKeyId = &driverConfig.KmsKey.ARN
}
output, err := ec2Client.CopyImage(input)
if err != nil {
Expand Down
131 changes: 131 additions & 0 deletions driver/kms_driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package driver

import (
"fmt"
"io"
"light-stemcell-builder/config"
"light-stemcell-builder/resources"
"log"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
)

type SDKKmsDriver struct {
creds config.Credentials
logger *log.Logger
}

func NewKmsDriver(logDest io.Writer, creds config.Credentials) *SDKKmsDriver {
logger := log.New(logDest, "KmsDriver ", log.LstdFlags)

return &SDKKmsDriver{creds: creds, logger: logger}
}

func (d *SDKKmsDriver) CreateAlias(driverConfig resources.KmsCreateAliasDriverConfig) (resources.KmsAlias, error) {
if driverConfig.KmsKeyId == "" {
return resources.KmsAlias{}, nil
}

createStartTime := time.Now()
defer func(startTime time.Time) {
d.logger.Printf("Completed CreateKeyAlias() in %f minutes\n", time.Since(startTime).Minutes())
}(createStartTime)

kmsClient := d.createKmsClient(driverConfig.Region)

d.logger.Printf("Creating alias: %s\n", driverConfig.KmsKeyAliasName)
_, err := kmsClient.CreateAlias(&kms.CreateAliasInput{
AliasName: &driverConfig.KmsKeyAliasName,
TargetKeyId: &driverConfig.KmsKeyId,
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case kms.ErrCodeAlreadyExistsException:
d.logger.Printf("Alias %s already exists\n", driverConfig.KmsKeyAliasName)
default:
return resources.KmsAlias{}, fmt.Errorf("failed to create alias: %s", err)
}
} else {
return resources.KmsAlias{}, fmt.Errorf("failed to create alias: %s", err)
}
}

d.logger.Printf("Checking existence of alias: %s\n", driverConfig.KmsKeyAliasName)
listAliasResult, err := kmsClient.ListAliases(&kms.ListAliasesInput{
KeyId: &driverConfig.KmsKeyId,
})
if err != nil {
return resources.KmsAlias{}, fmt.Errorf("checking alias existence: %s", err)
}

for i := range listAliasResult.Aliases {
if *listAliasResult.Aliases[i].AliasName == driverConfig.KmsKeyAliasName {
d.logger.Printf("Reusing existing alias: %s\n", driverConfig.KmsKeyAliasName)
return resources.KmsAlias{
TargetKeyId: *listAliasResult.Aliases[i].TargetKeyId,
ARN: *listAliasResult.Aliases[i].AliasArn,
}, nil
}
}

return resources.KmsAlias{}, fmt.Errorf("could not find existing alias: %s", err)
}

func (d *SDKKmsDriver) ReplicateKey(driverConfig resources.KmsReplicateKeyDriverConfig) (resources.KmsKey, error) {
if driverConfig.KmsKeyId == "" {
return resources.KmsKey{}, nil
}

createStartTime := time.Now()
defer func(startTime time.Time) {
d.logger.Printf("Completed ReplicateKey() in %f minutes\n", time.Since(startTime).Minutes())
}(createStartTime)

d.logger.Printf("Replicating kms key: %s\n", driverConfig.KmsKeyId)
_, err := d.createKmsClient(driverConfig.SourceRegion).ReplicateKey(&kms.ReplicateKeyInput{
KeyId: &driverConfig.KmsKeyId,
ReplicaRegion: &driverConfig.TargetRegion,
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case kms.ErrCodeAlreadyExistsException:
d.logger.Printf("Kms key %s already replicated\n", driverConfig.KmsKeyId)
default:
return resources.KmsKey{}, fmt.Errorf("failed to replicate key: %s", err)
}
} else {
return resources.KmsKey{}, fmt.Errorf("failed to replicate key: %s", err)
}
}

listKeyResult, err := d.createKmsClient(driverConfig.TargetRegion).ListKeys(&kms.ListKeysInput{})
for i := range listKeyResult.Keys {
if strings.HasSuffix(driverConfig.KmsKeyId, *listKeyResult.Keys[i].KeyId) {
return resources.KmsKey{
ARN: *listKeyResult.Keys[i].KeyArn,
}, nil
}
}

return resources.KmsKey{}, fmt.Errorf("could not replicated kms key: %s", err)
}

func (d *SDKKmsDriver) createKmsClient(region string) *kms.KMS {
creds := config.Credentials{
AccessKey: d.creds.AccessKey,
SecretKey: d.creds.SecretKey,
RoleArn: d.creds.RoleArn,
Region: region,
}

awsConfig := creds.GetAwsConfig().
WithLogger(newDriverLogger(d.logger))

return kms.New(session.Must(session.NewSession(awsConfig)))
}
100 changes: 100 additions & 0 deletions driver/kms_driver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package driver_test

import (
"light-stemcell-builder/driverset"
"light-stemcell-builder/resources"
"math/rand"
"strconv"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("KmsDriver", func() {
It("creates an alias for a given kms key", func() {
aliasName := "alias/" + strconv.Itoa(rand.Int())

driverConfig := resources.KmsCreateAliasDriverConfig{
KmsKeyAliasName: aliasName,
KmsKeyId: kmsKeyId,
Region: creds.Region,
}
ds := driverset.NewStandardRegionDriverSet(GinkgoWriter, creds)
driver := ds.KmsDriver()

aliasCreationResult, err := driver.CreateAlias(driverConfig)
Expect(err).ToNot(HaveOccurred())

//defer cleanup of the created alias
defer func(aliasName string, aliasCreationResult resources.KmsAlias) {
awsSession, _ := session.NewSession(creds.GetAwsConfig())
kmsClient := kms.New(awsSession)
_, _ = kmsClient.DeleteAlias(&kms.DeleteAliasInput{
AliasName: &aliasName,
})
}(aliasName, aliasCreationResult)

awsSession, err := session.NewSession(creds.GetAwsConfig())
Expect(err).ToNot(HaveOccurred())
kmsClient := kms.New(awsSession)
listAliasResult, err := kmsClient.ListAliases(&kms.ListAliasesInput{
KeyId: &kmsKeyId,
})
Expect(err).ToNot(HaveOccurred())

aliasCount := 0
for i := range listAliasResult.Aliases {
if *listAliasResult.Aliases[i].AliasName == aliasName {
aliasCount++
}
}
Expect(aliasCount).To(Equal(1))
})

It("replicates a given kms key to another region", func() {
driverConfig := resources.KmsReplicateKeyDriverConfig{
KmsKeyId: kmsKeyId,
SourceRegion: creds.Region,
TargetRegion: destinationRegion,
}
ds := driverset.NewStandardRegionDriverSet(GinkgoWriter, creds)
driver := ds.KmsDriver()

replicateKeyResult, err := driver.ReplicateKey(driverConfig)
Expect(err).ToNot(HaveOccurred())

original_region := creds.Region
creds.Region = destinationRegion

//defer cleanup of the created key replica, sadly we can only schedule it to be deleted after 7 days
//therefore this test will reuse the replicated key for 7 days and only afterwards create a new one
defer func(aliasCreationResult resources.KmsKey) {
destinationKeyId := strings.ReplaceAll(kmsKeyId, original_region, destinationRegion)
awsSession, _ := session.NewSession(creds.GetAwsConfig())
kmsClient := kms.New(awsSession)

_, _ = kmsClient.ScheduleKeyDeletion(&kms.ScheduleKeyDeletionInput{
KeyId: &destinationKeyId,
PendingWindowInDays: aws.Int64(7),
})
}(replicateKeyResult)

awsSession, err := session.NewSession(creds.GetAwsConfig())
Expect(err).ToNot(HaveOccurred())
kmsClient := kms.New(awsSession)
listKeyResult, err := kmsClient.ListKeys(&kms.ListKeysInput{})
Expect(err).ToNot(HaveOccurred())

keysCount := 0
for i := range listKeyResult.Keys {
if strings.HasSuffix(driverConfig.KmsKeyId, *listKeyResult.Keys[i].KeyId) {
keysCount++
}
}
Expect(keysCount).To(Equal(1))
})
})
12 changes: 10 additions & 2 deletions driver/snapshot_from_image_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,20 @@ func (d *SDKSnapshotFromImageDriver) Create(driverConfig resources.SnapshotDrive
}(createStartTime)

d.logger.Printf("initiating ImportSnapshot task from image: %s\n", driverConfig.MachineImageURL)
reqOutput, err := d.ec2Client.ImportSnapshot(&ec2.ImportSnapshotInput{

input := &ec2.ImportSnapshotInput{
DiskContainer: &ec2.SnapshotDiskContainer{
Url: &driverConfig.MachineImageURL,
Format: aws.String(driverConfig.FileFormat),
},
})
Encrypted: &driverConfig.AmiProperties.Encrypted,
}

if driverConfig.KmsAlias.ARN != "" {
input.KmsKeyId = &driverConfig.KmsAlias.ARN
}

reqOutput, err := d.ec2Client.ImportSnapshot(input)
if err != nil {
return resources.Snapshot{}, fmt.Errorf("creating import snapshot task: %s", err)
}
Expand Down
65 changes: 65 additions & 0 deletions driverset/driversetfakes/fake_standard_region_driver_set.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions driverset/standard_aws_region.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ type StandardRegionDriverSet interface {
CreateSnapshotDriver() resources.SnapshotDriver
CreateAmiDriver() resources.AmiDriver
CopyAmiDriver() resources.AmiDriver
KmsDriver() resources.KmsDriver
}

type standardRegionDriverSet struct {
machineImageDriver resources.MachineImageDriver
snapshotDriver *driver.SDKSnapshotFromImageDriver
amiDriver *driver.SDKCreateAmiDriver
copyAmiDriver *driver.SDKCopyAmiDriver
kmsDriver *driver.SDKKmsDriver
}

func NewStandardRegionDriverSet(logDest io.Writer, creds config.Credentials) StandardRegionDriverSet {
Expand All @@ -35,6 +37,7 @@ func NewStandardRegionDriverSet(logDest io.Writer, creds config.Credentials) Sta
snapshotDriver: driver.NewSnapshotFromImageDriver(logDest, creds),
amiDriver: driver.NewCreateAmiDriver(logDest, creds),
copyAmiDriver: driver.NewCopyAmiDriver(logDest, creds),
kmsDriver: driver.NewKmsDriver(logDest, creds),
}
}

Expand All @@ -53,3 +56,7 @@ func (s *standardRegionDriverSet) CreateAmiDriver() resources.AmiDriver {
func (s *standardRegionDriverSet) CopyAmiDriver() resources.AmiDriver {
return s.copyAmiDriver
}

func (s *standardRegionDriverSet) KmsDriver() resources.KmsDriver {
return s.kmsDriver
}
Loading

0 comments on commit fed3ae5

Please sign in to comment.