Skip to content

Commit

Permalink
Fixed bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Varun Rathore committed Nov 6, 2024
1 parent 5c35540 commit 55f2a0a
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 103 deletions.
8 changes: 0 additions & 8 deletions firebase_admin/_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,3 @@ def __init__(self, **kwargs):

def parse_body(self, resp):
return resp.json()

class RemoteConfigApiClient(HttpClient):
"""An HTTP client that parses response messages as JSON."""
def __init__(self, **kwargs):
HttpClient.__init__(self, **kwargs)
def parse_body(self, resp):
return resp.json()

232 changes: 138 additions & 94 deletions firebase_admin/remote_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import json
import logging
from typing import Dict, Optional, Literal, Callable, Union
from typing import Dict, Optional, Literal, Union
from enum import Enum
import re
import hashlib
Expand Down Expand Up @@ -228,15 +228,14 @@ def evaluate(self):
evaluated_conditions = self.evaluate_conditions(self._conditions, self._context)

# Overlays config Value objects derived by evaluating the template.
# evaluated_conditions = None
if self._parameters is not None:
if self._parameters:
for key, parameter in self._parameters.items():
conditional_values = parameter.get('conditionalValues', {})
default_value = parameter.get('defaultValue', {})
parameter_value_wrapper = None
# Iterates in order over condition list. If there is a value associated
# with a condition, this checks if the condition is true.
if evaluated_conditions is not None:
if evaluated_conditions:
for condition_name, condition_evaluation in evaluated_conditions.items():
if condition_name in conditional_values and condition_evaluation:
parameter_value_wrapper = conditional_values[condition_name]
Expand Down Expand Up @@ -404,6 +403,7 @@ def hash_seeded_randomization_id(self, seeded_randomization_id: str) -> int:
hash_object.update(seeded_randomization_id.encode('utf-8'))
hash64 = hash_object.hexdigest()
return abs(int(hash64, 16))

def evaluate_custom_signal_condition(self, custom_signal_condition,
context) -> bool:
"""Evaluates a custom signal condition.
Expand All @@ -417,124 +417,168 @@ def evaluate_custom_signal_condition(self, custom_signal_condition,
"""
custom_signal_operator = custom_signal_condition.get('custom_signal_operator') or {}
custom_signal_key = custom_signal_condition.get('custom_signal_key') or {}
tgt_custom_signal_values = custom_signal_condition.get('target_custom_signal_values') or {}
target_custom_signal_values = (
custom_signal_condition.get('target_custom_signal_values') or {})

if not all([custom_signal_operator, custom_signal_key, tgt_custom_signal_values]):
if not all([custom_signal_operator, custom_signal_key, target_custom_signal_values]):
logger.warning("Missing operator, key, or target values for custom signal condition.")
return False

if not tgt_custom_signal_values:
if not target_custom_signal_values:
return False
actual_custom_signal_value = getattr(context, custom_signal_key, None)
if actual_custom_signal_value is None:
actual_custom_signal_value = context.get(custom_signal_key) or {}

if not actual_custom_signal_value:
logger.warning("Custom signal value not found in context: %s", custom_signal_key)
return False

if custom_signal_operator == CustomSignalOperator.STRING_CONTAINS:
return compare_strings(lambda target, actual: target in actual)
return self._compare_strings(target_custom_signal_values,
actual_custom_signal_value,
lambda target, actual: target in actual)
if custom_signal_operator == CustomSignalOperator.STRING_DOES_NOT_CONTAIN:
return not compare_strings(lambda target, actual: target in actual)
return not self._compare_strings(target_custom_signal_values,
actual_custom_signal_value,
lambda target, actual: target in actual)
if custom_signal_operator == CustomSignalOperator.STRING_EXACTLY_MATCHES:
return compare_strings(lambda target, actual: target.strip() == actual.strip())
return self._compare_strings(target_custom_signal_values,
actual_custom_signal_value,
lambda target, actual: target.strip() == actual.strip())
if custom_signal_operator == CustomSignalOperator.STRING_CONTAINS_REGEX:
return compare_strings(lambda target, actual: re.search(target, actual) is not None)
return self._compare_strings(target_custom_signal_values,
actual_custom_signal_value,
re.search)

# For numeric operators only one target value is allowed.
if custom_signal_operator == CustomSignalOperator.NUMERIC_LESS_THAN:
return compare_numbers(lambda r: r < 0)
return self._compare_numbers(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r < 0)
if custom_signal_operator == CustomSignalOperator.NUMERIC_LESS_EQUAL:
return compare_numbers(lambda r: r <= 0)
return self._compare_numbers(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r <= 0)
if custom_signal_operator == CustomSignalOperator.NUMERIC_EQUAL:
return compare_numbers(lambda r: r == 0)
return self._compare_numbers(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r == 0)
if custom_signal_operator == CustomSignalOperator.NUMERIC_NOT_EQUAL:
return compare_numbers(lambda r: r != 0)
return self._compare_numbers(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r != 0)
if custom_signal_operator == CustomSignalOperator.NUMERIC_GREATER_THAN:
return compare_numbers(lambda r: r > 0)
return self._compare_numbers(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r > 0)
if custom_signal_operator == CustomSignalOperator.NUMERIC_GREATER_EQUAL:
return compare_numbers(lambda r: r >= 0)
return self._compare_numbers(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r >= 0)

# For semantic operators only one target value is allowed.
if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_LESS_THAN:
return compare_semantic_versions(lambda r: r < 0)
return self._compare_semantic_versions(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r < 0)
if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_LESS_EQUAL:
return compare_semantic_versions(lambda r: r <= 0)
return self._compare_semantic_versions(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r <= 0)
if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_EQUAL:
return compare_semantic_versions(lambda r: r == 0)
return self._compare_semantic_versions(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r == 0)
if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_NOT_EQUAL:
return compare_semantic_versions(lambda r: r != 0)
return self._compare_semantic_versions(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r != 0)
if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_GREATER_THAN:
return compare_semantic_versions(lambda r: r > 0)
return self._compare_semantic_versions(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r > 0)
if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_GREATER_EQUAL:
return compare_semantic_versions(lambda r: r >= 0)

def compare_strings(predicate_fn: Callable[[str, str], bool]) -> bool:
"""Compares the actual string value of a signal against a list of target values.
Args:
predicate_fn: A function that takes two string arguments (target and actual)
and returns a boolean indicating whether
the target matches the actual value.
Returns:
bool: True if the predicate function returns True for any target value in the list,
False otherwise.
"""
for target in tgt_custom_signal_values:
if predicate_fn(target, str(actual_custom_signal_value)):
return True
return False
return self._compare_semantic_versions(target_custom_signal_values[0],
actual_custom_signal_value,
lambda r: r >= 0)
logger.warning("Unknown custom signal operator: %s", custom_signal_operator)
return False

def compare_numbers(predicate_fn: Callable[[int], bool]) -> bool:
try:
target = float(tgt_custom_signal_values[0])
actual = float(actual_custom_signal_value)
result = -1 if actual < target else 1 if actual > target else 0
return predicate_fn(result)
except ValueError:
logger.warning("Invalid numeric value for comparison.")
return False
def _compare_strings(self, target_values, actual_value, predicate_fn) -> bool:
"""Compares the actual string value of a signal against a list of target values.
def compare_semantic_versions(predicate_fn: Callable[[int], bool]) -> bool:
"""Compares the actual semantic version value of a signal against a target value.
Calls the predicate function with -1, 0, 1 if actual is less than, equal to,
or greater than target.
Args:
predicate_fn: A function that takes an integer (-1, 0, or 1) and returns a boolean.
Returns:
bool: True if the predicate function returns True for the result of the comparison,
False otherwise.
"""
return compare_versions(str(actual_custom_signal_value),
str(tgt_custom_signal_values[0]), predicate_fn)
def compare_versions(version1: str, version2: str,
predicate_fn: Callable[[int], bool]) -> bool:
"""Compares two semantic version strings.
Args:
version1: The first semantic version string.
version2: The second semantic version string.
predicate_fn: A function that takes an integer and returns a boolean.
Returns:
bool: The result of the predicate function.
"""
try:
v1_parts = [int(part) for part in version1.split('.')]
v2_parts = [int(part) for part in version2.split('.')]
max_length = max(len(v1_parts), len(v2_parts))
v1_parts.extend([0] * (max_length - len(v1_parts)))
v2_parts.extend([0] * (max_length - len(v2_parts)))

for part1, part2 in zip(v1_parts, v2_parts):
if part1 < part2:
return predicate_fn(-1)
if part1 > part2:
return predicate_fn(1)
return predicate_fn(0)
except ValueError:
logger.warning("Invalid semantic version format for comparison.")
return False
Args:
target_values: A list of target string values.
actual_value: The actual value to compare, which can be a string or number.
predicate_fn: A function that takes two string arguments (target and actual)
and returns a boolean indicating whether
the target matches the actual value.
logger.warning("Unknown custom signal operator: %s", custom_signal_operator)
Returns:
bool: True if the predicate function returns True for any target value in the list,
False otherwise.
"""

for target in target_values:
if predicate_fn(target, str(actual_value)):
return True
return False

def _compare_numbers(self, target_value, actual_value, predicate_fn) -> bool:
try:
target = float(target_value)
actual = float(actual_value)
result = -1 if actual < target else 1 if actual > target else 0
return predicate_fn(result)
except ValueError:
logger.warning("Invalid numeric value for comparison.")
return False

def _compare_semantic_versions(self, target_value, actual_value, predicate_fn) -> bool:
"""Compares the actual semantic version value of a signal against a target value.
Calls the predicate function with -1, 0, 1 if actual is less than, equal to,
or greater than target.
Args:
target_values: A list of target string values.
actual_value: The actual value to compare, which can be a string or number.
predicate_fn: A function that takes an integer (-1, 0, or 1) and returns a boolean.
Returns:
bool: True if the predicate function returns True for the result of the comparison,
False otherwise.
"""
return self._compare_versions(str(actual_value),
str(target_value), predicate_fn)

def _compare_versions(self, version1, version2, predicate_fn) -> bool:
"""Compares two semantic version strings.
Args:
version1: The first semantic version string.
version2: The second semantic version string.
predicate_fn: A function that takes an integer and returns a boolean.
Returns:
bool: The result of the predicate function.
"""
try:
v1_parts = [int(part) for part in version1.split('.')]
v2_parts = [int(part) for part in version2.split('.')]
max_length = max(len(v1_parts), len(v2_parts))
v1_parts.extend([0] * (max_length - len(v1_parts)))
v2_parts.extend([0] * (max_length - len(v2_parts)))

for part1, part2 in zip(v1_parts, v2_parts):
if part1 < part2:
return predicate_fn(-1)
if part1 > part2:
return predicate_fn(1)
return predicate_fn(0)
except ValueError:
logger.warning("Invalid semantic version format for comparison.")
return False


async def get_server_template(app: App = None, default_config: Optional[Dict[str, str]] = None):
"""Initializes a new ServerTemplate instance and fetches the server template.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_remote_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def test_rc_instance_evaluate_between_approx(self):
assert truthy_assignments <= 20000 + tolerance
self.tear_down()

def test_rc_instance_evaluate_between_interquartile_range_approx(self):
def test_rc_instance_evaluate_between_interquartile_range_accuracy(self):
self.set_up()
condition = {
'name': 'is_true',
Expand Down

0 comments on commit 55f2a0a

Please sign in to comment.