Skip to content

Commit

Permalink
Merge pull request #55 from vimalloc/expires_overwrite
Browse files Browse the repository at this point in the history
Add abillity to change token expires time at a non-global level
  • Loading branch information
vimalloc authored Jun 14, 2017
2 parents 950a333 + 2bc81aa commit e45c0d1
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 4 deletions.
25 changes: 25 additions & 0 deletions docs/changing_default_behavior.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
Changing Default Behaviors
==========================

Changing callback functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~

We provide what we think are sensible behaviors when attempting to access a
protected endpoint. If the access token is not valid for any reason (missing,
expired, tampered with, etc) we will return json in the format of {'msg': 'why
Expand Down Expand Up @@ -34,3 +37,25 @@ Possible loader functions are:
* - **revoked_token_loader**
- Function to call when a revoked token accesses a protected endpoint
- None

Dynamic token expires time
~~~~~~~~~~~~~~~~~~~~~~~~~~

You can also change the expires time for a token via the **expires_delta** kwarg
in the **create_refresh_token** and **create_access_token** functions. This takes
a **datetime.timedelta** and overrides the **JWT_REFRESH_TOKEN_EXPIRES** and
**JWT_ACCESS_TOKEN_EXPIRES** options. This can be useful if you have different
use cases for different tokens. An example of this might be you use short lived
access tokens used in your web application, but you allow the creation of long
lived access tokens that other developers can generate and use to interact with
your api in their programs.

.. code-block:: python
@app.route('/create-dev-token', methods=[POST])
@jwt_required
def create_dev_token():
username = get_jwt_identity()
expires = datatime.timedelta(days=365)
token = create_access_token(username, expires_delta=expires)
return jsonify({'token': token}), 201
20 changes: 16 additions & 4 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def revoked_token_loader(self, callback):
self._revoked_token_callback = callback
return callback

def create_refresh_token(self, identity):
def create_refresh_token(self, identity, expires_delta=None):
"""
Creates a new refresh token
Expand All @@ -256,13 +256,19 @@ def create_refresh_token(self, identity):
query disk twice, once for initially finding the identity
in your login endpoint, and once for setting addition data
in the JWT via the user_claims_loader
:param expires_delta: A datetime.timedelta for how long this token should
last before it expires. If this is None, it will
use the 'JWT_REFRESH_TOKEN_EXPIRES` config value
:return: A new refresh token
"""
if expires_delta is None:
expires_delta = config.refresh_expires

refresh_token = encode_refresh_token(
identity=self._user_identity_callback(identity),
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=config.refresh_expires,
expires_delta=expires_delta,
csrf=config.csrf_protect
)

Expand All @@ -273,7 +279,7 @@ def create_refresh_token(self, identity):
store_token(decoded_token, revoked=False)
return refresh_token

def create_access_token(self, identity, fresh=False):
def create_access_token(self, identity, fresh=False, expires_delta=None):
"""
Creates a new access token
Expand All @@ -287,13 +293,19 @@ def create_access_token(self, identity, fresh=False):
in the JWT via the user_claims_loader
:param fresh: If this token should be marked as fresh, and can thus access
fresh_jwt_required protected endpoints. Defaults to False
:param expires_delta: A datetime.timedelta for how long this token should
last before it expires. If this is None, it will
use the 'JWT_ACCESS_TOKEN_EXPIRES` config value
:return: A new access token
"""
if expires_delta is None:
expires_delta = config.access_expires

access_token = encode_access_token(
identity=self._user_identity_callback(identity),
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=config.access_expires,
expires_delta=expires_delta,
fresh=fresh,
user_claims=self._user_claims_callback(identity),
csrf=config.csrf_protect
Expand Down
30 changes: 30 additions & 0 deletions tests/test_protected_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ def login():
}
return jsonify(ret), 200

@self.app.route('/auth/login2', methods=['POST'])
def login2():
expires = timedelta(minutes=5)
ret = {
'access_token': create_access_token('test', fresh=True,
expires_delta=expires),
'refresh_token': create_refresh_token('test', expires_delta=expires),
}
return jsonify(ret), 200

@self.app.route('/auth/refresh', methods=['POST'])
@jwt_refresh_token_required
def refresh():
Expand Down Expand Up @@ -342,6 +352,26 @@ def test_bad_tokens(self):
self.assertEqual(status_code, 422)
self.assertIn('msg', data)

def test_expires_time_override(self):
# Test access token
response = self.client.post('/auth/login2')
data = json.loads(response.get_data(as_text=True))
access_token = data['access_token']
time.sleep(2)
status_code, data = self._jwt_get('/partially-protected', access_token)
self.assertEqual(status_code, 200)
self.assertEqual(data, {'msg': 'protected hello world'})

# Test refresh token
response = self.client.post('/auth/login2')
data = json.loads(response.get_data(as_text=True))
refresh_token = data['refresh_token']
time.sleep(2)
status_code, data = self._jwt_post('/auth/refresh', refresh_token)
self.assertEqual(status_code, 200)
self.assertIn('access_token', data)
self.assertNotIn('msg', data)

def test_optional_jwt_bad_tokens(self):
# Test expired access token
response = self.client.post('/auth/login')
Expand Down

0 comments on commit e45c0d1

Please sign in to comment.