Skip to content

Commit

Permalink
Merge pull request #20 from IBM/get_dataset
Browse files Browse the repository at this point in the history
Get dataset
  • Loading branch information
wgifford authored Mar 26, 2024
2 parents 1bb542b + 19fe429 commit 57f476a
Show file tree
Hide file tree
Showing 5 changed files with 350 additions and 4 deletions.
41 changes: 41 additions & 0 deletions hacking/datasets_from_tsp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# %%
import pandas as pd

from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor


split_config = {"train": [0, 8640], "valid": [8640, 11520], "test": [11520, 14400]}


dataset_path = (
"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"
)

timestamp_column = "date"

df = pd.read_csv(
dataset_path,
parse_dates=[timestamp_column],
)

tsp = TimeSeriesPreprocessor(
id_columns=[],
timestamp_column=timestamp_column,
target_columns=["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"],
observable_columns=[],
control_columns=[],
conditional_columns=[],
static_categorical_columns=[],
scaling=True,
scaler_type="standard",
encode_categorical=False,
prediction_length=10,
context_length=96,
)

train, valid, test = tsp.get_datasets(df, split_config)

# %%
split_config = {"train": [0, 0.7], "valid": [0.7, 0.9], "test": [0.9, 1]}

train, valid, test = tsp.get_datasets(df, split_config)
71 changes: 70 additions & 1 deletion tests/toolkit/test_time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
create_timestamps,
extend_time_series,
)
from tsfm_public.toolkit.util import FractionLocation


def test_standard_scaler(sample_data):
Expand Down Expand Up @@ -72,7 +73,7 @@ def test_time_series_preprocessor_encodes(sample_data):
static_categorical_columns = ["cat", "cat2"]

tsp = TimeSeriesPreprocessor(
input_columns=["val", "val2"],
target_columns=["val", "val2"],
static_categorical_columns=static_categorical_columns,
)
tsp.train(sample_data)
Expand Down Expand Up @@ -156,3 +157,71 @@ def test_create_timestamps():
# it is an error to provide neither freq or sequence
with pytest.raises(ValueError):
ts = create_timestamps(start, periods=periods)


def test_get_datasets(ts_data):
tsp = TimeSeriesPreprocessor(
id_columns=["id"],
target_columns=["value1", "value2"],
prediction_length=5,
context_length=10,
)

train, valid, test = tsp.get_datasets(
ts_data,
split_config={"train": [0, 1 / 3], "valid": [1 / 3, 2 / 3], "test": [2 / 3, 1]},
)

# 3 time series of length 50
assert len(train) == 3 * (int((1 / 3) * 50) - (tsp.context_length + tsp.prediction_length) + 1)

assert len(valid) == len(test)

# no id columns, so treat as one big time series
tsp = TimeSeriesPreprocessor(
id_columns=[],
target_columns=["value1", "value2"],
prediction_length=5,
context_length=10,
)

train, valid, test = tsp.get_datasets(
ts_data,
split_config={
"train": [0, 100],
"valid": [100, 125],
"test": [125, 150],
},
fewshot_fraction=0.2,
fewshot_location=FractionLocation.LAST.value,
)

# new train length should be 20% of 100, minus the usual for context length and prediction length
fewshot_train_size = int(100 * 0.2) - (tsp.context_length + tsp.prediction_length) + 1
assert len(train) == fewshot_train_size

assert len(valid) == len(test)

# no id columns, so treat as one big time series
tsp = TimeSeriesPreprocessor(
id_columns=[],
target_columns=["value1", "value2"],
prediction_length=5,
context_length=10,
)

train, valid, test = tsp.get_datasets(
ts_data,
split_config={
"train": [0, 100],
"valid": [100, 125],
"test": [125, 150],
},
fewshot_fraction=0.2,
fewshot_location=FractionLocation.FIRST.value,
)

# new train length should be 20% of 100, minus the usual for context length and prediction length
assert len(train) == fewshot_train_size

assert len(valid) == len(test)
26 changes: 26 additions & 0 deletions tests/toolkit/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Tests for util functions"""

import pytest

from tsfm_public.toolkit.util import get_split_params


split_cases = [
(0, 1, "select_by_index"),
(0, 0.1, "select_by_relative_fraction"),
(0.0, 0.1, "select_by_relative_fraction"),
(0.0, 200.0, "select_by_index"),
(0.0, 200, "select_by_index"),
(0.5, 1, "select_by_relative_fraction"),
(0.5, 1.0, "select_by_relative_fraction"),
(10, 100.0, "select_by_index"),
]


@pytest.mark.parametrize("left_arg,right_arg,expected", split_cases)
def test_get_split_params(left_arg, right_arg, expected):
"""Test that get_split_params gives the right split function"""

split_config, split_function = get_split_params({"train": [left_arg, right_arg], "valid": [0, 1], "test": [0, 1]})

assert split_function["train"].__name__ == expected
93 changes: 92 additions & 1 deletion tsfm_public/toolkit/time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
PreTrainedFeatureExtractor,
)

from .util import join_list_without_repeat
from .dataset import ForecastDFDataset
from .util import (
FractionLocation,
get_split_params,
join_list_without_repeat,
select_by_fixed_fraction,
)


INTERNAL_ID_COLUMN = "__id"
Expand Down Expand Up @@ -586,6 +592,91 @@ def scale_func(grp, id_columns):

return df

def get_datasets(
self,
dataset: Union[Dataset, pd.DataFrame],
split_config: Dict[str, Any],
fewshot_fraction: Optional[float] = None,
fewshot_location: str = FractionLocation.LAST.value,
) -> Tuple[Any]:
"""Creates the preprocessed pytorch datasets needed for training and evaluation
using the HuggingFace trainer
Args:
dataset (Union[Dataset, pd.DataFrame]): Loaded pandas dataframe
split_config (Dict[str, Any]): Dictionary of dictionaries containing
split parameters. For example:
{
train: [0, 50],
valid: [50, 70],
test: [70, 100]
}
end value is not inclusive
fewshot_fraction (float, optional): When non-null, return this percent of the original training
dataset. This is done to support fewshot fine-tuning. The fraction of data chosen is at the
end of the training dataset.
fewshot_location (str): Determines where the fewshot data is chosen. Valid options are "first" and "last"
as described in the enum FewshotLocation. Default is to choose the fewshot data at the end
of the training dataset (i.e., "last").
Returns:
Tuple of pytorch datasets, including: train, validation, test.
"""

data = self._standardize_dataframe(dataset)

# get split_params
# split_params = get_split_params(config, self.context_length, len(data))

split_params, split_function = get_split_params(split_config, context_length=self.context_length)

# specify columns
column_specifiers = {
"id_columns": self.id_columns,
"timestamp_column": self.timestamp_column,
"target_columns": self.target_columns,
"observable_columns": self.observable_columns,
"control_columns": self.control_columns,
"conditional_columns": self.conditional_columns,
"static_categorical_columns": self.static_categorical_columns,
}

# split data
train_data = split_function["train"](data, id_columns=self.id_columns, **split_params["train"])
valid_data = split_function["valid"](data, id_columns=self.id_columns, **split_params["valid"])
test_data = split_function["test"](data, id_columns=self.id_columns, **split_params["test"])

# data preprocessing
self.train(train_data)

# handle fewshot operation
if fewshot_fraction is not None:
if not ((fewshot_fraction <= 1.0) and (fewshot_fraction > 0.0)):
raise ValueError(f"Fewshot fraction should be between 0 and 1, received {fewshot_fraction}")

train_data = select_by_fixed_fraction(
train_data,
id_columns=self.id_columns,
fraction=fewshot_fraction,
location=fewshot_location,
)

params = column_specifiers
params["context_length"] = self.context_length
params["prediction_length"] = self.prediction_length

# get torch datasets
test_dataset = ForecastDFDataset(
self.preprocess(test_data),
**params,
)
train_dataset = ForecastDFDataset(self.preprocess(train_data), **params)
valid_dataset = ForecastDFDataset(
self.preprocess(valid_data),
**params,
)
return train_dataset, valid_dataset, test_dataset


def create_timestamps(
last_timestamp: Union[datetime.datetime, pd.Timestamp],
Expand Down
Loading

0 comments on commit 57f476a

Please sign in to comment.