Skip to content

Commit

Permalink
Merge pull request #62 from techguy613/master
Browse files Browse the repository at this point in the history
Fix epk reversal in original ECDH-ES implementation
  • Loading branch information
rohe authored Jul 20, 2016
2 parents 43bde57 + e7e99c9 commit 311f188
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
16 changes: 7 additions & 9 deletions src/jwkest/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,7 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs):
# Generate an ephemeral key pair if none is given
curve = NISTEllipticCurve.by_name(key.crv)
if "epk" in kwargs:
epk = kwargs["epk"] if isinstance(kwargs["epk"], ECKey) else ECKey(
kwargs["epk"])
epk = kwargs["epk"] if isinstance(kwargs["epk"], ECKey) else ECKey(kwargs["epk"])
else:
raise Exception(
"Ephemeral Public Key (EPK) Required for ECDH-ES JWE "
Expand All @@ -592,7 +591,7 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs):
params = {
"apu": b64e(apu),
"apv": b64e(apv),
"epk": key.serialize(False)
"epk": epk.serialize(False)
}

cek = iv = None
Expand All @@ -602,19 +601,20 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs):
iv = kwargs['iv']

cek, iv = self._generate_key_and_iv(self.enc, cek=cek, iv=iv)

if self.alg == "ECDH-ES":
try:
dk_len = KEYLEN[self.enc]
except KeyError:
raise Exception(
"Unknown key length for algorithm %s" % self.enc)

cek = ecdh_derive_key(curve, key.d, (epk.x, epk.y), apu, apv,
cek = ecdh_derive_key(curve, epk.d, (key.x, key.y), apu, apv,
str(self.enc).encode(), dk_len)
elif self.alg in ["ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"]:
_pre, _post = self.alg.split("+")
klen = int(_post[1:4])
kek = ecdh_derive_key(curve, key.d, (epk.x, epk.y), apu, apv,
kek = ecdh_derive_key(curve, epk.d, (key.x, key.y), apu, apv,
str(_post).encode(), klen)
encrypted_key = aes_wrap_key(kek, cek)
else:
Expand All @@ -631,8 +631,7 @@ def dec_setup(self, token, key=None, **kwargs):

# Handle EPK / Curve
if "epk" not in self.headers or "crv" not in self.headers["epk"]:
raise Exception(
"Ephemeral Public Key Missing in ECDH-ES Computation")
raise Exception("Ephemeral Public Key Missing in ECDH-ES Computation")

epubkey = ECKey(**self.headers["epk"])
apu = apv = ""
Expand Down Expand Up @@ -759,8 +758,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):

if not keys:
logger.error(
"Could not find any suitable encryption key for alg='{"
"}'".format(_alg))
"Could not find any suitable encryption key for alg='{}'".format(_alg))
raise NoSuitableEncryptionKey(_alg)

# Determine Encryption Class by Algorithm
Expand Down
22 changes: 10 additions & 12 deletions tests/test_4_jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,20 +235,18 @@ def test_rsa_with_kid():

# Test ECDH-ES
curve = NISTEllipticCurve.by_name('P-256')
epriv, epub = curve.key_pair()
remotepriv, remotepub = curve.key_pair()
localpriv, localpub = curve.key_pair()

epk = ECKey(crv=curve.name(), d=epriv, x=epub[0], y=epub[1])
localkey = ECKey(crv=curve.name(), d=localpriv, x=localpub[0], y=localpub[1])

remotekey = ECKey(crv=curve.name(), d=remotepriv, x=remotepub[0], y=remotepub[1])

def test_ecdh_encrypt_decrypt_direct_key():
global epk

jwenc = JWE_EC(plain, alg="ECDH-ES", enc="A128GCM")
cek, encrypted_key, iv, params, ret_epk = jwenc.enc_setup(plain, '',
key=localkey,
epk=epk)
key=remotekey,
epk=localkey)

kwargs = {}
kwargs['params'] = params
Expand All @@ -263,8 +261,8 @@ def test_ecdh_encrypt_decrypt_direct_key():

ret_jwe = factory(jwt)
jwdec = JWE_EC()
jwdec.dec_setup(ret_jwe.jwt, key=epk)
msg = jwdec.decrypt(ret_jwe.jwt, key=epk)
jwdec.dec_setup(ret_jwe.jwt, key=remotekey)
msg = jwdec.decrypt(ret_jwe.jwt, key=remotekey)

assert msg == plain

Expand All @@ -274,8 +272,8 @@ def test_ecdh_encrypt_decrypt_keywrapped_key():

jwenc = JWE_EC(plain, alg="ECDH-ES+A128KW", enc="A128GCM")
cek, encrypted_key, iv, params, ret_epk = jwenc.enc_setup(plain, '',
key=localkey,
epk=epk)
key=remotekey,
epk=localkey)

kwargs = {}
kwargs['params'] = params
Expand All @@ -290,8 +288,8 @@ def test_ecdh_encrypt_decrypt_keywrapped_key():

ret_jwe = factory(jwt)
jwdec = JWE_EC()
jwdec.dec_setup(ret_jwe.jwt, key=epk)
msg = jwdec.decrypt(ret_jwe.jwt, key=epk)
jwdec.dec_setup(ret_jwe.jwt, key=remotekey)
msg = jwdec.decrypt(ret_jwe.jwt, key=remotekey)

assert msg == plain

Expand Down

0 comments on commit 311f188

Please sign in to comment.