From e13e1ae3fc3541563d337b18327e30598a4d27ce Mon Sep 17 00:00:00 2001 From: fabiobatalha Date: Tue, 10 Oct 2017 16:21:25 -0300 Subject: [PATCH] Auto Throttling requests according to X-Rate-Limit-Limit and X-Rate-Limit-Interval --- crossref/__init__.py | 2 +- crossref/restful.py | 126 +++++++++++++++++++++++++++--------------- tests/test_restful.py | 45 ++++++++++++++- 3 files changed, 127 insertions(+), 46 deletions(-) diff --git a/crossref/__init__.py b/crossref/__init__.py index 0ad4a58..ee65984 100644 --- a/crossref/__init__.py +++ b/crossref/__init__.py @@ -1 +1 @@ -VERSION = '1.1.1' +VERSION = '1.2.0' diff --git a/crossref/restful.py b/crossref/restful.py index 7b51947..dfad139 100644 --- a/crossref/restful.py +++ b/crossref/restful.py @@ -2,12 +2,16 @@ import requests import json +from time import sleep + +from datetime import datetime, timedelta from crossref import validators, VERSION LIMIT = 100 MAXOFFSET = 10000 FACETS_MAX_LIMIT = 1000 + API = "api.crossref.org" @@ -23,27 +27,61 @@ class UrlSyntaxError(CrossrefAPIError, ValueError): pass -def do_http_request(method, endpoint, data=None, files=None, timeout=10, only_headers=False, custom_header=None): +class HTTPRequest(object): + + THROTTLING_TUNNING_TIME = 600 + + def __init__(self, throttle=True): + self.throttle = throttle + self.rate_limits = { + 'X-Rate-Limit-Limit': 50, + 'X-Rate-Limit-Interval': 1 + } + + def _update_rate_limits(self, headers): + + self.rate_limits['X-Rate-Limit-Limit'] = int(headers.get('X-Rate-Limit-Limit', 50)) + + interval_value = int(headers.get('X-Rate-Limit-Interval', '1s')[:-1]) + interval_scope = headers.get('X-Rate-Limit-Interval', '1s')[-1] + + if interval_scope == 'm': + interval_value = interval_value * 60 - if only_headers is True: - return requests.head(endpoint) + if interval_scope == 'h': + interval_value = interval_value * 60 * 60 - if method == 'post': - action = requests.post - else: - action = requests.get + self.rate_limits['X-Rate-Limit-Interval'] = 
interval_value - if custom_header: - headers = {'user-agent': custom_header} - else: - headers = {'user-agent': str(Etiquette())} + @property + def throttling_time(self): + return self.rate_limits['X-Rate-Limit-Interval'] / self.rate_limits['X-Rate-Limit-Limit'] - if method == 'post': - result = action(endpoint, data=data, files=files, timeout=timeout, headers=headers) - else: - result = action(endpoint, params=data, timeout=timeout, headers=headers) + def do_http_request(self, method, endpoint, data=None, files=None, timeout=100, only_headers=False, custom_header=None): - return result + if only_headers is True: + return requests.head(endpoint) + + if method == 'post': + action = requests.post + else: + action = requests.get + + if custom_header: + headers = {'user-agent': custom_header} + else: + headers = {'user-agent': str(Etiquette())} + + if method == 'post': + result = action(endpoint, data=data, files=files, timeout=timeout, headers=headers) + else: + result = action(endpoint, params=data, timeout=timeout, headers=headers) + + if self.throttle is True: + self._update_rate_limits(result.headers) + sleep(self.throttling_time) + + return result def build_url_endpoint(endpoint, context=None): @@ -56,7 +94,6 @@ def build_url_endpoint(endpoint, context=None): class Etiquette: def __init__(self, application_name='undefined', application_version='undefined', application_url='undefined', contact_email='anonymous'): - self.application_name = application_name self.application_version = application_version self.application_url = application_url @@ -77,8 +114,8 @@ class Endpoint: CURSOR_AS_ITER_METHOD = False - def __init__(self, request_url=None, request_params=None, context=None, etiquette=None): - + def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True): + self.do_http_request = HTTPRequest(throttle=throttle).do_http_request self.etiquette = etiquette or Etiquette() self.request_url = request_url or 
build_url_endpoint(self.ENDPOINT, context) self.request_params = request_params or dict() @@ -89,11 +126,12 @@ def _rate_limits(self): request_params = dict(self.request_params) request_url = str(self.request_url) - result = do_http_request( + result = self.do_http_request( 'get', request_url, only_headers=True, - custom_header=str(self.etiquette) + custom_header=str(self.etiquette), + throttle=False ) rate_limits = { @@ -126,7 +164,7 @@ def version(self): request_params = dict(self.request_params) request_url = str(self.request_url) - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -167,7 +205,7 @@ def count(self): request_url = str(self.request_url) request_params['rows'] = 0 - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -216,7 +254,7 @@ def __iter__(self): if 'sample' in self.request_params: request_params = self._escaped_pagging() - result = do_http_request( + result = self.do_http_request( 'get', self.request_url, data=request_params, @@ -238,7 +276,7 @@ def __iter__(self): request_params['cursor'] = '*' request_params['rows'] = LIMIT while True: - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -262,7 +300,7 @@ def __iter__(self): request_params['offset'] = 0 request_params['rows'] = LIMIT while True: - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -763,7 +801,7 @@ def facet(self, facet_name, facet_count=100): facet_count = self.FACET_VALUES[facet_name] if self.FACET_VALUES[facet_name] is not None and self.FACET_VALUES[facet_name] <= facet_count else facet_count request_params['facet'] = '%s:%s' % (facet_name, facet_count) - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -906,7 +944,7 @@ def doi(self, doi, only_message=True): ) request_params = {} - result = do_http_request( + 
result = self.do_http_request( 'get', request_url, data=request_params, @@ -940,7 +978,7 @@ def agency(self, doi, only_message=True): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -980,7 +1018,7 @@ def doi_exists(self, doi): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1105,7 +1143,7 @@ def funder(self, funder_id, only_message=True): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1145,7 +1183,7 @@ def funder_exists(self, funder_id): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1324,7 +1362,7 @@ def member(self, member_id, only_message=True): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1364,7 +1402,7 @@ def member_exists(self, member_id): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1411,7 +1449,7 @@ def type(self, type_id, only_message=True): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1447,7 +1485,7 @@ def all(self): request_url = build_url_endpoint(self.ENDPOINT, self.context) request_params = dict(self.request_params) - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1488,7 +1526,7 @@ def type_exists(self, type_id): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1539,7 +1577,7 @@ def prefix(self, prefix_id, only_message=True): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1623,7 
+1661,7 @@ def journal(self, issn, only_message=True): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1664,7 +1702,7 @@ def journal_exists(self, issn): ) request_params = {} - result = do_http_request( + result = self.do_http_request( 'get', request_url, data=request_params, @@ -1693,7 +1731,7 @@ def works(self, issn): class Depositor(object): def __init__(self, prefix, api_user, api_key, etiquette=None): - + self.do_http_request = HTTPRequest(throttle=False).do_http_request self.etiquette = etiquette or Etiquette() self.prefix = prefix self.api_user = api_user @@ -1723,7 +1761,7 @@ def register_doi(self, submission_id, request_xml): 'login_passwd': self.api_key } - result = do_http_request( + result = self.do_http_request( 'post', endpoint, data=params, @@ -1754,7 +1792,7 @@ def request_doi_status_by_filename(self, file_name, data_type='result'): 'type': data_type } - result = do_http_request( + result = self.do_http_request( 'get', endpoint, data=params, @@ -1784,7 +1822,7 @@ def request_doi_status_by_batch_id(self, doi_batch_id, data_type='result'): 'type': data_type } - result = do_http_request( + result = self.do_http_request( 'get', endpoint, data=params, diff --git a/tests/test_restful.py b/tests/test_restful.py index 8906038..45bc089 100644 --- a/tests/test_restful.py +++ b/tests/test_restful.py @@ -75,4 +75,47 @@ def test_members_filters(self): def test_funders_filters(self): result = restful.Funders(etiquette=self.etiquette).filter(location="Japan").url - self.assertEqual(result, 'https://api.crossref.org/funders?filter=location%3AJapan') \ No newline at end of file + self.assertEqual(result, 'https://api.crossref.org/funders?filter=location%3AJapan') + + +class HTTPRequestTest(unittest.TestCase): + + def setUp(self): + + self.httprequest = restful.HTTPRequest() + + def test_default_rate_limits(self): + + expected = {'X-Rate-Limit-Interval': 1, 'X-Rate-Limit-Limit': 50} + + 
self.assertEqual(self.httprequest.rate_limits, expected) + + def test_update_rate_limits_seconds(self): + + headers = {'X-Rate-Limit-Interval': '2s', 'X-Rate-Limit-Limit': 50} + + self.httprequest._update_rate_limits(headers) + + expected = {'X-Rate-Limit-Interval': 2, 'X-Rate-Limit-Limit': 50} + + self.assertEqual(self.httprequest.rate_limits, expected) + + def test_update_rate_limits_minutes(self): + + headers = {'X-Rate-Limit-Interval': '2m', 'X-Rate-Limit-Limit': 50} + + self.httprequest._update_rate_limits(headers) + + expected = {'X-Rate-Limit-Interval': 120, 'X-Rate-Limit-Limit': 50} + + self.assertEqual(self.httprequest.rate_limits, expected) + + def test_update_rate_limits_hours(self): + + headers = {'X-Rate-Limit-Interval': '2h', 'X-Rate-Limit-Limit': 50} + + self.httprequest._update_rate_limits(headers) + + expected = {'X-Rate-Limit-Interval': 7200, 'X-Rate-Limit-Limit': 50} + + self.assertEqual(self.httprequest.rate_limits, expected)