Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support using custom KMS keys to build private stemcells #31

Merged
merged 1 commit into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.idea
config.json
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type AmiConfiguration struct {
VirtualizationType string `json:"virtualization_type"`
Encrypted bool `json:"encrypted"`
KmsKeyId string `json:"kms_key_id"`
KmsKeyAliasName string `json:"kms_key_alias_name"`
Visibility string `json:"visibility"`
Tags map[string]string `json:"tags,omitempty"`
}
Expand Down
2 changes: 1 addition & 1 deletion driver/copy_ami_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (d *SDKCopyAmiDriver) Create(driverConfig resources.AmiDriverConfig) (resou
Encrypted: &driverConfig.Encrypted,
}
if driverConfig.KmsKeyId != "" {
lnguyen marked this conversation as resolved.
Show resolved Hide resolved
input.KmsKeyId = &driverConfig.KmsKeyId
input.KmsKeyId = &driverConfig.KmsKey.ARN
}
output, err := ec2Client.CopyImage(input)
if err != nil {
Expand Down
131 changes: 131 additions & 0 deletions driver/kms_driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package driver

import (
"fmt"
"io"
"light-stemcell-builder/config"
"light-stemcell-builder/resources"
"log"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
)

type SDKKmsDriver struct {
creds config.Credentials
logger *log.Logger
}

func NewKmsDriver(logDest io.Writer, creds config.Credentials) *SDKKmsDriver {
logger := log.New(logDest, "KmsDriver ", log.LstdFlags)

return &SDKKmsDriver{creds: creds, logger: logger}
}

func (d *SDKKmsDriver) CreateAlias(driverConfig resources.KmsCreateAliasDriverConfig) (resources.KmsAlias, error) {
if driverConfig.KmsKeyId == "" {
return resources.KmsAlias{}, nil
}

createStartTime := time.Now()
defer func(startTime time.Time) {
d.logger.Printf("Completed CreateKeyAlias() in %f minutes\n", time.Since(startTime).Minutes())
}(createStartTime)

kmsClient := d.createKmsClient(driverConfig.Region)

d.logger.Printf("Creating alias: %s\n", driverConfig.KmsKeyAliasName)
_, err := kmsClient.CreateAlias(&kms.CreateAliasInput{
AliasName: &driverConfig.KmsKeyAliasName,
TargetKeyId: &driverConfig.KmsKeyId,
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case kms.ErrCodeAlreadyExistsException:
d.logger.Printf("Alias %s already exists\n", driverConfig.KmsKeyAliasName)
default:
return resources.KmsAlias{}, fmt.Errorf("failed to create alias: %s", err)
}
} else {
return resources.KmsAlias{}, fmt.Errorf("failed to create alias: %s", err)
}
}

d.logger.Printf("Checking existence of alias: %s\n", driverConfig.KmsKeyAliasName)
listAliasResult, err := kmsClient.ListAliases(&kms.ListAliasesInput{
KeyId: &driverConfig.KmsKeyId,
})
if err != nil {
return resources.KmsAlias{}, fmt.Errorf("checking alias existence: %s", err)
}

for i := range listAliasResult.Aliases {
if *listAliasResult.Aliases[i].AliasName == driverConfig.KmsKeyAliasName {
d.logger.Printf("Reusing existing alias: %s\n", driverConfig.KmsKeyAliasName)
return resources.KmsAlias{
TargetKeyId: *listAliasResult.Aliases[i].TargetKeyId,
ARN: *listAliasResult.Aliases[i].AliasArn,
}, nil
}
}

return resources.KmsAlias{}, fmt.Errorf("could not find existing alias: %s", err)
}

func (d *SDKKmsDriver) ReplicateKey(driverConfig resources.KmsReplicateKeyDriverConfig) (resources.KmsKey, error) {
if driverConfig.KmsKeyId == "" {
return resources.KmsKey{}, nil
}

createStartTime := time.Now()
defer func(startTime time.Time) {
d.logger.Printf("Completed ReplicateKey() in %f minutes\n", time.Since(startTime).Minutes())
}(createStartTime)

d.logger.Printf("Replicating kms key: %s\n", driverConfig.KmsKeyId)
_, err := d.createKmsClient(driverConfig.SourceRegion).ReplicateKey(&kms.ReplicateKeyInput{
KeyId: &driverConfig.KmsKeyId,
ReplicaRegion: &driverConfig.TargetRegion,
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case kms.ErrCodeAlreadyExistsException:
d.logger.Printf("Kms key %s already replicated\n", driverConfig.KmsKeyId)
default:
return resources.KmsKey{}, fmt.Errorf("failed to replicate key: %s", err)
}
} else {
return resources.KmsKey{}, fmt.Errorf("failed to replicate key: %s", err)
}
}

listKeyResult, err := d.createKmsClient(driverConfig.TargetRegion).ListKeys(&kms.ListKeysInput{})
for i := range listKeyResult.Keys {
if strings.HasSuffix(driverConfig.KmsKeyId, *listKeyResult.Keys[i].KeyId) {
return resources.KmsKey{
ARN: *listKeyResult.Keys[i].KeyArn,
}, nil
}
}

return resources.KmsKey{}, fmt.Errorf("could not replicated kms key: %s", err)
}

func (d *SDKKmsDriver) createKmsClient(region string) *kms.KMS {
creds := config.Credentials{
AccessKey: d.creds.AccessKey,
SecretKey: d.creds.SecretKey,
RoleArn: d.creds.RoleArn,
Region: region,
}

awsConfig := creds.GetAwsConfig().
WithLogger(newDriverLogger(d.logger))

return kms.New(session.Must(session.NewSession(awsConfig)))
}
100 changes: 100 additions & 0 deletions driver/kms_driver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package driver_test

import (
"light-stemcell-builder/driverset"
"light-stemcell-builder/resources"
"math/rand"
"strconv"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("KmsDriver", func() {
It("creates an alias for a given kms key", func() {
aliasName := "alias/" + strconv.Itoa(rand.Int())

driverConfig := resources.KmsCreateAliasDriverConfig{
KmsKeyAliasName: aliasName,
KmsKeyId: kmsKeyId,
Region: creds.Region,
}
ds := driverset.NewStandardRegionDriverSet(GinkgoWriter, creds)
driver := ds.KmsDriver()

aliasCreationResult, err := driver.CreateAlias(driverConfig)
Expect(err).ToNot(HaveOccurred())

//defer cleanup of the created alias
defer func(aliasName string, aliasCreationResult resources.KmsAlias) {
awsSession, _ := session.NewSession(creds.GetAwsConfig())
kmsClient := kms.New(awsSession)
_, _ = kmsClient.DeleteAlias(&kms.DeleteAliasInput{
AliasName: &aliasName,
})
}(aliasName, aliasCreationResult)

awsSession, err := session.NewSession(creds.GetAwsConfig())
Expect(err).ToNot(HaveOccurred())
kmsClient := kms.New(awsSession)
listAliasResult, err := kmsClient.ListAliases(&kms.ListAliasesInput{
KeyId: &kmsKeyId,
})
Expect(err).ToNot(HaveOccurred())

aliasCount := 0
for i := range listAliasResult.Aliases {
if *listAliasResult.Aliases[i].AliasName == aliasName {
aliasCount++
}
}
Expect(aliasCount).To(Equal(1))
})

It("replicates a given kms key to another region", func() {
driverConfig := resources.KmsReplicateKeyDriverConfig{
KmsKeyId: kmsKeyId,
SourceRegion: creds.Region,
TargetRegion: destinationRegion,
}
ds := driverset.NewStandardRegionDriverSet(GinkgoWriter, creds)
driver := ds.KmsDriver()

replicateKeyResult, err := driver.ReplicateKey(driverConfig)
Expect(err).ToNot(HaveOccurred())

original_region := creds.Region
creds.Region = destinationRegion

//defer cleanup of the created key replica, sadly we can only schedule it to be deleted after 7 days
//therefore this test will reuse the replicated key for 7 days and only afterwards create a new one
defer func(aliasCreationResult resources.KmsKey) {
destinationKeyId := strings.ReplaceAll(kmsKeyId, original_region, destinationRegion)
awsSession, _ := session.NewSession(creds.GetAwsConfig())
kmsClient := kms.New(awsSession)

_, _ = kmsClient.ScheduleKeyDeletion(&kms.ScheduleKeyDeletionInput{
KeyId: &destinationKeyId,
PendingWindowInDays: aws.Int64(7),
})
}(replicateKeyResult)

awsSession, err := session.NewSession(creds.GetAwsConfig())
Expect(err).ToNot(HaveOccurred())
kmsClient := kms.New(awsSession)
listKeyResult, err := kmsClient.ListKeys(&kms.ListKeysInput{})
Expect(err).ToNot(HaveOccurred())

keysCount := 0
for i := range listKeyResult.Keys {
if strings.HasSuffix(driverConfig.KmsKeyId, *listKeyResult.Keys[i].KeyId) {
keysCount++
}
}
Expect(keysCount).To(Equal(1))
})
})
12 changes: 10 additions & 2 deletions driver/snapshot_from_image_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,20 @@ func (d *SDKSnapshotFromImageDriver) Create(driverConfig resources.SnapshotDrive
}(createStartTime)

d.logger.Printf("initiating ImportSnapshot task from image: %s\n", driverConfig.MachineImageURL)
reqOutput, err := d.ec2Client.ImportSnapshot(&ec2.ImportSnapshotInput{

input := &ec2.ImportSnapshotInput{
DiskContainer: &ec2.SnapshotDiskContainer{
Url: &driverConfig.MachineImageURL,
Format: aws.String(driverConfig.FileFormat),
},
})
Encrypted: &driverConfig.AmiProperties.Encrypted,
}

if driverConfig.KmsAlias.ARN != "" {
input.KmsKeyId = &driverConfig.KmsAlias.ARN
}

reqOutput, err := d.ec2Client.ImportSnapshot(input)
if err != nil {
return resources.Snapshot{}, fmt.Errorf("creating import snapshot task: %s", err)
}
Expand Down
65 changes: 65 additions & 0 deletions driverset/driversetfakes/fake_standard_region_driver_set.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions driverset/standard_aws_region.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ type StandardRegionDriverSet interface {
CreateSnapshotDriver() resources.SnapshotDriver
CreateAmiDriver() resources.AmiDriver
CopyAmiDriver() resources.AmiDriver
KmsDriver() resources.KmsDriver
}

type standardRegionDriverSet struct {
machineImageDriver resources.MachineImageDriver
snapshotDriver *driver.SDKSnapshotFromImageDriver
amiDriver *driver.SDKCreateAmiDriver
copyAmiDriver *driver.SDKCopyAmiDriver
kmsDriver *driver.SDKKmsDriver
}

func NewStandardRegionDriverSet(logDest io.Writer, creds config.Credentials) StandardRegionDriverSet {
Expand All @@ -35,6 +37,7 @@ func NewStandardRegionDriverSet(logDest io.Writer, creds config.Credentials) Sta
snapshotDriver: driver.NewSnapshotFromImageDriver(logDest, creds),
amiDriver: driver.NewCreateAmiDriver(logDest, creds),
copyAmiDriver: driver.NewCopyAmiDriver(logDest, creds),
kmsDriver: driver.NewKmsDriver(logDest, creds),
}
}

Expand All @@ -53,3 +56,7 @@ func (s *standardRegionDriverSet) CreateAmiDriver() resources.AmiDriver {
func (s *standardRegionDriverSet) CopyAmiDriver() resources.AmiDriver {
return s.copyAmiDriver
}

func (s *standardRegionDriverSet) KmsDriver() resources.KmsDriver {
return s.kmsDriver
}
Loading