Skip to content

Commit

Permalink
Split events into timestamps and datum (#1345)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxHalford authored Oct 21, 2023
1 parent 424cc38 commit 4fc1575
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docs/releases/unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
## tree

- Fix a bug in `tree.splitter.NominalSplitterClassif` that generated a mismatch between the number of existing tree branches and the number of tracked branches.

## utils

- `utils.TimeRolling` now works correctly if two samples with the same timestamp are added in a row.
12 changes: 8 additions & 4 deletions river/utils/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,27 +125,31 @@ class TimeRolling(BaseRolling):
def __init__(self, obj: Rollable, period: dt.timedelta):
# NOTE(review): this is a flattened diff view (indentation and +/- markers lost). The
# `_events` line looks like the removed side and `_timestamps`/`_datum` the added side;
# confirm against the repository before relying on either.
super().__init__(obj)
# Width of the time window: `update` reverts stored events whose timestamp is <= t - period.
self.period = period
self._events: list[tuple[dt.datetime, typing.Any]] = []
# Parallel lists sharing one index: `_timestamps[i]` is the time of `_datum[i]`.
self._timestamps: list[dt.datetime] = []
# Each entry is the `(args, kwargs)` pair that was passed to `update`.
self._datum: list[typing.Any] = []
# Most recent timestamp seen so far; starts at year 1 so any later timestamp compares as newer.
self._latest = dt.datetime(1, 1, 1)

def update(self, *args, t: dt.datetime, **kwargs):
# Feed the observation to the wrapped object, record it under timestamp `t`, then revert
# every stored observation that has fallen out of the `period` window. Returns `self`.
self.obj.update(*args, **kwargs)
# NOTE(review): tuple-based insertion; when two timestamps are equal the tuple comparison
# falls through to the (args, kwargs) payload -- presumably the failure behind issue 1343.
bisect.insort_left(self._events, (t, (args, kwargs)))
# Find the slot from the timestamps alone, then insert at the same index in both lists so
# that they stay aligned and the payload is never compared.
i = bisect.bisect_left(self._timestamps, t)
self._timestamps.insert(i, t)
self._datum.insert(i, (args, kwargs))

# There will only be events to revert if the new event is younger than the previously seen
# youngest event
if t > self._latest:
self._latest = t

# Count how many of the oldest events have expired; the scan stops at the first event that
# is still inside the window.
i = 0
for ti, (argsi, kwargsi) in self._events:
for ti, (argsi, kwargsi) in zip(self._timestamps, self._datum):
if ti > t - self.period:
break
self.obj.revert(*argsi, **kwargsi)
i += 1

# Remove expired events
if i > 0:
self._events = self._events[i:]
self._timestamps = self._timestamps[i:]
self._datum = self._datum[i:]

return self
14 changes: 13 additions & 1 deletion river/utils/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from river import stats, utils
from river import proba, stats, utils


def test_with_counter():
Expand Down Expand Up @@ -38,3 +38,15 @@ def test_rolling_with_not_rollable():
def test_time_rolling_with_not_rollable():
# Wrapping stats.Quantile() in TimeRolling is expected to be rejected with a ValueError.
with pytest.raises(ValueError):
utils.TimeRolling(stats.Quantile(), period=dt.timedelta(seconds=10))


def test_issue_1343():
"""
Regression test: two consecutive updates sharing the same timestamp must not raise.

https://github.com/online-ml/river/issues/1343
"""
rmean = utils.TimeRolling(proba.MultivariateGaussian(), period=dt.timedelta(microseconds=1))
# Reuse one timestamp for both updates to reproduce the tie; the payloads are dicts.
t = dt.datetime.utcnow()
rmean.update({"a": 0}, t=t)
# No assertion on the result: the test passes as long as this second call does not raise.
rmean.update({"a": 1}, t=t)

0 comments on commit 4fc1575

Please sign in to comment.