diff --git a/tests/conftest.py b/tests/conftest.py index 195ab1d..1e7f063 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,8 @@ from invenio_db import db from invenio_app.factory import create_app as _create_app from invenio_accounts.testutils import login_user_via_session +from invenio_records.api import Record +from invenio_records_files.models import RecordsBuckets from pathlib import Path import sys @@ -185,3 +187,59 @@ def create_proprietary_record(client): response = client.post("/records", json=create_proprietary_record, headers=minimal_headers()) assert response.status_code == 201 return response.json['id'] + + +@pytest.fixture(scope="session") +def create_record(): + """Factory pattern for a loaded Record. + + The returned dict record has the interface of a Record. + + It provides a default value for each required field. + """ + + def _create_record(metadata=None): + # TODO: Modify according to record schema + metadata = metadata or {} + record = { + "_access": { + # TODO: Remove if "access_right" includes it + "metadata_restricted": False, + "files_restricted": False, + }, + "access_right": "open", + "title": "This is a record", + "description": "This record is a test record", + "owners": [1, 2, 3], + "internal": { + "access_levels": {}, + }, + } + record.update(metadata) + return record + + return _create_record + + +@pytest.fixture(scope="function") +def create_real_record(create_record, location): + """Factory pattern to create a real Record. + + This is needed for tests relying on database and search engine operations. + """ + + def _create_real_record(bucket, metadata=None): + record_dict = create_record(metadata) + + record = Record.create(record_dict, with_bucket=False) + + # Create link between record and bucket + RecordsBuckets.create(record=record.model, bucket=bucket) + record._bucket = bucket + + return record + # Flush to index and database + # current_search.flush_and_refresh(index='*') + # db.session.commit() + + return _create_real_record diff --git a/tests/test_generators.py b/tests/test_generators.py index e14b3c7..497ebde 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -1,10 +1,20 @@ import pytest from invenio_access.models import RoleNeed from invenio_access.permissions import authenticated_user, superuser_access, any_user +from invenio_files_rest.models import Bucket, Location, ObjectVersion +from io import BytesIO import sys -sys.path.append('../ultraviolet_permissions') -from ultraviolet_permissions.generators import AdminSuperUser, Depositor, Viewer, RestrictedDataUser, PublicViewer, Curator +sys.path.append("../ultraviolet_permissions") +from ultraviolet_permissions.generators import ( + AdminSuperUser, + Depositor, + Viewer, + RestrictedDataUser, + PublicViewer, + Curator, + IfSuppressedFile, +) def test_admin_superuser(): @@ -59,4 +69,26 @@ def test_curator(user_roles_propriatery_record, propriatery_record): other_record = propriatery_record assert generator.needs(record=record) == [RoleNeed("curator")] - assert generator.needs(record=other_record) == [] \ No newline at end of file + assert generator.needs(record=other_record) == [] + + +def test_suppressed(db, bucket_from_dir, create_real_record, location): + generator = IfSuppressedFile() + + suppressed_location = Location.get_by_name("suppressed") + + assert suppressed_location + + suppressed_bucket = Bucket.create(suppressed_location) + ObjectVersion.create( + suppressed_bucket, key="suppressed.txt", stream=BytesIO(b"suppressed") + ) + record = create_real_record(bucket=suppressed_bucket) + + assert generator._condition(record) == True + + bucket = Bucket.create(location) + ObjectVersion.create(bucket, key="public.txt", stream=BytesIO(b"public")) + record = create_real_record(bucket=bucket) + + assert generator._condition(record) == False diff --git a/ultraviolet_permissions/generators.py b/ultraviolet_permissions/generators.py index 94c13bc..dc07d75 100644 --- a/ultraviolet_permissions/generators.py +++ b/ultraviolet_permissions/generators.py @@ -16,6 +16,40 @@ from invenio_records_permissions.generators import Generator from flask_login import current_user +try: + from invenio_records_permissions.generators import ConditionalGenerator +except ImportError: + from invenio_rdm_records.services.generators import ConditionalGenerator + +from invenio_files_rest.models import Location +from invenio_rdm_records.proxies import current_rdm_records + + +SUPPRESSED_NAME = "suppressed" + + +def get_files_from_pid(record_pid): + record = current_rdm_records.records_service.read( + system_identity, record_pid + ) + record = current_rdm_records.records_service.record_cls.pid.resolve( + record_pid + ) + file_manager = record.files + for file_key in file_manager: + pprint(file_key) + pprint(record.files[file_key].file.uri) + pprint(record.files[file_key].__dir__()) + pprint(record.files[file_key].metadata) + + +def get_suppressed_location_uri(): + return Location.get_by_name(SUPPRESSED_NAME).uri + + +def is_suppressed(record_file): + return record_file.uri.startswith(get_suppressed_location_uri()) + def get_roles(record, user_role): # roles = [] @@ -198,4 +232,21 @@ def needs(self, record=None, **kwargs): def query_filter(self, **kwargs): """Filters for current identity as super user.""" # TODO: Implement with new permissions metadata - return dsl.Q('match_all') \ No newline at end of file + return dsl.Q('match_all') + + +class IfSuppressedFile(ConditionalGenerator): + """Conditional generator for suppressed files.""" + + def _condition(self, record, file_key=None, **kwargs): + is_file_suppressed = False + if file_key: + file_record = record.files.get(file_key) + file = file_record.file if file_record is not None else None + is_file_suppressed = file and is_suppressed(file) + else: + file_records = record.files.entries + is_file_suppressed = file_records and all( + is_suppressed(file_record.file) for file_record in file_records + ) + return is_file_suppressed