Skip to content

Commit

Permalink
Merge pull request #37 from dallerbarn/py3_string
Browse files Browse the repository at this point in the history
Python 2 and 3 string conversion
  • Loading branch information
Roland Hedberg committed Aug 20, 2015
2 parents 2b139d8 + 23722a5 commit 8f640e7
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 37 deletions.
26 changes: 26 additions & 0 deletions src/jwkest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,29 @@ def constant_time_compare(a, b):
for c, d in zip(a, b):
r |= c ^ d
return r == 0


def as_bytes(s):
"""
Convert an unicode string to bytes.
:param s: Unicode / bytes string
:return: bytes string
"""
try:
s = s.encode()
except (AttributeError, UnicodeDecodeError):
pass
return s


def as_unicode(b):
"""
Convert a byte string to a unicode string
:param b: byte string
:return: unicode string
"""
try:
b = b.decode()
except (AttributeError, UnicodeDecodeError):
pass
return b
3 changes: 2 additions & 1 deletion src/jwkest/aes_key_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
PyCrypto's AES.
"""
from __future__ import division

try:
from builtins import hex
from builtins import range
Expand All @@ -39,7 +40,7 @@ def aes_unwrap_key_and_iv(kek, wrapped):
B = decrypt(ciphertext)
a = QUAD.unpack(B[:8])[0]
r[i] = B[8:]
return "".join(r[1:]), a
return b"".join(r[1:]), a


def aes_unwrap_key(kek, wrapped, iv=0xa6a6a6a6a6a6a6a6):
Expand Down
10 changes: 5 additions & 5 deletions src/jwkest/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from Crypto.Cipher import PKCS1_v1_5
from Crypto.Cipher import PKCS1_OAEP

from jwkest import b64d
from jwkest import b64d, as_bytes
from jwkest import b64e
from jwkest import JWKESTException
from jwkest import MissingKey
Expand Down Expand Up @@ -404,7 +404,7 @@ def encrypt(self, key, iv="", cek="", **kwargs):

# If no iv and cek are given generate them
cek, iv = self._generate_key_and_iv(self["enc"], cek, iv)
if isinstance(key, six.string_types):
if isinstance(key, six.binary_type):
kek = key
else:
kek = intarr2str(key)
Expand All @@ -415,7 +415,7 @@ def encrypt(self, key, iv="", cek="", **kwargs):

_enc = self["enc"]

ctxt, tag, cek = self.enc_setup(_enc, _msg, jwe.b64_encode_header(),
ctxt, tag, cek = self.enc_setup(_enc, _msg.encode(), jwe.b64_encode_header(),
cek, iv=iv)
return jwe.pack(parts=[jek, iv, ctxt, tag])

Expand Down Expand Up @@ -456,7 +456,7 @@ def encrypt(self, key, iv="", cek="", **kwargs):
:return: A jwe
"""

_msg = self.msg
_msg = as_bytes(self.msg)
if "zip" in self:
if self["zip"] == "DEF":
_msg = zlib.compress(_msg)
Expand Down Expand Up @@ -681,7 +681,7 @@ def decrypt(self, token, keys=None, alg=None):
for key in keys:
_key = key.encryption_key(alg=_alg, private=False)
try:
msg = decrypter.decrypt(bytes(token), _key)
msg = decrypter.decrypt(as_bytes(token), _key)
except (KeyError, DecryptionFailed):
pass
else:
Expand Down
9 changes: 5 additions & 4 deletions src/jwkest/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from requests import request

from jwkest import base64url_to_long
from jwkest import as_bytes
from jwkest import base64_to_long
from jwkest import long_to_base64
from jwkest import JWKESTException
Expand Down Expand Up @@ -66,15 +67,15 @@ def intarr2str(arr):


def sha256_digest(msg):
return hashlib.sha256(msg).digest()
return hashlib.sha256(as_bytes(msg)).digest()


def sha384_digest(msg):
return hashlib.sha384(msg).digest()
return hashlib.sha384(as_bytes(msg)).digest()


def sha512_digest(msg):
return hashlib.sha512(msg).digest()
return hashlib.sha512(as_bytes(msg)).digest()


# =============================================================================
Expand Down Expand Up @@ -534,7 +535,7 @@ class SYMKey(Key):

def __init__(self, kty="oct", alg="", use="", kid="", key=None,
x5c=None, x5t="", x5u="", k="", mtrl="", **kwargs):
Key.__init__(self, kty, alg, use, kid, key, x5c, x5t, x5u, **kwargs)
Key.__init__(self, kty, alg, use, kid, as_bytes(key), x5c, x5t, x5u, **kwargs)
self.k = k
if not self.key and self.k:
if isinstance(self.k, str):
Expand Down
17 changes: 7 additions & 10 deletions src/jwkest/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from Crypto.Util.number import bytes_to_long
import sys

from jwkest import b64d
from jwkest import b64d, as_unicode
from jwkest import b64e
from jwkest import constant_time_compare
from jwkest import safe_str_cmp
Expand Down Expand Up @@ -74,11 +74,11 @@ class SignerAlgError(JWSException):
def left_hash(msg, func="HS256"):
""" 128 bits == 16 bytes """
if func == 'HS256':
return b64e(sha256_digest(msg)[:16])
return as_unicode(b64e(sha256_digest(msg)[:16]))
elif func == 'HS384':
return b64e(sha384_digest(msg)[:24])
return as_unicode(b64e(sha384_digest(msg)[:24]))
elif func == 'HS512':
return b64e(sha512_digest(msg)[:32])
return as_unicode(b64e(sha512_digest(msg)[:32]))


def mpint(b):
Expand Down Expand Up @@ -249,10 +249,7 @@ class JWx(object):
"""

def __init__(self, msg=None, with_digest=False, **kwargs):
if six.PY3 and isinstance(msg, six.string_types):
self.msg = msg.encode("utf-8")
else:
self.msg = msg
self.msg = msg

self._dict = {}
self.with_digest = with_digest
Expand Down Expand Up @@ -492,9 +489,9 @@ def sign_compact(self, keys=None, protected=None):
raise UnknownAlgorithm(_alg)

_input = jwt.pack(parts=[self.msg])
sig = _signer.sign(_input, key.get_key(alg=_alg, private=True))
sig = _signer.sign(_input.encode("utf-8"), key.get_key(alg=_alg, private=True))
logger.debug("Signed message using key with kid=%s" % key.kid)
return b".".join([_input, b64encode_item(sig)])
return ".".join([_input, b64encode_item(sig).decode("utf-8")])

def verify_compact(self, jws, keys=None, allow_none=False, sigalg=None):
"""
Expand Down
8 changes: 4 additions & 4 deletions src/jwkest/jwt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import six
from jwkest import b64d
from jwkest import b64d, as_unicode
from jwkest import b64e
from jwkest import BadSyntax

Expand Down Expand Up @@ -58,7 +58,7 @@ def unpack(self, token):
part = split_token(token)
self.b64part = part
self.part = [b64d(p) for p in part]
self.headers = json.loads(self.part[0].decode("utf-8"))
self.headers = json.loads(self.part[0].decode())
return self

def pack(self, parts, headers=None):
Expand All @@ -77,10 +77,10 @@ def pack(self, parts, headers=None):
_all = self.b64part = [ self.b64part[0] ]
_all.extend([b64encode_item(p) for p in parts])

return b".".join(_all)
return ".".join([a.decode() for a in _all])

def payload(self):
_msg = self.part[1].decode("utf-8")
_msg = as_unicode(self.part[1])

# If not JSON web token assume JSON
if "cty" in self.headers and self.headers["cty"].lower() != "jwt":
Expand Down
6 changes: 2 additions & 4 deletions tests/test_1_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_pack_jwt():
jwt = _jwt.pack(parts=[{"iss": "joe", "exp": 1300819380,
"http://example.com/is_root": True}, ""])

p = jwt.split(b'.')
p = jwt.split('.')
assert len(p) == 3


Expand All @@ -40,12 +40,10 @@ def test_unpack_str():
"http://example.com/is_root": True}
jwt = _jwt.pack(parts=[payload, ""])

jwt = jwt.decode('utf-8')

_jwt2 = JWT().unpack(jwt)
assert _jwt2
out_payload = _jwt2.payload()


if __name__ == "__main__":
test_unpack_str()
test_unpack_str()
18 changes: 9 additions & 9 deletions tests/test_3_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,13 @@ def test_hmac_from_keyrep():


def test_left_hash_hs256():
hsh = jws.left_hash(b'Please take a moment to register today')
assert hsh == b'rCFHVJuxTqRxOsn2IUzgvA'
hsh = jws.left_hash('Please take a moment to register today')
assert hsh == 'rCFHVJuxTqRxOsn2IUzgvA'


def test_left_hash_hs512():
hsh = jws.left_hash(b'Please take a moment to register today', "HS512")
assert hsh == b'_h6feWLt8zbYcOFnaBmekTzMJYEHdVTaXlDgJSWsEeY'
hsh = jws.left_hash('Please take a moment to register today', "HS512")
assert hsh == '_h6feWLt8zbYcOFnaBmekTzMJYEHdVTaXlDgJSWsEeY'


def test_rs256():
Expand Down Expand Up @@ -349,7 +349,7 @@ def test_signer_ps256_fail():
keys = [RSAKey(key=import_rsa_key_from_file(KEY))]
#keys[0]._keytype = "private"
_jws = JWS(payload, alg="PS256")
_jwt = _jws.sign_compact(keys)[:-5] + b'abcde'
_jwt = _jws.sign_compact(keys)[:-5] + 'abcde'

_rj = JWS()
try:
Expand Down Expand Up @@ -428,9 +428,9 @@ def test_signer_protected_headers():

exp_protected = protected.copy()
exp_protected['alg'] = 'ES256'
enc_header, enc_payload, sig = _jwt.split(b'.')
assert json.loads(b64d(enc_header).decode("utf-8")) == exp_protected
assert b64d(enc_payload).decode("utf-8") == payload
enc_header, enc_payload, sig = _jwt.split('.')
assert json.loads(b64d(enc_header.encode("utf-8")).decode("utf-8")) == exp_protected
assert b64d(enc_payload.encode("utf-8")).decode("utf-8") == payload

_rj = JWS()
info = _rj.verify_compact(_jwt, keys)
Expand All @@ -444,7 +444,7 @@ def test_verify_protected_headers():
protected = dict(header1=u"header1 is protected",
header2="header2 is protected too", a=1)
_jwt = _jws.sign_compact(keys, protected=protected)
protectedHeader, enc_payload, sig = _jwt.split(b".")
protectedHeader, enc_payload, sig = _jwt.split(".")
data = dict(payload=enc_payload, signatures=[
dict(
header=dict(alg=u"ES256", jwk=_key.serialize()),
Expand Down

0 comments on commit 8f640e7

Please sign in to comment.