Skip to content

Commit

Permalink
Add AEADWithContext support to KMSEnvelopeAEAD.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698728450
Change-Id: Ia9714f19394d6ebfdb64b61c669f513f08080f32
  • Loading branch information
juergw authored and copybara-github committed Nov 21, 2024
1 parent a595eb8 commit b1610e5
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 7 deletions.
66 changes: 63 additions & 3 deletions aead/kms_envelope_aead.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package aead

import (
"context"
"encoding/binary"
"errors"
"fmt"
Expand All @@ -29,7 +30,8 @@ const (
maxLengthEncryptedDEK = 4096
)

// KMSEnvelopeAEAD represents an instance of Envelope AEAD.
// KMSEnvelopeAEAD represents an instance of KMS Envelope AEAD that implements
// the [tink.AEAD] interface.
type KMSEnvelopeAEAD struct {
dekTemplate *tinkpb.KeyTemplate
kekAEAD tink.AEAD
Expand All @@ -51,7 +53,36 @@ func isSupporedKMSEnvelopeDEK(dekKeyTypeURL string) bool {
return found
}

// NewKMSEnvelopeAEAD2 creates an new instance of KMSEnvelopeAEAD.
// KMSEnvelopeAEADWithContext represents an instance of KMS Envelope AEAD that implements
// the [tink.AEADWithContext] interface.
type KMSEnvelopeAEADWithContext struct {
dekTemplate *tinkpb.KeyTemplate
kekAEAD tink.AEADWithContext
}

// NewKMSEnvelopeAEADWithContext creates an new instance of [KMSEnvelopeAEADWithContext].
//
// dekTemplate must be a KeyTemplate for any of these Tink AEAD key types (any
// other key template will be rejected):
// - AesCtrHmacAeadKey
// - AesGcmKey
// - ChaCha20Poly1305Key
// - XChaCha20Poly1305
// - AesGcmSivKey
//
// keyEncryptionAEAD is used to encrypt the DEK, and is usually a remote AEAD
// provided by a KMS.
func NewKMSEnvelopeAEADWithContext(dekTemplate *tinkpb.KeyTemplate, keyEncryptionAEAD tink.AEADWithContext) (*KMSEnvelopeAEADWithContext, error) {
if !isSupporedKMSEnvelopeDEK(dekTemplate.GetTypeUrl()) {
return nil, errors.New("unsupported DEK key type")
}
return &KMSEnvelopeAEADWithContext{
dekTemplate: dekTemplate,
kekAEAD: keyEncryptionAEAD,
}, nil
}

// NewKMSEnvelopeAEAD2 creates an new instance of [KMSEnvelopeAEAD].
//
// dekTemplate specifies the key template of the data encryption key (DEK).
// It must be a KeyTemplate for any of these Tink AEAD key types (any
Expand All @@ -62,7 +93,8 @@ func isSupporedKMSEnvelopeDEK(dekKeyTypeURL string) bool {
// - XChaCha20Poly1305
// - AesGcmSivKey
//
// keyEncryptionAEAD is used to encrypt the DEK.
// keyEncryptionAEAD is used to encrypt the DEK, and is usually a remote AEAD
// provided by a KMS. It is preferable to use [NewKMSEnvelopeAEADWithContext] instead.
func NewKMSEnvelopeAEAD2(dekTemplate *tinkpb.KeyTemplate, keyEncryptionAEAD tink.AEAD) *KMSEnvelopeAEAD {
if !isSupporedKMSEnvelopeDEK(dekTemplate.GetTypeUrl()) {
return &KMSEnvelopeAEAD{
Expand Down Expand Up @@ -182,3 +214,31 @@ func (a *KMSEnvelopeAEAD) Decrypt(ciphertext, associatedData []byte) ([]byte, er

return decryptDataWithDEK(a.dekTemplate.GetTypeUrl(), dek, payload, associatedData)
}

// EncryptWithContext implements the [tink.AEADWithContext] interface for encryption.
func (a *KMSEnvelopeAEADWithContext) EncryptWithContext(ctx context.Context, plaintext, associatedData []byte) ([]byte, error) {
dek, err := newDEK(a.dekTemplate)
if err != nil {
return nil, err
}
encryptedDEK, err := a.kekAEAD.EncryptWithContext(ctx, dek, []byte{})
if err != nil {
return nil, err
}
return encryptDataAndSerializeEnvelope(a.dekTemplate.GetTypeUrl(), dek, encryptedDEK, plaintext, associatedData)
}

// DecryptWithContext implements the [tink.AEADWithContext] interface for decryption.
func (a *KMSEnvelopeAEADWithContext) DecryptWithContext(ctx context.Context, ciphertext, associatedData []byte) ([]byte, error) {
encryptedDEK, payload, err := parseEnvelope(ciphertext)
if err != nil {
return nil, err
}

dek, err := a.kekAEAD.DecryptWithContext(ctx, encryptedDEK, []byte{})
if err != nil {
return nil, err
}

return decryptDataWithDEK(a.dekTemplate.GetTypeUrl(), dek, payload, associatedData)
}
107 changes: 103 additions & 4 deletions aead/kms_envelope_aead_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package aead_test

import (
"bytes"
"context"
"encoding/hex"
"testing"

Expand All @@ -30,6 +31,10 @@ func TestKMSEnvelopeWorksWithTinkKeyTemplatesAsDekTemplate(t *testing.T) {
if err != nil {
t.Fatalf("fakekms.NewAEAD(keyURI) err = %q, want nil", err)
}
kekAEADWithContext, err := fakekms.NewAEADWithContext(keyURI)
if err != nil {
t.Fatalf("fakekms.NewAEADWithContext(keyURI) err = %q, want nil", err)
}
plaintext := []byte("plaintext")
associatedData := []byte("associatedData")
invalidAssociatedData := []byte("invalidAssociatedData")
Expand Down Expand Up @@ -87,6 +92,42 @@ func TestKMSEnvelopeWorksWithTinkKeyTemplatesAsDekTemplate(t *testing.T) {
if _, err = a.Decrypt(ciphertext, invalidAssociatedData); err == nil {
t.Error("a.Decrypt(ciphertext, invalidAssociatedData) err = nil, want error")
}

ctx := context.Background()
r, err := aead.NewKMSEnvelopeAEADWithContext(tc.dekTemplate, kekAEADWithContext)
if err != nil {
t.Error("a.DecryptWithContext(ctx, ciphertext, invalidAssociatedData) err = nil, want error")
}
ciphertext2, err := r.EncryptWithContext(ctx, plaintext, associatedData)
if err != nil {
t.Fatalf("a.EncryptWithContext(ctx, plaintext, associatedData) err = %q, want nil", err)
}
gotPlaintext2, err := r.DecryptWithContext(ctx, ciphertext2, associatedData)
if err != nil {
t.Fatalf("a.DecryptWithContext(ctx, ciphertext2, associatedData) err = %q, want nil", err)
}
if !bytes.Equal(gotPlaintext2, plaintext) {
t.Fatalf("got plaintext %q, want %q", gotPlaintext, plaintext)
}
if _, err = r.DecryptWithContext(ctx, ciphertext2, invalidAssociatedData); err == nil {
t.Error("a.DecryptWithContext(ctx, ciphertext2, invalidAssociatedData) err = nil, want error")
}

// check that DecryptWithContext is compatible with Decrypt
gotPlaintext3, err := r.DecryptWithContext(ctx, ciphertext, associatedData)
if err != nil {
t.Fatalf("r.DecryptWithContext(ctx, ciphertext, associatedData) err = %q, want nil", err)
}
if !bytes.Equal(gotPlaintext3, plaintext) {
t.Fatalf("got plaintext %q, want %q", gotPlaintext3, plaintext)
}
gotPlaintext4, err := a.Decrypt(ciphertext2, associatedData)
if err != nil {
t.Fatalf("a.Decrypt(ciphertext2, associatedData) err = %q, want nil", err)
}
if !bytes.Equal(gotPlaintext4, plaintext) {
t.Fatalf("got plaintext %q, want %q", gotPlaintext4, plaintext)
}
})
}
}
Expand All @@ -102,6 +143,7 @@ func TestKMSEnvelopeDecryptTestVector(t *testing.T) {
t.Fatalf("hex.DecodeString(ciphertextHex) err = %q, want nil", err)
}

// with NewKMSEnvelopeAEAD2.
kekAEAD, err := fakekms.NewAEAD(keyURI)
if err != nil {
t.Fatalf("fakekms.NewAEAD(keyURI) err = %q, want nil", err)
Expand All @@ -114,6 +156,24 @@ func TestKMSEnvelopeDecryptTestVector(t *testing.T) {
if !bytes.Equal(gotPlaintext, plaintext) {
t.Fatalf("got plaintext %q, want %q", gotPlaintext, plaintext)
}

// with NewKMSEnvelopeAEADWithContext.
ctx := context.Background()
kekAEADWithContext, err := fakekms.NewAEADWithContext(keyURI)
if err != nil {
t.Fatalf("fakekms.NewAEADWithContext(keyURI) err = %q, want nil", err)
}
r, err := aead.NewKMSEnvelopeAEADWithContext(aead.AES256GCMKeyTemplate(), kekAEADWithContext)
if err != nil {
t.Fatalf("aead.NewKMSEnvelopeAEADWithContext() err = %q, want nil", err)
}
gotPlaintext2, err := r.DecryptWithContext(ctx, ciphertext, associatedData)
if err != nil {
t.Fatalf("r.DecryptWithContext(ciphertext, associatedData) err = %q, want nil", err)
}
if !bytes.Equal(gotPlaintext2, plaintext) {
t.Fatalf("got plaintext %q, want %q", gotPlaintext, plaintext)
}
}

func TestKMSEnvelopeWithKmsEnvelopeKeyTemplatesAsDekTemplate_fails(t *testing.T) {
Expand All @@ -126,6 +186,7 @@ func TestKMSEnvelopeWithKmsEnvelopeKeyTemplatesAsDekTemplate_fails(t *testing.T)
t.Fatalf("aead.CreateKMSEnvelopAEADKeyTemplate() err = %q, want nil", err)
}

// NewKMSEnvelopeAEAD2 can't return an error. But it always fails when calling Encrypt.
kekAEAD, err := fakekms.NewAEAD(keyURI)
if err != nil {
t.Fatalf("fakekms.NewAEAD(keyURI) err = %q, want nil", err)
Expand All @@ -135,10 +196,22 @@ func TestKMSEnvelopeWithKmsEnvelopeKeyTemplatesAsDekTemplate_fails(t *testing.T)
if err == nil {
t.Error("a.Encrypt(plaintext, associatedData) err = nil, want error")
}

// NewKMSEnvelopeAEADWithContext returns an error.
kekAEADWithContext, err := fakekms.NewAEADWithContext(keyURI)
if err != nil {
t.Fatalf("fakekms.NewAEADWithContext(keyURI) err = %q, want nil", err)
}
_, err = aead.NewKMSEnvelopeAEADWithContext(envelopeDEKTemplate, kekAEADWithContext)
if err == nil {
t.Error("NewKMSEnvelopeAEADWithContext() err = nil, want error")
}
}

func TestKMSEnvelopeShortCiphertext(t *testing.T) {
keyURI := "fake-kms://CM2b3_MDElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEIK75t5L-adlUwVhWvRuWUwYARABGM2b3_MDIAE"

// with NewKMSEnvelopeAEAD2.
kekAEAD, err := fakekms.NewAEAD(keyURI)
if err != nil {
t.Fatalf("fakekms.NewAEAD(keyURI) err = %q, want nil", err)
Expand All @@ -147,23 +220,49 @@ func TestKMSEnvelopeShortCiphertext(t *testing.T) {
if _, err = a.Decrypt([]byte{1}, nil); err == nil {
t.Error("a.Decrypt([]byte{1}, nil) err = nil, want error")
}

// with NewKMSEnvelopeAEADWithContext.
kekAEADWithContext, err := fakekms.NewAEADWithContext(keyURI)
if err != nil {
t.Fatalf("fakekms.NewAEADWithContext(keyURI) err = %q, want nil", err)
}
r, err := aead.NewKMSEnvelopeAEADWithContext(aead.AES256GCMKeyTemplate(), kekAEADWithContext)
if err != nil {
t.Fatalf("fakekms.NewKMSEnvelopeAEADWithContext() err = %q, want nil", err)
}
if _, err = r.DecryptWithContext(context.Background(), []byte{1}, nil); err == nil {
t.Error("a.DecryptWithContext([]byte{1}, nil) err = nil, want error")
}
}

func TestKMSEnvelopeDecryptHugeEncryptedDek(t *testing.T) {
keyURI := "fake-kms://CM2b3_MDElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEIK75t5L-adlUwVhWvRuWUwYARABGM2b3_MDIAE"
// A ciphertext with a huge encrypted DEK length
ciphertext := []byte{0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88}

// with NewKMSEnvelopeAEAD2.
kekAEAD, err := fakekms.NewAEAD(keyURI)
if err != nil {
t.Fatalf("fakekms.NewAEAD(keyURI) err = %q, want nil", err)
}
a := aead.NewKMSEnvelopeAEAD2(aead.AES256GCMKeyTemplate(), kekAEAD)

ciphertext := []byte{0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88}
if _, err = a.Decrypt(ciphertext, nil); err == nil {
t.Error("a.Decrypt([]byte{1}, nil) err = nil, want error")
}
expectedError := "kms_envelope_aead: length of encrypted DEK too large"
if err.Error() != expectedError {
t.Errorf("a.Decrypt([]byte{1}, nil) err = %q, want %q", err, expectedError)

// with NewKMSEnvelopeAEADWithContext.
ctx := context.Background()
kekAEADWithContext, err := fakekms.NewAEADWithContext(keyURI)
if err != nil {
t.Fatalf("fakekms.NewAEADWithContext(keyURI) err = %q, want nil", err)
}
r, err := aead.NewKMSEnvelopeAEADWithContext(aead.AES256GCMKeyTemplate(), kekAEADWithContext)
if err != nil {
t.Fatalf("fakekms.NewKMSEnvelopeAEADWithContext() err = %q, want nil", err)
}
if _, err = r.DecryptWithContext(ctx, ciphertext, nil); err == nil {
t.Error("a.Decrypt([]byte{1}, nil) err = nil, want error")
}
}

Expand Down

0 comments on commit b1610e5

Please sign in to comment.