diff --git a/flag_engine/segments/constants.py b/flag_engine/segments/constants.py index 98fddd1..60cf2cd 100644 --- a/flag_engine/segments/constants.py +++ b/flag_engine/segments/constants.py @@ -20,3 +20,5 @@ IS_SET: ConditionOperator = "IS_SET" IS_NOT_SET: ConditionOperator = "IS_NOT_SET" IN: ConditionOperator = "IN" + +SEGMENT_IDENTIFIER_PROPERTY_NAME: str = "_$identity.identifier" diff --git a/flag_engine/segments/evaluator.py b/flag_engine/segments/evaluator.py index 13f3a18..3c8ea16 100644 --- a/flag_engine/segments/evaluator.py +++ b/flag_engine/segments/evaluator.py @@ -51,10 +51,10 @@ def evaluate_identity_in_segment( """ return len(segment.rules) > 0 and all( _traits_match_segment_rule( - override_traits or identity.identity_traits, - rule, - segment.id, - identity.django_id or identity.composite_key, + identity_traits=override_traits or identity.identity_traits, + rule=rule, + segment_id=segment.id, + identity=identity, ) for rule in segment.rules ) @@ -64,13 +64,16 @@ def _traits_match_segment_rule( identity_traits: typing.List[TraitModel], rule: SegmentRuleModel, segment_id: typing.Union[int, str], - identity_id: typing.Union[int, str], + identity: IdentityModel, ) -> bool: matches_conditions = ( rule.matching_function( [ _traits_match_segment_condition( - identity_traits, condition, segment_id, identity_id + identity_traits=identity_traits, + condition=condition, + segment_id=segment_id, + identity=identity, ) for condition in rule.conditions ] @@ -80,7 +83,7 @@ def _traits_match_segment_rule( ) return matches_conditions and all( - _traits_match_segment_rule(identity_traits, rule, segment_id, identity_id) + _traits_match_segment_rule(identity_traits, rule, segment_id, identity) for rule in rule.rules ) @@ -89,16 +92,21 @@ def _traits_match_segment_condition( identity_traits: typing.List[TraitModel], condition: SegmentConditionModel, segment_id: typing.Union[int, str], - identity_id: typing.Union[int, str], + identity: IdentityModel, ) -> bool: if condition.operator == constants.PERCENTAGE_SPLIT: assert condition.value float_value = float(condition.value) return ( - get_hashed_percentage_for_object_ids([segment_id, identity_id]) + get_hashed_percentage_for_object_ids( + [segment_id, identity.django_id or identity.composite_key] + ) <= float_value ) + if condition.property_ == constants.SEGMENT_IDENTIFIER_PROPERTY_NAME: + return _condition_matches_value(condition, identity.identifier) + trait = next( filter(lambda t: t.trait_key == condition.property_, identity_traits), None ) @@ -109,41 +117,41 @@ def _traits_match_segment_condition( if condition.operator == constants.IS_SET: return trait is not None - return _matches_trait_value(condition, trait.trait_value) if trait else False + return _condition_matches_value(condition, trait.trait_value) if trait else False -def _matches_trait_value( +def _condition_matches_value( condition: SegmentConditionModel, - trait_value: TraitValue, + matched_value: TraitValue, ) -> bool: if match_func := MATCH_FUNCS_BY_OPERATOR.get(condition.operator): - return match_func(condition.value, trait_value) + return match_func(condition.value, matched_value) return False def _evaluate_not_contains( segment_value: typing.Optional[str], - trait_value: TraitValue, + matched_value: TraitValue, ) -> bool: - return isinstance(trait_value, str) and str(segment_value) not in trait_value + return isinstance(matched_value, str) and str(segment_value) not in matched_value def _evaluate_regex( segment_value: typing.Optional[str], - trait_value: TraitValue, + matched_value: TraitValue, ) -> bool: return ( - trait_value is not None - and re.compile(str(segment_value)).match(str(trait_value)) is not None + matched_value is not None + and re.compile(str(segment_value)).match(str(matched_value)) is not None ) def _evaluate_modulo( segment_value: typing.Optional[str], - trait_value: TraitValue, + matched_value: TraitValue, ) -> bool: - if not isinstance(trait_value, (int, float)): + if not isinstance(matched_value, (int, float)): return False if segment_value is None: @@ -156,35 +164,37 @@ def _evaluate_modulo( except ValueError: return False - return trait_value % divisor == remainder + return matched_value % divisor == remainder -def _evaluate_in(segment_value: typing.Optional[str], trait_value: TraitValue) -> bool: +def _evaluate_in( + segment_value: typing.Optional[str], matched_value: TraitValue +) -> bool: if segment_value: - if isinstance(trait_value, str): - return trait_value in segment_value.split(",") - if isinstance(trait_value, int) and not any( - trait_value is x for x in (False, True) + if isinstance(matched_value, str): + return matched_value in segment_value.split(",") + if isinstance(matched_value, int) and not any( + matched_value is x for x in (False, True) ): - return str(trait_value) in segment_value.split(",") + return str(matched_value) in segment_value.split(",") return False -def _trait_value_typed( +def _matched_value_typed( func: typing.Callable[..., bool], ) -> typing.Callable[[typing.Optional[str], TraitValue], bool]: @wraps(func) def inner( segment_value: typing.Optional[str], - trait_value: typing.Union[TraitValue, semver.Version], + matched_value: typing.Union[TraitValue, semver.Version], ) -> bool: with suppress(TypeError, ValueError): - if isinstance(trait_value, str) and is_semver(segment_value): - trait_value = semver.Version.parse( - trait_value, + if isinstance(matched_value, str) and is_semver(segment_value): + matched_value = semver.Version.parse( + matched_value, ) - match_value = get_casting_function(trait_value)(segment_value) - return func(trait_value, match_value) + matched_against_value = get_casting_function(matched_value)(segment_value) + return func(matched_value, matched_against_value) return False return inner @@ -197,11 +207,11 @@ def inner( constants.REGEX: _evaluate_regex, constants.MODULO: _evaluate_modulo, constants.IN: _evaluate_in, - constants.EQUAL: _trait_value_typed(operator.eq), - constants.GREATER_THAN: _trait_value_typed(operator.gt), - constants.GREATER_THAN_INCLUSIVE: _trait_value_typed(operator.ge), - constants.LESS_THAN: _trait_value_typed(operator.lt), - constants.LESS_THAN_INCLUSIVE: _trait_value_typed(operator.le), - constants.NOT_EQUAL: _trait_value_typed(operator.ne), - constants.CONTAINS: _trait_value_typed(operator.contains), + constants.EQUAL: _matched_value_typed(operator.eq), + constants.GREATER_THAN: _matched_value_typed(operator.gt), + constants.GREATER_THAN_INCLUSIVE: _matched_value_typed(operator.ge), + constants.LESS_THAN: _matched_value_typed(operator.lt), + constants.LESS_THAN_INCLUSIVE: _matched_value_typed(operator.le), + constants.NOT_EQUAL: _matched_value_typed(operator.ne), + constants.CONTAINS: _matched_value_typed(operator.contains), } diff --git a/flag_engine/segments/models.py b/flag_engine/segments/models.py index dcfdc11..c88ebf0 100644 --- a/flag_engine/segments/models.py +++ b/flag_engine/segments/models.py @@ -39,3 +39,4 @@ class SegmentModel(BaseModel): name: str rules: typing.List[SegmentRuleModel] = Field(default_factory=list) feature_states: typing.List[FeatureStateModel] = Field(default_factory=list) + meta: typing.Optional[typing.Dict[str, str]] = None diff --git a/tests/unit/segments/fixtures.py b/tests/unit/segments/fixtures.py index 69229d8..7814d12 100644 --- a/tests/unit/segments/fixtures.py +++ b/tests/unit/segments/fixtures.py @@ -14,6 +14,7 @@ trait_key_3 = "date_joined" trait_value_3 = "2021-01-01" +identifier = "identity_1" empty_segment = SegmentModel(id=1, name="empty_segment") segment_single_condition = SegmentModel( @@ -148,3 +149,24 @@ ) ], ) +segment_identity_override = SegmentModel( + id=7, + name="dentity_override_identity_1_b6c2e2", + rules=[ + SegmentRuleModel( + type=constants.ALL_RULE, + conditions=[ + SegmentConditionModel( + operator=constants.EQUAL, + property_=constants.SEGMENT_IDENTIFIER_PROPERTY_NAME, + value=identifier, + ) + ], + ) + ], + meta={ + "identity_uuid": "d049c16b-e4dd-4830-b238-db2241c159e6", + "identity_identifier": identifier, + "type": "IDENTITY_OVERRIDE", + }, +) diff --git a/tests/unit/segments/test_segments_evaluator.py b/tests/unit/segments/test_segments_evaluator.py index 3097510..a23f5fd 100644 --- a/tests/unit/segments/test_segments_evaluator.py +++ b/tests/unit/segments/test_segments_evaluator.py @@ -8,7 +8,7 @@ from flag_engine.identities.traits.models import TraitModel from flag_engine.segments import constants from flag_engine.segments.evaluator import ( - _matches_trait_value, + _condition_matches_value, evaluate_identity_in_segment, ) from flag_engine.segments.models import ( @@ -19,7 +19,9 @@ from flag_engine.segments.types import ConditionOperator from tests.unit.segments.fixtures import ( empty_segment, + identifier, segment_conditions_and_nested_rules, + segment_identity_override, segment_multiple_conditions_all, segment_multiple_conditions_any, segment_nested_rules, @@ -106,6 +108,7 @@ ], True, ), + (segment_identity_override, [], True), ), ) def test_identity_in_segment( @@ -114,7 +117,7 @@ def test_identity_in_segment( expected_result: bool, ) -> None: identity = IdentityModel( - identifier="foo", + identifier=identifier, identity_traits=identity_traits, environment_api_key="api-key", ) @@ -265,7 +268,7 @@ def test_identity_in_segment_is_set_and_is_not_set( (constants.IN, 1, None, False), ), ) -def test_segment_condition_matches_trait_value( +def test_segment_condition_matches_value( operator: ConditionOperator, trait_value: typing.Union[None, int, str, float], condition_value: object, @@ -279,7 +282,7 @@ def test_segment_condition_matches_trait_value( ) # When - result = _matches_trait_value(segment_condition, trait_value) + result = _condition_matches_value(segment_condition, trait_value) # Then assert result == expected_result @@ -298,7 +301,7 @@ def test_segment_condition__unsupported_operator__return_false( trait_value = "foo" # When - result = _matches_trait_value(segment_condition, trait_value) + result = _condition_matches_value(segment_condition, trait_value) # Then assert result is False @@ -329,7 +332,7 @@ def test_segment_condition__unsupported_operator__return_false( (constants.LESS_THAN_INCLUSIVE, "1.0.1", "1.0.0:semver", False), ], ) -def test_segment_condition_matches_trait_value_for_semver( +def test_segment_condition_matches_value_for_semver( operator: ConditionOperator, trait_value: str, condition_value: str, @@ -343,7 +346,7 @@ def test_segment_condition_matches_trait_value_for_semver( ) # When - result = _matches_trait_value(segment_condition, trait_value) + result = _condition_matches_value(segment_condition, trait_value) # Then assert result == expected_result @@ -364,7 +367,7 @@ def test_segment_condition_matches_trait_value_for_semver( (1, None, False), ], ) -def test_segment_condition_matches_trait_value_for_modulo( +def test_segment_condition_matches_value_for_modulo( trait_value: typing.Union[int, float, str, bool], condition_value: typing.Optional[str], expected_result: bool, @@ -377,7 +380,7 @@ def test_segment_condition_matches_trait_value_for_modulo( ) # When - result = _matches_trait_value(segment_condition, trait_value) + result = _condition_matches_value(segment_condition, trait_value) # Then assert result == expected_result