From 616c6be03a0b8c9dd742415c2fd2cde8cd08c95d Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Tue, 28 Nov 2023 10:22:53 +0000 Subject: [PATCH] fix: Rely on Flagsmith Engine for segment evaluation, avoid N+1 queries (#3038) --- api/conftest.py | 3 +- .../tests/test_dynamodb_identity_wrapper.py | 3 +- api/environments/identities/managers.py | 28 +- api/environments/identities/models.py | 21 +- .../identities/tests/test_models.py | 18 +- .../identities/tests/test_views.py | 18 +- api/environments/identities/views.py | 19 +- .../feature_segments/tests/test_models.py | 3 +- api/integrations/integration.py | 11 +- api/integrations/webhook/serializers.py | 14 +- api/segments/models.py | 256 +----------------- api/segments/serializers.py | 3 +- ...test_environments_views_sdk_environment.py | 3 +- .../test_unit_environments_views.py | 3 +- .../test_unit_import_export_export.py | 4 +- .../unit/projects/test_unit_projects_admin.py | 3 +- api/tests/unit/segments/test_conditions.py | 157 ----------- .../unit/segments/test_unit_segments_views.py | 3 +- api/util/mappers/engine.py | 67 +++-- 19 files changed, 175 insertions(+), 462 deletions(-) delete mode 100644 api/tests/unit/segments/test_conditions.py diff --git a/api/conftest.py b/api/conftest.py index 146788a42ce0..71ea7d54d92d 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -3,6 +3,7 @@ import pytest from django.contrib.contenttypes.models import ContentType from django.core.cache import cache +from flag_engine.segments.constants import EQUAL from rest_framework.authtoken.models import Token from rest_framework.test import APIClient @@ -46,7 +47,7 @@ ) from projects.permissions import VIEW_PROJECT from projects.tags.models import Tag -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule from task_processor.task_run_method import TaskRunMethod from users.models import FFAdminUser, UserPermissionGroup diff --git a/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py b/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py index 262b83f1d7c7..f1ca7411506b 100644 --- a/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py +++ b/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py @@ -5,12 +5,13 @@ from core.constants import INTEGER from django.core.exceptions import ObjectDoesNotExist from flag_engine.identities.builders import build_identity_model +from flag_engine.segments.constants import IN from rest_framework.exceptions import NotFound from environments.dynamodb import DynamoIdentityWrapper from environments.identities.models import Identity from environments.identities.traits.models import Trait -from segments.models import IN, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule from util.mappers import ( map_environment_to_environment_document, map_identity_to_identity_document, diff --git a/api/environments/identities/managers.py b/api/environments/identities/managers.py index 2d6e607b2f8d..db36da45174c 100644 --- a/api/environments/identities/managers.py +++ b/api/environments/identities/managers.py @@ -1,6 +1,32 @@ +from typing import TYPE_CHECKING, Iterable + from django.db.models import Manager +if TYPE_CHECKING: + from environments.identities.models import Identity + from environments.models import Environment + from integrations.integration import IntegrationConfig + -class IdentityManager(Manager): +class IdentityManager(Manager["Identity"]): def get_by_natural_key(self, identifier, environment_api_key): return self.get(identifier=identifier, environment__api_key=environment_api_key) + + def get_or_create_for_sdk( + self, + identifier: str, + environment: "Environment", + integrations: Iterable["IntegrationConfig"], + ) -> tuple["Identity", bool]: + return ( + self.select_related( + "environment", + "environment__project", + *[ + f"environment__{integration['relation_name']}" + for integration in integrations + ], + ) + .prefetch_related("identity_traits") + .get_or_create(identifier=identifier, environment=environment) + ) diff --git a/api/environments/identities/models.py b/api/environments/identities/models.py index 61aca21719d6..b33c956a193e 100644 --- a/api/environments/identities/models.py +++ b/api/environments/identities/models.py @@ -3,6 +3,7 @@ from django.db import models from django.db.models import Prefetch, Q from django.utils import timezone +from flag_engine.segments.evaluator import evaluate_identity_in_segment from environments.identities.managers import IdentityManager from environments.identities.traits.models import Trait @@ -10,6 +11,11 @@ from features.models import FeatureState from features.multivariate.models import MultivariateFeatureStateValue from segments.models import Segment +from util.mappers.engine import ( + map_identity_to_engine, + map_segment_to_engine, + map_traits_to_engine, +) class Identity(models.Model): @@ -153,8 +159,21 @@ def get_segments( else: all_segments = self.environment.project.get_segments_from_cache() + engine_identity = map_identity_to_engine( + self, + with_overrides=False, + with_traits=False, + ) + engine_traits = map_traits_to_engine(traits) + for segment in all_segments: - if segment.does_identity_match(self, traits=traits): + engine_segment = map_segment_to_engine(segment) + + if evaluate_identity_in_segment( + identity=engine_identity, + segment=engine_segment, + override_traits=engine_traits, + ): matching_segments.append(segment) return matching_segments diff --git a/api/environments/identities/tests/test_models.py b/api/environments/identities/tests/test_models.py index 3ec3bad9f5a6..9e0691b22f12 100644 --- a/api/environments/identities/tests/test_models.py +++ b/api/environments/identities/tests/test_models.py @@ -1,6 +1,13 @@ import pytest from core.constants import FLOAT from django.utils import timezone +from flag_engine.segments.constants import ( + EQUAL, + GREATER_THAN, + GREATER_THAN_INCLUSIVE, + LESS_THAN_INCLUSIVE, + NOT_EQUAL, +) from rest_framework.test import APITestCase from environments.identities.models import Identity @@ -15,16 +22,7 @@ from features.value_types import BOOLEAN, INTEGER, STRING from organisations.models import Organisation from projects.models import Project -from segments.models import ( - EQUAL, - GREATER_THAN, - GREATER_THAN_INCLUSIVE, - LESS_THAN_INCLUSIVE, - NOT_EQUAL, - Condition, - Segment, - SegmentRule, -) +from segments.models import Condition, Segment, SegmentRule from .helpers import ( create_trait_for_identity, diff --git a/api/environments/identities/tests/test_views.py b/api/environments/identities/tests/test_views.py index b9552d671391..bbec39ed38b4 100644 --- a/api/environments/identities/tests/test_views.py +++ b/api/environments/identities/tests/test_views.py @@ -4,10 +4,11 @@ from unittest.case import TestCase import pytest -from core.constants import FLAGSMITH_UPDATED_AT_HEADER +from core.constants import FLAGSMITH_UPDATED_AT_HEADER, STRING from django.test import override_settings from django.urls import reverse from django.utils import timezone +from flag_engine.segments.constants import PERCENTAGE_SPLIT from rest_framework import status from rest_framework.test import APIClient, APITestCase @@ -21,7 +22,6 @@ from integrations.amplitude.models import AmplitudeConfiguration from organisations.models import Organisation, OrganisationRole from projects.models import Project -from segments import models from segments.models import Condition, Segment, SegmentRule from util.tests import Helper @@ -371,7 +371,7 @@ def test_identities_endpoint_returns_traits(self, mock_amplitude_wrapper): trait = Trait.objects.create( identity=self.identity, trait_key="trait_key", - value_type="STRING", + value_type=STRING, string_value="trait_value", ) @@ -423,7 +423,7 @@ def test_identities_endpoint_returns_value_for_segment_if_identity_in_segment( Trait.objects.create( identity=self.identity, trait_key=trait_key, - value_type="STRING", + value_type=STRING, string_value=trait_value, ) segment = Segment.objects.create(name="Test Segment", project=self.project) @@ -477,7 +477,7 @@ def test_identities_endpoint_returns_value_for_segment_if_identity_in_segment_an Trait.objects.create( identity=self.identity, trait_key=trait_key, - value_type="STRING", + value_type=STRING, string_value=trait_value, ) segment = Segment.objects.create(name="Test Segment", project=self.project) @@ -529,7 +529,7 @@ def test_identities_endpoint_returns_value_for_segment_if_rule_type_percentage_s [segment.id, self.identity.id] ) Condition.objects.create( - operator=models.PERCENTAGE_SPLIT, + operator=PERCENTAGE_SPLIT, value=(identity_percentage_value + (1 - identity_percentage_value) / 2) * 100.0, rule=segment_rule, @@ -576,7 +576,7 @@ def test_identities_endpoint_returns_default_value_if_rule_type_percentage_split [segment.id, self.identity.id] ) Condition.objects.create( - operator=models.PERCENTAGE_SPLIT, + operator=PERCENTAGE_SPLIT, value=identity_percentage_value / 2, rule=segment_rule, ) @@ -629,13 +629,13 @@ def test_post_identify_deletes_a_trait_if_trait_value_is_none(self): trait_1 = Trait.objects.create( identity=self.identity, trait_key="trait_key_1", - value_type="STRING", + value_type=STRING, string_value="trait_value", ) trait_2 = Trait.objects.create( identity=self.identity, trait_key="trait_key_2", - value_type="STRING", + value_type=STRING, string_value="trait_value", ) diff --git a/api/environments/identities/views.py b/api/environments/identities/views.py index dec1b52dd672..af5750011be1 100644 --- a/api/environments/identities/views.py +++ b/api/environments/identities/views.py @@ -112,11 +112,11 @@ class SDKIdentitiesDeprecated(SDKAPIView): def get(self, request, identifier, *args, **kwargs): # if we have identifier fetch, or create if does not exist if identifier: - identity, _ = Identity.objects.get_or_create( + identity, _ = Identity.objects.get_or_create_for_sdk( identifier=identifier, environment=request.environment, + integrations=IDENTITY_INTEGRATIONS, ) - else: return Response( {"detail": "Missing identifier"}, status=status.HTTP_400_BAD_REQUEST @@ -172,17 +172,10 @@ def get(self, request): {"detail": "Missing identifier"} ) # TODO: add 400 status - will this break the clients? - identity, _ = ( - Identity.objects.select_related( - "environment", - "environment__project", - *[ - f"environment__{integration['relation_name']}" - for integration in IDENTITY_INTEGRATIONS - ], - ) - .prefetch_related("identity_traits") - .get_or_create(identifier=identifier, environment=request.environment) + identity, _ = Identity.objects.get_or_create_for_sdk( + identifier=identifier, + environment=request.environment, + integrations=IDENTITY_INTEGRATIONS, ) self.identity = identity diff --git a/api/features/feature_segments/tests/test_models.py b/api/features/feature_segments/tests/test_models.py index 2ff952d1a663..34222c6f3f3b 100644 --- a/api/features/feature_segments/tests/test_models.py +++ b/api/features/feature_segments/tests/test_models.py @@ -1,6 +1,7 @@ import pytest from core.constants import STRING from django.test import TestCase +from flag_engine.segments.constants import EQUAL from environments.identities.models import Identity from environments.identities.traits.models import Trait @@ -8,7 +9,7 @@ from features.models import Feature, FeatureSegment from organisations.models import Organisation from projects.models import Project -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule @pytest.mark.django_db diff --git a/api/integrations/integration.py b/api/integrations/integration.py index 2704a9fbff32..faecf80aebff 100644 --- a/api/integrations/integration.py +++ b/api/integrations/integration.py @@ -1,11 +1,20 @@ +from typing import Type, TypedDict + from integrations.amplitude.amplitude import AmplitudeWrapper +from integrations.common.wrapper import AbstractBaseIdentityIntegrationWrapper from integrations.heap.heap import HeapWrapper from integrations.mixpanel.mixpanel import MixpanelWrapper from integrations.rudderstack.rudderstack import RudderstackWrapper from integrations.segment.segment import SegmentWrapper from integrations.webhook.webhook import WebhookWrapper -IDENTITY_INTEGRATIONS = [ + +class IntegrationConfig(TypedDict): + relation_name: str + wrapper: Type[AbstractBaseIdentityIntegrationWrapper] + + +IDENTITY_INTEGRATIONS: list[IntegrationConfig] = [ {"relation_name": "amplitude_config", "wrapper": AmplitudeWrapper}, {"relation_name": "segment_config", "wrapper": SegmentWrapper}, {"relation_name": "heap_config", "wrapper": HeapWrapper}, diff --git a/api/integrations/webhook/serializers.py b/api/integrations/webhook/serializers.py index a07d43b0704e..e161b36a0316 100644 --- a/api/integrations/webhook/serializers.py +++ b/api/integrations/webhook/serializers.py @@ -1,6 +1,7 @@ import typing from django.db.models import Q +from flag_engine.segments.evaluator import evaluate_identity_in_segment from rest_framework import serializers from features.serializers import FeatureStateSerializerFull @@ -8,6 +9,7 @@ BaseEnvironmentIntegrationModelSerializer, ) from segments.models import Segment +from util.mappers.engine import map_identity_to_engine, map_segment_to_engine from .models import WebhookConfiguration @@ -25,8 +27,16 @@ class Meta: model = Segment fields = ("id", "name", "member") - def get_member(self, obj): - return obj.does_identity_match(identity=self.context.get("identity")) + def get_member(self, obj: Segment) -> bool: + engine_identity = map_identity_to_engine( + self.context.get("identity"), + with_overrides=False, + ) + engine_segment = map_segment_to_engine(obj) + return evaluate_identity_in_segment( + identity=engine_identity, + segment=engine_segment, + ) class IntegrationFeatureStateSerializer(FeatureStateSerializerFull): diff --git a/api/segments/models.py b/api/segments/models.py index a49e73b94ebc..21eeb0d6de69 100644 --- a/api/segments/models.py +++ b/api/segments/models.py @@ -2,8 +2,6 @@ import typing from copy import deepcopy -import semver -from core.constants import BOOLEAN, FLOAT, INTEGER from core.models import ( AbstractBaseExportableModel, SoftDeleteExportableModel, @@ -11,47 +9,15 @@ ) from django.core.exceptions import ValidationError from django.db import models -from flag_engine.utils.semver import is_semver, remove_semver_suffix +from flag_engine.segments import constants from audit.constants import SEGMENT_CREATED_MESSAGE, SEGMENT_UPDATED_MESSAGE from audit.related_object_type import RelatedObjectType -from environments.identities.helpers import ( - get_hashed_percentage_for_object_ids, -) from features.models import Feature from projects.models import Project -if typing.TYPE_CHECKING: - from environments.identities.models import Identity - from environments.identities.traits.models import Trait - - logger = logging.getLogger(__name__) -try: - import re2 as re - - logger.info("Using re2 library for regex.") -except ImportError: - logger.warning("Unable to import re2. Falling back to re.") - import re - -# Condition Types -EQUAL = "EQUAL" -GREATER_THAN = "GREATER_THAN" -LESS_THAN = "LESS_THAN" -LESS_THAN_INCLUSIVE = "LESS_THAN_INCLUSIVE" -CONTAINS = "CONTAINS" -GREATER_THAN_INCLUSIVE = "GREATER_THAN_INCLUSIVE" -NOT_CONTAINS = "NOT_CONTAINS" -NOT_EQUAL = "NOT_EQUAL" -REGEX = "REGEX" -PERCENTAGE_SPLIT = "PERCENTAGE_SPLIT" -MODULO = "MODULO" -IS_SET = "IS_SET" -IS_NOT_SET = "IS_NOT_SET" -IN = "IN" - class Segment( SoftDeleteExportableModel, @@ -106,14 +72,6 @@ def id_exists_in_rules_data(rules_data: typing.List[dict]) -> bool: return False - def does_identity_match( - self, identity: "Identity", traits: typing.List["Trait"] = None - ) -> bool: - rules = self.rules.all() - return rules.count() > 0 and all( - rule.does_identity_match(identity, traits) for rule in rules - ) - def get_create_log_message(self, history_instance) -> typing.Optional[str]: return SEGMENT_CREATED_MESSAGE % self.name @@ -155,34 +113,6 @@ def __str__(self): str(self.segment) if self.segment else str(self.rule), ) - def does_identity_match( - self, identity: "Identity", traits: typing.List["Trait"] = None - ) -> bool: - matches_conditions = False - conditions = self.conditions.all() - - if conditions.count() == 0: - matches_conditions = True - elif self.type == self.ALL_RULE: - matches_conditions = all( - condition.does_identity_match(identity, traits) - for condition in conditions - ) - elif self.type == self.ANY_RULE: - matches_conditions = any( - condition.does_identity_match(identity, traits) - for condition in conditions - ) - elif self.type == self.NONE_RULE: - matches_conditions = not any( - condition.does_identity_match(identity, traits) - for condition in conditions - ) - - return matches_conditions and all( - rule.does_identity_match(identity, traits) for rule in self.rules.all() - ) - def get_segment(self): """ rules can be a child of a parent rule instead of a segment, this method iterates back up the tree to find the @@ -203,20 +133,20 @@ class Condition( related_object_type = RelatedObjectType.SEGMENT CONDITION_TYPES = ( - (EQUAL, "Exactly Matches"), - (GREATER_THAN, "Greater than"), - (LESS_THAN, "Less than"), - (CONTAINS, "Contains"), - (GREATER_THAN_INCLUSIVE, "Greater than or equal to"), - (LESS_THAN_INCLUSIVE, "Less than or equal to"), - (NOT_CONTAINS, "Does not contain"), - (NOT_EQUAL, "Does not match"), - (REGEX, "Matches regex"), - (PERCENTAGE_SPLIT, "Percentage split"), - (MODULO, "Modulo Operation"), - (IS_SET, "Is set"), - (IS_NOT_SET, "Is not set"), - (IN, "In"), + (constants.EQUAL, "Exactly Matches"), + (constants.GREATER_THAN, "Greater than"), + (constants.LESS_THAN, "Less than"), + (constants.CONTAINS, "Contains"), + (constants.GREATER_THAN_INCLUSIVE, "Greater than or equal to"), + (constants.LESS_THAN_INCLUSIVE, "Less than or equal to"), + (constants.NOT_CONTAINS, "Does not contain"), + (constants.NOT_EQUAL, "Does not match"), + (constants.REGEX, "Matches regex"), + (constants.PERCENTAGE_SPLIT, "Percentage split"), + (constants.MODULO, "Modulo Operation"), + (constants.IS_SET, "Is set"), + (constants.IS_NOT_SET, "Is not set"), + (constants.IN, "In"), ) operator = models.CharField(choices=CONDITION_TYPES, max_length=500) @@ -241,162 +171,6 @@ def __str__(self): self.value, ) - def does_identity_match( # noqa: C901 - self, identity: "Identity", traits: typing.List["Trait"] = None - ) -> bool: - if self.operator == PERCENTAGE_SPLIT: - return self._check_percentage_split_operator(identity) - - # we allow passing in traits to handle when they aren't - # persisted for certain organisations - traits = identity.identity_traits.all() if traits is None else traits - matching_trait = next( - filter(lambda t: t.trait_key == self.property, traits), None - ) - if matching_trait is None: - return self.operator == IS_NOT_SET - - if self.operator in (IS_SET, IS_NOT_SET): - return self.operator == IS_SET - elif self.operator == MODULO: - if matching_trait.value_type in [INTEGER, FLOAT]: - return self._check_modulo_operator(matching_trait.trait_value) - elif self.operator == IN: - return str(matching_trait.trait_value) in self.value.split(",") - elif matching_trait.value_type == INTEGER: - return self.check_integer_value(matching_trait.integer_value) - elif matching_trait.value_type == FLOAT: - return self.check_float_value(matching_trait.float_value) - elif matching_trait.value_type == BOOLEAN: - return self.check_boolean_value(matching_trait.boolean_value) - elif is_semver(self.value): - return self.check_semver_value(matching_trait.string_value) - - return self.check_string_value(matching_trait.string_value) - - def _check_percentage_split_operator(self, identity): - try: - float_value = float(self.value) / 100.0 - except ValueError: - return False - - segment = self.rule.get_segment() - return ( - get_hashed_percentage_for_object_ids( - object_ids=[segment.id, identity.get_hash_key()] - ) - <= float_value - ) - - def _check_modulo_operator(self, value: typing.Union[int, float]) -> bool: - try: - divisor, remainder = self.value.split("|") - divisor = float(divisor) - remainder = float(remainder) - except ValueError: - return False - - return value % divisor == remainder - - def check_integer_value(self, value: int) -> bool: - try: - int_value = int(str(self.value)) - except ValueError: - return False - - if self.operator == EQUAL: - return value == int_value - elif self.operator == GREATER_THAN: - return value > int_value - elif self.operator == GREATER_THAN_INCLUSIVE: - return value >= int_value - elif self.operator == LESS_THAN: - return value < int_value - elif self.operator == LESS_THAN_INCLUSIVE: - return value <= int_value - elif self.operator == NOT_EQUAL: - return value != int_value - - return False - - def check_float_value(self, value: float) -> bool: - try: - float_value = float(str(self.value)) - except ValueError: - return False - - if self.operator == EQUAL: - return value == float_value - elif self.operator == GREATER_THAN: - return value > float_value - elif self.operator == GREATER_THAN_INCLUSIVE: - return value >= float_value - elif self.operator == LESS_THAN: - return value < float_value - elif self.operator == LESS_THAN_INCLUSIVE: - return value <= float_value - elif self.operator == NOT_EQUAL: - return value != float_value - - return False - - def check_boolean_value(self, value: bool) -> bool: - if self.value in ("False", "false", "0"): - bool_value = False - elif self.value in ("True", "true", "1"): - bool_value = True - else: - return False - - if self.operator == EQUAL: - return value == bool_value - elif self.operator == NOT_EQUAL: - return value != bool_value - - return False - - def check_semver_value(self, value: str) -> bool: - try: - condition_version_info = semver.VersionInfo.parse( - remove_semver_suffix(self.value) - ) - except ValueError: - return False - - if self.operator == EQUAL: - return value == condition_version_info - elif self.operator == GREATER_THAN: - return value > condition_version_info - elif self.operator == GREATER_THAN_INCLUSIVE: - return value >= condition_version_info - elif self.operator == LESS_THAN: - return value < condition_version_info - elif self.operator == LESS_THAN_INCLUSIVE: - return value <= condition_version_info - elif self.operator == NOT_EQUAL: - return value != condition_version_info - - return False - - def check_string_value(self, value: str) -> bool: - try: - str_value = str(self.value) - except ValueError: - return False - - if self.operator == EQUAL: - return value == str_value - elif self.operator == NOT_EQUAL: - return value != str_value - elif self.operator == CONTAINS: - return str_value in value - elif self.operator == NOT_CONTAINS: - return str_value not in value - elif self.operator == REGEX: - return re.compile(str(self.value)).match(value) is not None - - return False - def get_update_log_message(self, history_instance) -> typing.Optional[str]: return f"Condition updated on segment '{self._get_segment().name}'." diff --git a/api/segments/serializers.py b/api/segments/serializers.py index 29a99c292e25..28966e65fd7c 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -1,12 +1,13 @@ import typing +from flag_engine.segments.constants import PERCENTAGE_SPLIT from rest_framework import serializers from rest_framework.exceptions import ValidationError from rest_framework.serializers import ListSerializer from rest_framework_recursive.fields import RecursiveField from projects.models import Project -from segments.models import PERCENTAGE_SPLIT, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule class ConditionSerializer(serializers.ModelSerializer): diff --git a/api/tests/unit/environments/test_environments_views_sdk_environment.py b/api/tests/unit/environments/test_environments_views_sdk_environment.py index 9dc3983b1466..5857fa66b895 100644 --- a/api/tests/unit/environments/test_environments_views_sdk_environment.py +++ b/api/tests/unit/environments/test_environments_views_sdk_environment.py @@ -1,11 +1,12 @@ from core.constants import FLAGSMITH_UPDATED_AT_HEADER from django.urls import reverse +from flag_engine.segments.constants import EQUAL from rest_framework import status from rest_framework.test import APIClient from environments.models import Environment, EnvironmentAPIKey from features.models import Feature -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule def test_get_environment_document( diff --git a/api/tests/unit/environments/test_unit_environments_views.py b/api/tests/unit/environments/test_unit_environments_views.py index d6614c2f0c85..2c903e1c463b 100644 --- a/api/tests/unit/environments/test_unit_environments_views.py +++ b/api/tests/unit/environments/test_unit_environments_views.py @@ -6,6 +6,7 @@ from core.constants import STRING from django.contrib.contenttypes.models import ContentType from django.urls import reverse +from flag_engine.segments.constants import EQUAL from pytest_lazyfixture import lazy_fixture from pytest_mock import MockerFixture from rest_framework import status @@ -26,7 +27,7 @@ UserProjectPermission, ) from projects.permissions import CREATE_ENVIRONMENT, VIEW_PROJECT -from segments.models import EQUAL, Condition, SegmentRule +from segments.models import Condition, SegmentRule from users.models import FFAdminUser from util.tests import Helper diff --git a/api/tests/unit/import_export/test_unit_import_export_export.py b/api/tests/unit/import_export/test_unit_import_export_export.py index 509734db1ef5..19b48a50f6db 100644 --- a/api/tests/unit/import_export/test_unit_import_export_export.py +++ b/api/tests/unit/import_export/test_unit_import_export_export.py @@ -7,7 +7,7 @@ from django.contrib.contenttypes.models import ContentType from django.core.management import call_command from django.core.serializers.json import DjangoJSONEncoder -from flag_engine.segments.constants import ALL_RULE +from flag_engine.segments.constants import ALL_RULE, EQUAL from moto import mock_s3 from environments.models import Environment, EnvironmentAPIKey, Webhook @@ -42,7 +42,7 @@ from organisations.models import Organisation, OrganisationWebhook from projects.models import Project from projects.tags.models import Tag -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule def test_export_organisation(db): diff --git a/api/tests/unit/projects/test_unit_projects_admin.py b/api/tests/unit/projects/test_unit_projects_admin.py index 0be0f847cad3..e1f735f576e1 100644 --- a/api/tests/unit/projects/test_unit_projects_admin.py +++ b/api/tests/unit/projects/test_unit_projects_admin.py @@ -3,12 +3,13 @@ import pytest from django.contrib.admin import AdminSite +from flag_engine.segments.constants import EQUAL from environments.models import Environment from features.models import Feature, FeatureSegment, FeatureState from projects.admin import ProjectAdmin from projects.models import Project -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule if typing.TYPE_CHECKING: from django.contrib.auth.models import AbstractUser diff --git a/api/tests/unit/segments/test_conditions.py b/api/tests/unit/segments/test_conditions.py deleted file mode 100644 index 48191203996d..000000000000 --- a/api/tests/unit/segments/test_conditions.py +++ /dev/null @@ -1,157 +0,0 @@ -import pytest -from core.constants import INTEGER, STRING - -from environments.identities.traits.models import Trait -from segments.models import ( - EQUAL, - GREATER_THAN, - GREATER_THAN_INCLUSIVE, - IN, - IS_NOT_SET, - IS_SET, - LESS_THAN, - LESS_THAN_INCLUSIVE, - MODULO, - NOT_EQUAL, - Condition, -) - - -@pytest.mark.parametrize( - "operator, trait_value, condition_value, result", - [ - (EQUAL, "1.0.0", "1.0.0:semver", True), - (EQUAL, "1.0.0", "1.0.1:semver", False), - (NOT_EQUAL, "1.0.0", "1.0.0:semver", False), - (NOT_EQUAL, "1.0.0", "1.0.1:semver", True), - (GREATER_THAN, "1.0.1", "1.0.0:semver", True), - (GREATER_THAN, "1.0.0", "1.0.0-beta:semver", True), - (GREATER_THAN, "1.0.1", "1.2.0:semver", False), - (GREATER_THAN, "1.0.1", "1.0.1:semver", False), - (GREATER_THAN, "1.2.4", "1.2.3-pre.2+build.4:semver", True), - (LESS_THAN, "1.0.0", "1.0.1:semver", True), - (LESS_THAN, "1.0.0", "1.0.0:semver", False), - (LESS_THAN, "1.0.1", "1.0.0:semver", False), - (LESS_THAN, "1.0.0-rc.2", "1.0.0-rc.3:semver", True), - (GREATER_THAN_INCLUSIVE, "1.0.1", "1.0.0:semver", True), - (GREATER_THAN_INCLUSIVE, "1.0.1", "1.2.0:semver", False), - (GREATER_THAN_INCLUSIVE, "1.0.1", "1.0.1:semver", True), - (LESS_THAN_INCLUSIVE, "1.0.0", "1.0.1:semver", True), - (LESS_THAN_INCLUSIVE, "1.0.0", "1.0.0:semver", True), - (LESS_THAN_INCLUSIVE, "1.0.1", "1.0.0:semver", False), - ], -) -def test_does_identity_match_for_semver_values( - identity, operator, trait_value, condition_value, result -): - # Given - condition = Condition(operator=operator, property="version", value=condition_value) - traits = [ - Trait( - trait_key="version", - string_value=trait_value, - identity=identity, - ) - ] - # Then - assert condition.does_identity_match(identity, traits) is result - - -@pytest.mark.parametrize( - "trait_value, condition_value, result", - [ - (1, "2|0", False), - (2, "2|0", True), - (3, "2|0", False), - (34.2, "4|3", False), - (35.0, "4|3", True), - ("dummy", "3|0", False), - ("1.0.0", "3|0", False), - (False, "1|3", False), - ], -) -def test_does_identity_match_for_modulo_operator( - identity, trait_value, condition_value, result -): - condition = Condition(operator=MODULO, property="user_id", value=condition_value) - - trait_value_data = Trait.generate_trait_value_data(trait_value) - traits = [Trait(trait_key="user_id", identity=identity, **trait_value_data)] - - assert condition.does_identity_match(identity, traits) is result - - -def test_does_identity_match_is_set_true(identity): - # Given - trait_key = "some_property" - condition = Condition(operator=IS_SET, property=trait_key) - traits = [Trait(trait_key=trait_key, identity=identity)] - - # Then - assert condition.does_identity_match(identity, traits) is True - - -def test_does_identity_match_is_set_false(identity): - # Given - trait_key = "some_property" - condition = Condition(operator=IS_SET, property=trait_key) - traits = [] - - # Then - assert condition.does_identity_match(identity, traits) is False - - -def test_does_identity_match_is_not_set_true(identity): - # Given - trait_key = "some_property" - condition = Condition(operator=IS_NOT_SET, property=trait_key) - traits = [Trait(trait_key=trait_key, identity=identity)] - - # Then - assert condition.does_identity_match(identity, traits) is False - - -def test_does_identity_match_is_not_set_false(identity): - # Given - trait_key = "some_property" - condition = Condition(operator=IS_NOT_SET, property=trait_key) - traits = [] - - # Then - assert condition.does_identity_match(identity, traits) is True - - -@pytest.mark.parametrize( - "condition_value, trait_value_type, trait_string_value, trait_integer_value, expected_result", - ( - ("", STRING, "foo", None, False), - ("foo,bar", STRING, "foo", None, True), - ("foo", STRING, "foo", None, True), - ("1,2,3,4", INTEGER, None, 1, True), - ("", INTEGER, None, 1, False), - ("1", INTEGER, None, 1, True), - ), -) -def test_does_identity_match_in( - identity, - condition_value, - trait_value_type, - trait_string_value, - trait_integer_value, - expected_result, -): - # Given - trait_key = "some_property" - condition = Condition(operator=IN, property=trait_key, value=condition_value) - traits = [ - Trait( - trait_key=trait_key, - identity=identity, - value_type=trait_value_type, - string_value=trait_string_value, - integer_value=trait_integer_value, - ) - ] - - # Then - assert condition.does_identity_match(identity, traits) is expected_result diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index ebdbe7c545c3..e0126b386a1e 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -4,6 +4,7 @@ import pytest from django.contrib.auth import get_user_model from django.urls import reverse +from flag_engine.segments.constants import EQUAL from pytest_lazyfixture import lazy_fixture from rest_framework import status @@ -11,7 +12,7 @@ from audit.related_object_type import RelatedObjectType from environments.models import Environment from features.models import Feature -from segments.models import EQUAL, Condition, Segment, SegmentRule +from segments.models import Condition, Segment, SegmentRule from util.mappers import map_identity_to_identity_document User = get_user_model() diff --git a/api/util/mappers/engine.py b/api/util/mappers/engine.py index eb545a425dc5..87b0beb81b49 100644 --- a/api/util/mappers/engine.py +++ b/api/util/mappers/engine.py @@ -45,9 +45,34 @@ "map_feature_to_engine", "map_identity_to_engine", "map_mv_option_to_engine", + "map_segment_to_engine", + "map_traits_to_engine", ) +def map_traits_to_engine(traits: Iterable["Trait"]) -> list[TraitModel]: + return [ + TraitModel(trait_key=trait.trait_key, trait_value=trait.trait_value) + for trait in traits + ] + + +def map_segment_to_engine( + segment: "Segment", +) -> SegmentModel: + segment_rules = segment.rules.all() + + # No reading from ORM past this point! + + return SegmentModel( + id=segment.pk, + name=segment.name, + rules=[ + map_segment_rule_to_engine(segment_rule) for segment_rule in segment_rules + ], + ) + + def map_segment_rule_to_engine( segment_rule: "SegmentRule", ) -> SegmentRuleModel: @@ -167,7 +192,7 @@ def map_environment_to_engine( int, Iterable["SegmentRule"], ] = {segment.pk: segment.rules.all() for segment in project_segments} - project_segment_feature_states_by_segment_id = _get_project_segment_feature_states( + project_segment_feature_states_by_segment_id = _get_segment_feature_states( project_segments, environment.pk, ) @@ -330,19 +355,30 @@ def map_environment_api_key_to_engine( ) -def map_identity_to_engine(identity: "Identity") -> IdentityModel: +def map_identity_to_engine( + identity: "Identity", + *, + with_overrides: bool = True, + with_traits: bool = True, +) -> IdentityModel: environment_api_key = identity.environment.api_key # Read relationships - grab all the data needed from the ORM here. - identity_feature_states: List["FeatureState"] = _get_prioritised_feature_states( - identity.identity_features.all(), - ) - multivariate_feature_state_values_by_feature_state_id = { - feature_state.pk: feature_state.multivariate_feature_state_values.all() - for feature_state in identity_feature_states - } + if with_overrides: + identity_feature_states: List["FeatureState"] = _get_prioritised_feature_states( + identity.identity_features.all(), + ) + multivariate_feature_state_values_by_feature_state_id = { + feature_state.pk: feature_state.multivariate_feature_state_values.all() + for feature_state in identity_feature_states + } + else: + identity_feature_states = [] + multivariate_feature_state_values_by_feature_state_id = {} - identity_traits: Iterable["Trait"] = identity.identity_traits.all() + identity_traits: Iterable["Trait"] = ( + identity.identity_traits.all() if with_traits else [] + ) # Prepare relationships. identity_feature_state_models = [ @@ -352,10 +388,7 @@ def map_identity_to_engine(identity: "Identity") -> IdentityModel: ) for feature_state in identity_feature_states ] - identity_trait_models = [ - TraitModel(trait_key=trait.trait_key, trait_value=trait.trait_value) - for trait in identity_traits - ] + identity_trait_models = map_traits_to_engine(identity_traits) return IdentityModel( # Attributes: @@ -388,12 +421,12 @@ def _get_prioritised_feature_states( return list(prioritised_feature_state_by_feature_id.values()) -def _get_project_segment_feature_states( - project_segments: Iterable["Segment"], +def _get_segment_feature_states( + segments: Iterable["Segment"], environment_id: int, ) -> Dict[int, List["FeatureState"]]: feature_states_by_segment_id = {} - for segment in project_segments: + for segment in segments: segment_feature_states = feature_states_by_segment_id.setdefault(segment.pk, []) for feature_segment in segment.feature_segments.all(): if feature_segment.environment_id != environment_id: