Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow for linear tree classifiers #135

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
44 changes: 40 additions & 4 deletions src/omlt/linear_tree/lt_definition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import lineartree
import sklearn


class LinearTreeDefinition:
Expand Down Expand Up @@ -266,7 +267,10 @@
# Include checks to ensure that the input dict is the model summary which
# is obtained by calling the summary() method contained within the
# linear-tree package (e.g. dict = model.summary())
if isinstance(model, lineartree.lineartree.LinearTreeRegressor) is True:
if (
isinstance(model, lineartree.lineartree.LinearTreeRegressor) is True
or isinstance(model, lineartree.lineartree.LinearTreeClassifier) is True
):
leaves = model.summary(only_leaves=True)
splits = model.summary()
elif isinstance(model, dict) is True:
Expand Down Expand Up @@ -294,11 +298,43 @@
raise TypeError("Model entry must be dict or linear-tree instance")

# This loop adds keys for the slopes and intercept and removes the leaf
# keys in the splits dictionary
# keys in the splits dictionary. For LinearTreeClassifier, check if
# the model in the leaf is a DummyClassifier. If so, use the information
# in the prior to determine whether the intercept is 1, or -1. Otherwise
# use the slope/intercept information in the RidgeClassifier or
# LinearTreeRegressor classes
for leaf in leaves:
del splits[leaf]
leaves[leaf]["slope"] = list(leaves[leaf]["models"].coef_)
leaves[leaf]["intercept"] = leaves[leaf]["models"].intercept_
if isinstance(model, lineartree.lineartree.LinearTreeClassifier) is True:
num_classes = len(leaves[leaf]["classes"])

Check warning on line 309 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L309

Added line #L309 was not covered by tests
else:
num_classes = 999

if num_classes < 2:
class_val = int(leaves[leaf]["classes"][0])
leaves[leaf]["slope"] = list(np.zeros(len(input_bounds.keys())))

Check warning on line 315 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L314-L315

Added lines #L314 - L315 were not covered by tests
if class_val == 0:
leaves[leaf]["intercept"] = -1

Check warning on line 317 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L317

Added line #L317 was not covered by tests
else:
leaves[leaf]["intercept"] = 1

Check warning on line 319 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L319

Added line #L319 was not covered by tests
else:
model_in_leaf = leaves[leaf]["models"]
if isinstance(model_in_leaf, sklearn.dummy.DummyClassifier):
prior = model_in_leaf.class_prior_
leaves[leaf]["slope"] = list(np.zeros(len(input_bounds.keys())))

Check warning on line 324 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L323-L324

Added lines #L323 - L324 were not covered by tests
if len(prior) < 2:
pred_val = int(model_in_leaf.predict([0])[0])

Check warning on line 326 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L326

Added line #L326 was not covered by tests
if pred_val == 0:
leaves[leaf]["intercept"] = -1

Check warning on line 328 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L328

Added line #L328 was not covered by tests
else:
leaves[leaf]["intercept"] = 1

Check warning on line 330 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L330

Added line #L330 was not covered by tests
elif prior[0] <= prior[1]:
leaves[leaf]["intercept"] = 1

Check warning on line 332 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L332

Added line #L332 was not covered by tests
else:
leaves[leaf]["intercept"] = -1

Check warning on line 334 in src/omlt/linear_tree/lt_definition.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/linear_tree/lt_definition.py#L334

Added line #L334 was not covered by tests
else:
leaves[leaf]["slope"] = list(model_in_leaf.coef_.reshape((-1,)))
leaves[leaf]["intercept"] = model_in_leaf.intercept_.reshape((-1,))[0]

# This loop creates an parent node id entry for each node in the tree
for split in splits:
Expand Down
Loading