Skip to content

Commit

Permalink
Merge pull request #59 from ibm-granite/stride
Browse files Browse the repository at this point in the history
Striding for datasets
  • Loading branch information
wgifford authored May 31, 2024
2 parents 02dc9dc + 25a1d0d commit d1d5ebe
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
41 changes: 41 additions & 0 deletions tests/toolkit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,47 @@ def test_forecasting_df_dataset(ts_data_with_categorical):
assert np.all(ds[0]["future_values"][:, 2].numpy() == 0)


def test_forecasting_df_dataset_stride(ts_data_with_categorical):
prediction_length = 2
context_length = 3
stride = 13
target_columns = ["value1", "value2"]

df = ts_data_with_categorical

ds = ForecastDFDataset(
df,
timestamp_column="timestamp",
id_columns=["id"],
target_columns=target_columns,
context_length=context_length,
prediction_length=prediction_length,
stride=stride,
)

# length check
series_len = len(df) / len(df["id"].unique())
assert len(ds) == ((series_len - prediction_length - context_length + 1) // stride) * len(df["id"].unique())

# check proper windows are selected based on chosen stride
ds_past_np = np.array([v["past_values"].numpy() for v in ds])
ds_past_np_expected = np.array(
[
[[0.0, 10.0], [1.0, 10.333333], [2.0, 10.666667]],
[[13.0, 14.333333], [14.0, 14.666667], [15.0, 15.0]],
[[26.0, 18.666666], [27.0, 19.0], [28.0, 19.333334]],
[[50.0, 26.666666], [51.0, 27.0], [52.0, 27.333334]],
[[63.0, 31.0], [64.0, 31.333334], [65.0, 31.666666]],
[[76.0, 35.333332], [77.0, 35.666668], [78.0, 36.0]],
[[100.0, 43.333332], [101.0, 43.666668], [102.0, 44.0]],
[[113.0, 47.666668], [114.0, 48.0], [115.0, 48.333332]],
[[126.0, 52.0], [127.0, 52.333332], [128.0, 52.666668]],
]
)

np.testing.assert_allclose(ds_past_np, ds_past_np_expected)


def test_forecasting_df_dataset_non_autoregressive(ts_data_with_categorical):
prediction_length = 2
target_columns = ["value1"]
Expand Down
35 changes: 30 additions & 5 deletions tsfm_public/toolkit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
context_length: int = 1,
prediction_length: int = 0,
zero_padding: bool = True,
stride: int = 1,
):
super().__init__()
if not isinstance(x_cols, list):
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
self.zero_padding = zero_padding
self.timestamps = None
self.group_id = group_id
self.stride = stride

# sort the data by datetime
if timestamp_column in list(data_df.columns):
Expand Down Expand Up @@ -116,7 +118,7 @@ def pad_zero(self, data_df):
)

def __len__(self):
return len(self.X) - self.context_length - self.prediction_length + 1
return (len(self.X) - self.context_length - self.prediction_length + 1) // self.stride

def __getitem__(self, index: int):
"""
Expand Down Expand Up @@ -153,6 +155,7 @@ def __init__(
prediction_length: int = 1,
num_workers: int = 1,
cls=BaseDFDataset,
stride: int = 1,
**kwargs,
):
if len(id_columns) > 0:
Expand All @@ -166,6 +169,7 @@ def __init__(
self.num_workers = num_workers
self.cls = cls
self.prediction_length = prediction_length
self.stride = stride
self.extra_kwargs = kwargs

# create groupby object
Expand Down Expand Up @@ -208,6 +212,7 @@ def concat_dataset(self):
self.context_length,
self.prediction_length,
self.drop_cols,
self.stride,
self.extra_kwargs,
)
for group_id, group in group_df
Expand All @@ -228,6 +233,7 @@ def get_group_data(
context_length: int = 1,
prediction_length: int = 1,
drop_cols: Optional[List[str]] = None,
stride: int = 1,
extra_kwargs: Dict[str, Any] = {},
):
return cls(
Expand All @@ -238,6 +244,7 @@ def get_group_data(
context_length=context_length,
prediction_length=prediction_length,
drop_cols=drop_cols,
stride=stride,
**extra_kwargs,
)

Expand All @@ -264,6 +271,7 @@ def __init__(
target_columns: List[str] = [],
context_length: int = 1,
num_workers: int = 1,
stride: int = 1,
):
super().__init__(
data_df=data,
Expand All @@ -274,6 +282,7 @@ def __init__(
prediction_length=0,
cls=self.BasePretrainDFDataset,
target_columns=target_columns,
stride=stride,
)
self.n_inp = 1

Expand All @@ -288,6 +297,7 @@ def __init__(
id_columns: List[str] = [],
timestamp_column: Optional[str] = None,
target_columns: List[str] = [],
stride: int = 1,
):
self.target_columns = target_columns

Expand All @@ -304,9 +314,11 @@ def __init__(
prediction_length=prediction_length,
group_id=group_id,
drop_cols=drop_cols,
stride=stride,
)

def __getitem__(self, time_id):
def __getitem__(self, index):
time_id = index * self.stride
seq_x = self.X[time_id : time_id + self.context_length].values
ret = {"past_values": np_to_torch(seq_x)}
if self.datetime_col:
Expand Down Expand Up @@ -346,6 +358,7 @@ def __init__(
num_workers: int = 1,
frequency_token: Optional[int] = None,
autoregressive_modeling: bool = True,
stride: int = 1,
):
# output_columns_tmp = input_columns if output_columns == [] else output_columns

Expand All @@ -357,6 +370,7 @@ def __init__(
context_length=context_length,
prediction_length=prediction_length,
cls=self.BaseForecastDFDataset,
stride=stride,
# extra_args
target_columns=target_columns,
observable_columns=observable_columns,
Expand Down Expand Up @@ -391,6 +405,7 @@ def __init__(
static_categorical_columns: List[str] = [],
frequency_token: Optional[int] = None,
autoregressive_modeling: bool = True,
stride: int = 1,
):
self.frequency_token = frequency_token
self.target_columns = target_columns
Expand Down Expand Up @@ -430,10 +445,14 @@ def __init__(
prediction_length=prediction_length,
group_id=group_id,
drop_cols=drop_cols,
stride=stride,
)

def __getitem__(self, time_id):
def __getitem__(self, index):
# seq_x: batch_size x seq_len x num_x_cols

time_id = index * self.stride

seq_x = self.X[time_id : time_id + self.context_length].values
if not self.autoregressive_modeling:
seq_x[:, self.x_mask_targets] = 0
Expand Down Expand Up @@ -465,7 +484,7 @@ def __getitem__(self, time_id):
return ret

def __len__(self):
return len(self.X) - self.context_length - self.prediction_length + 1
return (len(self.X) - self.context_length - self.prediction_length + 1) // self.stride


class RegressionDFDataset(BaseConcatDFDataset):
Expand All @@ -492,6 +511,7 @@ def __init__(
static_categorical_columns: List[str] = [],
context_length: int = 1,
num_workers: int = 1,
stride: int = 1,
):
# self.y_cols = y_cols

Expand All @@ -505,6 +525,7 @@ def __init__(
input_columns=input_columns,
target_columns=target_columns,
static_categorical_columns=static_categorical_columns,
stride=stride,
)

self.n_inp = 2
Expand All @@ -526,6 +547,7 @@ def __init__(
target_columns: List[str] = [],
input_columns: List[str] = [],
static_categorical_columns: List[str] = [],
stride: int = 1,
):
self.target_columns = target_columns
self.input_columns = input_columns
Expand All @@ -544,10 +566,13 @@ def __init__(
prediction_length=prediction_length,
group_id=group_id,
drop_cols=drop_cols,
stride=stride,
)

def __getitem__(self, time_id):
def __getitem__(self, index):
# seq_x: batch_size x seq_len x num_x_cols

time_id = index * self.stride
seq_x = self.X[time_id : time_id + self.context_length].values
seq_y = self.y[time_id + self.context_length - 1 : time_id + self.context_length].values.ravel()
# return _torch(seq_x, seq_y)
Expand Down

0 comments on commit d1d5ebe

Please sign in to comment.