From 5ced0efbfa769952e72c4772da61ba40cb21790d Mon Sep 17 00:00:00 2001
From: Javier Romero Castro
Date: Tue, 16 Jul 2024 09:39:53 +0200
Subject: [PATCH] service: add create update many

* closes https://github.com/inveniosoftware/invenio-vocabularies/issues/353
---
 .../services/records/service.py              | 80 ++++++++++-
 invenio_records_resources/services/uow.py    | 27 ++++
 tests/services/conftest.py                   |  8 +
 tests/services/test_service_create_update.py | 147 ++++++++++++++++++
 4 files changed, 261 insertions(+), 1 deletion(-)
 create mode 100644 tests/services/test_service_create_update.py

diff --git a/invenio_records_resources/services/records/service.py b/invenio_records_resources/services/records/service.py
index 1d816b51..e369130b 100644
--- a/invenio_records_resources/services/records/service.py
+++ b/invenio_records_resources/services/records/service.py
@@ -17,6 +17,8 @@
 from invenio_search import current_search_client
 from invenio_search.engine import dsl
 from kombu import Queue
+from marshmallow import ValidationError
+from sqlalchemy.orm.exc import NoResultFound
 from werkzeug.local import LocalProxy
 
 from invenio_records_resources.services.errors import (
@@ -26,7 +28,7 @@
 
 from ..base import LinksTemplate, Service
 from ..errors import RevisionIdMismatchError
-from ..uow import RecordCommitOp, RecordDeleteOp, unit_of_work
+from ..uow import RecordBulkCommitOp, RecordCommitOp, RecordDeleteOp, unit_of_work
 from .schema import ServiceSchemaWrapper
 
 
@@ -579,3 +581,79 @@ def on_relation_update(
             self.reindex(identity, search_query=search_query)
 
         return True
+
+    @unit_of_work()
+    def create_or_update_many(self, identity, data, uow=None):
+        """Create or update a list of records.
+
+        Each entry is validated on its own: entries that fail schema
+        validation are reported back with their errors and are neither
+        created nor updated, while all valid entries are committed and
+        bulk-indexed together.
+
+        Args:
+            identity (object): The user identity performing the operation.
+            data (list): A list of ``(record_id, record_dict)`` tuples. An ID
+                that does not resolve to an existing record results in the
+                creation of a new record.
+            uow (UnitOfWork, optional): The unit of work to register the
+                record operations. Defaults to None.
+
+        Returns:
+            list: A list of ``(op_type, record, errors)`` tuples where
+                ``op_type`` is ``"create"`` or ``"update"``, ``record`` is the
+                processed record (or the input dict when validation failed)
+                and ``errors`` is the list of schema errors.
+        """
+        records_processed = []
+        records_to_commit = []
+        for record_id, record_dict in data:
+            # Only the PID lookup may decide between "create" and "update";
+            # a lookup error raised later (e.g. by a component) must propagate
+            # instead of silently turning an update into a create.
+            try:
+                record = self.record_cls.pid.resolve(record_id)
+            except (NoResultFound, PIDDoesNotExistError):
+                record = None
+
+            if record is None:
+                op_type = "create"
+                self.require_permission(identity, "create")
+                context = {"identity": identity}
+            else:
+                op_type = "update"
+                self.require_permission(identity, "update", record=record)
+                context = {"identity": identity, "pid": record.pid, "record": record}
+
+            record_data, schema_errors = self.schema.load(
+                record_dict, context=context, raise_errors=False
+            )
+
+            # Invalid entries are reported back but never persisted.
+            if schema_errors:
+                records_processed.append((op_type, record_dict, schema_errors))
+                continue
+
+            if record is None:
+                # It is the components that save the actual data in the record.
+                record = self.record_cls.create({})
+                self.run_components(
+                    "create",
+                    identity,
+                    data=record_data,
+                    record=record,
+                    errors=schema_errors,
+                    uow=uow,
+                )
+            else:
+                self.run_components(
+                    "update", identity, data=record_data, record=record, uow=uow
+                )
+
+            records_processed.append((op_type, record, schema_errors))
+            records_to_commit.append(record)
+
+        # Only records that passed validation are committed and indexed.
+        uow.register(RecordBulkCommitOp(records_to_commit, self.indexer))
+
+        return records_processed
diff --git a/invenio_records_resources/services/uow.py b/invenio_records_resources/services/uow.py
index 32b5e3be..86e14c60 100644
--- a/invenio_records_resources/services/uow.py
+++ b/invenio_records_resources/services/uow.py
@@ -183,6 +183,33 @@ def on_commit(self, uow):
             self._indexer.index(self._record, arguments=arguments)
 
 
+class RecordBulkCommitOp(Operation):
+    """Record bulk commit operation with indexing."""
+
+    def __init__(self, records, indexer=None, index_refresh=False):
+        """Initialize the bulk record commit operation.
+
+        :param records: Iterable of records to commit.
+        :param indexer: Indexer used to bulk index the records (optional).
+        :param index_refresh: Stored but currently not used by this operation.
+        """
+        # Materialize once: the records are iterated on register and on commit.
+        self._records = list(records)
+        self._indexer = indexer
+        self._index_refresh = index_refresh
+
+    def on_register(self, uow):
+        """Save objects to the session."""
+        for record in self._records:
+            record.commit()
+
+    def on_commit(self, uow):
+        """Send the committed records for bulk indexing."""
+        if self._indexer is not None and self._records:
+            record_ids = [record.id for record in self._records]
+            self._indexer.bulk_index(record_ids)
+
+
 class RecordIndexOp(RecordCommitOp):
     """Record indexing operation."""
 
diff --git a/tests/services/conftest.py b/tests/services/conftest.py
index 13156b44..e35a8cbc 100644
--- a/tests/services/conftest.py
+++ b/tests/services/conftest.py
@@ -89,3 +89,11 @@ def cache():
         yield current_cache
     finally:
         current_cache.clear()
+
+
+@pytest.fixture()
+def invalid_input_data():
+    """Input data expected to fail validation (non-string title)."""
+    return {
+        "metadata": {"title": 10},
+    }
diff --git a/tests/services/test_service_create_update.py b/tests/services/test_service_create_update.py
new file mode 100644
index 00000000..e6ed480d
--- /dev/null
+++ b/tests/services/test_service_create_update.py
@@ -0,0 +1,147 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2020 CERN.
+# Copyright (C) 2020 Northwestern University.
+#
+# Invenio-Records-Resources is free software; you can redistribute it and/or
+# modify it under the terms of the MIT License; see LICENSE file for more
+# details.
+
+"""Service create update many tests."""
+
+from copy import deepcopy
+
+import pytest
+from invenio_pidstore.errors import PIDDoesNotExistError
+
+from invenio_records_resources.services.errors import PermissionDeniedError
+
+
+def test_create(app, service, identity_simple, input_data):
+    """Create a record."""
+    data = [(None, input_data)]
+    records = service.create_or_update_many(identity_simple, data)
+    assert len(records) == 1
+    op_type, record, errors = records[0]
+    assert record.id
+    assert errors == []
+    assert op_type == "create"
+    assert record.get("metadata") == input_data["metadata"]
+
+    # Assert it's saved
+    read_item = service.read(identity_simple, record.get("id"))
+    assert record.get("id") == read_item.id
+    assert record.get("metadata") == read_item.data.get("metadata")
+
+
+def test_create_multiple_records(app, service, identity_simple, input_data):
+    """Create multiple records."""
+    data = [(None, input_data), (None, input_data)]
+    records = service.create_or_update_many(identity_simple, data)
+    assert len(records) == 2
+    for op_type, record, errors in records:
+        assert record.id
+        assert errors == []
+        assert op_type == "create"
+        assert record.get("metadata") == input_data["metadata"]
+
+        # Assert it's saved
+        read_item = service.read(identity_simple, record.get("id"))
+        assert record.get("id") == read_item.id
+        assert record.get("metadata") == read_item.data.get("metadata")
+
+
+def test_update_example_record(app, service, identity_simple, input_data):
+    """Update an existing record."""
+    item = service.create(identity_simple, input_data)
+    id_ = item.id
+    updated_data = deepcopy(input_data)
+    updated_data["metadata"]["title"] = "Updated Title"
+
+    data = [(id_, updated_data)]
+    records = service.create_or_update_many(identity_simple, data)
+    assert len(records) == 1
+    op_type, record, errors = records[0]
+    assert record.get("id") == id_
+    assert errors == []
+    assert op_type == "update"
+    assert record.get("metadata")["title"] == "Updated Title"
+
+
+def test_create_and_update_mixed(app, service, identity_simple, input_data):
+    """Create and update records in one call."""
+    item = service.create(identity_simple, input_data)
+    id_ = item.id
+    updated_data = deepcopy(input_data)
+    updated_data["metadata"]["title"] = "Updated Title"
+
+    data = [(id_, updated_data), (None, input_data)]
+    records = service.create_or_update_many(identity_simple, data)
+    assert len(records) == 2
+    for op_type, record, errors in records:
+        assert record.id
+        assert errors == []
+        if op_type == "create":
+            assert record.get("metadata") == input_data["metadata"]
+        elif op_type == "update":
+            assert record.get("metadata")["title"] == "Updated Title"
+
+        # Assert it's saved
+        read_item = service.read(identity_simple, record.get("id"))
+        assert record.get("id") == read_item.id
+        assert record.get("metadata") == read_item.data.get("metadata")
+
+
+def test_create_with_validation_errors(
+    app, service, identity_simple, invalid_input_data
+):
+    """Create a record with validation errors."""
+    data = [(None, invalid_input_data)]
+    records = service.create_or_update_many(identity_simple, data)
+    assert len(records) == 1
+    op_type, record, errors = records[0]
+    assert errors != []
+    assert op_type == "create"
+
+    # Assert it's not saved
+    with pytest.raises(PIDDoesNotExistError):
+        service.read(identity_simple, record.get("id"))
+
+
+def test_update_with_validation_errors(
+    app, service, identity_simple, input_data, invalid_input_data
+):
+    """Update an existing record with validation errors."""
+    item = service.create(identity_simple, input_data)
+    id_ = item.id
+    invalid_input_data["id"] = id_
+    data = [(id_, invalid_input_data)]
+    records = service.create_or_update_many(identity_simple, data)
+    assert len(records) == 1
+    op_type, record, errors = records[0]
+    assert record.get("id") == id_
+    assert errors != []
+    assert op_type == "update"
+
+    # Assert the stored record was left untouched
+    read_item = service.read(identity_simple, id_)
+    assert read_item.data.get("metadata") == input_data["metadata"]
+
+
+def test_multiple_records(
+    app, service, identity_simple, input_data, invalid_input_data
+):
+    """Create valid and invalid records in a single call."""
+    data = [(None, input_data), (None, invalid_input_data)]
+    records = service.create_or_update_many(identity_simple, data)
+    assert len(records) == 2
+    for op_type, record, errors in records:
+        if errors:
+            # Assert it failed to insert
+            with pytest.raises(PIDDoesNotExistError):
+                service.read(identity_simple, record.get("id"))
+        else:
+            # Assert it's saved
+            read_item = service.read(identity_simple, record.get("id"))
+            assert record.get("id") == read_item.id
+            assert record.get("metadata") == read_item.data.get("metadata")