diff --git a/invenio_records_rest/config.py b/invenio_records_rest/config.py index 821ab00e..625ff015 100644 --- a/invenio_records_rest/config.py +++ b/invenio_records_rest/config.py @@ -88,8 +88,9 @@ def can(self): RECORDS_REST_ENDPOINTS = { - 'record-pid-type': { + 'endpoint-prefix': { 'create_permission_factory_imp': permission_check_factory(), + 'default_endpoint_prefix': True, 'default_media_type': 'application/json', 'delete_permission_factory_imp': permission_check_factory(), 'item_route': ''/recods/'', @@ -132,6 +133,10 @@ def can(self): :param create_permission_factory_imp: Import path to factory that create permission object for a given record. +:param default_endpoint_prefix: declare the current endpoint as the default + when building endpoints for the defined ``pid_type``. By default the + default prefix is defined to be the value of ``pid_type``. + :param default_media_type: Default media type for both records and search. :param delete_permission_factory_imp: Import path to factory that creates a @@ -146,10 +151,10 @@ def can(self): :param max_result_window: Maximum total number of records retrieved from a query. -:param pid_type: It specifies the record pid type. It's used also to build the - endpoint name. Required. +:param pid_type: It specifies the record pid type. Required. You can generate an URL to list all records of the given ``pid_type`` by - calling ``url_for('invenio_records_rest.{0}_list'.format(pid_type))``. + calling ``url_for('invenio_records_rest.{0}_list'.format( + current_records_rest.default_endpoint_prefixes[pid_type]))``. :param pid_fetcher: It identifies the registered fetcher name. Required. diff --git a/invenio_records_rest/ext.py b/invenio_records_rest/ext.py index ea04b969..54ffeff4 100644 --- a/invenio_records_rest/ext.py +++ b/invenio_records_rest/ext.py @@ -29,7 +29,7 @@ from werkzeug.utils import cached_property from . import config -from .utils import load_or_import_from_config +from .utils import load_or_import_from_config, build_default_endpoint_prefixes from .views import create_blueprint @@ -75,6 +75,10 @@ def delete_permission_factory(self): 'RECORDS_REST_DEFAULT_DELETE_PERMISSION_FACTORY', app=self.app ) + @cached_property + def default_endpoint_prefixes(self): + return build_default_endpoint_prefixes() + def reset_permission_factories(self): """Remove cached permission factories.""" for key in ('read', 'create', 'update', 'delete'): diff --git a/invenio_records_rest/links.py b/invenio_records_rest/links.py index 5233c8f1..475f1ab3 100644 --- a/invenio_records_rest/links.py +++ b/invenio_records_rest/links.py @@ -26,6 +26,8 @@ from flask import url_for +from .proxies import current_records_rest + def default_links_factory(pid): """Factory for record links generation. @@ -33,7 +35,8 @@ def default_links_factory(pid): :param pid: A Persistent Identifier instance. :returns: Dictionary containing a list of useful links for the record. """ - endpoint = '.{0}_item'.format(pid.pid_type) + endpoint = '.{0}_item'.format( + current_records_rest.default_endpoint_prefixes[pid.pid_type]) links = dict(self=url_for(endpoint, pid_value=pid.pid_value, _external=True)) return links diff --git a/invenio_records_rest/utils.py b/invenio_records_rest/utils.py index f1d0ba63..a5306d0c 100644 --- a/invenio_records_rest/utils.py +++ b/invenio_records_rest/utils.py @@ -38,6 +38,45 @@ from .errors import PIDDeletedRESTError, PIDDoesNotExistRESTError, \ PIDMissingObjectRESTError, PIDRedirectedRESTError, \ PIDUnregisteredRESTError +from .proxies import current_records_rest + + +def build_default_endpoint_prefixes(): + """Build the default_endpoint_prefixes map.""" + ret = {} + record_rest_endpoints = current_app.config['RECORD_REST_ENDPOINTS'] + for endpoint in record_rest_endpoints.values(): + pid_type = endpoint['pid_type'] + ret[pid_type] = get_default_endpoint_for(pid_type, + record_rest_endpoints) + + return ret + + +def get_default_endpoint_for(pid_type, _record_rest_endpoints=None): + """Get default endpoint for the given pid_type.""" + if _record_rest_endpoints is None: + _record_rest_endpoints = current_app.config['RECORD_REST_ENDPOINTS'] + + endpoint_prefix = None + + for key, value in _record_rest_endpoints.items(): + if (value['pid_type'] == pid_type and + value.get('default_endpoint_prefix')): + if endpoint_prefix is None: + endpoint_prefix = key + else: + raise ValueError('More than one endpoint-prefix has been ' + 'defined as default for ' + 'pid_type="{0}"'.format(pid_type)) + + if endpoint_prefix: + return endpoint_prefix + if pid_type in _record_rest_endpoints: + return pid_type + + raise ValueError('No endpoint-prefix corresponds to pid_type="{0}"'.format( + pid_type)) def obj_or_import_string(value, default=None): @@ -129,7 +168,9 @@ def data(self): except PIDRedirectedError as e: try: location = url_for( - '.{0}_item'.format(e.destination_pid.pid_type), + '.{0}_item'.format( + current_records_rest.default_endpoint_prefixes[ + e.destination_pid.pid_type]), pid_value=e.destination_pid.pid_value) data = dict( status=301, @@ -139,7 +180,7 @@ def data(self): response = make_response(jsonify(data), data['status']) response.headers['Location'] = location abort(response) - except BuildError: + except (BuildError, KeyError): current_app.logger.exception( 'Invalid redirect - pid_type "{0}" ' 'endpoint missing.'.format( diff --git a/invenio_records_rest/views.py b/invenio_records_rest/views.py index 84ad305e..4acae28b 100644 --- a/invenio_records_rest/views.py +++ b/invenio_records_rest/views.py @@ -452,7 +452,8 @@ def get(self, **kwargs): size=size, _external=True, ) - endpoint = '.{0}_list'.format(self.pid_type) + endpoint = '.{0}_list'.format( + current_records_rest.default_endpoint_prefixes[self.pid_type]) links = dict(self=url_for(endpoint, page=page, **urlkwargs)) if page > 1: links['prev'] = url_for(endpoint, page=page - 1, **urlkwargs) @@ -515,7 +516,8 @@ def post(self, **kwargs): pid, record, 201, links_factory=self.item_links_factory) # Add location headers - endpoint = '.{0}_item'.format(pid.pid_type) + endpoint = '.{0}_item'.format( + current_records_rest.default_endpoint_prefixes[pid.pid_type]) location = url_for(endpoint, pid_value=pid.pid_value, _external=True) response.headers.extend(dict(location=location)) return response diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..cbd3adbb --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# +# This file is part of Invenio. +# Copyright (C) 2016 CERN. +# +# Invenio is free software; you can redistribute it +# and/or modify it under the terms of the GNU General Public License as +# published by the Free Software Foundation; either version 2 of the +# License, or (at your option) any later version. +# +# Invenio is distributed in the hope that it will be +# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Invenio; if not, write to the +# Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, +# MA 02111-1307, USA. +# +# In applying this license, CERN does not +# waive the privileges and immunities granted to it by virtue of its status +# as an Intergovernmental Organization or submit itself to any jurisdiction. + + +"""Utils tests.""" + +from __future__ import absolute_import, print_function + +from pytest import raises + +from invenio_records_rest.utils import get_default_endpoint_for + + +def test_get_default_endpoint_for(): + """Test get_default_endpoint_for().""" + assert get_default_endpoint_for('recid', { + 'recid': { + 'pid_type': 'recid', + }}) == 'recid' + + assert get_default_endpoint_for('recid', { + 'recid': { + 'pid_type': 'recid', + 'default_endpoint_prefix': True, + }}) == 'recid' + + assert get_default_endpoint_for('recid', { + 'recid': { + 'pid_type': 'recid', + }, + 'recid2': { + 'pid_type': 'recid', + }}) == 'recid' + + assert get_default_endpoint_for('recid', { + 'recid': { + 'pid_type': 'recid', + 'default_endpoint_prefix': True, + }, + 'recid2': { + 'pid_type': 'recid', + }}) == 'recid' + + assert get_default_endpoint_for('recid', { + 'recid': { + 'pid_type': 'recid', + }, + 'recid2': { + 'pid_type': 'recid', + 'default_endpoint_prefix': True, + }}) == 'recid2' + + with raises(ValueError) as excinfo: + get_default_endpoint_for('recid', { + 'recid1': { + 'pid_type': 'recid', + 'default_endpoint_prefix': True, + }, + 'recid2': { + 'pid_type': 'recid', + 'default_endpoint_prefix': True, + }}) + assert 'More than one' in str(excinfo.value) + + with raises(ValueError) as excinfo: + get_default_endpoint_for('recid', { + 'recid1': { + 'pid_type': 'recid', + }, + 'recid2': { + 'pid_type': 'recid', + }}) + assert 'No endpoint-prefix' in str(excinfo.value) + + with raises(ValueError) as excinfo: + get_default_endpoint_for('foo', { + 'recid1': { + 'pid_type': 'recid', + }, + 'recid2': { + 'pid_type': 'recid', + }}) + assert 'No endpoint-prefix' in str(excinfo.value)