Skip to content

Commit

Permalink
tests: Select community before publishing
Browse files Browse the repository at this point in the history
  • Loading branch information
sakshamarora1 committed Sep 23, 2024
1 parent cd57bfd commit daa9e94
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 26 deletions.
4 changes: 2 additions & 2 deletions invenio_rdm_records/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def always_valid(identifier):
#
# Record communities
#
RDM_RECORD_ALWAYS_IN_COMMUNITY = True
"""Enforces at least one community per record on remove community function."""
RDM_RECORD_ALWAYS_IN_COMMUNITY = False
"""Enforces at least one community per record."""

#
# Search configuration
Expand Down
2 changes: 1 addition & 1 deletion invenio_rdm_records/services/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class CommunityNotSelectedError(Exception):
description = "Cannot publish without selecting a community."


class CannotRemoveCommunityError(PermissionDenied):
class CannotRemoveCommunityError(Exception):
"""Error thrown when the last community is being removed from the record."""

description = "Cannot remove. A record should be part of atleast 1 community."
1 change: 1 addition & 0 deletions invenio_rdm_records/services/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def purge_record(self, identity, id_, uow=None):

raise NotImplementedError()

@unit_of_work()
def publish(self, identity, id_, uow=None, expand=False):
"""Publish a draft.
Expand Down
36 changes: 36 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from invenio_records_resources.proxies import current_service_registry
from invenio_records_resources.references.entity_resolvers import ServiceResultResolver
from invenio_records_resources.services.custom_fields import TextCF
from invenio_records_resources.services.uow import UnitOfWork
from invenio_requests.notifications.builders import (
CommentRequestEventCreateNotificationBuilder,
)
Expand Down Expand Up @@ -2080,6 +2081,41 @@ def create_record(
return RecordFactory()


@pytest.fixture()
def record_required_community(db, uploader, minimal_record, community):
"""Creates a record that belongs to a community before publishing."""

class Record:
"""Test record class."""

def create_record(
self,
record_dict=minimal_record,
uploader=uploader,
community=community,
):
"""Creates new record that belongs to the same community."""
# create draft
draft = current_rdm_records_service.create(uploader.identity, record_dict)
record = draft._record
# add the record to the community
community_record = community._record
record.parent.communities.add(community_record, default=False)
record.parent.commit()
db.session.commit()
current_rdm_records_service.indexer.index(
record, arguments={"refresh": True}
)

# publish and get record
community_record = current_rdm_records_service.publish(
uploader.identity, draft.id
)
return community_record

return Record()


@pytest.fixture(scope="session")
def headers():
"""Default headers for making requests."""
Expand Down
57 changes: 34 additions & 23 deletions tests/resources/test_resources_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,13 @@ def ensure_record_community_exists_config(app):


def test_restricted_record_creation(
app, record_community, uploader, curator, community_owner, test_user, superuser
app,
record_community,
uploader,
curator,
community_owner,
test_user,
superuser,
):
"""Verify CommunityNotSelectedError is raised when direct publish a record"""
# You can directly publish a record when the config is disabled
Expand All @@ -942,7 +948,7 @@ def test_restricted_record_creation(


def test_remove_last_existing_non_existing_community(
app, client, uploader, record_community, headers, community
app, client, uploader, record_required_community, headers, community
):
"""Test removal of an existing and non-existing community from the record,
while ensuring at least one community exists."""
Expand All @@ -955,23 +961,24 @@ def test_remove_last_existing_non_existing_community(
}

client = uploader.login(client)
record = record_community.create_record()
record = record_required_community.create_record()
record_pid = record._record.pid.pid_value
with ensure_record_community_exists_config(app):
response = client.delete(
f"/records/{record.pid.pid_value}/communities",
f"/records/{record_pid}/communities",
headers=headers,
json=data,
)
assert response.is_json
assert response.status_code == 200
assert response.status_code == 400
# Should get 3 errors: Can't remove community, 2 bad IDs
assert len(response.json["errors"]) == 3
record_saved = client.get(f"/records/{record.pid.pid_value}", headers=headers)
record_saved = client.get(f"/records/{record_pid}", headers=headers)
assert record_saved.json["parent"]["communities"]


def test_remove_last_community_api_error_handling(
record_community,
record_required_community,
community,
uploader,
headers,
Expand All @@ -980,7 +987,8 @@ def test_remove_last_community_api_error_handling(
app,
):
"""Testing error message when trying to remove last community."""
record = record_community.create_record()
record = record_required_community.create_record()
record_pid = record._record.pid.pid_value
data = {"communities": [{"id": community.id}]}
for user in [uploader, curator]:
client = user.login(client)
Expand All @@ -994,16 +1002,14 @@ def test_remove_last_community_api_error_handling(
)
with ensure_record_community_exists_config(app):
response = client.delete(
f"/records/{record.pid.pid_value}/communities",
f"/records/{record_pid}/communities",
headers=headers,
json=data,
)
assert response.is_json
assert response.status_code == 200
assert response.status_code == 400

record_saved = client.get(
f"/records/{record.pid.pid_value}", headers=headers
)
record_saved = client.get(f"/records/{record_pid}", headers=headers)
assert record_saved.json["parent"]["communities"]
assert len(response.json["errors"]) == 1

Expand All @@ -1023,7 +1029,7 @@ def test_remove_last_community_api_error_handling(
def test_remove_record_last_community_with_multiple_communities(
closed_review_community,
open_review_community,
record_community,
record_required_community,
community2,
uploader,
headers,
Expand All @@ -1034,40 +1040,45 @@ def test_remove_record_last_community_with_multiple_communities(
"""Testing correct removal of multiple communities"""
client = uploader.login(client)

record = record_community.create_record()
record = record_required_community.create_record()
record_pid = record._record.pid.pid_value
comm = [
community2,
open_review_community,
closed_review_community,
] # one more in the rec fixture so it's 4
for com in comm:
_add_to_community(db, record, com)
assert len(record.parent.communities.ids) == 4
_add_to_community(db, record._record, com)
assert len(record._record.parent.communities.ids) == 4

with ensure_record_community_exists_config(app):
data = {"communities": [{"id": x} for x in record.parent.communities.ids]}
data = {
"communities": [{"id": x} for x in record._record.parent.communities.ids]
}

response = client.delete(
f"/records/{record.pid.pid_value}/communities",
f"/records/{record_pid}/communities",
headers=headers,
json=data,
)
# You get res 200 with error msg if all communities you are deleting
assert response.status_code == 200
assert "error" in str(response.data)

rec_com_left = client.get(f"/records/{record.pid.pid_value}", headers=headers)
rec_com_left = client.get(f"/records/{record_pid}", headers=headers)
assert len(rec_com_left.json["parent"]["communities"]["ids"]) == 1

# You get res 400 with error msg if you Delete the last one only.
response = client.delete(
f"/records/{record.pid.pid_value}/communities",
f"/records/{record_pid}/communities",
headers=headers,
json={"communities": [{"id": str(record.parent.communities.ids[0])}]},
json={
"communities": [{"id": str(record._record.parent.communities.ids[0])}]
},
)
assert response.status_code == 400
assert "error" in str(response.data)

record_saved = client.get(f"/records/{record.pid.pid_value}", headers=headers)
record_saved = client.get(f"/records/{record_pid}", headers=headers)
# check that only one community ID is associated with the record
assert len(record_saved.json["parent"]["communities"]["ids"]) == 1

0 comments on commit daa9e94

Please sign in to comment.