diff --git a/invenio_rdm_records/services/communities/service.py b/invenio_rdm_records/services/communities/service.py index 613d168c8..92e26aa9d 100644 --- a/invenio_rdm_records/services/communities/service.py +++ b/invenio_rdm_records/services/communities/service.py @@ -387,7 +387,9 @@ def set_default(self, identity, id_, data, uow): ) record = self.record_cls.pid.resolve(id_) self.require_permission(identity, "manage", record=record) - record.parent.communities.default = valid_data["default"]["id"] + + default_community_id = valid_data.get("default", {}).get("id") or None + record.parent.communities.default = default_community_id uow.register( ParentRecordCommitOp( diff --git a/invenio_rdm_records/services/schemas/parent/communities.py b/invenio_rdm_records/services/schemas/parent/communities.py index 89007e243..05fde991e 100644 --- a/invenio_rdm_records/services/schemas/parent/communities.py +++ b/invenio_rdm_records/services/schemas/parent/communities.py @@ -15,7 +15,7 @@ class CommunitiesSchema(Schema): """Communities schema.""" ids = fields.List(fields.String()) - default = fields.String(attribute="default.id") + default = fields.String(attribute="default.id", allow_none=True) entries = fields.List(fields.Nested(CommunitySchema)) @post_dump diff --git a/tests/resources/test_resources_communities.py b/tests/resources/test_resources_communities.py index db56027e2..1e2e5130e 100644 --- a/tests/resources/test_resources_communities.py +++ b/tests/resources/test_resources_communities.py @@ -7,7 +7,6 @@ """Tests record's communities resources.""" -from contextlib import contextmanager from copy import deepcopy import pytest diff --git a/tests/resources/test_resources_record_communities.py b/tests/resources/test_resources_record_communities.py index a9812884f..2c9698496 100644 --- a/tests/resources/test_resources_record_communities.py +++ b/tests/resources/test_resources_record_communities.py @@ -1,3 +1,13 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-RDM-Records is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. + +"""Record communities resources tests.""" + + def test_search_record_suggested_communities( client, community, @@ -43,16 +53,15 @@ def test_search_record_suggested_communities( assert hits["total"] == 0 # add record to a community - data = { - "communities": [ - {"id": open_review_community.id}, # test with id - ] - } record = record_community.create_record() response = client.post( f"/records/{record.pid.pid_value}/communities", headers=headers, - json=data, + json={ + "communities": [ + {"id": open_review_community.id}, # test with id + ] + }, ) assert response.status_code == 200 @@ -69,3 +78,102 @@ def test_search_record_suggested_communities( hits = response.json["hits"] assert hits["total"] == 1 assert hits["hits"][0]["id"] != open_review_community.id + + +def test_set_default_community( + client, + headers, + curator, + inviter, + community, + open_review_community, + closed_review_community, + record_community, +): + """Test setting a default community for a record.""" + # Add the curator user to the open review community + inviter(curator.id, open_review_community.id, "curator") + record = record_community.create_record(uploader=curator) + + # Login as the curator user + client = curator.login(client) + + # Add the record to the open review community + resp = client.post( + f"/records/{record.pid.pid_value}/communities", + headers=headers, + json={"communities": [{"id": open_review_community.id}]}, + ) + assert resp.status_code == 200 + assert not resp.json.get("errors") + processed = resp.json["processed"] + assert len(processed) == 1 + + def _assert_record_in_(communities, default=None): + resp = client.get(f"/records/{record.pid.pid_value}", headers=headers) + record_communities = resp.json["parent"]["communities"] + assert set(record_communities["ids"]) == communities + if default: + assert record_communities["default"] == default + else: + assert "default" not in record_communities + + _assert_record_in_({community.id, open_review_community.id}, default=None) + + # Set the default community to invalid values + resp = client.put( + f"/records/{record.pid.pid_value}/communities", + headers=headers, + json={"default": closed_review_community.id}, + ) + assert resp.status_code == 400 + assert resp.json["message"] == ( + "Cannot set community as the default. The record has not been added to the community." + ) + + # Set the default community to the open review community + resp = client.put( + f"/records/{record.pid.pid_value}/communities", + headers=headers, + json={"default": open_review_community.id}, + ) + assert resp.status_code == 200 + assert resp.json["communities"]["default"] == open_review_community.id + + _assert_record_in_( + {community.id, open_review_community.id}, + default=open_review_community.id, + ) + + # Unset the default community + resp = client.put( + f"/records/{record.pid.pid_value}/communities", + headers=headers, + json={"default": None}, + ) + assert resp.status_code == 200 + assert "default" not in resp.json["communities"] + + _assert_record_in_({community.id, open_review_community.id}, default=None) + + # Set the default community to the original community + resp = client.put( + f"/records/{record.pid.pid_value}/communities", + headers=headers, + json={"default": community.id}, + ) + assert resp.status_code == 200 + assert resp.json["communities"]["default"] == community.id + + _assert_record_in_({community.id, open_review_community.id}, default=community.id) + + # Unset the default community using empty string (current wrong behaviour in the UI) + resp = client.put( + f"/records/{record.pid.pid_value}/communities", + headers=headers, + json={"default": ""}, + ) + assert resp.status_code == 200 + assert "default" not in resp.json["communities"] + + _assert_record_in_({community.id, open_review_community.id}, default=None)