Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add balanced class functionality #53

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion skrules/skope_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ class SkopeRules(BaseEstimator):
`ceil(min_samples_split * n_samples)` are the minimum
number of samples for each split.

class_weight: dict, list of dict or "balanced", default=None
The weights to be used for the DecisionTreeClassifier. Weights associated
with classes in the form {class_label: weight}. If None,
all classes are supposed to have weight one.

n_jobs : integer, optional (default=1)
The number of jobs to run in parallel for both `fit` and `predict`.
If -1, then the number of jobs is set to the number of cores.
Expand Down Expand Up @@ -150,6 +155,7 @@ def __init__(self,
max_depth_duplication=None,
max_features=1.,
min_samples_split=2,
class_weight=None,
n_jobs=1,
random_state=None,
verbose=0):
Expand All @@ -164,6 +170,7 @@ def __init__(self,
self.max_depth = max_depth
self.max_depth_duplication = max_depth_duplication
self.max_features = max_features
self.class_weight = class_weight
self.min_samples_split = min_samples_split
self.n_jobs = n_jobs
self.random_state = random_state
Expand Down Expand Up @@ -270,7 +277,8 @@ def fit(self, X, y, sample_weight=None):
base_estimator=DecisionTreeClassifier(
max_depth=max_depth,
max_features=self.max_features,
min_samples_split=self.min_samples_split),
min_samples_split=self.min_samples_split,
class_weight=self.class_weight),
n_estimators=self.n_estimators,
max_samples=self.max_samples_,
max_features=self.max_samples_features,
Expand Down
3 changes: 3 additions & 0 deletions skrules/tests/test_skope_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def test_skope_rules():
recall_min=0.,
precision_min=0.).fit(X_train, y_train).predict(X_test)

# with additional class weights
SkopeRules(n_estimators=50, class_weight='balanced').fit(X_train, y_train).predict(X_test)


def test_skope_rules_error():
"""Test that it gives proper exception on deficient input."""
Expand Down