
Commit

Fixes #1248 (#1337)
* bug fix

* update affected test
smastelini authored Oct 11, 2023
1 parent 3a78523 commit 23143c6
Showing 4 changed files with 18 additions and 10 deletions.
9 changes: 6 additions & 3 deletions docs/releases/unreleased.md
@@ -11,11 +11,10 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
## clustering

- Add fixes to `cluster.DBSTREAM` algorithm, including:
    - Addition of the `-` sign before the `fading_factor`, in accordance with Algorithm 2 of Hahsler and Bolaños (2016), so that clusters with low weights can be removed (see the sketch after this list).
    - A new `micro_cluster` is added with a key derived from the maximum key of the existing micro clusters. If the set of micro clusters is still empty (`len == 0`), a new micro cluster is added with key 0.
    - `cluster_is_up_to_date` is set to `True` at the end of the `self._recluster()` function.
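
For intuition on the first two items: Hahsler and Bolaños (2016) fade micro-cluster weights with f(Δt) = 2^(-λ · Δt), so the exponent needs the minus sign for stale clusters to decay below the removal threshold. A minimal standalone sketch; the helper name, parameter values, and the `max + 1` key rule are illustrative assumptions, not River's API:

```python
import math

def faded_weight(weight: float, fading_factor: float, time_gap: float) -> float:
    # Fading function from Hahsler and Bolaños (2016): f(t) = 2 ** (-lambda * t).
    # Without the minus sign, stale clusters would gain weight over time and
    # never drop below the removal threshold.
    return weight * math.pow(2, -fading_factor * time_gap)

print(faded_weight(1.0, 0.25, 10.0))  # ~0.177: the weight decays, so the cluster can be pruned

# Second item, sketched, assuming "derived from the maximum key" means max + 1:
micro_clusters = {0: 0.8, 1: 0.3}  # key -> weight
new_key = max(micro_clusters, default=-1) + 1  # 2 here; 0 when no clusters exist
```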


## datasets

- Added `datasets.WebTraffic`, a dataset that counts occurrences of events on a website. It is a multi-output regression dataset with two outputs (see the usage sketch below).
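
A minimal usage sketch, assuming the new dataset follows River's usual protocol of iterating over `(x, y)` pairs with `y` a dict holding the two targets; not verified against the final API:

```python
from river import datasets

dataset = datasets.WebTraffic()
for x, y in dataset:
    # x: feature dict for one event; y: dict with the two regression targets.
    print(x, y)
    break
```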
@@ -31,3 +30,7 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
## proba

- Added `_from_state` method to `proba.MultivariateGaussian` to warm start from previous knowledge.

## tree

- Fix a bug in `tree.splitter.NominalSplitterClassif` that generated a mismatch between the number of existing tree branches and the number of tracked branches (the underlying pitfall is sketched below).
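
The pitfall behind this fix, as the splitter diff further down shows: reading from a nested `collections.defaultdict` silently inserts the missing key, so lookups made while evaluating candidate splits could register attribute values that were never observed and desynchronize the branch counts. A standalone sketch of the Python behaviour only; the class and attribute names are made up:

```python
import collections
import functools

# Old structure: a defaultdict of defaultdict(float).
dist = collections.defaultdict(functools.partial(collections.defaultdict, float))
dist["class_a"]["red"] += 1.0

_ = dist["class_a"]["blue"]   # a mere read...
print(list(dist["class_a"]))  # ['red', 'blue'] -- 'blue' was never observed

# New structure: plain inner dicts, so unseen values are not silently added.
dist = collections.defaultdict(dict)
dist["class_a"]["red"] = 1.0
print(dist["class_a"].get("blue"))  # None, and nothing was inserted
print(list(dist["class_a"]))        # ['red']
```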
2 changes: 1 addition & 1 deletion river/ensemble/streaming_random_patches.py
@@ -409,7 +409,7 @@ class SRPClassifier(BaseSRPEnsemble, base.Classifier):
>>> metric = metrics.Accuracy()
>>> evaluate.progressive_val_score(dataset, model, metric)
-Accuracy: 72.77%
+Accuracy: 71.97%
Notes
-----
2 changes: 1 addition & 1 deletion river/tree/extremely_fast_decision_tree.py
@@ -432,7 +432,7 @@ def _reevaluate_best_split(self, node, parent, branch_index, **kwargs):

# update EFDT
if parent is None:
-# Root case : replace the root node by a new split node
+# Root case : replace the root node by the new node
self._root = best_split
else:
parent.children[branch_index] = best_split
15 changes: 10 additions & 5 deletions river/tree/splitter/nominal_splitter_classif.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import collections
-import functools

from ..utils import BranchFactory
from .base import Splitter
@@ -18,9 +17,7 @@ def __init__(self):
super().__init__()
self._total_weight_observed = 0.0
self._missing_weight_observed = 0.0
-self._att_dist_per_class = collections.defaultdict(
-    functools.partial(collections.defaultdict, float)
-)
+self._att_dist_per_class = collections.defaultdict(dict)
self._att_values = set()

@property
@@ -32,12 +29,20 @@ def update(self, att_val, target_val, sample_weight):
self._missing_weight_observed += sample_weight
else:
self._att_values.add(att_val)
-self._att_dist_per_class[target_val][att_val] += sample_weight
+
+try:
+    self._att_dist_per_class[target_val][att_val] += sample_weight
+except KeyError:
+    self._att_dist_per_class[target_val][att_val] = sample_weight

self._total_weight_observed += sample_weight

def cond_proba(self, att_val, target_val):
class_dist = self._att_dist_per_class[target_val]

+if att_val not in class_dist:
+    return 0.0

value = class_dist[att_val]
try:
return value / sum(class_dist.values())
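
Taken together, the hunks above replace the auto-vivifying inner defaultdicts with plain per-class dicts: writes create entries on first observation via `try`/`except`, and reads are guarded so an unseen attribute value contributes probability 0.0 without side effects, i.e. P(att_val | target_val) = weight(att_val) / total class weight. A standalone sketch of that pattern, with module-level functions standing in for the splitter's methods:

```python
import collections

att_dist_per_class = collections.defaultdict(dict)

def update(att_val, target_val, sample_weight):
    # Write path: create the entry on first observation.
    try:
        att_dist_per_class[target_val][att_val] += sample_weight
    except KeyError:
        att_dist_per_class[target_val][att_val] = sample_weight

def cond_proba(att_val, target_val):
    # Read path: P(att_val | target_val) = weight(att_val) / total class weight.
    class_dist = att_dist_per_class[target_val]
    if att_val not in class_dist:
        return 0.0
    return class_dist[att_val] / sum(class_dist.values())

update("red", "spam", 2.0)
update("blue", "spam", 1.0)
print(cond_proba("red", "spam"))    # 0.666...
print(cond_proba("green", "spam"))  # 0.0, and no 'green' entry is created
```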
