diff --git a/src/omlt/linear_tree/lt_definition.py b/src/omlt/linear_tree/lt_definition.py index e45274fd..74882666 100644 --- a/src/omlt/linear_tree/lt_definition.py +++ b/src/omlt/linear_tree/lt_definition.py @@ -1,5 +1,6 @@ import numpy as np import lineartree +import sklearn class LinearTreeDefinition: @@ -266,7 +267,10 @@ def _parse_tree_data(model, input_bounds): # Include checks to ensure that the input dict is the model summary which # is obtained by calling the summary() method contained within the # linear-tree package (e.g. dict = model.summary()) - if isinstance(model, lineartree.lineartree.LinearTreeRegressor) is True: + if ( + isinstance(model, lineartree.lineartree.LinearTreeRegressor) is True + or isinstance(model, lineartree.lineartree.LinearTreeClassifier) is True + ): leaves = model.summary(only_leaves=True) splits = model.summary() elif isinstance(model, dict) is True: @@ -294,11 +298,43 @@ def _parse_tree_data(model, input_bounds): raise TypeError("Model entry must be dict or linear-tree instance") # This loop adds keys for the slopes and intercept and removes the leaf - # keys in the splits dictionary + # keys in the splits dictionary. For LinearTreeClassifier, check if + # the model in the leaf is a DummyClassifier. If so, use the information + # in the prior to determine whether the intercept is 1, or -1. Otherwise + # use the slope/intercept information in the RidgeClassifier or + # LinearTreeRegressor classes for leaf in leaves: del splits[leaf] - leaves[leaf]["slope"] = list(leaves[leaf]["models"].coef_) - leaves[leaf]["intercept"] = leaves[leaf]["models"].intercept_ + if isinstance(model, lineartree.lineartree.LinearTreeClassifier) is True: + num_classes = len(leaves[leaf]["classes"]) + else: + num_classes = 999 + + if num_classes < 2: + class_val = int(leaves[leaf]["classes"][0]) + leaves[leaf]["slope"] = list(np.zeros(len(input_bounds.keys()))) + if class_val == 0: + leaves[leaf]["intercept"] = -1 + else: + leaves[leaf]["intercept"] = 1 + else: + model_in_leaf = leaves[leaf]["models"] + if isinstance(model_in_leaf, sklearn.dummy.DummyClassifier): + prior = model_in_leaf.class_prior_ + leaves[leaf]["slope"] = list(np.zeros(len(input_bounds.keys()))) + if len(prior) < 2: + pred_val = int(model_in_leaf.predict([0])[0]) + if pred_val == 0: + leaves[leaf]["intercept"] = -1 + else: + leaves[leaf]["intercept"] = 1 + elif prior[0] <= prior[1]: + leaves[leaf]["intercept"] = 1 + else: + leaves[leaf]["intercept"] = -1 + else: + leaves[leaf]["slope"] = list(model_in_leaf.coef_.reshape((-1,))) + leaves[leaf]["intercept"] = model_in_leaf.intercept_.reshape((-1,))[0] # This loop creates an parent node id entry for each node in the tree for split in splits: