Skip to content

Commit

Permalink
feat: strict typing
Browse files Browse the repository at this point in the history
- add strict mode mypy, fix typing errors
- add absolufy-imports, ditch relative imports
- move segment evaluation logic to `flag_engine.segments.evaluation` module
- add `type: ignore` comment for decorator usage on a property
- add `type: ignore` comments for untyped dependencies
  • Loading branch information
khvn26 committed Sep 18, 2023
1 parent 399b4f6 commit 832de2c
Show file tree
Hide file tree
Showing 36 changed files with 389 additions and 208 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,22 @@ jobs:
- name: Check Formatting
run: black --check .

- name: Check Imports
run: |
git ls-files | grep '\.py$' | xargs absolufy-imports
isort . --check
- name: Check flake8 linting
run: flake8 .

- name: Check Typing
run: mypy --strict .

- name: Run Tests
run: pytest -p no:warnings

- name: Check Coverage
uses: 5monkeys/cobertura-action@v13
with:
minimum_coverage: 100
fail_below_threshold: true
minimum_coverage: 100
fail_below_threshold: true
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
hooks:
- id: mypy
args: [--strict]
additional_dependencies:
[pydantic, pytest, pytest_mock, types-pytest-lazy-fixture, types-setuptools, semver]
- repo: https://github.com/MarcoGorelli/absolufy-imports
rev: v0.3.1
hooks:
- id: absolufy-imports
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
Expand Down
14 changes: 9 additions & 5 deletions flag_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from flag_engine.utils.exceptions import FeatureStateNotFound


def get_environment_feature_states(environment: EnvironmentModel):
def get_environment_feature_states(
environment: EnvironmentModel,
) -> typing.List[FeatureStateModel]:
"""
Get a list of feature states for a given environment
Expand All @@ -19,7 +21,9 @@ def get_environment_feature_states(environment: EnvironmentModel):
return environment.feature_states


def get_environment_feature_state(environment: EnvironmentModel, feature_name: str):
def get_environment_feature_state(
environment: EnvironmentModel, feature_name: str
) -> FeatureStateModel:
"""
Get a specific feature state for a given feature_name in a given environment
Expand All @@ -38,7 +42,7 @@ def get_environment_feature_state(environment: EnvironmentModel, feature_name: s
def get_identity_feature_states(
environment: EnvironmentModel,
identity: IdentityModel,
override_traits: typing.List[TraitModel] = None,
override_traits: typing.Optional[typing.List[TraitModel]] = None,
) -> typing.List[FeatureStateModel]:
"""
Get a list of feature states for a given identity in a given environment.
Expand All @@ -63,8 +67,8 @@ def get_identity_feature_state(
environment: EnvironmentModel,
identity: IdentityModel,
feature_name: str,
override_traits: typing.List[TraitModel] = None,
):
override_traits: typing.Optional[typing.List[TraitModel]] = None,
) -> FeatureStateModel:
"""
Get a specific feature state for a given identity in a given environment.
Expand Down
2 changes: 1 addition & 1 deletion flag_engine/environments/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from flag_engine.environments.models import EnvironmentAPIKeyModel, EnvironmentModel


def build_environment_model(environment_dict: dict[str, Any]) -> EnvironmentModel:
def build_environment_model(environment_dict: Dict[str, Any]) -> EnvironmentModel:
return EnvironmentModel.model_validate(environment_dict)


Expand Down
10 changes: 5 additions & 5 deletions flag_engine/environments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class EnvironmentAPIKeyModel(BaseModel):
active: bool = True

@property
def is_valid(self):
def is_valid(self) -> bool:
return self.active and (
not self.expires_at or self.expires_at > utcnow_with_tz()
)
Expand Down Expand Up @@ -52,7 +52,7 @@ class EnvironmentModel(BaseModel):

webhook_config: typing.Optional[WebhookModel] = None

_INTEGRATION_ATTS = [
_INTEGRATION_ATTRS = [
"amplitude_config",
"heap_config",
"mixpanel_config",
Expand All @@ -76,9 +76,9 @@ def integrations_data(self) -> typing.Dict[str, typing.Dict[str, str]]:
"""

integrations_data = {}
for integration_attr in self._INTEGRATION_ATTS:
integration_config: IntegrationModel = getattr(self, integration_attr, None)
if integration_config:
for integration_attr in self._INTEGRATION_ATTRS:
integration_config: typing.Optional[IntegrationModel]
if integration_config := getattr(self, integration_attr, None):
integrations_data[integration_attr] = {
"base_url": integration_config.base_url,
"api_key": integration_config.api_key,
Expand Down
52 changes: 29 additions & 23 deletions flag_engine/features/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import typing
import uuid

from annotated_types import Ge, Le
from annotated_types import Ge, Le, SupportsLt
from pydantic import UUID4, BaseModel, Field, model_validator
from pydantic_collections import BaseCollectionModel
from typing_extensions import Annotated

from flag_engine.utils.exceptions import InvalidPercentageAllocation
from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids
Expand All @@ -15,21 +16,21 @@ class FeatureModel(BaseModel):
name: str
type: str

def __eq__(self, other):
return self.id == other.id
def __eq__(self, other: object) -> bool:
return isinstance(other, FeatureModel) and self.id == other.id

def __hash__(self):
def __hash__(self) -> int:
return hash(self.id)


class MultivariateFeatureOptionModel(BaseModel):
value: typing.Any
id: int = None
id: typing.Optional[int] = None


class MultivariateFeatureStateValueModel(BaseModel):
multivariate_feature_option: MultivariateFeatureOptionModel
percentage_allocation: typing.Annotated[float, Ge(0), Le(100)]
percentage_allocation: Annotated[float, Ge(0), Le(100)]
id: typing.Optional[int] = None
mv_fs_value_uuid: UUID4 = Field(default_factory=uuid.uuid4)

Expand All @@ -39,7 +40,7 @@ class FeatureSegmentModel(BaseModel):


class MultivariateFeatureStateValueList(
BaseCollectionModel[MultivariateFeatureStateValueModel]
BaseCollectionModel[MultivariateFeatureStateValueModel] # type: ignore[misc,no-any-unimported]
):
@staticmethod
def _ensure_correct_percentage_allocations(
Expand Down Expand Up @@ -74,15 +75,15 @@ def append(
class FeatureStateModel(BaseModel, validate_assignment=True):
feature: FeatureModel
enabled: bool
django_id: int = None
feature_segment: FeatureSegmentModel = None
django_id: typing.Optional[int] = None
feature_segment: typing.Optional[FeatureSegmentModel] = None
featurestate_uuid: UUID4 = Field(default_factory=uuid.uuid4)
feature_state_value: typing.Any = None
multivariate_feature_state_values: MultivariateFeatureStateValueList = Field(
default_factory=MultivariateFeatureStateValueList
)

def set_value(self, value: typing.Any):
def set_value(self, value: typing.Any) -> None:
self.feature_state_value = value

def get_value(self, identity_id: typing.Union[None, int, str] = None) -> typing.Any:
Expand Down Expand Up @@ -112,18 +113,19 @@ def is_higher_segment_priority(self, other: "FeatureStateModel") -> bool:
"""

try:
return (
getattr(
self.feature_segment,
"priority",
math.inf,
if other_feature_segment := other.feature_segment:
if (
other_feature_segment_priority := other_feature_segment.priority
) is not None:
return (
getattr(
self.feature_segment,
"priority",
math.inf,
)
< other_feature_segment_priority
)
< other.feature_segment.priority
)

except (TypeError, AttributeError):
return False
return False

def _get_multivariate_value(
self, identity_id: typing.Union[int, str]
Expand All @@ -137,10 +139,14 @@ def _get_multivariate_value(
# the percentage allocations of the multivariate options. This gives us a
# way to ensure that the same value is returned every time we use the same
# percentage value.
start_percentage = 0
start_percentage = 0.0

def _mv_fs_sort_key(mv_value: MultivariateFeatureStateValueModel) -> SupportsLt:
return mv_value.id or mv_value.mv_fs_value_uuid

for mv_value in sorted(
self.multivariate_feature_state_values,
key=lambda v: v.id or v.mv_fs_value_uuid,
key=_mv_fs_sort_key,
):
limit = mv_value.percentage_allocation + start_percentage
if start_percentage <= percentage_value < limit:
Expand Down
4 changes: 2 additions & 2 deletions flag_engine/identities/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from flag_engine.utils.exceptions import DuplicateFeatureState


class IdentityFeaturesList(BaseCollectionModel[FeatureStateModel]):
class IdentityFeaturesList(BaseCollectionModel[FeatureStateModel]): # type: ignore[misc,no-any-unimported]
@staticmethod
def _ensure_unique_feature_ids(
value: typing.MutableSequence[FeatureStateModel],
Expand Down Expand Up @@ -45,7 +45,7 @@ class IdentityModel(BaseModel):
identity_uuid: UUID4 = Field(default_factory=uuid.uuid4)
django_id: typing.Optional[int] = None

@computed_field
@computed_field # type: ignore[misc]
@property
def composite_key(self) -> str:
return self.generate_composite_key(self.environment_api_key, self.identifier)
Expand Down
3 changes: 2 additions & 1 deletion flag_engine/identities/traits/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from decimal import Decimal
from typing import Any, TypeGuard, Union, get_args
from typing import Any, Union, get_args

from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypeGuard

from flag_engine.identities.traits.types import TraitValue

Expand Down
3 changes: 2 additions & 1 deletion flag_engine/identities/traits/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated, Union
from typing import Union

from pydantic.types import AllowInfNan, StringConstraints
from typing_extensions import Annotated

from flag_engine.identities.traits.constants import TRAIT_STRING_VALUE_MAX_LENGTH

Expand Down
2 changes: 1 addition & 1 deletion flag_engine/organisations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ class OrganisationModel(BaseModel):
persist_trait_data: bool

@property
def unique_slug(self):
def unique_slug(self) -> str:
return str(self.id) + "-" + self.name
Empty file added flag_engine/py.typed
Empty file.
55 changes: 19 additions & 36 deletions flag_engine/segments/constants.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,22 @@
# Segment Rules
ALL_RULE = "ALL"
ANY_RULE = "ANY"
NONE_RULE = "NONE"
from flag_engine.segments.types import ConditionOperator, RuleType

RULE_TYPES = [ALL_RULE, ANY_RULE, NONE_RULE]
# Segment Rules
ALL_RULE: RuleType = "ALL"
ANY_RULE: RuleType = "ANY"
NONE_RULE: RuleType = "NONE"

# Segment Condition Operators
EQUAL = "EQUAL"
GREATER_THAN = "GREATER_THAN"
LESS_THAN = "LESS_THAN"
LESS_THAN_INCLUSIVE = "LESS_THAN_INCLUSIVE"
CONTAINS = "CONTAINS"
GREATER_THAN_INCLUSIVE = "GREATER_THAN_INCLUSIVE"
NOT_CONTAINS = "NOT_CONTAINS"
NOT_EQUAL = "NOT_EQUAL"
REGEX = "REGEX"
PERCENTAGE_SPLIT = "PERCENTAGE_SPLIT"
MODULO = "MODULO"
IS_SET = "IS_SET"
IS_NOT_SET = "IS_NOT_SET"
IN = "IN"

CONDITION_OPERATORS = [
EQUAL,
GREATER_THAN,
LESS_THAN,
LESS_THAN_INCLUSIVE,
CONTAINS,
GREATER_THAN_INCLUSIVE,
NOT_CONTAINS,
NOT_EQUAL,
REGEX,
PERCENTAGE_SPLIT,
MODULO,
IS_SET,
IS_NOT_SET,
IN,
]
EQUAL: ConditionOperator = "EQUAL"
GREATER_THAN: ConditionOperator = "GREATER_THAN"
LESS_THAN: ConditionOperator = "LESS_THAN"
LESS_THAN_INCLUSIVE: ConditionOperator = "LESS_THAN_INCLUSIVE"
CONTAINS: ConditionOperator = "CONTAINS"
GREATER_THAN_INCLUSIVE: ConditionOperator = "GREATER_THAN_INCLUSIVE"
NOT_CONTAINS: ConditionOperator = "NOT_CONTAINS"
NOT_EQUAL: ConditionOperator = "NOT_EQUAL"
REGEX: ConditionOperator = "REGEX"
PERCENTAGE_SPLIT: ConditionOperator = "PERCENTAGE_SPLIT"
MODULO: ConditionOperator = "MODULO"
IS_SET: ConditionOperator = "IS_SET"
IS_NOT_SET: ConditionOperator = "IS_NOT_SET"
IN: ConditionOperator = "IN"
2 changes: 1 addition & 1 deletion flag_engine/segments/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _trait_value_typed(
@wraps(func)
def inner(
segment_value: typing.Optional[str],
trait_value: TraitValue,
trait_value: typing.Union[TraitValue, semver.Version],
) -> bool:
with suppress(TypeError, ValueError):
if isinstance(trait_value, str) and is_semver(segment_value):
Expand Down
6 changes: 3 additions & 3 deletions flag_engine/segments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ class SegmentRuleModel(BaseModel):
conditions: typing.List[SegmentConditionModel] = Field(default_factory=list)

@staticmethod
def none(iterable: typing.Iterable) -> bool:
def none(iterable: typing.Iterable[object]) -> bool:
return not any(iterable)

@property
def matching_function(self) -> callable:
def matching_function(self) -> typing.Callable[[typing.Iterable[object]], bool]:
return {
constants.ANY_RULE: any,
constants.ALL_RULE: all,
constants.NONE_RULE: SegmentRuleModel.none,
}.get(self.type)
}[self.type]


class SegmentModel(BaseModel):
Expand Down
4 changes: 3 additions & 1 deletion flag_engine/utils/hashing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import hashlib
import typing

from flag_engine.utils.types import SupportsStr


def get_hashed_percentage_for_object_ids(
object_ids: typing.Iterable[typing.Any], iterations: int = 1
object_ids: typing.Iterable[SupportsStr], iterations: int = 1
) -> float:
"""
Given a list of object ids, get a floating point number between 0 (inclusive) and
Expand Down
2 changes: 1 addition & 1 deletion flag_engine/utils/json/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class DecimalEncoder(json.JSONEncoder):
int/float(for us) converted to decimal by boto3/dynamodb.
"""

def default(self, obj):
def default(self, obj: object) -> object:
if isinstance(obj, decimal.Decimal):
if obj % 1 == 0:
return int(obj)
Expand Down
Loading

0 comments on commit 832de2c

Please sign in to comment.