Skip to content

Commit

Permalink
Modified version conflict.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxjohn committed Dec 21, 2023
1 parent 61019f9 commit 6e4c2ff
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
2 changes: 1 addition & 1 deletion hyperts/framework/stats/sktime_ex/_sfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def _igb(self, dft, y):
breakpoints = np.zeros((self.word_length, self.alphabet_size))
clf = DecisionTreeClassifier(
criterion="entropy",
max_depth=np.log2(self.alphabet_size),
max_depth=int(np.floor(np.log2(self.alphabet_size))),
max_leaf_nodes=self.alphabet_size,
random_state=1,
)
Expand Down
18 changes: 17 additions & 1 deletion hyperts/framework/stats/sktime_ex/_tsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ def __init__(
# We need to add is-fitted state when inheriting from scikit-learn
self._is_fitted = False

@property
def _estimator(self):
"""Access first parameter in self, self inheriting from sklearn BaseForest.
The attribute was renamed from base_estimator to estimator in sklearn 1.2.0.
"""
import sklearn
from packaging.specifiers import SpecifierSet

sklearn_version = sklearn.__version__

if sklearn_version in SpecifierSet(">=1.2.0"):
return self.estimator
else:
return self.base_estimator

def fit(self, X, y):
"""Build a forest of trees from the training set (X, y).
Expand Down Expand Up @@ -110,7 +126,7 @@ def fit(self, X, y):

self.estimators_ = Parallel(n_jobs=n_jobs)(
delayed(_fit_estimator)(
_clone_estimator(self.base_estimator, rng), X, y, self.intervals_[i]
_clone_estimator(self._estimator, rng), X, y, self.intervals_[i]
)
for i in range(self.n_estimators)
)
Expand Down

0 comments on commit 6e4c2ff

Please sign in to comment.