From 22471dc5828a9039a9ead20dfac8a0e3b49b06a8 Mon Sep 17 00:00:00 2001 From: Kyle McCormick Date: Mon, 12 Jun 2017 15:20:10 -0400 Subject: [PATCH] Change course_summaries() to use POST instead of GET POST method allows large number of course ID arguments to be passed as data, while GET method is restricted by URL length. EDUCATOR-464 --- AUTHORS | 1 + analyticsclient/client.py | 76 ++++++++++++++++--- analyticsclient/course_summaries.py | 15 ++-- analyticsclient/tests/__init__.py | 25 ++++-- analyticsclient/tests/test_client.py | 29 ++++++- .../tests/test_course_summaries.py | 1 + 6 files changed, 115 insertions(+), 32 deletions(-) diff --git a/AUTHORS b/AUTHORS index f840b12..495829d 100644 --- a/AUTHORS +++ b/AUTHORS @@ -6,3 +6,4 @@ Dylan Rhodes Dmitry Viskov Tyler Hallada Braden MacDonald +Kyle McCormick diff --git a/analyticsclient/client.py b/analyticsclient/client.py index bd8fec4..b96574e 100644 --- a/analyticsclient/client.py +++ b/analyticsclient/client.py @@ -28,6 +28,9 @@ class Client(object): DATE_FORMAT = '%Y-%m-%d' DATETIME_FORMAT = DATE_FORMAT + 'T%H%M%S' + METHOD_GET = 'GET' + METHOD_POST = 'POST' + def __init__(self, base_url, auth_token=None, timeout=0.25): """ Initialize the client. @@ -63,17 +66,37 @@ def get(self, resource, timeout=None, data_format=DF.JSON): Raises: ClientError if the resource cannot be retrieved for any reason. """ - response = self._request(resource, timeout=timeout, data_format=data_format) + return self._get_or_post( + self.METHOD_GET, + resource, + timeout=timeout, + data_format=data_format + ) + + def post(self, resource, post_data=None, timeout=None, data_format=DF.JSON): + """ + Retrieve the data for POST request. - if data_format == DF.CSV: - return response.text + Arguments: - try: - return response.json() - except ValueError: - message = 'Unable to decode JSON response' - log.exception(message) - raise ClientError(message) + resource (str): Path in the form of slash separated strings. + post_data (dict): Dictionary containing POST data. + timeout (float): Continue to attempt to retrieve a resource for this many seconds before giving up and + raising an error. + data_format (str): Format in which data should be returned + + Returns: API response data in specified data_format + + Raises: ClientError if the resource cannot be retrieved for any reason. + + """ + return self._get_or_post( + self.METHOD_POST, + resource, + post_data=post_data, + timeout=timeout, + data_format=data_format + ) def has_resource(self, resource, timeout=None): """ @@ -91,13 +114,32 @@ def has_resource(self, resource, timeout=None): """ try: - self._request(resource, timeout=timeout) + self._request(self.METHOD_GET, resource, timeout=timeout) return True except ClientError: return False + def _get_or_post(self, method, resource, post_data=None, timeout=None, data_format=DF.JSON): + response = self._request( + method, + resource, + post_data=post_data, + timeout=timeout, + data_format=data_format + ) + + if data_format == DF.CSV: + return response.text + + try: + return response.json() + except ValueError: + message = 'Unable to decode JSON response' + log.exception(message) + raise ClientError(message) + # pylint: disable=no-member - def _request(self, resource, timeout=None, data_format=DF.JSON): + def _request(self, method, resource, post_data=None, timeout=None, data_format=DF.JSON): if timeout is None: timeout = self.timeout @@ -114,7 +156,17 @@ def _request(self, resource, timeout=None, data_format=DF.JSON): try: uri = '{0}/{1}'.format(self.base_url, resource) - response = requests.get(uri, headers=headers, timeout=timeout) + + if method == self.METHOD_GET: + response = requests.get(uri, headers=headers, timeout=timeout) + elif method == self.METHOD_POST: + response = requests.post(uri, data=(post_data or {}), headers=headers, timeout=timeout) + else: + raise ValueError( + 'Invalid \'method\' argument: expected {0} or {1}, got {2}'.format( + self.METHOD_GET, self.METHOD_POST, method + ) + ) status = response.status_code if status != requests.codes.ok: diff --git a/analyticsclient/course_summaries.py b/analyticsclient/course_summaries.py index 648e83a..6ce0eda 100644 --- a/analyticsclient/course_summaries.py +++ b/analyticsclient/course_summaries.py @@ -1,5 +1,3 @@ -import urllib - import analyticsclient.constants.data_format as DF @@ -27,15 +25,12 @@ def course_summaries(self, course_ids=None, fields=None, exclude=None, programs= exclude: Array of fields to exclude from response. Default is to not exclude any fields. programs: If included in the query parameters, will include the programs array in the response. """ - query_params = {} - for query_arg, data in zip(['course_ids', 'fields', 'exclude', 'programs'], - [course_ids, fields, exclude, programs]): + post_data = {} + for param_name, data in zip(['course_ids', 'fields', 'exclude', 'programs'], + [course_ids, fields, exclude, programs]): if data: - query_params[query_arg] = ','.join(data) + post_data[param_name] = data path = 'course_summaries/' - querystring = urllib.urlencode(query_params) - if querystring: - path += '?{0}'.format(querystring) - return self.client.get(path, data_format=data_format) + return self.client.post(path, post_data=post_data, data_format=data_format) diff --git a/analyticsclient/tests/__init__.py b/analyticsclient/tests/__init__.py index 9983ea4..e67f533 100644 --- a/analyticsclient/tests/__init__.py +++ b/analyticsclient/tests/__init__.py @@ -34,6 +34,7 @@ class APIListTestCase(object): # Override in the subclass: endpoint = 'list' id_field = 'id' + uses_post_method = False def setUp(self): """Set up the test case.""" @@ -58,17 +59,25 @@ def expected_query(self, **kwargs): def kwarg_test(self, **kwargs): """Construct URL with given query parameters and check if it is what we expect.""" httpretty.reset() - uri_template = '{uri}?' - for key in kwargs: - uri_template += '%s={%s}' % (key, key) - uri = uri_template.format(uri=self.base_uri, **kwargs) - httpretty.register_uri(httpretty.GET, uri, body='{}') - getattr(self.client_class, self.endpoint)(**kwargs) - self.verify_last_querystring_equal(self.expected_query(**kwargs)) + if self.uses_post_method: + httpretty.register_uri(httpretty.POST, self.base_uri, body='{}') + getattr(self.client_class, self.endpoint)(**kwargs) + self.assertDictEqual(httpretty.last_request().parsed_body or {}, kwargs) + else: + uri_template = '{uri}?' + for key in kwargs: + uri_template += '%s={%s}' % (key, key) + uri = uri_template.format(uri=self.base_uri, **kwargs) + httpretty.register_uri(httpretty.GET, uri, body='{}') + getattr(self.client_class, self.endpoint)(**kwargs) + self.verify_last_querystring_equal(self.expected_query(**kwargs)) def test_all_items_url(self): """Endpoint can be called without parameters.""" - httpretty.register_uri(httpretty.GET, self.base_uri, body='{}') + httpretty.register_uri( + httpretty.POST if self.uses_post_method else httpretty.GET, + self.base_uri, body='{}' + ) getattr(self.client_class, self.endpoint)() @ddt.data( diff --git a/analyticsclient/tests/test_client.py b/analyticsclient/tests/test_client.py index 2a8d981..385f7a3 100644 --- a/analyticsclient/tests/test_client.py +++ b/analyticsclient/tests/test_client.py @@ -46,6 +46,11 @@ def test_get(self): httpretty.register_uri(httpretty.GET, self.test_url, body=json.dumps(data)) self.assertEquals(self.client.get(self.test_endpoint), data) + def test_post(self): + data = {'foo': 'bar'} + httpretty.register_uri(httpretty.POST, self.test_url, body=json.dumps(data)) + self.assertEquals(self.client.post(self.test_endpoint), data) + def test_get_invalid_response_body(self): """ Verify that client raises a ClientError if the response body cannot be properly parsed. """ @@ -71,7 +76,13 @@ def test_request_timeout(self, mock_get, lc): timeout = None headers = {'Accept': 'application/json'} - self.assertRaises(TimeoutError, self.client._request, self.test_endpoint, timeout=timeout) + self.assertRaises( + TimeoutError, + self.client._request, + self.client.METHOD_GET, + self.test_endpoint, + timeout=timeout + ) msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, self.client.timeout) lc.check(('analyticsclient.client', 'ERROR', msg)) lc.clear() @@ -79,7 +90,13 @@ def test_request_timeout(self, mock_get, lc): mock_get.reset_mock() timeout = 10 - self.assertRaises(TimeoutError, self.client._request, self.test_endpoint, timeout=timeout) + self.assertRaises( + TimeoutError, + self.client._request, + self.client.METHOD_GET, + self.test_endpoint, + timeout=timeout + ) mock_get.assert_called_once_with(url, headers=headers, timeout=timeout) msg = 'Response from {0} exceeded timeout of {1}s.'.format(self.test_endpoint, timeout) lc.check(('analyticsclient.client', 'ERROR', msg)) @@ -100,3 +117,11 @@ def test_request_format(self): response = self.client.get(self.test_endpoint, data_format=data_format.JSON) self.assertEquals(httpretty.last_request().headers['Accept'], 'application/json') self.assertDictEqual(response, {}) + + def test_unsupported_method(self): + self.assertRaises( + ValueError, + self.client._request, + 'PATCH', + self.test_endpoint + ) diff --git a/analyticsclient/tests/test_course_summaries.py b/analyticsclient/tests/test_course_summaries.py index 75d2df9..9a5990e 100644 --- a/analyticsclient/tests/test_course_summaries.py +++ b/analyticsclient/tests/test_course_summaries.py @@ -9,6 +9,7 @@ class CourseSummariesTests(APIListTestCase, ClientTestCase): endpoint = 'course_summaries' id_field = 'course_ids' + uses_post_method = True @ddt.data( ['123'],