diff --git a/src/jwkest/jwe.py b/src/jwkest/jwe.py index 3aa0f1b..a69520d 100644 --- a/src/jwkest/jwe.py +++ b/src/jwkest/jwe.py @@ -489,15 +489,20 @@ def encrypt(self, key, iv="", cek="", **kwargs): else: raise ParameterError("Zip has unknown value: %s" % self["zip"]) + kwarg_cek = cek or None + _enc = self["enc"] cek, iv = self._generate_key_and_iv(_enc, cek, iv) + self["cek"] = cek logger.debug("cek: %s, iv: %s" % ([c for c in cek], [c for c in iv])) _encrypt = RSAEncrypter(self.with_digest).encrypt _alg = self["alg"] - if _alg == "RSA-OAEP": + if kwarg_cek: + jwe_enc_key = '' + elif _alg == "RSA-OAEP": jwe_enc_key = _encrypt(cek, key, 'pkcs1_oaep_padding') elif _alg == "RSA1_5": jwe_enc_key = _encrypt(cek, key) @@ -511,7 +516,7 @@ def encrypt(self, key, iv="", cek="", **kwargs): ctxt, tag, key = self.enc_setup(_enc, _msg, enc_header, cek, iv) return jwe.pack(parts=[jwe_enc_key, iv, ctxt, tag]) - def decrypt(self, token, key): + def decrypt(self, token, key, cek=None): """ Decrypts a JWT :param token: The JWT @@ -529,13 +534,16 @@ def decrypt(self, token, key): _decrypt = RSAEncrypter(self.with_digest).decrypt _alg = jwe.headers["alg"] - if _alg == "RSA-OAEP": + if cek: + pass + elif _alg == "RSA-OAEP": cek = _decrypt(jek, key, 'pkcs1_oaep_padding') elif _alg == "RSA1_5": cek = _decrypt(jek, key) else: raise NotSupportedAlgorithm(_alg) + self["cek"] = cek enc = jwe.headers["enc"] try: assert enc in SUPPORTED["enc"] @@ -687,7 +695,7 @@ def encrypt(self, key, iv="", cek="", **kwargs): return jwe.pack(parts=[kwargs['encrypted_key'], iv, ctxt, tag]) return jwe.pack(parts=[iv, ctxt, tag]) - def decrypt(self, token=None, key=None): + def decrypt(self, token=None, key=None, **kwargs): if not self.cek: raise Exception("Content Encryption Key is Not Yet Set") @@ -747,7 +755,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs): :return: Encrypted message """ - encrypted_key = cek = iv = None + # encrypted_key = cek = iv = None _alg = self["alg"] # Find Usable Keys @@ -801,6 +809,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs): try: token = encrypter.encrypt(_key, **kwargs) + self["cek"] = encrypter.cek if 'cek' in encrypter else None except TypeError as err: raise err else: @@ -811,7 +820,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs): logger.error("Could not find any suitable encryption key") raise NoSuitableEncryptionKey() - def decrypt(self, token=None, keys=None, alg=None): + def decrypt(self, token=None, keys=None, alg=None, cek=None): if token: jwe = JWEnc().unpack(token) # header, ek, eiv, ctxt, tag = token.split(b".") @@ -829,7 +838,7 @@ def decrypt(self, token=None, keys=None, alg=None): else: keys = self._pick_keys(self._get_keys(), use="enc", alg=_alg) - if not keys: + if not keys and not cek: raise NoSuitableDecryptionKey(_alg) if _alg in ["RSA-OAEP", "RSA1_5"]: @@ -847,10 +856,21 @@ def decrypt(self, token=None, keys=None, alg=None): else: raise NotSupportedAlgorithm + if cek: + try: + msg = decrypter.decrypt(as_bytes(token), None, cek=cek) + self["cek"] = decrypter.cek if 'cek' in decrypter else None + except (KeyError, DecryptionFailed): + pass + else: + logger.debug("Decrypted message using exiting CEK") + return msg + for key in keys: _key = key.encryption_key(alg=_alg, private=False) try: msg = decrypter.decrypt(as_bytes(token), _key) + self["cek"] = decrypter.cek if 'cek' in decrypter else None except (KeyError, DecryptionFailed): pass else: diff --git a/tests/test_4_jwe.py b/tests/test_4_jwe.py index d8f4a12..173269d 100644 --- a/tests/test_4_jwe.py +++ b/tests/test_4_jwe.py @@ -193,6 +193,37 @@ def full_path(local_file): rsa = RSA.importKey(open(KEY, 'r').read()) plain = b'Now is the time for all good men to come to the aid of their country.' +def test_cek_reuse_encryption_rsaes_rsa15(): + + _rsa = JWE_RSA(plain, alg="RSA1_5", enc="A128CBC-HS256") + jwt = _rsa.encrypt(rsa) + dec = JWE_RSA() + msg = dec.decrypt(jwt, rsa) + + assert msg == plain + + _rsa2 = JWE_RSA(plain, alg="RSA1_5", enc="A128CBC-HS256") + jwt = _rsa2.encrypt(None, cek=dec["cek"]) + dec2 = JWE_RSA() + msg = dec2.decrypt(jwt, None, cek=_rsa["cek"]) + + assert msg == plain + +def test_cek_reuse_encryption_rsaes_rsa_oaep(): + + _rsa = JWE_RSA(plain, alg="RSA-OAEP", enc="A256GCM") + jwt = _rsa.encrypt(rsa) + dec = JWE_RSA() + msg = dec.decrypt(jwt, rsa) + + assert msg == plain + + _rsa2 = JWE_RSA(plain, alg="RSA-OAEP", enc="A256GCM") + jwt = _rsa2.encrypt(None, cek=dec["cek"]) + dec2 = JWE_RSA() + msg = dec2.decrypt(jwt, None, cek=_rsa["cek"]) + + assert msg == plain def test_rsa_encrypt_decrypt_rsa_cbc(): _rsa = JWE_RSA(plain, alg="RSA1_5", enc="A128CBC-HS256")