Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IfSuppressedFile generator #5

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from invenio_db import db
from invenio_app.factory import create_app as _create_app
from invenio_accounts.testutils import login_user_via_session
from invenio_records.api import Record
from invenio_records_files.models import RecordsBuckets

from pathlib import Path
import sys
Expand Down Expand Up @@ -185,3 +187,59 @@ def create_proprietary_record(client):
response = client.post("/records", json=create_proprietary_record, headers=minimal_headers())
assert response.status_code == 201
return response.json['id']


@pytest.fixture(scope="session")
def create_record():
"""Factory pattern for a loaded Record.

The returned dict record has the interface of a Record.

It provides a default value for each required field.
"""

def _create_record(metadata=None):
# TODO: Modify according to record schema
metadata = metadata or {}
record = {
"_access": {
# TODO: Remove if "access_right" includes it
"metadata_restricted": False,
"files_restricted": False,
},
"access_right": "open",
"title": "This is a record",
"description": "This record is a test record",
"owners": [1, 2, 3],
"internal": {
"access_levels": {},
},
}
record.update(metadata)
return record

return _create_record


@pytest.fixture(scope="function")
def create_real_record(create_record, location):
"""Factory pattern to create a real Record.

This is needed for tests relying on database and search engine operations.
"""

def _create_real_record(bucket, metadata=None):
record_dict = create_record(metadata)

record = Record.create(record_dict, with_bucket=False)

# Create link between record and bucket
RecordsBuckets.create(record=record.model, bucket=bucket)
record._bucket = bucket

return record
# Flush to index and database
# current_search.flush_and_refresh(index='*')
# db.session.commit()

return _create_real_record
38 changes: 35 additions & 3 deletions tests/test_generators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import pytest
from invenio_access.models import RoleNeed
from invenio_access.permissions import authenticated_user, superuser_access, any_user
from invenio_files_rest.models import Bucket, Location, ObjectVersion
from io import BytesIO

import sys
sys.path.append('../ultraviolet_permissions')
from ultraviolet_permissions.generators import AdminSuperUser, Depositor, Viewer, RestrictedDataUser, PublicViewer, Curator
sys.path.append("../ultraviolet_permissions")
from ultraviolet_permissions.generators import (
AdminSuperUser,
Depositor,
Viewer,
RestrictedDataUser,
PublicViewer,
Curator,
IfSuppressedFile,
)


def test_admin_superuser():
Expand Down Expand Up @@ -59,4 +69,26 @@ def test_curator(user_roles_propriatery_record, propriatery_record):
other_record = propriatery_record

assert generator.needs(record=record) == [RoleNeed("curator")]
assert generator.needs(record=other_record) == []
assert generator.needs(record=other_record) == []


def test_suppressed(db, bucket_from_dir, create_real_record, location):
generator = IfSuppressedFile()

suppressed_location = Location.get_by_name("suppressed")

assert suppressed_location

suppressed_bucket = Bucket.create(suppressed_location)
ObjectVersion.create(
suppressed_bucket, key="suppressed.txt", stream=BytesIO(b"suppressed")
)
record = create_real_record(bucket=suppressed_bucket)

assert generator._condition(record) == True

bucket = Bucket.create(location)
ObjectVersion.create(bucket, key="public.txt", stream=BytesIO(b"public"))
record = create_real_record(bucket=bucket)

assert generator._condition(record) == False
53 changes: 52 additions & 1 deletion ultraviolet_permissions/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,40 @@
from invenio_records_permissions.generators import Generator
from flask_login import current_user

try:
from invenio_records_permissions.generators import ConditionalGenerator
except ImportError:
from invenio_rdm_records.services.generators import ConditionalGenerator

from invenio_files_rest.models import Location
from invenio_rdm_records.proxies import current_rdm_records


SUPPRESSED_NAME = "suppressed"


def get_files_from_pid(record_pid):
record = current_rdm_records.records_service.read(
system_identity, record_pid
)
record = current_rdm_records.records_service.record_cls.pid.resolve(
record_pid
)
file_manager = record.files
for file_key in file_manager:
pprint(file_key)
pprint(record.files[file_key].file.uri)
pprint(record.files[file_key].__dir__())
pprint(record.files[file_key].metadata)


def get_suppressed_location_uri():
return Location.get_by_name(SUPPRESSED_NAME).uri


def is_suppressed(record_file):
return record_file.uri.startswith(get_suppressed_location_uri())


def get_roles(record, user_role):
# roles = []
Expand Down Expand Up @@ -198,4 +232,21 @@ def needs(self, record=None, **kwargs):
def query_filter(self, **kwargs):
"""Filters for current identity as super user."""
# TODO: Implement with new permissions metadata
return dsl.Q('match_all')
return dsl.Q('match_all')


class IfSuppressedFile(ConditionalGenerator):
"""Conditional generator for suppressed files."""

def _condition(self, record, file_key=None, **kwargs):
is_file_suppressed = False
if file_key:
file_record = record.files.get(file_key)
file = file_record.file if file_record is not None else None
is_file_suppressed = file and is_suppressed(file)
else:
file_records = record.files.entries
is_file_suppressed = file_records and all(
is_suppressed(file_record.file) for file_record in file_records
)
return is_file_suppressed