diff --git a/services/kms/kms.go b/services/kms/kms.go index 204ae8e..1333209 100644 --- a/services/kms/kms.go +++ b/services/kms/kms.go @@ -647,6 +647,9 @@ func (k *KMS) GenerateDataKeyPair(input GenerateDataKeyPairInput) (*GenerateData default: return nil, ValidationException("Unknown value for KeyPair Spec") } + if err != nil { + return nil, KMSInternalException(err.Error()) + } serializedPublicKey, err := x509.MarshalPKIXPublicKey(pkey.Public()) if err != nil { @@ -839,30 +842,29 @@ func (k *KMS) Decrypt(input DecryptInput) (*DecryptOutput, *awserrors.Error) { input.EncryptionAlgorithm = "SYMMETRIC_DEFAULT" } - keyArn := input.KeyId + keyId := input.KeyId ciphertext := input.CiphertextBlob if len(ciphertext) == 0 { return nil, InvalidCiphertextException("") } - k.mu.Lock() - defer k.mu.Unlock() - - encryptionKey := k.lockedGetKey(keyArn) - - if keyArn == "" || encryptionKey.IsAES() { - // AES can pack keyId into the ciphertext - // This logic is the opposite of Key.Encrypt + if keyId == "" { + // Passing KeyId is optional for symmetric encyption. + // Run the opposite of Key.Encrypt to unpack the key that was used data := ciphertext keyArnLen, data := uint8(data[0]), data[1:] if len(data) < 4+int(keyArnLen) { return nil, InvalidCiphertextException("") } - keyArn, ciphertext = string(data[:keyArnLen]), data[keyArnLen:] + keyId = string(data[:keyArnLen]) } - encryptionKey = k.lockedGetKey(keyArn) + k.mu.Lock() + defer k.mu.Unlock() + + encryptionKey := k.lockedGetKey(keyId) + if encryptionKey == nil { return nil, NotFoundException("") } @@ -871,6 +873,21 @@ func (k *KMS) Decrypt(input DecryptInput) (*DecryptOutput, *awserrors.Error) { return nil, DisabledException("") } + if encryptionKey.IsAES() { + if input.KeyId != "" && input.KeyId != encryptionKey.Id() { + // TODO(zbarsky): we should flag this + // return nil, IncorrectKeyException("") + k.logger.Info("WRONG aes key", "expected", input.KeyId, "actual", encryptionKey.Id()) + } + // AES can pack keyId into the ciphertext + // This logic is the opposite of Key.Encrypt + keyArnLen := uint8(ciphertext[0]) + if len(ciphertext) < 5+int(keyArnLen) { + return nil, InvalidCiphertextException("") + } + ciphertext = ciphertext[keyArnLen+1:] + } + plaintext, err := encryptionKey.Decrypt(ciphertext, input.EncryptionAlgorithm, input.EncryptionContext) if err != nil { if errors.Is(err, key.ErrBadAlgorithm) {