diff --git a/docs/releases/unreleased.md b/docs/releases/unreleased.md
index 79f2c12327..ed1d466c6d 100644
--- a/docs/releases/unreleased.md
+++ b/docs/releases/unreleased.md
@@ -37,3 +37,7 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
 ## tree
 
 - Fix a bug in `tree.splitter.NominalSplitterClassif` that generated a mismatch between the number of existing tree branches and the number of tracked branches.
+
+## utils
+
+- `utils.TimeRolling` now works correctly if two samples with the same timestamp are added in a row.
diff --git a/river/utils/rolling.py b/river/utils/rolling.py
index 849f88ba0b..6e01eb232b 100644
--- a/river/utils/rolling.py
+++ b/river/utils/rolling.py
@@ -125,12 +125,15 @@ class TimeRolling(BaseRolling):
     def __init__(self, obj: Rollable, period: dt.timedelta):
         super().__init__(obj)
         self.period = period
-        self._events: list[tuple[dt.datetime, typing.Any]] = []
+        self._timestamps: list[dt.datetime] = []
+        self._datum: list[typing.Any] = []
         self._latest = dt.datetime(1, 1, 1)
 
     def update(self, *args, t: dt.datetime, **kwargs):
         self.obj.update(*args, **kwargs)
-        bisect.insort_left(self._events, (t, (args, kwargs)))
+        i = bisect.bisect_left(self._timestamps, t)
+        self._timestamps.insert(i, t)
+        self._datum.insert(i, (args, kwargs))
 
         # There will only be events to revert if the new event if younger than the previously seen
         # youngest event
@@ -138,7 +141,7 @@ def update(self, *args, t: dt.datetime, **kwargs):
             self._latest = t
 
             i = 0
-            for ti, (argsi, kwargsi) in self._events:
+            for ti, (argsi, kwargsi) in zip(self._timestamps, self._datum):
                 if ti > t - self.period:
                     break
                 self.obj.revert(*argsi, **kwargsi)
@@ -146,6 +149,7 @@ def update(self, *args, t: dt.datetime, **kwargs):
 
             # Remove expired events
             if i > 0:
-                self._events = self._events[i:]
+                self._timestamps = self._timestamps[i:]
+                self._datum = self._datum[i:]
 
         return self
diff --git a/river/utils/test_rolling.py b/river/utils/test_rolling.py
index 6d1cedb4de..fef2b764ba 100644
--- a/river/utils/test_rolling.py
+++ b/river/utils/test_rolling.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from river import stats, utils
+from river import proba, stats, utils
 
 
 def test_with_counter():
@@ -38,3 +38,15 @@ def test_rolling_with_not_rollable():
 def test_time_rolling_with_not_rollable():
     with pytest.raises(ValueError):
         utils.TimeRolling(stats.Quantile(), period=dt.timedelta(seconds=10))
+
+
+def test_issue_1343():
+    """
+
+    https://github.com/online-ml/river/issues/1343
+
+    """
+    rmean = utils.TimeRolling(proba.MultivariateGaussian(), period=dt.timedelta(microseconds=1))
+    t = dt.datetime.utcnow()
+    rmean.update({"a": 0}, t=t)
+    rmean.update({"a": 1}, t=t)