Skip to content

Commit

Permalink
- Fix epk reversal in original ECDH-ES implementation (RFC specifies …
Browse files Browse the repository at this point in the history
…that epk is the key of the originator and this implementation uses the epk as the receiver's public key). This has been fixed + unittests updated.
  • Loading branch information
Matt David committed Jul 20, 2016
1 parent 43bde57 commit e7e99c9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
16 changes: 7 additions & 9 deletions src/jwkest/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,7 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs):
# Generate an ephemeral key pair if none is given
curve = NISTEllipticCurve.by_name(key.crv)
if "epk" in kwargs:
epk = kwargs["epk"] if isinstance(kwargs["epk"], ECKey) else ECKey(
kwargs["epk"])
epk = kwargs["epk"] if isinstance(kwargs["epk"], ECKey) else ECKey(kwargs["epk"])
else:
raise Exception(
"Ephemeral Public Key (EPK) Required for ECDH-ES JWE "
Expand All @@ -592,7 +591,7 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs):
params = {
"apu": b64e(apu),
"apv": b64e(apv),
"epk": key.serialize(False)
"epk": epk.serialize(False)
}

cek = iv = None
Expand All @@ -602,19 +601,20 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs):
iv = kwargs['iv']

cek, iv = self._generate_key_and_iv(self.enc, cek=cek, iv=iv)

if self.alg == "ECDH-ES":
try:
dk_len = KEYLEN[self.enc]
except KeyError:
raise Exception(
"Unknown key length for algorithm %s" % self.enc)

cek = ecdh_derive_key(curve, key.d, (epk.x, epk.y), apu, apv,
cek = ecdh_derive_key(curve, epk.d, (key.x, key.y), apu, apv,
str(self.enc).encode(), dk_len)
elif self.alg in ["ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"]:
_pre, _post = self.alg.split("+")
klen = int(_post[1:4])
kek = ecdh_derive_key(curve, key.d, (epk.x, epk.y), apu, apv,
kek = ecdh_derive_key(curve, epk.d, (key.x, key.y), apu, apv,
str(_post).encode(), klen)
encrypted_key = aes_wrap_key(kek, cek)
else:
Expand All @@ -631,8 +631,7 @@ def dec_setup(self, token, key=None, **kwargs):

# Handle EPK / Curve
if "epk" not in self.headers or "crv" not in self.headers["epk"]:
raise Exception(
"Ephemeral Public Key Missing in ECDH-ES Computation")
raise Exception("Ephemeral Public Key Missing in ECDH-ES Computation")

epubkey = ECKey(**self.headers["epk"])
apu = apv = ""
Expand Down Expand Up @@ -759,8 +758,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):

if not keys:
logger.error(
"Could not find any suitable encryption key for alg='{"
"}'".format(_alg))
"Could not find any suitable encryption key for alg='{}'".format(_alg))
raise NoSuitableEncryptionKey(_alg)

# Determine Encryption Class by Algorithm
Expand Down
22 changes: 10 additions & 12 deletions tests/test_4_jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,20 +235,18 @@ def test_rsa_with_kid():

# Test ECDH-ES
curve = NISTEllipticCurve.by_name('P-256')
epriv, epub = curve.key_pair()
remotepriv, remotepub = curve.key_pair()
localpriv, localpub = curve.key_pair()

epk = ECKey(crv=curve.name(), d=epriv, x=epub[0], y=epub[1])
localkey = ECKey(crv=curve.name(), d=localpriv, x=localpub[0], y=localpub[1])

remotekey = ECKey(crv=curve.name(), d=remotepriv, x=remotepub[0], y=remotepub[1])

def test_ecdh_encrypt_decrypt_direct_key():
global epk

jwenc = JWE_EC(plain, alg="ECDH-ES", enc="A128GCM")
cek, encrypted_key, iv, params, ret_epk = jwenc.enc_setup(plain, '',
key=localkey,
epk=epk)
key=remotekey,
epk=localkey)

kwargs = {}
kwargs['params'] = params
Expand All @@ -263,8 +261,8 @@ def test_ecdh_encrypt_decrypt_direct_key():

ret_jwe = factory(jwt)
jwdec = JWE_EC()
jwdec.dec_setup(ret_jwe.jwt, key=epk)
msg = jwdec.decrypt(ret_jwe.jwt, key=epk)
jwdec.dec_setup(ret_jwe.jwt, key=remotekey)
msg = jwdec.decrypt(ret_jwe.jwt, key=remotekey)

assert msg == plain

Expand All @@ -274,8 +272,8 @@ def test_ecdh_encrypt_decrypt_keywrapped_key():

jwenc = JWE_EC(plain, alg="ECDH-ES+A128KW", enc="A128GCM")
cek, encrypted_key, iv, params, ret_epk = jwenc.enc_setup(plain, '',
key=localkey,
epk=epk)
key=remotekey,
epk=localkey)

kwargs = {}
kwargs['params'] = params
Expand All @@ -290,8 +288,8 @@ def test_ecdh_encrypt_decrypt_keywrapped_key():

ret_jwe = factory(jwt)
jwdec = JWE_EC()
jwdec.dec_setup(ret_jwe.jwt, key=epk)
msg = jwdec.decrypt(ret_jwe.jwt, key=epk)
jwdec.dec_setup(ret_jwe.jwt, key=remotekey)
msg = jwdec.decrypt(ret_jwe.jwt, key=remotekey)

assert msg == plain

Expand Down

0 comments on commit e7e99c9

Please sign in to comment.