Striding for datasets #59

Merged · 1 commit · May 31, 2024
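
This PR threads a stride argument through the toolkit dataset classes (PretrainDFDataset, ForecastDFDataset, RegressionDFDataset) so that sliding windows are sampled every stride rows instead of every row; stride=1 keeps the old behavior. A minimal usage sketch, assuming a long-format dataframe with timestamp/id/value columns like the test fixture below (file name and column names illustrative):

import pandas as pd
from tsfm_public.toolkit.dataset import ForecastDFDataset

df = pd.read_csv("data.csv")  # hypothetical input with columns: timestamp, id, value1, value2

ds = ForecastDFDataset(
    df,
    timestamp_column="timestamp",
    id_columns=["id"],
    target_columns=["value1", "value2"],
    context_length=3,
    prediction_length=2,
    stride=13,  # emit every 13th window per series instead of every window
)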
41 changes: 41 additions & 0 deletions tests/toolkit/test_dataset.py
@@ -152,6 +152,47 @@ def test_forecasting_df_dataset(ts_data_with_categorical):
assert np.all(ds[0]["future_values"][:, 2].numpy() == 0)


def test_forecasting_df_dataset_stride(ts_data_with_categorical):
prediction_length = 2
context_length = 3
stride = 13
target_columns = ["value1", "value2"]

df = ts_data_with_categorical

ds = ForecastDFDataset(
df,
timestamp_column="timestamp",
id_columns=["id"],
target_columns=target_columns,
context_length=context_length,
prediction_length=prediction_length,
stride=stride,
)

# length check
series_len = len(df) / len(df["id"].unique())
assert len(ds) == ((series_len - prediction_length - context_length + 1) // stride) * len(df["id"].unique())

# check proper windows are selected based on chosen stride
ds_past_np = np.array([v["past_values"].numpy() for v in ds])
ds_past_np_expected = np.array(
[
[[0.0, 10.0], [1.0, 10.333333], [2.0, 10.666667]],
[[13.0, 14.333333], [14.0, 14.666667], [15.0, 15.0]],
[[26.0, 18.666666], [27.0, 19.0], [28.0, 19.333334]],
[[50.0, 26.666666], [51.0, 27.0], [52.0, 27.333334]],
[[63.0, 31.0], [64.0, 31.333334], [65.0, 31.666666]],
[[76.0, 35.333332], [77.0, 35.666668], [78.0, 36.0]],
[[100.0, 43.333332], [101.0, 43.666668], [102.0, 44.0]],
[[113.0, 47.666668], [114.0, 48.0], [115.0, 48.333332]],
[[126.0, 52.0], [127.0, 52.333332], [128.0, 52.666668]],
]
)

np.testing.assert_allclose(ds_past_np, ds_past_np_expected)


def test_forecasting_df_dataset_non_autoregressive(ts_data_with_categorical):
prediction_length = 2
target_columns = ["value1"]
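For concreteness, the length assertion above works out as follows, assuming the ts_data_with_categorical fixture holds three series of 50 rows each (consistent with the expected windows, whose per-series start offsets are 0, 13 and 26):

# Illustrative arithmetic only, not part of the test suite.
series_len = 50                          # assumed rows per id
valid_starts = series_len - 3 - 2 + 1    # context_length=3, prediction_length=2 -> 46
windows_per_series = valid_starts // 13  # stride=13 -> 3
assert windows_per_series * 3 == 9       # 3 series -> the 9 expected windows
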
35 changes: 30 additions & 5 deletions tsfm_public/toolkit/dataset.py
@@ -42,6 +42,7 @@ def __init__(
context_length: int = 1,
prediction_length: int = 0,
zero_padding: bool = True,
stride: int = 1,
):
super().__init__()
if not isinstance(x_cols, list):
@@ -72,6 +73,7 @@ def __init__(
self.zero_padding = zero_padding
self.timestamps = None
self.group_id = group_id
self.stride = stride

# sort the data by datetime
if timestamp_column in list(data_df.columns):
@@ -116,7 +118,7 @@ def pad_zero(self, data_df):
)

def __len__(self):
- return len(self.X) - self.context_length - self.prediction_length + 1
+ return (len(self.X) - self.context_length - self.prediction_length + 1) // self.stride
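
One subtlety worth noting: floor division returns one window fewer than the number of admissible strided starts whenever stride does not divide the count of valid starts. With 46 valid starts and stride 13, for example, the dataset exposes starts 0, 13 and 26 even though a window starting at 39 would still fit; the test above encodes exactly this behavior. Illustrative arithmetic:

valid_starts = 50 - 3 - 2 + 1                     # 46
implemented_len = valid_starts // 13              # 3 windows
admissible_starts = (valid_starts - 1) // 13 + 1  # 4 (offsets 0, 13, 26, 39)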

def __getitem__(self, index: int):
"""
@@ -153,6 +155,7 @@ def __init__(
prediction_length: int = 1,
num_workers: int = 1,
cls=BaseDFDataset,
stride: int = 1,
**kwargs,
):
if len(id_columns) > 0:
@@ -166,6 +169,7 @@ def __init__(
self.num_workers = num_workers
self.cls = cls
self.prediction_length = prediction_length
self.stride = stride
self.extra_kwargs = kwargs

# create groupby object
@@ -208,6 +212,7 @@ def concat_dataset(self):
self.context_length,
self.prediction_length,
self.drop_cols,
self.stride,
self.extra_kwargs,
)
for group_id, group in group_df
@@ -228,6 +233,7 @@ def get_group_data(
context_length: int = 1,
prediction_length: int = 1,
drop_cols: Optional[List[str]] = None,
stride: int = 1,
extra_kwargs: Dict[str, Any] = {},
):
return cls(
@@ -238,6 +244,7 @@ def __init__(
context_length=context_length,
prediction_length=prediction_length,
drop_cols=drop_cols,
stride=stride,
**extra_kwargs,
)
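
Since concat_dataset builds one per-series dataset via get_group_data and concatenates afterwards, striding is applied within each series independently: no window mixes rows from two ids, and the index arithmetic restarts at zero for every group. A rough equivalent, illustrative only:

context_length, prediction_length, stride = 3, 2, 13  # values from the test above
for group_id, group in df.groupby("id"):
    # each series is windowed on its own; strides never cross id boundaries
    n_windows = (len(group) - context_length - prediction_length + 1) // stride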

@@ -264,6 +271,7 @@ def __init__(
target_columns: List[str] = [],
context_length: int = 1,
num_workers: int = 1,
stride: int = 1,
):
super().__init__(
data_df=data,
@@ -274,6 +282,7 @@ def __init__(
prediction_length=0,
cls=self.BasePretrainDFDataset,
target_columns=target_columns,
stride=stride,
)
self.n_inp = 1

@@ -288,6 +297,7 @@ def __init__(
id_columns: List[str] = [],
timestamp_column: Optional[str] = None,
target_columns: List[str] = [],
stride: int = 1,
):
self.target_columns = target_columns

@@ -304,9 +314,11 @@ def __init__(
prediction_length=prediction_length,
group_id=group_id,
drop_cols=drop_cols,
stride=stride,
)

- def __getitem__(self, time_id):
+ def __getitem__(self, index):
+ time_id = index * self.stride
seq_x = self.X[time_id : time_id + self.context_length].values
ret = {"past_values": np_to_torch(seq_x)}
if self.datetime_col:
@@ -346,6 +358,7 @@ def __init__(
num_workers: int = 1,
frequency_token: Optional[int] = None,
autoregressive_modeling: bool = True,
stride: int = 1,
):
# output_columns_tmp = input_columns if output_columns == [] else output_columns

@@ -357,6 +370,7 @@ def __init__(
context_length=context_length,
prediction_length=prediction_length,
cls=self.BaseForecastDFDataset,
stride=stride,
# extra_args
target_columns=target_columns,
observable_columns=observable_columns,
@@ -391,6 +405,7 @@ def __init__(
static_categorical_columns: List[str] = [],
frequency_token: Optional[int] = None,
autoregressive_modeling: bool = True,
stride: int = 1,
):
self.frequency_token = frequency_token
self.target_columns = target_columns
@@ -430,10 +445,14 @@ def __init__(
prediction_length=prediction_length,
group_id=group_id,
drop_cols=drop_cols,
stride=stride,
)

- def __getitem__(self, time_id):
+ def __getitem__(self, index):
# seq_x: batch_size x seq_len x num_x_cols

+ time_id = index * self.stride

seq_x = self.X[time_id : time_id + self.context_length].values
if not self.autoregressive_modeling:
seq_x[:, self.x_mask_targets] = 0
Expand Down Expand Up @@ -465,7 +484,7 @@ def __getitem__(self, time_id):
return ret

def __len__(self):
- return len(self.X) - self.context_length - self.prediction_length + 1
+ return (len(self.X) - self.context_length - self.prediction_length + 1) // self.stride


class RegressionDFDataset(BaseConcatDFDataset):
@@ -492,6 +511,7 @@ def __init__(
static_categorical_columns: List[str] = [],
context_length: int = 1,
num_workers: int = 1,
stride: int = 1,
):
# self.y_cols = y_cols

@@ -505,6 +525,7 @@ def __init__(
input_columns=input_columns,
target_columns=target_columns,
static_categorical_columns=static_categorical_columns,
stride=stride,
)

self.n_inp = 2
@@ -526,6 +547,7 @@ def __init__(
target_columns: List[str] = [],
input_columns: List[str] = [],
static_categorical_columns: List[str] = [],
stride: int = 1,
):
self.target_columns = target_columns
self.input_columns = input_columns
@@ -544,10 +566,13 @@ def __init__(
prediction_length=prediction_length,
group_id=group_id,
drop_cols=drop_cols,
stride=stride,
)

- def __getitem__(self, time_id):
+ def __getitem__(self, index):
# seq_x: batch_size x seq_len x num_x_cols

+ time_id = index * self.stride
seq_x = self.X[time_id : time_id + self.context_length].values
seq_y = self.y[time_id + self.context_length - 1 : time_id + self.context_length].values.ravel()
# return _torch(seq_x, seq_y)
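
Downstream consumers are unaffected: each item is still a dict of tensors, and consecutive items within a series now start stride rows apart. A short iteration sketch over the ForecastDFDataset built in the test above:

for item in ds:
    past = item["past_values"]      # shape: (context_length, n_targets)
    future = item["future_values"]  # shape: (prediction_length, n_targets)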