diff --git a/kloppy/domain/models/tracking.py b/kloppy/domain/models/tracking.py index ca6603dd..e5487d9b 100644 --- a/kloppy/domain/models/tracking.py +++ b/kloppy/domain/models/tracking.py @@ -1,35 +1,83 @@ +from collections import defaultdict from dataclasses import dataclass, field -from typing import List, Dict, Optional, Callable, Union, Any +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +from scipy.signal import savgol_filter from kloppy.domain.models.common import DatasetType +from kloppy.exceptions import KloppyError +from kloppy.utils import deprecated -from .common import Dataset, DataRecord, Player -from .pitch import Point, Point3D -from kloppy.utils import ( - deprecated, -) +from .common import DataRecord, Dataset, Player +from .pitch import Point @dataclass -class PlayerData: +class Detection: + """A single detection of a trackable object in a frame. + + Attributes: + coordinates: The coordinates of the object in the frame. + distance: The distance the object has traveled since the previous frame. + speed: The speed of the object in the frame. + acceleration: The acceleration of the object in the frame. + other_data: Additional data about the object in the frame. + """ + coordinates: Point distance: Optional[float] = None speed: Optional[float] = None + acceleration: Optional[float] = None other_data: Dict[str, Any] = field(default_factory=dict) +@dataclass +class Trajectory: + """Detections of a trackable object over a sequence of consecutive frames. + + Attributes: + trackable_object: The object being tracked. Either a player or "ball". + start_frame: The frame number of the first detection in the trajectory. + end_frame: The frame number of the last detection in the trajectory. + detections: A list of Detection objects, one for each frame in the trajectory. + """ + + trackable_object: Union[Player, str] + start_frame: int + end_frame: int + detections: List[Detection] + + def __iter__(self): + return iter(self.detections) + + def __len__(self): + return len(self.detections) + + @dataclass(repr=False) class Frame(DataRecord): frame_id: int - players_data: Dict[Player, PlayerData] + ball_data: Optional[Detection] + players_data: Dict[Player, Detection] other_data: Dict[str, Any] - ball_coordinates: Point3D - ball_speed: Optional[float] = None @property def record_id(self) -> int: return self.frame_id + @property + def ball_coordinates(self): + if self.ball_data is None: + return None + return self.ball_data.coordinates + + @property + def ball_speed(self): + if self.ball_data is None: + return None + return self.ball_data.speed + @property def players_coordinates(self): return { @@ -52,6 +100,167 @@ def frames(self): def frame_rate(self): return self.metadata.frame_rate + @property + def trajectories(self): + trajectories = defaultdict(list) + + # get ball trajectories + current_trajectory = None + for record in self.records: + if ( + record.ball_data + and record.ball_data.coordinates + and record.ball_data.coordinates.x != float("nan") + ): + if current_trajectory is None: + current_trajectory = Trajectory( + trackable_object="ball", + start_frame=record.frame_id, + end_frame=record.frame_id, + detections=[record.ball_data], + ) + else: + current_trajectory.end_frame = record.frame_id + current_trajectory.detections.append(record.ball_data) + else: + if current_trajectory: + trajectories["ball"].append(current_trajectory) + current_trajectory = None + if current_trajectory: + trajectories["ball"].append(current_trajectory) + + # get player trajectories + for team in self.metadata.teams: + for player in team.players: + current_trajectory = None + for record in self.records: + if ( + player in record.players_data + and record.players_data[player].coordinates is not None + and record.players_data[player].coordinates.x + != float("nan") + ): + if current_trajectory is None: + current_trajectory = Trajectory( + trackable_object=player, + start_frame=record.frame_id, + end_frame=record.frame_id, + detections=[record.players_data[player]], + ) + else: + current_trajectory.end_frame = record.frame_id + current_trajectory.detections.append( + record.players_data[player] + ) + else: + if current_trajectory: + trajectories[player].append(current_trajectory) + current_trajectory = None + if current_trajectory: + trajectories[player].append(current_trajectory) + + return trajectories + + def compute_kinematics( + self, + n_smooth_speed: int = 6, + n_smooth_acc: int = 10, + max_speed_player: float = 12.0, + max_speed_ball: float = 50.0, + ): + """Compute speed and acceleration for each object in the dataset. + + Args: + n_smooth_speed: The number of frames to smooth over when computing speed. + n_smooth_acc: The number of frames to smooth over when computing acceleration. + max_speed_player: The maximum speed allowed for a player (in m/s). + max_speed_ball: The maximum speed allowed for the ball (in m/s). + + """ + if self.metadata.frame_rate is None: + raise KloppyError( + "Frame rate is not set in metadata. Please set the frame rate before computing kinematics." + ) + for trackable_object, trajectories in self.trajectories.items(): + max_speed = ( + max_speed_player + if isinstance(trackable_object, Player) + else max_speed_ball + ) + + for trajectory in trajectories: + if len(trajectory) < n_smooth_speed + 1: + continue + + # get x-y coordinates in metric space + tracked_maps = np.empty((len(trajectory), 2)) + for i, detection in enumerate(trajectory): + point = detection.coordinates + metric_point = self.metadata.pitch_dimensions.to_base( + point + ) + tracked_maps[i] = [metric_point.x, metric_point.y] + + # apply a Savitzky-Golay filter for smoothing + tracked_maps = smoothing_savgol_3rd(tracked_maps) + + # get speed vect and speed norm + dist = ( + tracked_maps[n_smooth_speed:] + - tracked_maps[:-n_smooth_speed] + ) + dist_norm = np.linalg.norm(dist, axis=1) / n_smooth_speed + + speed_vect = dist / (n_smooth_speed / self.metadata.frame_rate) + speed_norm = np.linalg.norm(speed_vect, axis=1) + + # acc process for short tracks + if speed_vect.shape[0] < self.metadata.frame_rate: + acc_vect = np.nan * np.ones_like(speed_vect) + acc_norm = np.nan * np.ones_like(speed_norm) + else: + # acc vect process for other tracks + diff_acc = ( + speed_vect[n_smooth_acc:] - speed_vect[:-n_smooth_acc] + ) + acc_vect = diff_acc / ( + n_smooth_acc / self.metadata.frame_rate + ) + + # padding to respect the shape after the smoothing + add = np.zeros((n_smooth_acc // 2, 2)) + np.nan + acc_vect = np.concatenate((add, acc_vect, add)) + + # apply a physical check based on speed and acc + acc_vect = apply_criterion(speed_vect, acc_vect, max_speed) + + # acc norm process for other tracks + diff_acc = ( + speed_norm[n_smooth_acc:] - speed_norm[:-n_smooth_acc] + ) + acc_norm = diff_acc / ( + n_smooth_acc / self.metadata.frame_rate + ) + + # padding to respect the shape after the smoothing + add = np.zeros((n_smooth_acc // 2)) + np.nan + acc_norm = np.concatenate((add, acc_norm, add)) + + # apply last padding + add = np.zeros((n_smooth_speed // 2, 2)) + np.nan + speed_vect = np.concatenate((add, speed_vect, add)) + + add = np.zeros((n_smooth_speed // 2)) + np.nan + dist_norm = np.concatenate((add, dist_norm, add)) + speed_norm = np.concatenate((add, speed_norm, add)) + acc_norm = np.concatenate((add, acc_norm, add)) + + # fill detection dict with physical info + for i, detection in enumerate(trajectory): + detection.distance = dist_norm[i] + detection.speed = speed_norm[i] + detection.acceleration = acc_norm[i] + @deprecated( "to_pandas will be removed in the future. Please use to_df instead." ) @@ -94,4 +303,42 @@ def generic_record_converter(frame: Frame): ) -__all__ = ["Frame", "TrackingDataset", "PlayerData"] +def apply_criterion(speeds, acc, max_speed): + """ + Criterion used to spot tracking inaccuracies. + + Args: + speeds (np.array): one player/ball speed (vx, vy) per row + acc (np.array): one player/ball acceleration (ax, ay) per row, same shape as speeds + max_speed (float): maximum speed allowed for a player or the ball + + Returns: + acc: same as acc with value set to np.NaN if criterion <= 0 + """ + criterion = -(9.1 / max_speed) * speeds + 9.1 - acc + mask = np.isnan(criterion) + criterion[mask] = -np.inf + mask_criterion = criterion <= 0.0 + acc[mask_criterion] = np.nan + return acc + + +def smoothing_savgol_3rd(raw_maps): + """ + Smooth player/ball positions using a Savitzky-Golay filter. + + Args: + raw_maps (np.array): one player/ball position (x, y) per row + + Returns: + tracked_maps: smoothed player positions + """ + window_length = min(raw_maps.shape[0], 31) + if window_length % 2 == 0: + window_length = window_length - 1 + polyorder = min(window_length - 1, 3) + tracked_maps = savgol_filter(raw_maps, window_length, polyorder, axis=0) + return tracked_maps + + +__all__ = ["Frame", "TrackingDataset", "Detection", "Trajectory"] diff --git a/kloppy/domain/services/transformers/dataset.py b/kloppy/domain/services/transformers/dataset.py index 3e784cc6..576896f9 100644 --- a/kloppy/domain/services/transformers/dataset.py +++ b/kloppy/domain/services/transformers/dataset.py @@ -1,6 +1,6 @@ from dataclasses import fields, replace -from kloppy.domain.models.tracking import PlayerData +from kloppy.domain.models.tracking import Detection from typing import Union, Optional from kloppy.domain import ( @@ -204,17 +204,22 @@ def __change_frame_coordinate_system(self, frame: Frame): ball_state=frame.ball_state, period=frame.period, # changes - ball_coordinates=self.__change_point_coordinate_system( - frame.ball_coordinates - ), - ball_speed=frame.ball_speed, + ball_data=replace( + frame.ball_data, + coordinates=self.__change_point_coordinate_system( + frame.ball_data.coordinates + ), + ) + if frame.ball_data + else None, players_data={ - key: PlayerData( + key: Detection( coordinates=self.__change_point_coordinate_system( player_data.coordinates ), distance=player_data.distance, speed=player_data.speed, + acceleration=player_data.acceleration, other_data=player_data.other_data, ) for key, player_data in frame.players_data.items() @@ -231,11 +236,16 @@ def __change_frame_dimensions(self, frame: Frame): ball_state=frame.ball_state, period=frame.period, # changes - ball_coordinates=self.change_point_dimensions( - frame.ball_coordinates - ), + ball_data=replace( + frame.ball_data, + coordinates=self.change_point_dimensions( + frame.ball_data.coordinates + ), + ) + if frame.ball_data + else None, players_data={ - key: PlayerData( + key: Detection( coordinates=self.change_point_dimensions( player_data.coordinates ), @@ -285,7 +295,7 @@ def __change_point_coordinate_system( def __flip_frame(self, frame: Frame): players_data = {} for player, data in frame.players_data.items(): - players_data[player] = PlayerData( + players_data[player] = Detection( coordinates=self.flip_point(data.coordinates), distance=data.distance, speed=data.speed, @@ -300,7 +310,12 @@ def __flip_frame(self, frame: Frame): ball_state=frame.ball_state, period=frame.period, # changes - ball_coordinates=self.flip_point(frame.ball_coordinates), + ball_data=replace( + frame.ball_data, + coordinates=self.flip_point(frame.ball_data.coordinates), + ) + if frame.ball_data + else None, players_data=players_data, other_data=frame.other_data, ) diff --git a/kloppy/infra/serializers/event/statsbomb/helpers.py b/kloppy/infra/serializers/event/statsbomb/helpers.py index 85edcae1..db795fea 100644 --- a/kloppy/infra/serializers/event/statsbomb/helpers.py +++ b/kloppy/infra/serializers/event/statsbomb/helpers.py @@ -9,7 +9,7 @@ Frame, Period, Player, - PlayerData, + Detection, ) from kloppy.exceptions import DeserializationError @@ -111,14 +111,14 @@ def get_player_from_freeze_frame(player_data, team, i): freeze_frame_player, freeze_frame_team, i ) - players_data[player] = PlayerData( + players_data[player] = Detection( coordinates=parse_coordinates( freeze_frame_player["location"], fidelity_version ) ) if event.player not in players_data: - players_data[event.player] = PlayerData(coordinates=event.coordinates) + players_data[event.player] = Detection(coordinates=event.coordinates) FREEZE_FRAME_FPS = 25 frame_id = int( @@ -128,8 +128,10 @@ def get_player_from_freeze_frame(player_data, team, i): return Frame( frame_id=frame_id, - ball_coordinates=Point3D( - x=event.coordinates.x, y=event.coordinates.y, z=0 + ball_data=Detection( + coordinates=Point3D( + x=event.coordinates.x, y=event.coordinates.y, z=0 + ), ), players_data=players_data, period=event.period, diff --git a/kloppy/infra/serializers/tracking/metrica_csv.py b/kloppy/infra/serializers/tracking/metrica_csv.py index e20c1172..3cd5cf7f 100644 --- a/kloppy/infra/serializers/tracking/metrica_csv.py +++ b/kloppy/infra/serializers/tracking/metrica_csv.py @@ -18,7 +18,7 @@ Team, Ground, Player, - PlayerData, + Detection, ) from kloppy.infra.serializers.tracking.deserializer import ( TrackingDataDeserializer, @@ -108,7 +108,7 @@ def __create_iterator( period=period, frame_id=frame_id, players_data={ - player: PlayerData( + player: Detection( coordinates=Point( x=float(columns[3 + i * 2]), y=1 - float(columns[3 + i * 2 + 1]), @@ -183,6 +183,10 @@ def deserialize( period: Period = home_partial_frame.period frame_id: int = home_partial_frame.frame_id + ball_data = Detection( + coordinates=home_partial_frame.ball_coordinates + ) + players_data = { **home_partial_frame.players_data, **away_partial_frame.players_data, @@ -192,7 +196,7 @@ def deserialize( frame_id=frame_id, timestamp=timedelta(seconds=frame_id / frame_rate) - period.start_timestamp, - ball_coordinates=home_partial_frame.ball_coordinates, + ball_data=ball_data, players_data=players_data, period=period, ball_state=None, diff --git a/kloppy/infra/serializers/tracking/metrica_epts/deserializer.py b/kloppy/infra/serializers/tracking/metrica_epts/deserializer.py index b81a6d45..6b77f7c0 100644 --- a/kloppy/infra/serializers/tracking/metrica_epts/deserializer.py +++ b/kloppy/infra/serializers/tracking/metrica_epts/deserializer.py @@ -8,7 +8,7 @@ Point, Point3D, Provider, - PlayerData, + Detection, DatasetTransformer, ) from kloppy.utils import performance_logging @@ -57,7 +57,7 @@ def _frame_from_row( player_sensor_val = row.get(player_sensor_field_str) other_data.update({sensor.sensor_id: player_sensor_val}) - players_data[player] = PlayerData( + players_data[player] = Detection( coordinates=Point( x=row[f"player_{player.player_id}_x"], y=row[f"player_{player.player_id}_y"], @@ -81,8 +81,10 @@ def _frame_from_row( period=period, players_data=players_data, other_data={}, - ball_coordinates=Point3D( - x=row["ball_x"], y=row["ball_y"], z=row.get("ball_z") + ball_data=Detection( + coordinates=Point3D( + x=row["ball_x"], y=row["ball_y"], z=row.get("ball_z") + ), ), ) diff --git a/kloppy/infra/serializers/tracking/secondspectrum.py b/kloppy/infra/serializers/tracking/secondspectrum.py index 5933f13a..ecd37602 100644 --- a/kloppy/infra/serializers/tracking/secondspectrum.py +++ b/kloppy/infra/serializers/tracking/secondspectrum.py @@ -22,7 +22,7 @@ Ground, Player, Provider, - PlayerData, + Detection, ) from kloppy.utils import Readable, performance_logging @@ -62,13 +62,14 @@ def _frame_from_framedata(cls, teams, period, frame_data): if frame_data["ball"]["xyz"]: ball_x, ball_y, ball_z = frame_data["ball"]["xyz"] - ball_coordinates = Point3D( - float(ball_x), float(ball_y), float(ball_z) + ball_data = Detection( + coordinates=Point3D( + float(ball_x), float(ball_y), float(ball_z) + ), + speed=frame_data["ball"]["speed"], ) - ball_speed = frame_data["ball"]["speed"] else: - ball_coordinates = None - ball_speed = None + ball_data = None ball_state = BallState.ALIVE if frame_data["live"] else BallState.DEAD ball_owning_team = ( @@ -91,15 +92,14 @@ def _frame_from_framedata(cls, teams, period, frame_data): ) team.players.append(player) - players_data[player] = PlayerData( + players_data[player] = Detection( coordinates=Point(float(x), float(y)), speed=speed ) return Frame( frame_id=frame_id, timestamp=frame_timestamp, - ball_coordinates=ball_coordinates, - ball_speed=ball_speed, + ball_data=ball_data, ball_state=ball_state, ball_owning_team=ball_owning_team, players_data=players_data, diff --git a/kloppy/infra/serializers/tracking/skillcorner.py b/kloppy/infra/serializers/tracking/skillcorner.py index 39d07c1c..f6dca302 100644 --- a/kloppy/infra/serializers/tracking/skillcorner.py +++ b/kloppy/infra/serializers/tracking/skillcorner.py @@ -25,7 +25,7 @@ Score, Team, TrackingDataset, - PlayerData, + Detection, ) from kloppy.infra.serializers.tracking.deserializer import ( TrackingDataDeserializer, @@ -91,7 +91,7 @@ def _get_frame_data( else: raise ValueError(f"Unknown period id {frame_period}") - ball_coordinates = None + ball_data = None players_data = {} # ball_carrier = frame["possession"].get("trackable_object") @@ -119,7 +119,9 @@ def _get_frame_data( z = frame_record.get("z") if z is not None: z = float(z) - ball_coordinates = Point3D(x=float(x), y=float(y), z=z) + ball_data = Detection( + coordinates=Point3D(x=float(x), y=float(y), z=z) + ) continue elif trackable_object in referee_dict.keys(): @@ -152,12 +154,12 @@ def _get_frame_data( else: player = anon_players["AWAY"][f"anon_away_{player_id}"] - players_data[player] = PlayerData(coordinates=Point(x, y)) + players_data[player] = Detection(coordinates=Point(x, y)) return Frame( frame_id=frame_id, timestamp=frame_time, - ball_coordinates=ball_coordinates, + ball_data=ball_data, players_data=players_data, period=periods[frame_period], ball_state=None, diff --git a/kloppy/infra/serializers/tracking/sportec/deserializer.py b/kloppy/infra/serializers/tracking/sportec/deserializer.py index 038cb3ab..dfefe3db 100644 --- a/kloppy/infra/serializers/tracking/sportec/deserializer.py +++ b/kloppy/infra/serializers/tracking/sportec/deserializer.py @@ -19,7 +19,7 @@ attacking_direction_from_frame, Metadata, Provider, - PlayerData, + Detection, ) from kloppy.utils import performance_logging @@ -176,7 +176,7 @@ def _iter(): else BallState.DEAD, period=period, players_data={ - player_map[player_id]: PlayerData( + player_map[player_id]: Detection( coordinates=Point( x=float(raw_player_data["X"]), y=float(raw_player_data["Y"]), @@ -187,12 +187,14 @@ def _iter(): if player_id != "ball" }, other_data={}, - ball_coordinates=Point3D( - x=float(ball_data["X"]), - y=float(ball_data["Y"]), - z=float(ball_data["Z"]), + ball_data=Detection( + coordinates=Point3D( + x=float(ball_data["X"]), + y=float(ball_data["Y"]), + z=float(ball_data["Z"]), + ), + speed=float(ball_data["S"]), ), - ball_speed=float(ball_data["S"]), ) frames = [] diff --git a/kloppy/infra/serializers/tracking/statsperform.py b/kloppy/infra/serializers/tracking/statsperform.py index 2265f5ad..4c3de986 100644 --- a/kloppy/infra/serializers/tracking/statsperform.py +++ b/kloppy/infra/serializers/tracking/statsperform.py @@ -16,7 +16,7 @@ Orientation, Period, Player, - PlayerData, + Detection, Point, Point3D, Provider, @@ -84,11 +84,13 @@ def _frame_from_framedata(cls, teams_list, period, frame_data): ball_owning_team = None if len(components) > 2: - ball_data = components[2].split(";")[0].split(",") - ball_x, ball_y, ball_z = map(float, ball_data) - ball_coordinates = Point3D(ball_x, ball_y, ball_z) + raw_ball_data = components[2].split(";")[0].split(",") + ball_x, ball_y, ball_z = map(float, raw_ball_data) + ball_data = Detection( + coordinates=Point3D(ball_x, ball_y, ball_z), + ) else: - ball_coordinates = None + ball_data = None players_data = {} player_info = components[1].split(";")[:-1] @@ -115,12 +117,12 @@ def _frame_from_framedata(cls, teams_list, period, frame_data): ) team.players.append(player) - players_data[player] = PlayerData(coordinates=Point(x, y)) + players_data[player] = Detection(coordinates=Point(x, y)) return Frame( frame_id=frame_id, timestamp=frame_timestamp, - ball_coordinates=ball_coordinates, + ball_data=ball_data, ball_state=ball_state, ball_owning_team=ball_owning_team, players_data=players_data, diff --git a/kloppy/infra/serializers/tracking/tracab/tracab_dat.py b/kloppy/infra/serializers/tracking/tracab/tracab_dat.py index 4a636980..67788ae8 100644 --- a/kloppy/infra/serializers/tracking/tracab/tracab_dat.py +++ b/kloppy/infra/serializers/tracking/tracab/tracab_dat.py @@ -21,7 +21,7 @@ Ground, Player, Provider, - PlayerData, + Detection, ) from kloppy.exceptions import DeserializationError @@ -77,7 +77,7 @@ def _frame_from_line(cls, teams, period, line, frame_rate): ) team.players.append(player) - players_data[player] = PlayerData( + players_data[player] = Detection( coordinates=Point(float(x), float(y)), speed=float(speed) ) @@ -90,6 +90,11 @@ def _frame_from_line(cls, teams, period, line, frame_rate): ball_state, ) = ball.rstrip(";").split(",")[:6] + ball_data = Detection( + coordinates=Point3D(float(ball_x), float(ball_y), float(ball_z)), + speed=float(ball_speed), + ) + frame_id = int(frame_id) if ball_owning_team == "H": @@ -112,9 +117,7 @@ def _frame_from_line(cls, teams, period, line, frame_rate): frame_id=frame_id, timestamp=timedelta(seconds=frame_id / frame_rate) - period.start_timestamp, - ball_coordinates=Point3D( - float(ball_x), float(ball_y), float(ball_z) - ), + ball_data=ball_data, ball_state=ball_state, ball_owning_team=ball_owning_team, players_data=players_data, diff --git a/kloppy/infra/serializers/tracking/tracab/tracab_json.py b/kloppy/infra/serializers/tracking/tracab/tracab_json.py index ca361183..064dc6c0 100644 --- a/kloppy/infra/serializers/tracking/tracab/tracab_json.py +++ b/kloppy/infra/serializers/tracking/tracab/tracab_json.py @@ -20,7 +20,7 @@ Ground, Player, Provider, - PlayerData, + Detection, Position, attacking_direction_from_frame, ) @@ -75,7 +75,7 @@ def _create_frame(cls, teams, period, raw_frame, frame_rate): player = team.get_player_by_jersey_number(jersey_no) if player: - players_data[player] = PlayerData( + players_data[player] = Detection( coordinates=Point(x, y), speed=speed ) else: @@ -88,6 +88,9 @@ def _create_frame(cls, teams, period, raw_frame, frame_rate): ball_y = raw_ball_position["Y"] ball_z = raw_ball_position["Z"] ball_speed = raw_ball_position["Speed"] + ball_data = Detection( + coordinates=Point3D(ball_x, ball_y, ball_z), speed=ball_speed + ) if raw_ball_position["BallOwningTeam"] == "H": ball_owning_team = teams[0] elif raw_ball_position["BallOwningTeam"] == "A": @@ -109,10 +112,9 @@ def _create_frame(cls, teams, period, raw_frame, frame_rate): frame_id=frame_id, timestamp=timedelta(seconds=frame_id / frame_rate) - period.start_timestamp, - ball_coordinates=Point3D(ball_x, ball_y, ball_z), + ball_data=ball_data, ball_state=ball_state, ball_owning_team=ball_owning_team, - ball_speed=ball_speed, players_data=players_data, period=period, other_data={}, diff --git a/kloppy/tests/test_helpers.py b/kloppy/tests/test_helpers.py index 65df77ad..40a618ed 100644 --- a/kloppy/tests/test_helpers.py +++ b/kloppy/tests/test_helpers.py @@ -25,7 +25,7 @@ Team, Ground, Player, - PlayerData, + Detection, Point3D, ) @@ -79,7 +79,9 @@ def _get_tracking_dataset(self): period=periods[0], players_data={}, other_data=None, - ball_coordinates=Point3D(x=100, y=-50, z=0), + ball_data=Detection( + coordinates=Point3D(x=100, y=-50, z=0), + ), ), Frame( frame_id=2, @@ -90,7 +92,7 @@ def _get_tracking_dataset(self): players_data={ Player( team=home_team, player_id="home_1", jersey_no=1 - ): PlayerData( + ): Detection( coordinates=Point(x=15, y=35), distance=0.03, speed=10.5, @@ -98,7 +100,9 @@ def _get_tracking_dataset(self): ) }, other_data={"extra_data": 1}, - ball_coordinates=Point3D(x=0, y=50, z=1), + ball_data=Detection( + coordinates=Point3D(x=0, y=50, z=1), + ), ), ], )