diff --git a/tests/toolkit/test_dataset.py b/tests/toolkit/test_dataset.py index 388879c7..1fc452f0 100644 --- a/tests/toolkit/test_dataset.py +++ b/tests/toolkit/test_dataset.py @@ -152,6 +152,47 @@ def test_forecasting_df_dataset(ts_data_with_categorical): assert np.all(ds[0]["future_values"][:, 2].numpy() == 0) +def test_forecasting_df_dataset_stride(ts_data_with_categorical): + prediction_length = 2 + context_length = 3 + stride = 13 + target_columns = ["value1", "value2"] + + df = ts_data_with_categorical + + ds = ForecastDFDataset( + df, + timestamp_column="timestamp", + id_columns=["id"], + target_columns=target_columns, + context_length=context_length, + prediction_length=prediction_length, + stride=stride, + ) + + # length check + series_len = len(df) / len(df["id"].unique()) + assert len(ds) == ((series_len - prediction_length - context_length + 1) // stride) * len(df["id"].unique()) + + # check proper windows are selected based on chosen stride + ds_past_np = np.array([v["past_values"].numpy() for v in ds]) + ds_past_np_expected = np.array( + [ + [[0.0, 10.0], [1.0, 10.333333], [2.0, 10.666667]], + [[13.0, 14.333333], [14.0, 14.666667], [15.0, 15.0]], + [[26.0, 18.666666], [27.0, 19.0], [28.0, 19.333334]], + [[50.0, 26.666666], [51.0, 27.0], [52.0, 27.333334]], + [[63.0, 31.0], [64.0, 31.333334], [65.0, 31.666666]], + [[76.0, 35.333332], [77.0, 35.666668], [78.0, 36.0]], + [[100.0, 43.333332], [101.0, 43.666668], [102.0, 44.0]], + [[113.0, 47.666668], [114.0, 48.0], [115.0, 48.333332]], + [[126.0, 52.0], [127.0, 52.333332], [128.0, 52.666668]], + ] + ) + + np.testing.assert_allclose(ds_past_np, ds_past_np_expected) + + def test_forecasting_df_dataset_non_autoregressive(ts_data_with_categorical): prediction_length = 2 target_columns = ["value1"] diff --git a/tsfm_public/toolkit/dataset.py b/tsfm_public/toolkit/dataset.py index 57cb9c6c..e4222328 100644 --- a/tsfm_public/toolkit/dataset.py +++ b/tsfm_public/toolkit/dataset.py @@ -42,6 +42,7 @@ def __init__( context_length: int = 1, prediction_length: int = 0, zero_padding: bool = True, + stride: int = 1, ): super().__init__() if not isinstance(x_cols, list): @@ -72,6 +73,7 @@ def __init__( self.zero_padding = zero_padding self.timestamps = None self.group_id = group_id + self.stride = stride # sort the data by datetime if timestamp_column in list(data_df.columns): @@ -116,7 +118,7 @@ def pad_zero(self, data_df): ) def __len__(self): - return len(self.X) - self.context_length - self.prediction_length + 1 + return (len(self.X) - self.context_length - self.prediction_length + 1) // self.stride def __getitem__(self, index: int): """ @@ -153,6 +155,7 @@ def __init__( prediction_length: int = 1, num_workers: int = 1, cls=BaseDFDataset, + stride: int = 1, **kwargs, ): if len(id_columns) > 0: @@ -166,6 +169,7 @@ def __init__( self.num_workers = num_workers self.cls = cls self.prediction_length = prediction_length + self.stride = stride self.extra_kwargs = kwargs # create groupby object @@ -208,6 +212,7 @@ def concat_dataset(self): self.context_length, self.prediction_length, self.drop_cols, + self.stride, self.extra_kwargs, ) for group_id, group in group_df @@ -228,6 +233,7 @@ def get_group_data( context_length: int = 1, prediction_length: int = 1, drop_cols: Optional[List[str]] = None, + stride: int = 1, extra_kwargs: Dict[str, Any] = {}, ): return cls( @@ -238,6 +244,7 @@ def get_group_data( context_length=context_length, prediction_length=prediction_length, drop_cols=drop_cols, + stride=stride, **extra_kwargs, ) @@ -264,6 +271,7 @@ def __init__( target_columns: List[str] = [], context_length: int = 1, num_workers: int = 1, + stride: int = 1, ): super().__init__( data_df=data, @@ -274,6 +282,7 @@ def __init__( prediction_length=0, cls=self.BasePretrainDFDataset, target_columns=target_columns, + stride=stride, ) self.n_inp = 1 @@ -288,6 +297,7 @@ def __init__( id_columns: List[str] = [], timestamp_column: Optional[str] = None, target_columns: List[str] = [], + stride: int = 1, ): self.target_columns = target_columns @@ -304,9 +314,11 @@ def __init__( prediction_length=prediction_length, group_id=group_id, drop_cols=drop_cols, + stride=stride, ) - def __getitem__(self, time_id): + def __getitem__(self, index): + time_id = index * self.stride seq_x = self.X[time_id : time_id + self.context_length].values ret = {"past_values": np_to_torch(seq_x)} if self.datetime_col: @@ -346,6 +358,7 @@ def __init__( num_workers: int = 1, frequency_token: Optional[int] = None, autoregressive_modeling: bool = True, + stride: int = 1, ): # output_columns_tmp = input_columns if output_columns == [] else output_columns @@ -357,6 +370,7 @@ def __init__( context_length=context_length, prediction_length=prediction_length, cls=self.BaseForecastDFDataset, + stride=stride, # extra_args target_columns=target_columns, observable_columns=observable_columns, @@ -391,6 +405,7 @@ def __init__( static_categorical_columns: List[str] = [], frequency_token: Optional[int] = None, autoregressive_modeling: bool = True, + stride: int = 1, ): self.frequency_token = frequency_token self.target_columns = target_columns @@ -430,10 +445,14 @@ def __init__( prediction_length=prediction_length, group_id=group_id, drop_cols=drop_cols, + stride=stride, ) - def __getitem__(self, time_id): + def __getitem__(self, index): # seq_x: batch_size x seq_len x num_x_cols + + time_id = index * self.stride + seq_x = self.X[time_id : time_id + self.context_length].values if not self.autoregressive_modeling: seq_x[:, self.x_mask_targets] = 0 @@ -465,7 +484,7 @@ def __getitem__(self, time_id): return ret def __len__(self): - return len(self.X) - self.context_length - self.prediction_length + 1 + return (len(self.X) - self.context_length - self.prediction_length + 1) // self.stride class RegressionDFDataset(BaseConcatDFDataset): @@ -492,6 +511,7 @@ def __init__( static_categorical_columns: List[str] = [], context_length: int = 1, num_workers: int = 1, + stride: int = 1, ): # self.y_cols = y_cols @@ -505,6 +525,7 @@ def __init__( input_columns=input_columns, target_columns=target_columns, static_categorical_columns=static_categorical_columns, + stride=stride, ) self.n_inp = 2 @@ -526,6 +547,7 @@ def __init__( target_columns: List[str] = [], input_columns: List[str] = [], static_categorical_columns: List[str] = [], + stride: int = 1, ): self.target_columns = target_columns self.input_columns = input_columns @@ -544,10 +566,13 @@ def __init__( prediction_length=prediction_length, group_id=group_id, drop_cols=drop_cols, + stride=stride, ) - def __getitem__(self, time_id): + def __getitem__(self, index): # seq_x: batch_size x seq_len x num_x_cols + + time_id = index * self.stride seq_x = self.X[time_id : time_id + self.context_length].values seq_y = self.y[time_id + self.context_length - 1 : time_id + self.context_length].values.ravel() # return _torch(seq_x, seq_y)