From e7e99c9c4ac23cf623e10c382c636ef95671f2df Mon Sep 17 00:00:00 2001 From: Matt David Date: Tue, 19 Jul 2016 17:22:35 -0700 Subject: [PATCH] - Fix epk reversal in original ECDH-ES implementation (RFC specifies that epk is the key of the originator and this implementation uses the epk as the receiver's public key). This has been fixed + unittests updated. --- src/jwkest/jwe.py | 16 +++++++--------- tests/test_4_jwe.py | 22 ++++++++++------------ 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/jwkest/jwe.py b/src/jwkest/jwe.py index 0fe60d1..3aa0f1b 100644 --- a/src/jwkest/jwe.py +++ b/src/jwkest/jwe.py @@ -582,8 +582,7 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs): # Generate an ephemeral key pair if none is given curve = NISTEllipticCurve.by_name(key.crv) if "epk" in kwargs: - epk = kwargs["epk"] if isinstance(kwargs["epk"], ECKey) else ECKey( - kwargs["epk"]) + epk = kwargs["epk"] if isinstance(kwargs["epk"], ECKey) else ECKey(kwargs["epk"]) else: raise Exception( "Ephemeral Public Key (EPK) Required for ECDH-ES JWE " @@ -592,7 +591,7 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs): params = { "apu": b64e(apu), "apv": b64e(apv), - "epk": key.serialize(False) + "epk": epk.serialize(False) } cek = iv = None @@ -602,6 +601,7 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs): iv = kwargs['iv'] cek, iv = self._generate_key_and_iv(self.enc, cek=cek, iv=iv) + if self.alg == "ECDH-ES": try: dk_len = KEYLEN[self.enc] @@ -609,12 +609,12 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs): raise Exception( "Unknown key length for algorithm %s" % self.enc) - cek = ecdh_derive_key(curve, key.d, (epk.x, epk.y), apu, apv, + cek = ecdh_derive_key(curve, epk.d, (key.x, key.y), apu, apv, str(self.enc).encode(), dk_len) elif self.alg in ["ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"]: _pre, _post = self.alg.split("+") klen = int(_post[1:4]) - kek = ecdh_derive_key(curve, key.d, (epk.x, epk.y), apu, apv, + kek = ecdh_derive_key(curve, epk.d, (key.x, key.y), apu, apv, str(_post).encode(), klen) encrypted_key = aes_wrap_key(kek, cek) else: @@ -631,8 +631,7 @@ def dec_setup(self, token, key=None, **kwargs): # Handle EPK / Curve if "epk" not in self.headers or "crv" not in self.headers["epk"]: - raise Exception( - "Ephemeral Public Key Missing in ECDH-ES Computation") + raise Exception("Ephemeral Public Key Missing in ECDH-ES Computation") epubkey = ECKey(**self.headers["epk"]) apu = apv = "" @@ -759,8 +758,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs): if not keys: logger.error( - "Could not find any suitable encryption key for alg='{" - "}'".format(_alg)) + "Could not find any suitable encryption key for alg='{}'".format(_alg)) raise NoSuitableEncryptionKey(_alg) # Determine Encryption Class by Algorithm diff --git a/tests/test_4_jwe.py b/tests/test_4_jwe.py index 9172e39..d8f4a12 100644 --- a/tests/test_4_jwe.py +++ b/tests/test_4_jwe.py @@ -235,20 +235,18 @@ def test_rsa_with_kid(): # Test ECDH-ES curve = NISTEllipticCurve.by_name('P-256') -epriv, epub = curve.key_pair() +remotepriv, remotepub = curve.key_pair() localpriv, localpub = curve.key_pair() -epk = ECKey(crv=curve.name(), d=epriv, x=epub[0], y=epub[1]) localkey = ECKey(crv=curve.name(), d=localpriv, x=localpub[0], y=localpub[1]) - +remotekey = ECKey(crv=curve.name(), d=remotepriv, x=remotepub[0], y=remotepub[1]) def test_ecdh_encrypt_decrypt_direct_key(): - global epk jwenc = JWE_EC(plain, alg="ECDH-ES", enc="A128GCM") cek, encrypted_key, iv, params, ret_epk = jwenc.enc_setup(plain, '', - key=localkey, - epk=epk) + key=remotekey, + epk=localkey) kwargs = {} kwargs['params'] = params @@ -263,8 +261,8 @@ def test_ecdh_encrypt_decrypt_direct_key(): ret_jwe = factory(jwt) jwdec = JWE_EC() - jwdec.dec_setup(ret_jwe.jwt, key=epk) - msg = jwdec.decrypt(ret_jwe.jwt, key=epk) + jwdec.dec_setup(ret_jwe.jwt, key=remotekey) + msg = jwdec.decrypt(ret_jwe.jwt, key=remotekey) assert msg == plain @@ -274,8 +272,8 @@ def test_ecdh_encrypt_decrypt_keywrapped_key(): jwenc = JWE_EC(plain, alg="ECDH-ES+A128KW", enc="A128GCM") cek, encrypted_key, iv, params, ret_epk = jwenc.enc_setup(plain, '', - key=localkey, - epk=epk) + key=remotekey, + epk=localkey) kwargs = {} kwargs['params'] = params @@ -290,8 +288,8 @@ def test_ecdh_encrypt_decrypt_keywrapped_key(): ret_jwe = factory(jwt) jwdec = JWE_EC() - jwdec.dec_setup(ret_jwe.jwt, key=epk) - msg = jwdec.decrypt(ret_jwe.jwt, key=epk) + jwdec.dec_setup(ret_jwe.jwt, key=remotekey) + msg = jwdec.decrypt(ret_jwe.jwt, key=remotekey) assert msg == plain