
Commit 5220007: format
Saulo Martiello Mastelini authored and agriyakhetarpal committed Nov 18, 2024
1 parent 3d239e3 commit 5220007
Showing 4 changed files with 26 additions and 8 deletions.
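
All four diffs below make the same kind of change: a statement that exceeded the line-length limit is wrapped across several lines using parenthesized continuations, with no change in behaviour. As a minimal, self-contained sketch of that wrapping idiom (the names `compute_default`, `fallback`, and `value` are illustrative only, and the exact formatter and line-length limit are not stated in the commit):

    def compute_default(a: int, b: int, c: int) -> int:
        # Hypothetical helper, used only to make the example runnable.
        return a + b + c

    fallback = None

    # Before formatting, this would sit on one long line:
    # value = fallback if fallback is not None else compute_default(1, 2, 3)

    # After: the expression is wrapped inside parentheses, one clause per line.
    value = (
        fallback
        if fallback is not None
        else compute_default(1, 2, 3)
    )
    print(value)  # 6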
16 changes: 12 additions & 4 deletions river/ensemble/streaming_random_patches.py
@@ -161,7 +161,9 @@ def _generate_subspaces(self, features: list):
                     # Generate n_models subspaces from all possible
                     # feature combinations of size k
                     self._subspaces = []
-                    for i, combination in enumerate(itertools.cycle(itertools.combinations(features, k))):
+                    for i, combination in enumerate(
+                        itertools.cycle(itertools.combinations(features, k))
+                    ):
                         if i == self.n_models:
                             break
                         self._subspaces.append(list(combination))
@@ -171,7 +173,8 @@ def _generate_subspaces(self, features: list):
                 # subspaces without worrying about repetitions.
                 else:
                     self._subspaces = [
-                        random_subspace(all_features=features, k=k, rng=self._rng) for _ in range(self.n_models)
+                        random_subspace(all_features=features, k=k, rng=self._rng)
+                        for _ in range(self.n_models)
                     ]
             else:
                 # k == 0 or k > n_features (subspace size is larger than the
@@ -180,8 +183,13 @@
 
     def _init_ensemble(self, features: list):
         self._generate_subspaces(features=features)
-        subspace_indexes = list(range(self.n_models))  # For matching subspaces with ensemble members
-        if self.training_method == self._TRAIN_RANDOM_PATCHES or self.training_method == self._TRAIN_RANDOM_SUBSPACES:
+        subspace_indexes = list(
+            range(self.n_models)
+        )  # For matching subspaces with ensemble members
+        if (
+            self.training_method == self._TRAIN_RANDOM_PATCHES
+            or self.training_method == self._TRAIN_RANDOM_SUBSPACES
+        ):
             # Shuffle indexes
             self._rng.shuffle(subspace_indexes)
 
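The reformatted code in _generate_subspaces is the part of Streaming Random Patches that assigns each ensemble member a feature subspace: for low-dimensional data it cycles over all k-sized feature combinations, otherwise it draws subspaces at random; _init_ensemble then shuffles the subspace indexes before matching them to members. Below is a standalone sketch of that selection logic, not river's API itself; `rng.sample` stands in for river's `random_subspace` helper, and all names here are illustrative.

    import itertools
    import random

    def generate_subspaces(features: list, k: int, n_models: int, rng: random.Random) -> list[list]:
        # Standalone sketch of the logic reformatted above.
        n_features = len(features)
        if n_features <= 20 or k < 2:
            # Low dimensionality: cycle over all k-combinations until each of
            # the n_models members has been given a subspace.
            subspaces = []
            for i, combination in enumerate(
                itertools.cycle(itertools.combinations(features, k))
            ):
                if i == n_models:
                    break
                subspaces.append(list(combination))
            return subspaces
        # High dimensionality: draw k features per member at random; repeated
        # subspaces are unlikely enough to ignore.
        return [rng.sample(features, k) for _ in range(n_models)]

    print(generate_subspaces(["a", "b", "c", "d"], k=2, n_models=3, rng=random.Random(42)))
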
9 changes: 7 additions & 2 deletions river/forest/adaptive_random_forest.py
@@ -162,7 +162,8 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
                 y_true=y,
                 y_pred=(
                     model.predict_proba_one(x)
-                    if isinstance(self.metric, metrics.base.ClassificationMetric) and not self.metric.requires_labels
+                    if isinstance(self.metric, metrics.base.ClassificationMetric)
+                    and not self.metric.requires_labels
                     else y_pred
                 ),
             )
@@ -188,7 +189,11 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
                     self._warning_tracker[i] += 1
 
             if not self._drift_detection_disabled:
-                drift_input = drift_input if drift_input is not None else self._drift_detector_input(i, y, y_pred)
+                drift_input = (
+                    drift_input
+                    if drift_input is not None
+                    else self._drift_detector_input(i, y, y_pred)
+                )
                 self._drift_detectors[i].update(drift_input)
 
                 if self._drift_detectors[i].drift_detected:
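Both hunks touch the Adaptive Random Forest's learn_one: the first feeds the per-member metric predicted probabilities when the configured metric is a probabilistic classification metric, and the second reuses the drift-detector input already computed for the warning detector instead of recomputing it. As a small illustration of the surrounding machinery, here is a sketch of driving a single drift detector with a 0/1 misclassification signal of the kind the classifier variant derives from each tree's predictions, using river's ADWIN directly; the error stream is synthetic and made up for the example.

    from river import drift

    # Sketch: one detector fed a per-instance "was the prediction wrong?" signal.
    detector = drift.ADWIN()
    error_stream = [0] * 500 + [1] * 500  # synthetic: accuracy collapses halfway
    for t, error in enumerate(error_stream):
        detector.update(error)
        if detector.drift_detected:
            print(f"drift flagged at step {t}")
            break
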
5 changes: 4 additions & 1 deletion river/rules/amrules.py
@@ -389,7 +389,10 @@ def learn_one(self, x: dict, y: base.typing.RegTarget, w: int = 1):
             self._default_rule.learn_one(x, y, w)
 
             expanded = False
-            if self._default_rule.total_weight - self._default_rule.last_expansion_attempt_at >= self.n_min:
+            if (
+                self._default_rule.total_weight - self._default_rule.last_expansion_attempt_at
+                >= self.n_min
+            ):
                 updated_rule, expanded = self._default_rule.expand(self.delta, self.tau)
 
             if expanded:
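The wrapped condition is AMRules' expansion schedule: the default rule only attempts to expand once it has accumulated at least n_min more weight since its last expansion attempt. A minimal sketch of that bookkeeping, with illustrative names rather than river's classes:

    class ExpansionSchedule:
        """Sketch of the 'attempt expansion every n_min units of weight' check
        wrapped in the hunk above (names are illustrative, not river's)."""

        def __init__(self, n_min: float = 200):
            self.n_min = n_min
            self.total_weight = 0.0
            self.last_expansion_attempt_at = 0.0

        def learn_one(self, w: float = 1.0) -> bool:
            self.total_weight += w
            if self.total_weight - self.last_expansion_attempt_at >= self.n_min:
                self.last_expansion_attempt_at = self.total_weight
                return True  # an expansion attempt would happen here
            return False

    schedule = ExpansionSchedule(n_min=3)
    print([schedule.learn_one() for _ in range(7)])
    # [False, False, True, False, False, True, False]
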
4 changes: 3 additions & 1 deletion river/tree/hoeffding_adaptive_tree_regressor.py
@@ -304,7 +304,9 @@ def _new_leaf(self, initial_stats=None, parent=None, is_active=True):
 
         return new_adaptive
 
-    def _branch_selector(self, numerical_feature=True, multiway_split=False) -> type[AdaBranchRegressor]:
+    def _branch_selector(
+        self, numerical_feature=True, multiway_split=False
+    ) -> type[AdaBranchRegressor]:
         """Create a new split node."""
         if numerical_feature:
             if not multiway_split:
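_branch_selector only decides which branch class the adaptive tree will instantiate when it splits, based on whether the split feature is numerical and whether a multiway split was requested; the change merely wraps its signature. A sketch of the same return-a-class dispatch, with placeholder branch classes standing in for river's Ada* branch types:

    class NumericBinaryBranch: ...
    class NumericMultiwayBranch: ...
    class NominalBinaryBranch: ...
    class NominalMultiwayBranch: ...

    def branch_selector(
        numerical_feature: bool = True, multiway_split: bool = False
    ) -> type:
        # Placeholder classes above stand in for the tree's branch types.
        if numerical_feature:
            return NumericMultiwayBranch if multiway_split else NumericBinaryBranch
        return NominalMultiwayBranch if multiway_split else NominalBinaryBranch

    print(branch_selector(numerical_feature=True, multiway_split=False).__name__)
    # NumericBinaryBranch
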
