Skip to content

Commit

Permalink
Merge branch 'master' into feat/pitchdimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
probberechts committed Apr 3, 2024
2 parents a9c41cf + e678d62 commit 0786488
Show file tree
Hide file tree
Showing 44 changed files with 7,029 additions and 271 deletions.
39 changes: 35 additions & 4 deletions kloppy/_providers/tracab.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Optional
from typing import Optional, Union, Type


from kloppy.domain import TrackingDataset
from kloppy.infra.serializers.tracking.tracab import (
TRACABDeserializer,
from kloppy.infra.serializers.tracking.tracab.tracab_dat import (
TRACABDatDeserializer,
TRACABInputs,
)
from kloppy.infra.serializers.tracking.tracab.tracab_json import (
TRACABJSONDeserializer,
TRACABInputs,
)
from kloppy.io import FileLike, open_as_file
Expand All @@ -15,8 +20,16 @@ def load(
limit: Optional[int] = None,
coordinates: Optional[str] = None,
only_alive: Optional[bool] = True,
file_format: Optional[str] = None,
) -> TrackingDataset:
deserializer = TRACABDeserializer(
if file_format == "dat":
deserializer_class = TRACABDatDeserializer
elif file_format == "json":
deserializer_class = TRACABJSONDeserializer
else:
deserializer_class = identify_deserializer(meta_data, raw_data)

deserializer = deserializer_class(
sample_rate=sample_rate,
limit=limit,
coordinate_system=coordinates,
Expand All @@ -28,3 +41,21 @@ def load(
return deserializer.deserialize(
inputs=TRACABInputs(meta_data=meta_data_fp, raw_data=raw_data_fp)
)


def identify_deserializer(
meta_data: FileLike,
raw_data: FileLike,
) -> Union[Type[TRACABDatDeserializer], Type[TRACABJSONDeserializer]]:
deserializer = None
if "xml" in meta_data.name and "dat" in raw_data.name:
deserializer = TRACABDatDeserializer
if "json" in meta_data.name and "json" in raw_data.name:
deserializer = TRACABJSONDeserializer

if deserializer is None:
raise ValueError(
"Tracab file format could not be recognized, please specify"
)

return deserializer
3 changes: 2 additions & 1 deletion kloppy/domain/models/code.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from dataclasses import dataclass, field
from typing import List, Dict, Callable, Union, Any

Expand Down Expand Up @@ -26,7 +27,7 @@ class Code(DataRecord):

code_id: str
code: str
end_timestamp: float
end_timestamp: timedelta
labels: Dict[str, Union[bool, str]] = field(default_factory=dict)

@property
Expand Down
35 changes: 24 additions & 11 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
from enum import Enum, Flag
from typing import (
Dict,
Expand Down Expand Up @@ -44,6 +45,7 @@
OrientationError,
InvalidFilterError,
KloppyParameterError,
KloppyError,
)


Expand Down Expand Up @@ -248,21 +250,32 @@ class Period:
Period
Attributes:
id: `1` for first half, `2` for second half, `3` for first overtime,
`4` for second overtime, and `5` for penalty shootouts
start_timestamp: timestamp given by provider (can be unix timestamp or relative)
end_timestamp: timestamp given by provider (can be unix timestamp or relative)
id: `1` for first half, `2` for second half, `3` for first half of
overtime, `4` for second half of overtime, `5` for penalty shootout
start_timestamp: The UTC datetime of the kick-off or, if the
absolute datetime is not available, the offset between the start
of the data feed and the period's kick-off
end_timestamp: The UTC datetime of the final whistle or, if the
absolute datetime is not available, the offset between the start
of the data feed and the period's final whistle
attacking_direction: See [`AttackingDirection`][kloppy.domain.models.common.AttackingDirection]
"""

id: int
start_timestamp: float
end_timestamp: float
start_timestamp: Union[datetime, timedelta]
end_timestamp: Union[datetime, timedelta]

def contains(self, timestamp: float):
return self.start_timestamp <= timestamp <= self.end_timestamp
def contains(self, timestamp: datetime):
if isinstance(self.start_timestamp, datetime) and isinstance(
self.end_timestamp, datetime
):
return self.start_timestamp <= timestamp <= self.end_timestamp
raise KloppyError(
"This method can only be used when start_timestamp and end_timestamp are a datetime"
)

@property
def duration(self):
def duration(self) -> timedelta:
return self.end_timestamp - self.start_timestamp

def __eq__(self, other):
Expand Down Expand Up @@ -854,7 +867,7 @@ class DataRecord(ABC):
Attributes:
period: See [`Period`][kloppy.domain.models.common.Period]
timestamp: Timestamp of occurrence
timestamp: Timestamp of occurrence, relative to the period kick-off
ball_owning_team: See [`Team`][kloppy.domain.models.common.Team]
ball_state: See [`Team`][kloppy.domain.models.common.BallState]
"""
Expand All @@ -863,7 +876,7 @@ class DataRecord(ABC):
prev_record: Optional["DataRecord"] = field(init=False)
next_record: Optional["DataRecord"] = field(init=False)
period: Period
timestamp: float
timestamp: timedelta
ball_owning_team: Optional[Team]
ball_state: Optional[BallState]

Expand Down
23 changes: 22 additions & 1 deletion kloppy/domain/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class EventType(Enum):
BALL_OUT (EventType):
FOUL_COMMITTED (EventType):
GOALKEEPER (EventType):
PRESSURE (EventType):
FORMATION_CHANGE (EventType):
"""

Expand All @@ -238,6 +239,7 @@ class EventType(Enum):
BALL_OUT = "BALL_OUT"
FOUL_COMMITTED = "FOUL_COMMITTED"
GOALKEEPER = "GOALKEEPER"
PRESSURE = "PRESSURE"
FORMATION_CHANGE = "FORMATION_CHANGE"

def __repr__(self):
Expand Down Expand Up @@ -354,6 +356,7 @@ class PassType(Enum):
THROUGH_BALL = "THROUGH_BALL"
CHIPPED_PASS = "CHIPPED_PASS"
FLICK_ON = "FLICK_ON"
SHOT_ASSIST = "SHOT_ASSIST"
ASSIST = "ASSIST"
ASSIST_2ND = "ASSIST_2ND"
SWITCH_OF_PLAY = "SWITCH_OF_PLAY"
Expand Down Expand Up @@ -698,7 +701,7 @@ def matches(self, filter_) -> bool:
return True

def __str__(self):
m, s = divmod(self.timestamp, 60)
m, s = divmod(self.timestamp.total_seconds(), 60)

event_type = (
self.__class__.__name__
Expand Down Expand Up @@ -1023,6 +1026,24 @@ class GoalkeeperEvent(Event):
event_name: str = "goalkeeper"


@dataclass(repr=False)
@docstring_inherit_attributes(Event)
class PressureEvent(Event):
"""
PressureEvent
Attributes:
event_type (EventType): `EventType.Pressure` (See [`EventType`][kloppy.domain.models.event.EventType])
event_name (str): `"pressure"`,
end_timestamp (float):
"""

end_timestamp: float

event_type: EventType = EventType.PRESSURE
event_name: str = "pressure"


@dataclass(repr=False)
class EventDataset(Dataset[Event]):
"""
Expand Down
4 changes: 4 additions & 0 deletions kloppy/domain/services/event_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SubstitutionEvent,
GoalkeeperEvent,
)
from kloppy.domain.models.event import PressureEvent

T = TypeVar("T")

Expand Down Expand Up @@ -122,3 +123,6 @@ def build_substitution(self, **kwargs) -> SubstitutionEvent:

def build_goalkeeper_event(self, **kwargs) -> GoalkeeperEvent:
return create_event(GoalkeeperEvent, **kwargs)

def build_pressure_event(self, **kwargs) -> PressureEvent:
return create_event(PressureEvent, **kwargs)
23 changes: 17 additions & 6 deletions kloppy/infra/serializers/code/sportscode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from datetime import timedelta
from typing import Union, IO, NamedTuple

from lxml import objectify, etree
Expand Down Expand Up @@ -50,15 +51,19 @@ def deserialize(self, inputs: SportsCodeInputs) -> CodeDataset:
all_instances = objectify.fromstring(inputs.data.read())

codes = []
period = Period(id=1, start_timestamp=0, end_timestamp=0)
period = Period(
id=1,
start_timestamp=timedelta(seconds=0),
end_timestamp=timedelta(seconds=0),
)
for instance in all_instances.ALL_INSTANCES.iterchildren():
end_timestamp = float(instance.end)
end_timestamp = timedelta(seconds=float(instance.end))

code = Code(
period=period,
code_id=str(instance.ID),
code=str(instance.code),
timestamp=float(instance.start),
timestamp=timedelta(seconds=float(instance.start)),
end_timestamp=end_timestamp,
labels=parse_labels(instance),
ball_state=None,
Expand Down Expand Up @@ -88,7 +93,7 @@ def serialize(self, dataset: CodeDataset) -> bytes:
root = etree.Element("file")
all_instances = etree.SubElement(root, "ALL_INSTANCES")
for i, code in enumerate(dataset.codes):
relative_period_start = 0
relative_period_start = timedelta(seconds=0)
for period in dataset.metadata.periods:
if period == code.period:
break
Expand All @@ -100,10 +105,16 @@ def serialize(self, dataset: CodeDataset) -> bytes:
id_.text = code.code_id or str(i + 1)

start = etree.SubElement(instance, "start")
start.text = str(relative_period_start + code.start_timestamp)
start.text = str(
relative_period_start.total_seconds()
+ code.start_timestamp.total_seconds()
)

end = etree.SubElement(instance, "end")
end.text = str(relative_period_start + code.end_timestamp)
end.text = str(
relative_period_start.total_seconds()
+ code.end_timestamp.total_seconds()
)

code_ = etree.SubElement(instance, "code")
code_.text = code.code
Expand Down
57 changes: 46 additions & 11 deletions kloppy/infra/serializers/event/datafactory/deserializer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
from datetime import timedelta, datetime, timezone
from dataclasses import replace
from typing import Dict, List, Tuple, Union, IO, NamedTuple

from kloppy.domain import (
Expand Down Expand Up @@ -155,8 +157,10 @@
DF_EVENT_TYPE_PENALTY_SHOOTOUT_POST = 183


def parse_str_ts(raw_event: Dict) -> float:
return raw_event["t"]["m"] * 60 + (raw_event["t"]["s"] or 0)
def parse_str_ts(raw_event: Dict) -> timedelta:
return timedelta(
seconds=raw_event["t"]["m"] * 60 + (raw_event["t"]["s"] or 0)
)


def _parse_coordinates(coordinates: Dict[str, float]) -> Point:
Expand Down Expand Up @@ -397,8 +401,21 @@ def deserialize(self, inputs: DatafactoryInputs) -> EventDataset:
# setup periods
status = incidences.pop(DF_EVENT_CLASS_STATUS)
# start timestamps are fixed
start_ts = {1: 0, 2: 45 * 60, 3: 90 * 60, 4: 105 * 60, 5: 120 * 60}
start_ts = {
1: timedelta(minutes=0),
2: timedelta(minutes=45),
3: timedelta(minutes=90),
4: timedelta(minutes=105),
5: timedelta(minutes=120),
}
# check for end status updates to setup periods
start_event_types = {
DF_EVENT_TYPE_STATUS_MATCH_START,
DF_EVENT_TYPE_STATUS_SECOND_HALF_START,
DF_EVENT_TYPE_STATUS_FIRST_EXTRA_START,
DF_EVENT_TYPE_STATUS_SECOND_EXTRA_START,
DF_EVENT_TYPE_STATUS_PENALTY_SHOOTOUT_START,
}
end_event_types = {
DF_EVENT_TYPE_STATUS_MATCH_END,
DF_EVENT_TYPE_STATUS_FIRST_HALF_END,
Expand All @@ -408,15 +425,33 @@ def deserialize(self, inputs: DatafactoryInputs) -> EventDataset:
}
periods = {}
for status_update in status.values():
if status_update["type"] not in end_event_types:
if status_update["type"] not in (
start_event_types | end_event_types
):
continue
timestamp = datetime.strptime(
match["date"]
+ status_update["time"]
+ match["stadiumGMT"],
"%Y%m%d%H:%M:%S%z",
).astimezone(timezone.utc)
half = status_update["t"]["half"]
end_ts = parse_str_ts(status_update)
periods[half] = Period(
id=half,
start_timestamp=start_ts[half],
end_timestamp=end_ts,
)
if status_update["type"] == DF_EVENT_TYPE_STATUS_MATCH_START:
half = 1
if status_update["type"] in start_event_types:
periods[half] = Period(
id=half,
start_timestamp=timestamp,
end_timestamp=None,
)
elif status_update["type"] in end_event_types:
if half not in periods:
raise DeserializationError(
f"Missing start event for period {half}"
)
periods[half] = replace(
periods[half], end_timestamp=timestamp
)

# exclude goals, already listed as shots too
incidences.pop(DF_EVENT_CLASS_GOALS)
Expand Down Expand Up @@ -444,7 +479,7 @@ def deserialize(self, inputs: DatafactoryInputs) -> EventDataset:
# skip invalid event
continue

timestamp = parse_str_ts(raw_event)
timestamp = parse_str_ts(raw_event) - start_ts[period.id]
if (
previous_event is not None
and previous_event["t"]["half"] != raw_event["t"]["half"]
Expand Down
Loading

0 comments on commit 0786488

Please sign in to comment.