Skip to content

Commit

Permalink
Refactor kms_envelope_aead.go.
Browse files Browse the repository at this point in the history
Move parts of encrypt and decrypt into helper functions.
These helper function can then be used to implement encrypt and decrypt with context.

PiperOrigin-RevId: 697921752
Change-Id: Ifdf0833517b1a6507b0e8406bc98da9080923382
  • Loading branch information
juergw authored and copybara-github committed Nov 19, 2024
1 parent f7da869 commit 86a86c3
Showing 1 changed file with 50 additions and 24 deletions.
74 changes: 50 additions & 24 deletions aead/kms_envelope_aead.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,19 @@ func NewKMSEnvelopeAEAD2(dekTemplate *tinkpb.KeyTemplate, keyEncryptionAEAD tink
}
}

// Encrypt implements the tink.AEAD interface for encryption.
func (a *KMSEnvelopeAEAD) Encrypt(plaintext, associatedData []byte) ([]byte, error) {
if a.err != nil {
return nil, a.err
}
dekKeyData, err := registry.NewKeyData(a.dekTemplate)
if err != nil {
return nil, err
}
dek := dekKeyData.GetValue()
encryptedDEK, err := a.kekAEAD.Encrypt(dek, []byte{})
func newDEK(template *tinkpb.KeyTemplate) ([]byte, error) {
dekKeyData, err := registry.NewKeyData(template)
if err != nil {
return nil, err
}
return dekKeyData.GetValue(), nil
}

func encryptDataAndSerializeEnvelope(dekTypeURL string, dek, encryptedDEK []byte, plaintext, associatedData []byte) ([]byte, error) {
if len(encryptedDEK) == 0 {
return nil, errors.New("encrypted dek is empty")
}
p, err := registry.Primitive(a.dekTemplate.TypeUrl, dek)
p, err := registry.Primitive(dekTypeURL, dek)
if err != nil {
return nil, err
}
Expand All @@ -109,7 +104,9 @@ func (a *KMSEnvelopeAEAD) Encrypt(plaintext, associatedData []byte) ([]byte, err
return nil, err
}
if len(encryptedDEK) > maxLengthEncryptedDEK {
return nil, errors.New("kms_envelope_aead: length of encrypted DEK too large")
return nil, fmt.Errorf(
"kms_envelope_aead: length of encrypted DEK too large; got %d, want at most %d",
len(encryptedDEK), maxLengthEncryptedDEK)
}
res := make([]byte, 0, lenDEK+len(encryptedDEK)+len(payload))
res = binary.BigEndian.AppendUint32(res, uint32(len(encryptedDEK)))
Expand All @@ -118,34 +115,44 @@ func (a *KMSEnvelopeAEAD) Encrypt(plaintext, associatedData []byte) ([]byte, err
return res, nil
}

// Decrypt implements the tink.AEAD interface for decryption.
func (a *KMSEnvelopeAEAD) Decrypt(ciphertext, associatedData []byte) ([]byte, error) {
// Encrypt implements the tink.AEAD interface for encryption.
func (a *KMSEnvelopeAEAD) Encrypt(plaintext, associatedData []byte) ([]byte, error) {
if a.err != nil {
return nil, a.err
}
dek, err := newDEK(a.dekTemplate)
if err != nil {
return nil, err
}
encryptedDEK, err := a.kekAEAD.Encrypt(dek, []byte{})
if err != nil {
return nil, err
}
return encryptDataAndSerializeEnvelope(a.dekTemplate.GetTypeUrl(), dek, encryptedDEK, plaintext, associatedData)
}

// parseEnvelope extracts encryptedDEK and payload from the ciphertext.
func parseEnvelope(ciphertext []byte) ([]byte, []byte, error) {
// Verify we have enough bytes for the length of the encrypted DEK.
if len(ciphertext) <= lenDEK {
return nil, errors.New("kms_envelope_aead: invalid ciphertext")
return nil, nil, errors.New("kms_envelope_aead: invalid ciphertext")
}

// Extract length of encrypted DEK and advance past that length.
encryptedDEKLen := int(binary.BigEndian.Uint32(ciphertext[:lenDEK]))
if encryptedDEKLen <= 0 || encryptedDEKLen > maxLengthEncryptedDEK || encryptedDEKLen > len(ciphertext)-lenDEK {
return nil, errors.New("kms_envelope_aead: length of encrypted DEK too large")
return nil, nil, errors.New("kms_envelope_aead: length of encrypted DEK too large")
}
ciphertext = ciphertext[lenDEK:]

encryptedDEK := ciphertext[:encryptedDEKLen]
payload := ciphertext[encryptedDEKLen:]
ciphertext = nil

dek, err := a.kekAEAD.Decrypt(encryptedDEK, []byte{})
if err != nil {
return nil, err
}
return encryptedDEK, payload, nil
}

func decryptDataWithDEK(dekTypeURL string, dek []byte, payload, associatedData []byte) ([]byte, error) {
// Get an AEAD primitive corresponding to the DEK.
p, err := registry.Primitive(a.dekTemplate.TypeUrl, dek)
p, err := registry.Primitive(dekTypeURL, dek)
if err != nil {
return nil, fmt.Errorf("kms_envelope_aead: %s", err)
}
Expand All @@ -156,3 +163,22 @@ func (a *KMSEnvelopeAEAD) Decrypt(ciphertext, associatedData []byte) ([]byte, er

return dekAEAD.Decrypt(payload, associatedData)
}

// Decrypt implements the [tink.AEAD] interface for decryption.
func (a *KMSEnvelopeAEAD) Decrypt(ciphertext, associatedData []byte) ([]byte, error) {
if a.err != nil {
return nil, a.err
}

encryptedDEK, payload, err := parseEnvelope(ciphertext)
if err != nil {
return nil, err
}

dek, err := a.kekAEAD.Decrypt(encryptedDEK, []byte{})
if err != nil {
return nil, err
}

return decryptDataWithDEK(a.dekTemplate.GetTypeUrl(), dek, payload, associatedData)
}

0 comments on commit 86a86c3

Please sign in to comment.