diff --git a/README.md b/README.md index ee1feb5f..29eb4cbb 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,16 @@ class ProductVariables(BaseVariables): @select_rule_variable(options=Products.top_holiday_items()) def goes_well_with(self): return products.related_products + + @numeric_rule_variable(params=[{'fieldType': FIELD_NUMERIC, + 'name': 'days, + 'label': 'Days'}]) + def orders_sold_in_last_x_days(self, days): + count = 0 + for order in self.product.orders: + if (datetime.date.today() - order.date_sold).days < days: + count += 1 + return count ``` ### 2. Define your set of actions @@ -146,7 +156,23 @@ rules = [ { "name": "order_more", "fields":[{"name":"number_to_order", "value": 40}]} ] +}, + +# orders_sold_in_last_x_days(5) > 10 +{ "conditions": { "all": [ + { "name": "orders_sold_in_last_x_days", + "operator": "greater_than", + "value": 10, + "params": {"days": 5}, + } + ]}, + "actions": [ + { "name": "order_more", + "fields": [{"name": "number_to_order", "value": 40}] + } + ] }] + ``` ### Export the available variables, operators and actions @@ -165,15 +191,23 @@ that returns { "name": "expiration_days", "label": "Days until expiration", "field_type": "numeric", - "options": []}, + "options": [], + "params": []}, { "name": "current_month", "label": "Current Month", "field_type": "string", - "options": []}, + "options": [], + "params": []}, { "name": "goes_well_with", "label": "Goes Well With", "field_type": "select", - "options": ["Eggnog", "Cookies", "Beef Jerkey"]} + "options": ["Eggnog", "Cookies", "Beef Jerkey"], + "params": []}, + { "name": "orders_sold_in_last_x_days", + "label": "Orders Sold In Last X Days", + "field_type": "numeric", + "options": [], + "params": [{"fieldType": "numeric", "name": "days", "label": "Days"}]} ], "actions": [ { "name": "put_on_sale", @@ -227,6 +261,7 @@ The type represents the type of the value that will be returned for the variable All decorators can optionally take the arguments: - `label` - A human-readable label to show on the frontend. By default we just split the variable name on underscores and capitalize the words. - `cache_result` - Whether to cache the value of the variable for this instance of the variable container object. Defaults to `True` (to avoid re-doing expensive DB queries or computations if you have many rules based on the same variables). +- `params` - A list of parameters that will be passed to the variable when its value is calculated. The list elements should be dictionaries with a `fieldType` to specify the type and `name` that corresponds to an argument of the variable function. The available types and decorators are: diff --git a/business_rules/engine.py b/business_rules/engine.py index eb3c00ad..d39bec94 100644 --- a/business_rules/engine.py +++ b/business_rules/engine.py @@ -51,10 +51,11 @@ def check_condition(condition, defined_variables): object must have a variable defined for any variables in this condition. """ name, op, value = condition['name'], condition['operator'], condition['value'] - operator_type = _get_variable_value(defined_variables, name) + params = condition.get('params') or {} + operator_type = _get_variable_value(defined_variables, name, params) return _do_operator_comparison(operator_type, op, value) -def _get_variable_value(defined_variables, name): +def _get_variable_value(defined_variables, name, params): """ Call the function provided on the defined_variables object with the given name (raise exception if that doesn't exist) and casts it to the specified type. @@ -65,7 +66,7 @@ def fallback(*args, **kwargs): raise AssertionError("Variable {0} is not defined in class {1}".format( name, defined_variables.__class__.__name__)) method = getattr(defined_variables, name, fallback) - val = method() + val = method(**params) return method.field_type(val) def _do_operator_comparison(operator_type, operator_name, comparison_value): diff --git a/business_rules/fields.py b/business_rules/fields.py index a1c64240..efc7749d 100644 --- a/business_rules/fields.py +++ b/business_rules/fields.py @@ -1,5 +1,6 @@ FIELD_TEXT = 'text' FIELD_NUMERIC = 'numeric' +FIELD_DATE = 'date' FIELD_NO_INPUT = 'none' FIELD_SELECT = 'select' FIELD_SELECT_MULTIPLE = 'select_multiple' diff --git a/business_rules/operators.py b/business_rules/operators.py index 817a9356..67239464 100644 --- a/business_rules/operators.py +++ b/business_rules/operators.py @@ -1,10 +1,11 @@ import inspect +import datetime import re from functools import wraps from .six import string_types, integer_types from .fields import (FIELD_TEXT, FIELD_NUMERIC, FIELD_NO_INPUT, - FIELD_SELECT, FIELD_SELECT_MULTIPLE) + FIELD_SELECT, FIELD_SELECT_MULTIPLE, FIELD_DATE) from .utils import fn_name_to_pretty_label, float_to_decimal from decimal import Decimal, Inexact, Context @@ -136,13 +137,52 @@ def less_than_or_equal_to(self, other_numeric): return self.less_than(other_numeric) or self.equal_to(other_numeric) +@export_type +class DateType(BaseType): + + name = 'date' + + def _assert_valid_value_and_cast(self, value): + if isinstance(value, (datetime.date, datetime.datetime)): + return value + elif isinstance(value, string_types): + from dateutil import parser + try: + return parser.parse(value) + except ValueError: + raise AssertionError( + "{0} is not a valid date type.".format(value)) + else: + raise AssertionError("{0} is not a valid date type.".format(value)) + + @type_operator(FIELD_DATE) + def equal_to(self, other_date): + return self.value == other_date + + @type_operator(FIELD_DATE) + def greater_than(self, other_numeric): + return self.value > other_numeric + + @type_operator(FIELD_DATE) + def greater_than_or_equal_to(self, other_numeric): + return self.value >= other_numeric + + @type_operator(FIELD_DATE) + def less_than(self, other_numeric): + return self.value < other_numeric + + @type_operator(FIELD_DATE) + def less_than_or_equal_to(self, other_numeric): + return self.value <= other_numeric + + @export_type class BooleanType(BaseType): name = "boolean" def _assert_valid_value_and_cast(self, value): - if type(value) != bool: + if not isinstance(value, bool): raise AssertionError("{0} is not a valid boolean type". format(value)) return value diff --git a/business_rules/utils.py b/business_rules/utils.py index 078a076f..2da63264 100644 --- a/business_rules/utils.py +++ b/business_rules/utils.py @@ -8,7 +8,7 @@ def export_rule_data(variables, actions): """ export_rule_data is used to export all information about the variables, actions, and operators to the client. This will return a dictionary with three keys: - - variables: a list of all available variables along with their label, type and options + - variables: a list of all available variables along with their label, type, options, and params - actions: a list of all actions along with their label and params - variable_type_operators: a dictionary of all field_types -> list of available operators """ diff --git a/business_rules/variables.py b/business_rules/variables.py index 520099e3..57b8360d 100644 --- a/business_rules/variables.py +++ b/business_rules/variables.py @@ -1,8 +1,10 @@ import inspect from functools import wraps from .utils import fn_name_to_pretty_label +from . import fields from .operators import (BaseType, NumericType, + DateType, StringType, BooleanType, SelectType, @@ -19,17 +21,20 @@ def get_all_variables(cls): 'label': m[1].label, 'field_type': m[1].field_type.name, 'options': m[1].options, + 'params': m[1].params } for m in methods if getattr(m[1], 'is_rule_variable', False)] - -def rule_variable(field_type, label=None, options=None, cache_result=True): +def rule_variable(field_type, label=None, options=None, cache_result=True, params=None): """ Decorator to make a function into a rule variable """ options = options or [] + params = params or [] def wrapper(func): if not (type(field_type) == type and issubclass(field_type, BaseType)): raise AssertionError("{0} is not instance of BaseType in"\ " rule_variable field_type".format(field_type)) + _validate_variable_parameters(func, params) + func.params = params func.field_type = field_type if cache_result: func = _memoize_return_values(func) @@ -40,20 +45,30 @@ def wrapper(func): return func return wrapper -def numeric_rule_variable(label=None): - return rule_variable(NumericType, label=label) -def string_rule_variable(label=None): - return rule_variable(StringType, label=label) +def _rule_variable_wrapper(field_type, label, params=None): + if callable(label): + # Decorator is being called with no args, label is actually the decorated func + return rule_variable(field_type, params=params)(label) + return rule_variable(field_type, label=label, params=params) + +def numeric_rule_variable(label=None, params=None): + return _rule_variable_wrapper(NumericType, label, params=params) + +def date_rule_variable(label=None, params=None): + return _rule_variable_wrapper(DateType, label, params=params) + +def string_rule_variable(label=None, params=None): + return _rule_variable_wrapper(StringType, label, params=params) -def boolean_rule_variable(label=None): - return rule_variable(BooleanType, label=label) +def boolean_rule_variable(label=None, params=None): + return _rule_variable_wrapper(BooleanType, label, params=params) -def select_rule_variable(label=None, options=None): - return rule_variable(SelectType, label=label, options=options) +def select_rule_variable(label=None, options=None, params=None): + return rule_variable(SelectType, label=label, options=options, params=params) -def select_multiple_rule_variable(label=None, options=None): - return rule_variable(SelectMultipleType, label=label, options=options) +def select_multiple_rule_variable(label=None, options=None, params=None): + return rule_variable(SelectMultipleType, label=label, options=options, params=params) def _memoize_return_values(func): """ Simple memoization (cacheing) decorator, copied from @@ -67,3 +82,23 @@ def memf(*args, **kwargs): cache[key] = func(*args, **kwargs) return cache[key] return memf + +def _validate_variable_parameters(func, params): + """ Verifies that the parameters specified are actual parameters for the + function `func`, and that the field types are FIELD_* types in fields. + """ + if params is not None: + # Verify field name is valid + valid_fields = [getattr(fields, f) for f in dir(fields) \ + if f.startswith("FIELD_")] + for param in params: + param_name, field_type = param['name'], param['field_type'] + if param_name not in func.__code__.co_varnames: + raise AssertionError("Unknown parameter name {0} specified for"\ + " variable {1}".format( + param_name, func.__name__)) + + if field_type not in valid_fields: + raise AssertionError("Unknown field type {0} specified for"\ + " variable {1} param {2}".format( + field_type, func.__name__, param_name)) diff --git a/setup.py b/setup.py index cefc6e51..05b0e4e5 100644 --- a/setup.py +++ b/setup.py @@ -12,5 +12,8 @@ author_email='open-source@venmo.com', url='https://github.com/venmo/business-rules', packages=['business_rules'], + extras_require={ + 'DateType': ["dateutil"] + }, license='MIT' ) diff --git a/tests/test_integration.py b/tests/test_integration.py index 6dbf1ace..10be20e0 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,7 +1,7 @@ from business_rules.engine import check_condition from business_rules import export_rule_data from business_rules.actions import rule_action, BaseActions -from business_rules.variables import BaseVariables, string_rule_variable, numeric_rule_variable, boolean_rule_variable +from business_rules.variables import BaseVariables, string_rule_variable, numeric_rule_variable, boolean_rule_variable, date_rule_variable from business_rules.fields import FIELD_TEXT, FIELD_NUMERIC, FIELD_SELECT from . import TestCase @@ -12,10 +12,18 @@ class SomeVariables(BaseVariables): def foo(self): return "foo" + @numeric_rule_variable(params=[{'field_type': FIELD_NUMERIC, 'name': 'x', 'label': 'X'}]) + def x_plus_one(self, x): + return x + 1 + @numeric_rule_variable(label="Diez") def ten(self): return 10 + @date_rule_variable(label="MyDate") + def january_one_2015(self): + return '01/01/2015' + @boolean_rule_variable() def true_bool(self): return True @@ -71,6 +79,13 @@ def test_check_false_condition_happy_path(self): 'value': 'm'} self.assertFalse(check_condition(condition, SomeVariables())) + def test_numeric_variable_with_params(self): + condition = {'name': 'x_plus_one', + 'operator': 'equal_to', + 'value': 10, + 'params': {'x': 9}} + self.assertTrue(check_condition(condition, SomeVariables())) + def test_check_incorrect_method_name(self): condition = {'name': 'food', 'operator': 'equal_to', @@ -112,18 +127,32 @@ def test_export_rule_data(self): ]) self.assertEqual(all_data.get("variables"), - [{"name": "foo", - "label": "Foo", - "field_type": "string", - "options": []}, - {"name": "ten", - "label": "Diez", - "field_type": "numeric", - "options": []}, - {'name': 'true_bool', + [{'field_type': 'string', + 'label': 'Foo', + 'name': 'foo', + 'options': [], + 'params': []}, + {'field_type': 'date', + 'label': 'MyDate', + 'name': 'january_one_2015', + 'options': [], + 'params': []}, + {'field_type': 'numeric', + 'label': 'Diez', + 'name': 'ten', + 'options': [], + 'params': []}, + {'field_type': 'boolean', 'label': 'True Bool', - 'field_type': 'boolean', - 'options': []}]) + 'name': 'true_bool', + 'options': [], + 'params': []}, + {'field_type': 'numeric', + 'label': 'X Plus One', + 'name': 'x_plus_one', + 'options': [], + 'params': [{'field_type': 'numeric', 'label': 'X', 'name': 'x'}]}] + ) self.assertEqual(all_data.get("variable_type_operators"), {'boolean': [{'input_type': 'none', @@ -132,6 +161,21 @@ def test_export_rule_data(self): {'input_type': 'none', 'label': 'Is True', 'name': 'is_true'}], + 'date': [{'input_type': 'date', + 'label': 'Equal To', + 'name': 'equal_to'}, + {'input_type': 'date', + 'label': 'Greater Than', + 'name': 'greater_than'}, + {'input_type': 'date', + 'label': 'Greater Than Or Equal To', + 'name': 'greater_than_or_equal_to'}, + {'input_type': 'date', + 'label': 'Less Than', + 'name': 'less_than'}, + {'input_type': 'date', + 'label': 'Less Than Or Equal To', + 'name': 'less_than_or_equal_to'}], 'numeric': [{'input_type': 'numeric', 'label': 'Equal To', 'name': 'equal_to'}, diff --git a/tests/test_variables.py b/tests/test_variables.py index 2490542f..dbe08c00 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -1,5 +1,6 @@ from . import TestCase from business_rules.utils import fn_name_to_pretty_label +from business_rules.fields import FIELD_NUMERIC from business_rules.variables import (rule_variable, numeric_rule_variable, string_rule_variable, @@ -34,13 +35,15 @@ def test_rule_variable_decorator_internals(self): """ Make sure that the expected attributes are attached to a function by the variable decorators. """ - def some_test_function(self): pass - wrapper = rule_variable(StringType, 'Foo Name', options=['op1', 'op2']) + def some_test_function(self, param1): pass + params = [{'field_type': FIELD_NUMERIC, 'name': 'param1', 'label': 'Param1'}] + wrapper = rule_variable(StringType, 'Foo Name', options=['op1', 'op2'], params=params) func = wrapper(some_test_function) self.assertTrue(func.is_rule_variable) self.assertEqual(func.label, 'Foo Name') self.assertEqual(func.field_type, StringType) self.assertEqual(func.options, ['op1', 'op2']) + self.assertEqual(func.params, params) def test_rule_variable_works_as_decorator(self): @rule_variable(StringType, 'Blah') @@ -52,6 +55,28 @@ def test_rule_variable_decorator_auto_fills_label(self): def some_test_function(self): pass self.assertTrue(some_test_function.label, 'Some Test Function') + def test_rule_variable_doesnt_allow_unknown_field_types(self): + """ Tests that the variable decorator throws an error if a param + is defined with an invalid field type. + """ + params = [{'field_type': 'blah', 'name': 'foo', 'label': 'Foo'}] + err_string = "Unknown field type blah specified for variable "\ + "some_test_function param foo" + with self.assertRaisesRegexp(AssertionError, err_string): + @rule_variable(StringType, params=params) + def some_test_function(self, foo): pass + + def test_rule_variable_doesnt_allow_unknown_parameter_name(self): + """ Tests that decorator throws an error if a param name does not match + an argument in the function definition. + """ + params = [{'field_type': FIELD_NUMERIC, 'name': 'bar', 'label': 'Bar'}] + err_string = "Unknown parameter name bar specified for variable "\ + "some_test_function" + with self.assertRaisesRegexp(AssertionError, err_string): + @rule_variable(StringType, params=params) + def some_test_function(self, foo): pass + def test_rule_variable_decorator_caches_value(self): foo = 1 @rule_variable(NumericType) @@ -69,32 +94,59 @@ def foo_func(): self.assertEqual(foo_func(), 1) foo = 2 self.assertEqual(foo_func(), 2) - + ### ### rule_variable wrappers for each variable type ### def test_numeric_rule_variable(self): - @numeric_rule_variable() + @numeric_rule_variable('My Label') def numeric_var(): pass - + + self.assertTrue(getattr(numeric_var, 'is_rule_variable')) + self.assertEqual(getattr(numeric_var, 'field_type'), NumericType) + self.assertEqual(getattr(numeric_var, 'label'), 'My Label') + + def test_numeric_rule_variable_no_parens(self): + + @numeric_rule_variable + def numeric_var(): pass + self.assertTrue(getattr(numeric_var, 'is_rule_variable')) self.assertEqual(getattr(numeric_var, 'field_type'), NumericType) def test_string_rule_variable(self): - @string_rule_variable() + @string_rule_variable(label='My Label') + def string_var(): pass + + self.assertTrue(getattr(string_var, 'is_rule_variable')) + self.assertEqual(getattr(string_var, 'field_type'), StringType) + self.assertEqual(getattr(string_var, 'label'), 'My Label') + + def test_string_rule_variable_no_parens(self): + + @string_rule_variable def string_var(): pass - + self.assertTrue(getattr(string_var, 'is_rule_variable')) self.assertEqual(getattr(string_var, 'field_type'), StringType) - + def test_boolean_rule_variable(self): - @boolean_rule_variable() + @boolean_rule_variable(label='My Label') + def boolean_var(): pass + + self.assertTrue(getattr(boolean_var, 'is_rule_variable')) + self.assertEqual(getattr(boolean_var, 'field_type'), BooleanType) + self.assertEqual(getattr(boolean_var, 'label'), 'My Label') + + def test_boolean_rule_variable_no_parens(self): + + @boolean_rule_variable def boolean_var(): pass - + self.assertTrue(getattr(boolean_var, 'is_rule_variable')) self.assertEqual(getattr(boolean_var, 'field_type'), BooleanType) @@ -103,7 +155,7 @@ def test_select_rule_variable(self): options = {'foo':'bar'} @select_rule_variable(options=options) def select_var(): pass - + self.assertTrue(getattr(select_var, 'is_rule_variable')) self.assertEqual(getattr(select_var, 'field_type'), SelectType) self.assertEqual(getattr(select_var, 'options'), options) @@ -113,7 +165,7 @@ def test_select_multiple_rule_variable(self): options = {'foo':'bar'} @select_multiple_rule_variable(options=options) def select_multiple_var(): pass - + self.assertTrue(getattr(select_multiple_var, 'is_rule_variable')) self.assertEqual(getattr(select_multiple_var, 'field_type'), SelectMultipleType) self.assertEqual(getattr(select_multiple_var, 'options'), options) diff --git a/tests/test_variables_class.py b/tests/test_variables_class.py index 1a4bf7da..81ec616f 100644 --- a/tests/test_variables_class.py +++ b/tests/test_variables_class.py @@ -1,5 +1,6 @@ from business_rules.variables import BaseVariables, rule_variable from business_rules.operators import StringType +from business_rules.fields import FIELD_TEXT from . import TestCase class VariablesClassTests(TestCase): @@ -14,8 +15,8 @@ def test_get_all_variables(self): """ class SomeVariables(BaseVariables): - @rule_variable(StringType) - def this_is_rule_1(self): + @rule_variable(StringType, params=[{'field_type': FIELD_TEXT, 'name': 'foo', 'label': 'Foo'}]) + def this_is_rule_1(self, foo): return "blah" def non_rule(self): @@ -27,6 +28,7 @@ def non_rule(self): self.assertEqual(vars[0]['label'], 'This Is Rule 1') self.assertEqual(vars[0]['field_type'], 'string') self.assertEqual(vars[0]['options'], []) + self.assertEqual(vars[0]['params'], [{'field_type': FIELD_TEXT, 'name': 'foo', 'label': 'Foo'}]) # should work on an instance of the class too self.assertEqual(len(SomeVariables().get_all_variables()), 1)