Skip to content

Commit

Permalink
Add test for real data
Browse files Browse the repository at this point in the history
  • Loading branch information
koenvo committed Jun 7, 2024
1 parent d8b7b36 commit 5289425
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 77 deletions.
11 changes: 3 additions & 8 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
Generic,
NewType,
overload,
Iterable, NamedTuple,
Iterable,
NamedTuple,
)


Expand Down Expand Up @@ -876,9 +877,6 @@ class DatasetFlag(Flag):
BALL_STATE = 2





@dataclass
class DataRecord(ABC):
"""
Expand Down Expand Up @@ -906,10 +904,7 @@ def record_id(self) -> Union[int, str]:

@property
def abs_time(self) -> AbsTime:
return AbsTime(
period=self.period,
timestamp=self.timestamp
)
return AbsTime(period=self.period, timestamp=self.timestamp)

def set_refs(
self,
Expand Down
42 changes: 24 additions & 18 deletions kloppy/domain/models/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,22 @@ def contains(self, timestamp: datetime):
"This method can only be used when start_timestamp and end_timestamp are a datetime"
)

@property
def start_abs_time(self) -> "AbsTime":
return AbsTime(period=self, timestamp=self.start_timestamp)

@property
def end_abs_time(self) -> "AbsTime":
return AbsTime(period=self, timestamp=self.end_timestamp)

@property
def duration(self) -> timedelta:
return self.end_timestamp - self.start_timestamp

def __eq__(self, other):
return isinstance(other, Period) and other.id == self.id

def __lt__(self, other: 'Period'):
def __lt__(self, other: "Period"):
return self.id < other.id

def __ge__(self, other):
Expand All @@ -72,18 +80,20 @@ def set_refs(

@dataclass
class AbsTime:
period: 'Period'
period: "Period"
timestamp: timedelta

@overload
def __sub__(self, other: timedelta) -> 'AbsTime':
def __sub__(self, other: timedelta) -> "AbsTime":
...

@overload
def __sub__(self, other: 'AbsTime') -> timedelta:
def __sub__(self, other: "AbsTime") -> timedelta:
...

def __sub__(self, other: Union['AbsTime', timedelta]) -> Union['AbsTime', timedelta]:
def __sub__(
self, other: Union["AbsTime", timedelta]
) -> Union["AbsTime", timedelta]:
"""
Subtract a timedelta or AbsTime from the current AbsTime.
Expand All @@ -100,16 +110,14 @@ def __sub__(self, other: Union['AbsTime', timedelta]) -> Union['AbsTime', timede
if not current_period.prev_period:
# We reached start of the match, lets just return start itself
return AbsTime(
period=current_period,
timestamp=timedelta(0)
period=current_period, timestamp=timedelta(0)
)

current_period = current_period.prev_period
current_timestamp = current_period.duration

return AbsTime(
period=current_period,
timestamp=current_timestamp - other
period=current_period, timestamp=current_timestamp - other
)

elif isinstance(other, AbsTime):
Expand All @@ -124,34 +132,32 @@ def __sub__(self, other: Union['AbsTime', timedelta]) -> Union['AbsTime', timede
else:
return -other.__sub__(self)
else:
raise ValueError(f'Cannot subtract {other}')
raise ValueError(f"Cannot subtract {other}")

def __add__(self, other: timedelta) -> 'AbsTime':
def __add__(self, other: timedelta) -> "AbsTime":
assert isinstance(other, timedelta)
current_timestamp = self.timestamp
current_period = self.period
while other > current_period.duration:
# Subtract time left in this period

other -= (current_period.duration - current_timestamp)
other -= current_period.duration - current_timestamp
if not current_period.next_period:
# We reached start of the match, lets just return start itself
return AbsTime(
period=current_period,
timestamp=current_period.duration
period=current_period, timestamp=current_period.duration
)

current_period = current_period.next_period
current_timestamp = timedelta(0)

return AbsTime(
period=current_period,
timestamp=current_timestamp + other
period=current_period, timestamp=current_timestamp + other
)

def __radd__(self, other: timedelta) -> 'AbsTime':
def __radd__(self, other: timedelta) -> "AbsTime":
assert isinstance(other, timedelta)
return self.__add__(other)

def __rsub__(self, other):
raise RuntimeError("Doesn't make sense.")
raise RuntimeError("Doesn't make sense.")
92 changes: 41 additions & 51 deletions kloppy/tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from kloppy import statsbomb
from kloppy.domain import AbsTime, Period


Expand All @@ -11,17 +12,17 @@ def periods() -> Tuple[Period, Period, Period]:
period1 = Period(
id=1,
start_timestamp=timedelta(seconds=0),
end_timestamp=timedelta(seconds=2700)
end_timestamp=timedelta(seconds=2700),
)
period2 = Period(
id=2,
start_timestamp=timedelta(seconds=0),
end_timestamp=timedelta(seconds=3000)
end_timestamp=timedelta(seconds=3000),
)
period3 = Period(
id=3,
start_timestamp=timedelta(seconds=0),
end_timestamp=timedelta(seconds=1000)
end_timestamp=timedelta(seconds=1000),
)
period1.set_refs(None, period2)
period2.set_refs(period1, period3)
Expand All @@ -31,17 +32,13 @@ def periods() -> Tuple[Period, Period, Period]:

class TestAbsTime:
def test_subtract_timedelta_same_period(self, periods):
"""Test subtract with non-period overlapping timedelta. """
"""Test subtract with non-period overlapping timedelta."""
period1, *_ = periods

abs_time = AbsTime(
period=period1,
timestamp=timedelta(seconds=1800)
)
abs_time = AbsTime(period=period1, timestamp=timedelta(seconds=1800))

assert abs_time - timedelta(seconds=1000) == AbsTime(
period=period1,
timestamp=timedelta(seconds=800)
period=period1, timestamp=timedelta(seconds=800)
)

def test_subtract_timedelta_spans_periods(self, periods):
Expand All @@ -59,69 +56,45 @@ def test_subtract_timedelta_spans_periods(self, periods):
"""
period1, period2, period3 = periods

abs_time = AbsTime(
period=period3,
timestamp=timedelta(seconds=800)
)
abs_time = AbsTime(period=period3, timestamp=timedelta(seconds=800))

assert abs_time - timedelta(seconds=4000) == AbsTime(
period=period1,
timestamp=timedelta(seconds=2500)
period=period1, timestamp=timedelta(seconds=2500)
)

def test_subtract_timedelta_over_start(self, periods):
"""Test subtract that goes over start of first period. This should return start of match. """
"""Test subtract that goes over start of first period. This should return start of match."""
period1, *_ = periods

abs_time = AbsTime(
period=period1,
timestamp=timedelta(seconds=1800)
)
abs_time = AbsTime(period=period1, timestamp=timedelta(seconds=1800))

assert abs_time - timedelta(seconds=2000) == AbsTime(
period=period1,
timestamp=timedelta(0)
period=period1, timestamp=timedelta(0)
)

def test_subtract_two_abstime(self, periods):
"""Subtract two AbsTime in same period"""
period1, *_ = periods
abs_time1 = AbsTime(
period=period1,
timestamp=timedelta(seconds=1000)
)
abs_time2 = AbsTime(
period=period1,
timestamp=timedelta(seconds=800)
)
abs_time1 = AbsTime(period=period1, timestamp=timedelta(seconds=1000))
abs_time2 = AbsTime(period=period1, timestamp=timedelta(seconds=800))

assert abs_time1 - abs_time2 == timedelta(seconds=200)

def test_subtract_two_abstime_spans_periods(self, periods):
"""Subtract AbsTime over multiple periods."""
period1, period2, period3 = periods
abs_time1 = AbsTime(
period=period1,
timestamp=timedelta(seconds=800)
)
abs_time2 = AbsTime(
period=period2,
timestamp=timedelta(seconds=800)
)
abs_time1 = AbsTime(period=period1, timestamp=timedelta(seconds=800))
abs_time2 = AbsTime(period=period2, timestamp=timedelta(seconds=800))

assert abs_time2 - abs_time1 == timedelta(seconds=2700)

def test_add_timedelta_same_period(self, periods):
"""Test add timedelta in same period"""
period1, *_ = periods

abs_time = AbsTime(
period=period1,
timestamp=timedelta(seconds=800)
)
abs_time = AbsTime(period=period1, timestamp=timedelta(seconds=800))
assert abs_time + timedelta(seconds=100) == AbsTime(
period=period1,
timestamp=timedelta(seconds=900)
period=period1, timestamp=timedelta(seconds=900)
)

def test_add_timedelta_spans_periods(self, periods):
Expand All @@ -137,11 +110,28 @@ def test_add_timedelta_spans_periods(self, periods):
"""
period1, period2, period3 = periods

abs_time = AbsTime(
period=period1,
timestamp=timedelta(seconds=800)
)
abs_time = AbsTime(period=period1, timestamp=timedelta(seconds=800))
assert abs_time + timedelta(seconds=5000) == AbsTime(
period=period3,
timestamp=timedelta(seconds=100)
period=period3, timestamp=timedelta(seconds=100)
)

def test_statsbomb(self, base_dir):
dataset = statsbomb.load(
lineup_data=base_dir / "files/statsbomb_lineup.json",
event_data=base_dir / "files/statsbomb_event.json",
)
formation_changes = dataset.filter("formation_change")

# Determine time until first formation change
diff = (
formation_changes[0].abs_time
- dataset.metadata.periods[0].start_abs_time
)
assert diff == timedelta(seconds=2705.267)

# Time until last formation change
diff = (
formation_changes[-1].abs_time
- dataset.metadata.periods[0].start_abs_time
)
assert diff == timedelta(seconds=5067.367)

0 comments on commit 5289425

Please sign in to comment.