diff --git a/entity/constants.py b/entity/constants.py deleted file mode 100644 index 4660df2..0000000 --- a/entity/constants.py +++ /dev/null @@ -1,4 +0,0 @@ - -class MembershipType: - UNION = 'UNION' - INTERSECTION = 'INTERSECTION' diff --git a/entity/migrations/0002_entitygroup_membership_type.py b/entity/migrations/0002_entitygroup_logic_string.py similarity index 56% rename from entity/migrations/0002_entitygroup_membership_type.py rename to entity/migrations/0002_entitygroup_logic_string.py index 526185d..c65242a 100644 --- a/entity/migrations/0002_entitygroup_membership_type.py +++ b/entity/migrations/0002_entitygroup_logic_string.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.4 on 2023-08-29 13:41 +# Generated by Django 4.2.4 on 2023-08-30 18:09 from django.db import migrations, models @@ -12,7 +12,7 @@ class Migration(migrations.Migration): operations = [ migrations.AddField( model_name='entitygroup', - name='membership_type', - field=models.CharField(choices=[('UNION', 'Union'), ('INTERSECTION', 'Intersection')], default='UNION'), + name='logic_string', + field=models.TextField(blank=True, default=None, null=True), ), ] diff --git a/entity/models.py b/entity/models.py index 5aefff8..8a67f75 100644 --- a/entity/models.py +++ b/entity/models.py @@ -1,4 +1,5 @@ -from itertools import compress +import ast +from itertools import compress, chain from activatable_model.models import BaseActivatableModel, ActivatableManager, ActivatableQuerySet from django.contrib.contenttypes.fields import GenericForeignKey @@ -9,8 +10,6 @@ from python3_utils import compare_on_attr from functools import reduce -from entity.constants import MembershipType - class AllEntityKindManager(ActivatableManager): """ @@ -336,6 +335,8 @@ def get_membership_cache(self, group_ids=None, is_active=True): if group_ids: membership_queryset = membership_queryset.filter(entity_group_id__in=group_ids) + membership_queryset = membership_queryset.order_by('id') + membership_queryset = membership_queryset.values_list('entity_group_id', 'entity_id', 'sub_entity_kind_id') # Iterate over the query results and build the cache dict @@ -365,12 +366,7 @@ class EntityGroup(models.Model): objects = EntityGroupManager() - membership_type_choices = [ - (MembershipType.UNION, 'Union'), - (MembershipType.INTERSECTION, 'Intersection'), - ] - - membership_type = models.CharField(choices=membership_type_choices, default=MembershipType.UNION) + logic_string = models.TextField(default=None, null=True, blank=True) def all_entities(self, is_active=True): """ @@ -382,6 +378,96 @@ def all_entities(self, is_active=True): """ return self.get_all_entities(return_models=True, is_active=is_active) + def get_filter_indices(self, node): + """ + Makes sure that each filter referenced actually exists + """ + if hasattr(node, 'op'): + # multi-operand operators + if hasattr(node, 'values'): + return list(chain(*[self.get_filter_indices(value) for value in node.values])) + # unary operators + elif hasattr(node, 'operand'): + return list(chain(*[self.get_filter_indices(node.operand)])) + elif hasattr(node, 'n'): + return [node.n] + return None + + def validate_filter_indices(self, indices, memberships): + """ + Raises an error if an invalid filter index is referenced or if an index is not referenced + """ + for index in indices: + if hasattr(index, '__iter__'): + return self.validate_filter_indices(index, memberships) + if index < 1 or index > len(memberships): + raise ValidationError('Filter logic contains an invalid filter index ({0})'.format(index)) + + for i in range(1, len(memberships) + 1): + if i not in indices: + raise ValidationError('Filter logic is missing a filter index ({0})'.format(i)) + + return True + + def _node_to_kmatch(self, node): + """ + Looks at an ast node and either returns the value or recursively returns the kmatch syntax. This is meant + to convert the boolean logic like "1 AND 2" to kmatch syntax like ['&', [1, 2]] + :return: kmatch syntax where memberships are represented by numbers + :rtype: list + """ + if hasattr(node, 'op'): + if hasattr(node, 'values'): + return [node.op, [self._node_to_kmatch(value) for value in node.values]] + elif hasattr(node, 'operand'): + return [node.op, self._node_to_kmatch(node.operand)] + elif hasattr(node, 'n'): + return node.n + return None + + def _map_kmatch_values(self, kmatch, memberships): + """ + Replaces index placeholders in the kmatch with the actual memberships. Any memberships that could not be matched + up with a field will be replaced with None + :return: the complete kmatch pattern + :rtype: list + """ + # Check if single item + if isinstance(kmatch, int): + return memberships[kmatch - 1] + if hasattr(kmatch, '__iter__'): + return [self._map_kmatch_values(value, memberships) for value in kmatch] + + cls = getattr(kmatch, '__class__') + if cls == ast.And: + return '&' + elif cls == ast.Or: + return '|' + elif cls == ast.Not: + return '!' + + def _process_kmatch(self, kmatch, full_set): + """ + Every item is 2 elements - the operator and the value or list of values + """ + entity_ids = set() + operators = {'&', '|', '!'} + + if isinstance(kmatch, set): + return kmatch + + if len(kmatch) == 2 and kmatch[0] not in operators: + return kmatch + + if kmatch[0] == '&': + entity_ids = self._process_kmatch(kmatch[1][0], full_set) & self._process_kmatch(kmatch[1][1], full_set) + elif kmatch[0] == '|': + entity_ids = self._process_kmatch(kmatch[1][0], full_set) | self._process_kmatch(kmatch[1][1], full_set) + elif kmatch[0] == '!': + entity_ids = full_set - self._process_kmatch(kmatch[1], full_set) + + return entity_ids + def get_all_entities(self, membership_cache=None, entities_by_kind=None, return_models=False, is_active=True): """ Returns a list of all entity ids in this group or optionally returns a queryset for all entity models. @@ -410,30 +496,47 @@ def get_all_entities(self, membership_cache=None, entities_by_kind=None, return_ entity_ids = set() # This group does have entities - if membership_cache.get(self.id): - - # Loop over each membership in this group - for entity_id, entity_kind_id in membership_cache[self.id]: - entity_ids_to_apply = set() - if entity_id: - if entity_kind_id: - # All sub entities of this kind under this entity - entity_ids_to_apply.update(entities_by_kind[entity_kind_id][entity_id]) + memberships = membership_cache.get(self.id) + if memberships: + if self.logic_string: + try: + filter_tree = ast.parse(self.logic_string.lower()) + except: + raise Exception + + expanded_memberships = [] + for entity_id, entity_kind_id in memberships: + if entity_id: + if entity_kind_id: + # All sub entities of this kind under this entity + expanded_memberships.append(set(entities_by_kind[entity_kind_id][entity_id])) + else: + # Individual entity + expanded_memberships.append({entity_id}) else: - # Individual entity - entity_ids_to_apply.add(entity_id) - else: - # All entities of this kind - entity_ids_to_apply.update(entities_by_kind[entity_kind_id]['all']) - - # Check membership type - if self.membership_type == MembershipType.UNION: - entity_ids.update(entity_ids_to_apply) - elif self.membership_type == MembershipType.INTERSECTION: - if not entity_ids: - entity_ids.update(entity_ids_to_apply) + # All entities of this kind + expanded_memberships.append(set(entities_by_kind[entity_kind_id]['all'])) + + # Make sure each index is valid + indices = self.get_filter_indices(filter_tree.body[0].value) + self.validate_filter_indices(indices, expanded_memberships) + kmatch = self._node_to_kmatch(filter_tree.body[0].value) + kmatch = self._map_kmatch_values(kmatch, expanded_memberships) + entity_ids = self._process_kmatch(kmatch, full_set=expanded_memberships[-1]) + + else: + # Loop over each membership in this group + for entity_id, entity_kind_id in membership_cache[self.id]: + if entity_id: + if entity_kind_id: + # All sub entities of this kind under this entity + entity_ids.update(entities_by_kind[entity_kind_id][entity_id]) + else: + # Individual entity + entity_ids.add(entity_id) else: - entity_ids = entity_ids.intersection(entity_ids_to_apply) + # All entities of this kind + entity_ids.update(entities_by_kind[entity_kind_id]['all']) # Check if a queryset needs to be returned if return_models: diff --git a/entity/tests/model_tests.py b/entity/tests/model_tests.py index ad7ed7a..e241171 100644 --- a/entity/tests/model_tests.py +++ b/entity/tests/model_tests.py @@ -5,7 +5,6 @@ from entity.signal_handlers import turn_off_syncing, turn_on_syncing -from entity.constants import MembershipType from entity.models import ( Entity, EntityKind, EntityRelationship, EntityGroup, EntityGroupMembership, get_entities_by_kind ) @@ -758,53 +757,76 @@ def setUp(self): self.group = G(EntityGroup) - def test_membership_type_intersection(self): - """ - Given two memberships of entities under different entity kinds, verify that only the intersection is returned - instead of the union. - - This test sets up: - - 5 sub entities under super 1 - - 5 sub entities under super 2 - - 3 sub entities under both - """ - super_entity_kind1 = G(EntityKind) - super_entity_kind2 = G(EntityKind) + def test_logic_string(self): + """ + Given 10 users User 0 - User 9 and 4 groups Group A - Group D + Group A: 0, 1, 2 + Group B: 1, 2, 3 + Group C: 4, 5, 6 + Group D: 6, 7, 8 + + Memberships: + 1. User in Group A + 2. User in Group B + 3. User in Group C + 4. User in Group D + 5. User = User 1 + 6. User = User 9 + + Logic: (1 AND 2) OR (3 AND 4) AND NOT(5) OR 6 + ((0, 1, 2) AND (1, 2, 3)) OR ((4, 5, 6) AND (6, 7, 8)) AND NOT(1) OR (9) + (1, 2) OR (6) AND NOT(1) OR 9 + (1, 2, 6) AND NOT(1) OR 9 + 2, 6, 9 + """ + super_entity_kind = G(EntityKind) sub_entity_kind = G(EntityKind) - super_entity1 = G(Entity, entity_kind=super_entity_kind1) - super_entity2 = G(Entity, entity_kind=super_entity_kind2) - sub_entities1 = [ + super_entity_a = G(Entity, entity_kind=super_entity_kind) + super_entity_b = G(Entity, entity_kind=super_entity_kind) + super_entity_c = G(Entity, entity_kind=super_entity_kind) + super_entity_d = G(Entity, entity_kind=super_entity_kind) + sub_entities = [ G(Entity, entity_kind=sub_entity_kind) - for _ in range(5) - ] - sub_entities2 = [ - G(Entity, entity_kind=sub_entity_kind) - for _ in range(5) + for _ in range(10) ] # Create the relationships - for entity in sub_entities1: - G(EntityRelationship, sub_entity=entity, super_entity=super_entity1) - for entity in sub_entities2: - G(EntityRelationship, sub_entity=entity, super_entity=super_entity2) - - # Create the intersection relationships - G(EntityRelationship, sub_entity=sub_entities1[0], super_entity=super_entity2) - G(EntityRelationship, sub_entity=sub_entities1[1], super_entity=super_entity2) - G(EntityRelationship, sub_entity=sub_entities1[2], super_entity=super_entity2) + relationships = [ + EntityRelationship(sub_entity=sub_entities[0], super_entity=super_entity_a), + EntityRelationship(sub_entity=sub_entities[1], super_entity=super_entity_a), + EntityRelationship(sub_entity=sub_entities[2], super_entity=super_entity_a), + + EntityRelationship(sub_entity=sub_entities[1], super_entity=super_entity_b), + EntityRelationship(sub_entity=sub_entities[2], super_entity=super_entity_b), + EntityRelationship(sub_entity=sub_entities[3], super_entity=super_entity_b), + + EntityRelationship(sub_entity=sub_entities[4], super_entity=super_entity_c), + EntityRelationship(sub_entity=sub_entities[5], super_entity=super_entity_c), + EntityRelationship(sub_entity=sub_entities[6], super_entity=super_entity_c), + + EntityRelationship(sub_entity=sub_entities[6], super_entity=super_entity_d), + EntityRelationship(sub_entity=sub_entities[7], super_entity=super_entity_d), + EntityRelationship(sub_entity=sub_entities[8], super_entity=super_entity_d), + ] + EntityRelationship.objects.bulk_create(relationships) # Create the entity group - entity_group = G(EntityGroup, membership_type=MembershipType.INTERSECTION) + entity_group = G(EntityGroup, logic_string='(((1 AND 2) OR (3 AND 4)) AND NOT(5) OR 6) AND 7') # Create the memberships -- two memberships of all subs under a kind - G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=super_entity1) - G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=super_entity2) + G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=super_entity_a) + G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=super_entity_b) + G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=super_entity_c) + G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=super_entity_d) + G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=None, entity=sub_entities[1]) + G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=None, entity=sub_entities[9]) + G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=None) entity_ids = entity_group.get_all_entities() self.assertEqual(entity_ids, set([ - sub_entities1[0].id, - sub_entities1[1].id, - sub_entities1[2].id, + sub_entities[2].id, + sub_entities[6].id, + sub_entities[9].id, ])) def test_individual_entities_returned(self): diff --git a/release_notes.md b/release_notes.md index 61776fb..088cc5d 100644 --- a/release_notes.md +++ b/release_notes.md @@ -1,7 +1,7 @@ ## Release Notes - 6.2.0: - - Add support for intersection type memberships + - Add support for boolean logic strings to apply to entity group memberships - 6.1.1: - django support for 4.2 - drop django 2.2