Skip to content

Commit

Permalink
Add a tink.AEADWithContext interface. This allows remote AEAD to pass…
Browse files Browse the repository at this point in the history
… along a Context.

In the Keyset package, add functions that use this interface to encrypt/decrypt keysets.

PiperOrigin-RevId: 695204919
Change-Id: I7340fedeab2654c284b4baf97d73a31d0ca895af
  • Loading branch information
juergw authored and copybara-github committed Nov 11, 2024
1 parent 352765d commit 58f182b
Show file tree
Hide file tree
Showing 5 changed files with 365 additions and 11 deletions.
75 changes: 70 additions & 5 deletions keyset/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package keyset

import (
"context"
"errors"
"fmt"

Expand Down Expand Up @@ -260,6 +261,20 @@ func ReadWithAssociatedData(reader Reader, masterKey tink.AEAD, associatedData [
return newWithOptions(protoKeyset)
}

// ReadWithContext creates a keyset.Handle from an encrypted keyset obtained via
// reader using the provided AEADWithContext.
func ReadWithContext(ctx context.Context, reader Reader, keyEncryptionAEAD tink.AEADWithContext, associatedData []byte) (*Handle, error) {
encryptedKeyset, err := reader.ReadEncrypted()
if err != nil {
return nil, err
}
protoKeyset, err := decryptWithContext(ctx, encryptedKeyset, keyEncryptionAEAD, associatedData)
if err != nil {
return nil, err
}
return newWithOptions(protoKeyset)
}

// ReadWithNoSecrets tries to create a keyset.Handle from a keyset obtained via reader.
func ReadWithNoSecrets(reader Reader) (*Handle, error) {
protoKeyset, err := reader.Read()
Expand Down Expand Up @@ -381,6 +396,22 @@ func (h *Handle) WriteWithAssociatedData(writer Writer, masterKey tink.AEAD, ass
return writer.WriteEncrypted(encrypted)
}

// WriteWithContext encrypts and writes the keyset using the provided AEADWithContext.
func (h *Handle) WriteWithContext(ctx context.Context, writer Writer, keyEncryptionAEAD tink.AEADWithContext, associatedData []byte) error {
if h == nil {
return fmt.Errorf("keyset.Handle: nil handle")
}
protoKeyset, err := entriesToProtoKeyset(h.entries)
if err != nil {
return fmt.Errorf("keyset.Handle: %v", err)
}
encrypted, err := encryptWithContext(ctx, protoKeyset, keyEncryptionAEAD, associatedData)
if err != nil {
return fmt.Errorf("keyset.Handle: %v", err)
}
return writer.WriteEncrypted(encrypted)
}

// WriteWithNoSecrets exports the keyset in h to the given Writer w returning an error if the keyset
// contains secret key material.
func (h *Handle) WriteWithNoSecrets(w Writer) error {
Expand Down Expand Up @@ -530,11 +561,27 @@ func hasSecrets(ks *tinkpb.Keyset) bool {
return false
}

func decrypt(encryptedKeyset *tinkpb.EncryptedKeyset, masterKey tink.AEAD, associatedData []byte) (*tinkpb.Keyset, error) {
if encryptedKeyset == nil || masterKey == nil {
func decrypt(encryptedKeyset *tinkpb.EncryptedKeyset, keyEncryptionAEAD tink.AEAD, associatedData []byte) (*tinkpb.Keyset, error) {
if encryptedKeyset == nil || keyEncryptionAEAD == nil {
return nil, fmt.Errorf("keyset.Handle: invalid encrypted keyset")
}
decrypted, err := keyEncryptionAEAD.Decrypt(encryptedKeyset.GetEncryptedKeyset(), associatedData)
if err != nil {
return nil, fmt.Errorf("keyset.Handle: decryption failed: %v", err)
}
keyset := new(tinkpb.Keyset)
if err := proto.Unmarshal(decrypted, keyset); err != nil {
return nil, errInvalidKeyset
}
return keyset, nil
}

// decryptWithContext does the same as decrypt, but uses an AEADWithContext instead of an AEAD.
func decryptWithContext(ctx context.Context, encryptedKeyset *tinkpb.EncryptedKeyset, keyEncryptionAEAD tink.AEADWithContext, associatedData []byte) (*tinkpb.Keyset, error) {
if encryptedKeyset == nil || keyEncryptionAEAD == nil {
return nil, fmt.Errorf("keyset.Handle: invalid encrypted keyset")
}
decrypted, err := masterKey.Decrypt(encryptedKeyset.GetEncryptedKeyset(), associatedData)
decrypted, err := keyEncryptionAEAD.DecryptWithContext(ctx, encryptedKeyset.GetEncryptedKeyset(), associatedData)
if err != nil {
return nil, fmt.Errorf("keyset.Handle: decryption failed: %v", err)
}
Expand All @@ -545,12 +592,30 @@ func decrypt(encryptedKeyset *tinkpb.EncryptedKeyset, masterKey tink.AEAD, assoc
return keyset, nil
}

func encrypt(keyset *tinkpb.Keyset, masterKey tink.AEAD, associatedData []byte) (*tinkpb.EncryptedKeyset, error) {
func encrypt(keyset *tinkpb.Keyset, keyEncryptionAEAD tink.AEAD, associatedData []byte) (*tinkpb.EncryptedKeyset, error) {
serializedKeyset, err := proto.Marshal(keyset)
if err != nil {
return nil, errInvalidKeyset
}
encrypted, err := keyEncryptionAEAD.Encrypt(serializedKeyset, associatedData)
if err != nil {
return nil, fmt.Errorf("keyset.Handle: encryption failed: %v", err)
}
// get keyset info
encryptedKeyset := &tinkpb.EncryptedKeyset{
EncryptedKeyset: encrypted,
KeysetInfo: getKeysetInfo(keyset),
}
return encryptedKeyset, nil
}

// encryptWithContext does the same as encrypt, but uses an AEADWithContext instead of an AEAD.
func encryptWithContext(ctx context.Context, keyset *tinkpb.Keyset, keyEncryptionAEAD tink.AEADWithContext, associatedData []byte) (*tinkpb.EncryptedKeyset, error) {
serializedKeyset, err := proto.Marshal(keyset)
if err != nil {
return nil, errInvalidKeyset
}
encrypted, err := masterKey.Encrypt(serializedKeyset, associatedData)
encrypted, err := keyEncryptionAEAD.EncryptWithContext(ctx, serializedKeyset, associatedData)
if err != nil {
return nil, fmt.Errorf("keyset.Handle: encryption failed: %v", err)
}
Expand Down
129 changes: 123 additions & 6 deletions keyset/handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package keyset_test

import (
"bytes"
"context"
"errors"
"fmt"
"testing"

Expand All @@ -28,6 +30,7 @@ import (
"github.com/tink-crypto/tink-go/v2/keyset"
"github.com/tink-crypto/tink-go/v2/mac"
"github.com/tink-crypto/tink-go/v2/signature"
"github.com/tink-crypto/tink-go/v2/testing/fakekms"
"github.com/tink-crypto/tink-go/v2/testkeyset"
"github.com/tink-crypto/tink-go/v2/testutil"
"github.com/tink-crypto/tink-go/v2/tink"
Expand Down Expand Up @@ -207,14 +210,12 @@ func TestWriteAndReadInJSON(t *testing.T) {
}
}

const fakeKeyURI = "fake-kms://CM2b3_MDElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEIK75t5L-adlUwVhWvRuWUwYARABGM2b3_MDIAE"

func TestWriteAndReadWithAssociatedData(t *testing.T) {
keysetEncryptionHandle, err := keyset.NewHandle(aead.AES128GCMKeyTemplate())
keysetEncryptionAead, err := fakekms.NewAEAD(fakeKeyURI)
if err != nil {
t.Errorf("keyset.NewHandle(aead.AES128GCMKeyTemplate()) err = %v, want nil", err)
}
keysetEncryptionAead, err := aead.New(keysetEncryptionHandle)
if err != nil {
t.Errorf("aead.New(keysetEncryptionHandle) err = %v, want nil", err)
t.Fatalf("fakekms.NewAEAD(fakeKeyURI) err = %v, want nil", err)
}

handle, err := keyset.NewHandle(mac.HMACSHA256Tag128KeyTemplate())
Expand All @@ -238,6 +239,20 @@ func TestWriteAndReadWithAssociatedData(t *testing.T) {
if !proto.Equal(testkeyset.KeysetMaterial(handle), testkeyset.KeysetMaterial(handle2)) {
t.Errorf("keyset.ReadWithAssociatedData() = %v, want %v", handle2, handle)
}

// Test that ReadWithContext is compatible with WriteWithAssociatedData
kekAEADWithContext, err := fakekms.NewAEADWithContext(fakeKeyURI)
if err != nil {
t.Fatalf("fakekms.NewAEADWithContext(fakeKeyURI) err = %v, want nil", err)
}
ctx := context.Background()
handle3, err := keyset.ReadWithContext(ctx, keyset.NewBinaryReader(bytes.NewBuffer(encrypted)), kekAEADWithContext, associatedData)
if err != nil {
t.Fatalf("keyset.ReadWithContext() err = %v, want nil", err)
}
if !proto.Equal(testkeyset.KeysetMaterial(handle), testkeyset.KeysetMaterial(handle3)) {
t.Errorf("keyset.ReadWithContext() = %v, want %v", handle3, handle)
}
}

func TestReadWithMismatchedAssociatedData(t *testing.T) {
Expand Down Expand Up @@ -270,6 +285,108 @@ func TestReadWithMismatchedAssociatedData(t *testing.T) {
}
}

func TestWriteAndReadWithContext(t *testing.T) {
kekAEADWithContext, err := fakekms.NewAEADWithContext(fakeKeyURI)
if err != nil {
t.Fatalf("fakekms.NewAEADWithContext(fakeKeyURI) err = %v, want nil", err)
}

handle, err := keyset.NewHandle(mac.HMACSHA256Tag128KeyTemplate())
if err != nil {
t.Fatalf("keyset.NewHandle(mac.HMACSHA256Tag128KeyTemplate()) err = %v, want nil", err)
}
associatedData := []byte{0x01, 0x02}

ctx := context.Background()
buff := &bytes.Buffer{}
err = handle.WriteWithContext(ctx, keyset.NewBinaryWriter(buff), kekAEADWithContext, associatedData)
if err != nil {
t.Fatalf("handle.WriteWithContext() err = %v, want nil", err)
}
encrypted := buff.Bytes()

handle2, err := keyset.ReadWithContext(ctx, keyset.NewBinaryReader(bytes.NewBuffer(encrypted)), kekAEADWithContext, associatedData)
if err != nil {
t.Fatalf("keyset.ReadWithContext() err = %v, want nil", err)
}

if !proto.Equal(testkeyset.KeysetMaterial(handle), testkeyset.KeysetMaterial(handle2)) {
t.Errorf("keyset.ReadWithContext() = %v, want %v", handle2, handle)
}

invalidAssociatedData := []byte{0x01, 0x03}
_, err = keyset.ReadWithContext(ctx, keyset.NewBinaryReader(bytes.NewBuffer(encrypted)), kekAEADWithContext, invalidAssociatedData)
if err == nil {
t.Errorf("keyset.ReadWithContext() err = nil, want err")
}

// Test that ReadWithAssociatedData is compatible with WriteWithContext
kekAEAD, err := fakekms.NewAEAD(fakeKeyURI)
if err != nil {
t.Fatalf("fakekms.NewAEAD(fakeKeyURI) err = %v, want nil", err)
}
handle3, err := keyset.ReadWithAssociatedData(keyset.NewBinaryReader(bytes.NewBuffer(encrypted)), kekAEAD, associatedData)
if err != nil {
t.Fatalf("keyset.ReadWithAssociatedData() err = %v, want nil", err)
}
if !proto.Equal(testkeyset.KeysetMaterial(handle), testkeyset.KeysetMaterial(handle3)) {
t.Errorf("keyset.ReadWithAssociatedData() = %v, want %v", handle3, handle)
}
}

func TestWriteWithContextDoesNotIgnoreContext(t *testing.T) {
kekAEADWithContext, err := fakekms.NewAEADWithContext(fakeKeyURI)
if err != nil {
t.Fatalf("fakekms.NewAEADWithContext(fakeKeyURI) err = %v, want nil", err)
}

handle, err := keyset.NewHandle(mac.HMACSHA256Tag128KeyTemplate())
if err != nil {
t.Fatalf("keyset.NewHandle(mac.HMACSHA256Tag128KeyTemplate()) err = %v, want nil", err)
}
associatedData := []byte{0x01, 0x02}

canceledCtx, cancel := context.WithCancelCause(context.Background())
causeErr := errors.New("cause error message")
cancel(causeErr)

buff := &bytes.Buffer{}
err = handle.WriteWithContext(canceledCtx, keyset.NewBinaryWriter(buff), kekAEADWithContext, associatedData)
if err == nil {
t.Errorf("handle.WriteWithContext() err = nil, want error")
}
}

func TestReadWithContextDoesNotIgnoreContext(t *testing.T) {
kekAEADWithContext, err := fakekms.NewAEADWithContext(fakeKeyURI)
if err != nil {
t.Fatalf("fakekms.NewAEADWithContext(fakeKeyURI) err = %v, want nil", err)
}

handle, err := keyset.NewHandle(mac.HMACSHA256Tag128KeyTemplate())
if err != nil {
t.Fatalf("keyset.NewHandle(mac.HMACSHA256Tag128KeyTemplate()) err = %v, want nil", err)
}
associatedData := []byte{0x01, 0x02}

ctx := context.Background()
buff := &bytes.Buffer{}
err = handle.WriteWithContext(ctx, keyset.NewBinaryWriter(buff), kekAEADWithContext, associatedData)
if err != nil {
t.Fatalf("handle.WriteWithContext() err = %v, want nil", err)
}
encrypted := buff.Bytes()

canceledCtx, cancel := context.WithCancelCause(ctx)
causeErr := errors.New("cause error message")
cancel(causeErr)

_, err = keyset.ReadWithContext(canceledCtx, keyset.NewBinaryReader(bytes.NewBuffer(encrypted)), kekAEADWithContext, associatedData)
if err == nil {
t.Errorf("keyset.ReadWithContext() err = nil, want error")
}
}

func TestPrimaryReturnsError(t *testing.T) {
testCases := []struct {
name string
Expand Down
43 changes: 43 additions & 0 deletions testing/fakekms/fakekms.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ package fakekms

import (
"bytes"
"context"
"encoding/base64"
"fmt"
"strings"
Expand All @@ -44,6 +45,32 @@ type fakeClient struct {
uriPrefix string
}

type fakeAEADWithContext struct {
aead tink.AEAD
}

// EncryptWithContext implements the [tink.AEADWithContext] interface for encryption.
// The call fails if the context is canceled.
func (a *fakeAEADWithContext) EncryptWithContext(ctx context.Context, plaintext, associatedData []byte) ([]byte, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return a.aead.Encrypt(plaintext, associatedData)
}
}

// DecryptWithContext implements the [tink.AEADWithContext] interface for decryption.
// The call fails if the context is canceled.
func (a *fakeAEADWithContext) DecryptWithContext(ctx context.Context, ciphertext, associatedData []byte) ([]byte, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return a.aead.Decrypt(ciphertext, associatedData)
}
}

// NewClient returns a fake KMS client which will handle keys with uriPrefix prefix.
// keyURI must have the following format: 'fake-kms://<base64 encoded aead keyset>'.
func NewClient(uriPrefix string) (registry.KMSClient, error) {
Expand All @@ -65,6 +92,11 @@ func (c *fakeClient) GetAEAD(keyURI string) (tink.AEAD, error) {
if !c.Supported(keyURI) {
return nil, fmt.Errorf("keyURI must start with prefix %s, but got %s", c.uriPrefix, keyURI)
}
return NewAEAD(keyURI)
}

// NewAEAD returns a new [tink.AEAD] for the given keyURI.
func NewAEAD(keyURI string) (tink.AEAD, error) {
encodeKeyset := strings.TrimPrefix(keyURI, fakePrefix)
keysetData, err := base64.RawURLEncoding.DecodeString(encodeKeyset)
if err != nil {
Expand All @@ -78,6 +110,17 @@ func (c *fakeClient) GetAEAD(keyURI string) (tink.AEAD, error) {
return aead.New(handle)
}

// NewAEADWithContext returns a new [tink.AeadWithContext] for the given keyURI.
//
// The returned AEADWithContext will fail if the context is canceled.
func NewAEADWithContext(keyURI string) (tink.AEADWithContext, error) {
aead, err := NewAEAD(keyURI)
if err != nil {
return nil, err
}
return &fakeAEADWithContext{aead: aead}, nil
}

// NewKeyURI returns a new, random fake KMS key URI.
func NewKeyURI() (string, error) {
handle, err := keyset.NewHandle(aead.AES128GCMKeyTemplate())
Expand Down
Loading

0 comments on commit 58f182b

Please sign in to comment.