Skip to content

Commit

Permalink
feat: support the option to use aws-sdk-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
spennymac committed Mar 26, 2024
1 parent fda348b commit 6d9d7cf
Show file tree
Hide file tree
Showing 8 changed files with 451 additions and 34 deletions.
14 changes: 14 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,24 @@ go 1.21

require (
github.com/aws/aws-sdk-go v1.49.21
github.com/aws/aws-sdk-go-v2 v1.26.0
github.com/aws/aws-sdk-go-v2/config v1.27.9
github.com/aws/aws-sdk-go-v2/service/kms v1.30.0
github.com/tink-crypto/tink-go/v2 v2.1.0
)

require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.9 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.0 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.6 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.20.3 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.3 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.28.5 // indirect
github.com/aws/smithy-go v1.20.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/sys v0.13.0 // indirect
Expand Down
28 changes: 28 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,33 @@
github.com/aws/aws-sdk-go v1.49.21 h1:Rl8KW6HqkwzhATwvXhyr7vD4JFUMi7oXGAw9SrxxIFY=
github.com/aws/aws-sdk-go v1.49.21/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/aws/aws-sdk-go-v2 v1.26.0 h1:/Ce4OCiM3EkpW7Y+xUnfAFpchU78K7/Ug01sZni9PgA=
github.com/aws/aws-sdk-go-v2 v1.26.0/go.mod h1:35hUlJVYd+M++iLI3ALmVwMOyRYMmRqUXpTtRGW+K9I=
github.com/aws/aws-sdk-go-v2/config v1.27.9 h1:gRx/NwpNEFSk+yQlgmk1bmxxvQ5TyJ76CWXs9XScTqg=
github.com/aws/aws-sdk-go-v2/config v1.27.9/go.mod h1:dK1FQfpwpql83kbD873E9vz4FyAxuJtR22wzoXn3qq0=
github.com/aws/aws-sdk-go-v2/credentials v1.17.9 h1:N8s0/7yW+h8qR8WaRlPQeJ6czVMNQVNtNdUqf6cItao=
github.com/aws/aws-sdk-go-v2/credentials v1.17.9/go.mod h1:446YhIdmSV0Jf/SLafGZalQo+xr2iw7/fzXGDPTU1yQ=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.0 h1:af5YzcLf80tv4Em4jWVD75lpnOHSBkPUZxZfGkrI3HI=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.0/go.mod h1:nQ3how7DMnFMWiU1SpECohgC82fpn4cKZ875NDMmwtA=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.4 h1:0ScVK/4qZ8CIW0k8jOeFVsyS/sAiXpYxRBLolMkuLQM=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.4/go.mod h1:84KyjNZdHC6QZW08nfHI6yZgPd+qRgaWcYsyLUo3QY8=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.4 h1:sHmMWWX5E7guWEFQ9SVo6A3S4xpPrWnd77a6y4WM6PU=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.4/go.mod h1:WjpDrhWisWOIoS9n3nk67A3Ll1vfULJ9Kq6h29HTD48=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 h1:EyBZibRTVAs6ECHZOw5/wlylS9OcTzwyjeQMudmREjE=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1/go.mod h1:JKpmtYhhPs7D97NL/ltqz7yCkERFW5dOlHyVl66ZYF8=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.6 h1:b+E7zIUHMmcB4Dckjpkapoy47W6C9QBv/zoUP+Hn8Kc=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.6/go.mod h1:S2fNV0rxrP78NhPbCZeQgY8H9jdDMeGtwcfZIRxzBqU=
github.com/aws/aws-sdk-go-v2/service/kms v1.30.0 h1:yS0JkEdV6h9JOo8sy2JSpjX+i7vsKifU8SIeHrqiDhU=
github.com/aws/aws-sdk-go-v2/service/kms v1.30.0/go.mod h1:+I8VUUSVD4p5ISQtzpgSva4I8cJ4SQ4b1dcBcof7O+g=
github.com/aws/aws-sdk-go-v2/service/sso v1.20.3 h1:mnbuWHOcM70/OFUlZZ5rcdfA8PflGXXiefU/O+1S3+8=
github.com/aws/aws-sdk-go-v2/service/sso v1.20.3/go.mod h1:5HFu51Elk+4oRBZVxmHrSds5jFXmFj8C3w7DVF2gnrs=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.3 h1:uLq0BKatTmDzWa/Nu4WO0M1AaQDaPpwTKAeByEc6WFM=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.3/go.mod h1:b+qdhjnxj8GSR6t5YfphOffeoQSQ1KmpoVVuBn+PWxs=
github.com/aws/aws-sdk-go-v2/service/sts v1.28.5 h1:J/PpTf/hllOjx8Xu9DMflff3FajfLxqM5+tepvVXmxg=
github.com/aws/aws-sdk-go-v2/service/sts v1.28.5/go.mod h1:0ih0Z83YDH/QeQ6Ori2yGE2XvWYv/Xm+cZc01LC6oK0=
github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw=
github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
Expand Down
1 change: 0 additions & 1 deletion integration/awskms/aws_kms_aead.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
//
////////////////////////////////////////////////////////////////////////////////

// Package awskms provides integration with the AWS Key Management Service.
package awskms

import (
Expand Down
154 changes: 143 additions & 11 deletions integration/awskms/aws_kms_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
package awskms

import (
"context"
"encoding/csv"
"errors"
"fmt"
"os"
"regexp"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/config"
kmsv2 "github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -35,7 +39,8 @@ import (
)

const (
awsPrefix = "aws-kms://"
awsPrefix = "aws-kms://"
defaultTimeout = 5 * time.Second
)

var (
Expand All @@ -44,12 +49,18 @@ var (
errCredCSV = errors.New("malformed credential CSV file")
)

type V2KMS interface {
Encrypt(ctx context.Context, params *kmsv2.EncryptInput, optFns ...func(*kmsv2.Options)) (*kmsv2.EncryptOutput, error)
Decrypt(ctx context.Context, params *kmsv2.DecryptInput, optFns ...func(*kmsv2.Options)) (*kmsv2.DecryptOutput, error)
}

// awsClient is a wrapper around an AWS SDK provided KMS client that can
// instantiate Tink primitives.
type awsClient struct {
keyURIPrefix string
kms kmsiface.KMSAPI
encryptionContextName EncryptionContextName
builder func(keyURI string, encContextName EncryptionContextName) (tink.AEAD, error)
}

// ClientOption is an interface for defining options that are passed to
Expand All @@ -60,6 +71,13 @@ type option func(*awsClient) error

func (o option) set(a *awsClient) error { return o(a) }

// V2ClientOption is an interface for defining options that are passed to [WithV2KMSOptions].
type V2ClientOption interface{ set(*v2Client) error }

type v2option func(*v2Client) error

func (o v2option) set(a *v2Client) error { return o(a) }

// WithCredentialPath instantiates the underlying AWS KMS client using the
// credentials located at credentialPath.
//
Expand Down Expand Up @@ -113,7 +131,7 @@ const (
)

var encryptionContextNames = map[EncryptionContextName]string{
AssociatedData: "associatedData",
AssociatedData: "associatedData",
LegacyAdditionalData: "additionalData",
}

Expand Down Expand Up @@ -151,6 +169,78 @@ func WithEncryptionContextName(name EncryptionContextName) ClientOption {
})
}

func WithV2KMS(kms V2KMS) V2ClientOption {
return v2option(func(v *v2Client) error {
if v.kms != nil {
return errors.New("V2KMS client already set")
}

v.kms = kms

return nil
})
}

func UseV2() ClientOption {
return option(func(a *awsClient) error {
var v v2Client
a.builder = v.BuildAead

return nil
})
}

func WithV2KMSOptions(opts ...V2ClientOption) ClientOption {
return option(func(a *awsClient) error {
var v v2Client
a.builder = v.BuildAead

for _, opt := range opts {
if err := opt.set(&v); err != nil {
return fmt.Errorf("failed setting option: %v", err)
}
}

return nil
})
}

// WithAPITimeout sets the timeout for API requests made by the KMS client.
func WithAPITimeout(timeout time.Duration) V2ClientOption {
return v2option(func(v *v2Client) error {
if v.timeout != 0 {
return errors.New("timeout already set")
}
v.timeout = timeout

return nil
})
}

// WithLoadOptions sets the load options used to create the AWS SDK config.
func WithLoadOptions(opts ...func(*config.LoadOptions) error) V2ClientOption {
return v2option(func(v *v2Client) error {
if len(v.loadOpts) > 0 {
return errors.New("load options already set")
}
v.loadOpts = opts

return nil
})
}

// WithKMSOptions sets the options used to create the AWS SDK KMS client.
func WithKMSOptions(opts ...func(options *kmsv2.Options)) V2ClientOption {
return v2option(func(v *v2Client) error {
if len(v.kmsOpts) > 0 {
return errors.New("KMS options already set")
}
v.kmsOpts = opts

return nil
})
}

// NewClientWithOptions returns a [registry.KMSClient] which wraps an AWS KMS
// client and will handle keys whose URIs start with uriPrefix.
//
Expand All @@ -167,21 +257,26 @@ func NewClientWithOptions(uriPrefix string, opts ...ClientOption) (registry.KMSC
keyURIPrefix: uriPrefix,
}

// Default to v1 client
a.builder = func(keyURI string, encContextName EncryptionContextName) (tink.AEAD, error) {
// Populate values not defined via options.
if a.kms == nil {
k, err := getKMS(uriPrefix)
if err != nil {
return nil, err
}
a.kms = k
}
return newAWSAEAD(keyURI, a.kms, encContextName), nil
}

// Process options, if any.
for _, opt := range opts {
if err := opt.set(a); err != nil {
return nil, fmt.Errorf("failed setting option: %v", err)
}
}

// Populate values not defined via options.
if a.kms == nil {
k, err := getKMS(uriPrefix)
if err != nil {
return nil, err
}
a.kms = k
}
if a.encryptionContextName == 0 {
a.encryptionContextName = AssociatedData
}
Expand Down Expand Up @@ -275,7 +370,12 @@ func (c *awsClient) GetAEAD(keyURI string) (tink.AEAD, error) {
}

uri := strings.TrimPrefix(keyURI, awsPrefix)
return newAWSAEAD(uri, c.kms, c.encryptionContextName), nil
aead, err := c.builder(uri, c.encryptionContextName)
if err != nil {
return nil, fmt.Errorf("building AEAD: %w", err)
}

return aead, nil
}

func getKMS(uriPrefix string) (*kms.KMS, error) {
Expand Down Expand Up @@ -380,3 +480,35 @@ func getRegion(keyURI string) (string, error) {
}
return r[2], nil
}

type v2Client struct {
kms V2KMS
kmsOpts []func(options *kmsv2.Options)
loadOpts []func(*config.LoadOptions) error
timeout time.Duration
}

func (v *v2Client) BuildAead(keyURI string, encContextName EncryptionContextName) (tink.AEAD, error) {
if v.timeout == 0 {
v.timeout = defaultTimeout
}

ctx, cancel := context.WithTimeout(context.Background(), v.timeout)
defer cancel()

if v.kms == nil {
cfg, err := config.LoadDefaultConfig(ctx, v.loadOpts...)
if err != nil {
return nil, fmt.Errorf("loading AWS default config: %w", err)
}

kmsClient, err := kmsv2.NewFromConfig(cfg, v.kmsOpts...), nil
if err != nil {
return nil, fmt.Errorf("creating V2KMS client: %w", err)

}
v.kms = kmsClient
}

return newAWSV2AEAD(keyURI, v.kms, encContextName, v.timeout), nil
}
Loading

0 comments on commit 6d9d7cf

Please sign in to comment.