Merge pull request #91 from ibm-granite/freq_token
Freq token with pipeline
wgifford authored Jul 30, 2024
2 parents 5493539 + cdb4fae commit 9b4801d
Showing 2 changed files with 135 additions and 7 deletions.
128 changes: 121 additions & 7 deletions tests/toolkit/test_time_series_forecasting_pipeline.py
@@ -4,23 +4,88 @@
"""Tests the time series preprocessor and functions"""

import pandas as pd
import pytest
from transformers import PatchTSTForPrediction

from tsfm_public import TinyTimeMixerForPrediction
from tsfm_public.toolkit.time_series_forecasting_pipeline import (
TimeSeriesForecastingPipeline,
)
from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor
from tsfm_public.toolkit.time_series_preprocessor import DEFAULT_FREQUENCY_MAPPING, TimeSeriesPreprocessor
from tsfm_public.toolkit.util import select_by_index


def test_forecasting_pipeline_forecasts():
@pytest.fixture(scope="module")
def patchtst_model():
model_path = "ibm-granite/granite-timeseries-patchtst"
model = PatchTSTForPrediction.from_pretrained(model_path)

return model


@pytest.fixture(scope="module")
def ttm_model():
model_path = "ibm-granite/granite-timeseries-ttm-v1"
model = TinyTimeMixerForPrediction.from_pretrained(model_path)

return model


@pytest.fixture(scope="module")
def etth_data():
timestamp_column = "date"
id_columns = []
target_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
prediction_length = 96

model_path = "ibm/patchtst-etth1-forecasting"
model = PatchTSTForPrediction.from_pretrained(model_path)
dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv"
data = pd.read_csv(
dataset_path,
parse_dates=[timestamp_column],
)
train_end_index = 12 * 30 * 24

context_length = 512 # model.config.context_length

test_end_index = 12 * 30 * 24 + 8 * 30 * 24
test_start_index = test_end_index - context_length - 4

train_data = select_by_index(
data,
id_columns=id_columns,
start_index=0,
end_index=train_end_index,
)
test_data = select_by_index(
data,
id_columns=id_columns,
start_index=test_start_index,
end_index=test_end_index,
)

params = {
"timestamp_column": timestamp_column,
"id_columns": id_columns,
"target_columns": target_columns,
"prediction_length": prediction_length,
"context_length": context_length,
}

return train_data, test_data, params


def test_forecasting_pipeline_forecasts(patchtst_model):
timestamp_column = "date"
id_columns = []
target_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
prediction_length = 96

model = patchtst_model
context_length = model.config.context_length

forecast_pipeline = TimeSeriesForecastingPipeline(
@@ -111,14 +176,13 @@ def test_forecasting_pipeline_forecasts():
assert forecasts.shape == (10, 2 * len(target_columns) + 1)


def test_forecasting_pipeline_forecasts_with_preprocessor():
def test_forecasting_pipeline_forecasts_with_preprocessor(patchtst_model):
timestamp_column = "date"
id_columns = []
target_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
prediction_length = 96

model_path = "ibm/patchtst-etth1-forecasting"
model = PatchTSTForPrediction.from_pretrained(model_path)
model = patchtst_model
context_length = model.config.context_length

dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv"
@@ -181,3 +245,53 @@ def test_forecasting_pipeline_forecasts_with_preprocessor():

    # if outputs have been inverse scaled, the mean should be larger
assert forecasts["HUFL_prediction"].mean().mean() > 10


def test_frequency_token(ttm_model, etth_data):
model = ttm_model
train_data, test_data, params = etth_data

timestamp_column = params["timestamp_column"]
id_columns = params["id_columns"]
target_columns = params["target_columns"]
prediction_length = params["prediction_length"]
context_length = params["context_length"]

tsp = TimeSeriesPreprocessor(
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
context_length=context_length,
prediction_length=prediction_length,
freq="1h",
scaling=True,
)

tsp.train(train_data)

assert model.config.resolution_prefix_tuning is False

forecast_pipeline = TimeSeriesForecastingPipeline(
model=model,
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
freq="1h",
feature_extractor=tsp,
explode_forecasts=False,
inverse_scale_outputs=True,
)
assert forecast_pipeline._preprocess_params["frequency_token"] is None

model.config.resolution_prefix_tuning = True
forecast_pipeline = TimeSeriesForecastingPipeline(
model=model,
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
freq="1h",
feature_extractor=tsp,
explode_forecasts=False,
inverse_scale_outputs=True,
)
assert forecast_pipeline._preprocess_params["frequency_token"] == DEFAULT_FREQUENCY_MAPPING["h"]
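
For context, a minimal end-to-end sketch (not part of the diff) of the behavior the new test exercises: with resolution_prefix_tuning enabled on the model and freq="1h" on the preprocessor and pipeline, the frequency token is resolved automatically. Model names, columns, indices, and the dataset URL are taken from the fixtures above; the final pipeline call is the assumed usage pattern from the surrounding tests, not shown in this hunk.

import pandas as pd

from tsfm_public import TinyTimeMixerForPrediction
from tsfm_public.toolkit.time_series_forecasting_pipeline import TimeSeriesForecastingPipeline
from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor
from tsfm_public.toolkit.util import select_by_index

timestamp_column = "date"
target_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]

# ETTh2 split boundaries as in the etth_data fixture
data = pd.read_csv(
    "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv",
    parse_dates=[timestamp_column],
)
train_data = select_by_index(data, id_columns=[], start_index=0, end_index=12 * 30 * 24)
test_end_index = 12 * 30 * 24 + 8 * 30 * 24
test_data = select_by_index(
    data, id_columns=[], start_index=test_end_index - 512 - 4, end_index=test_end_index
)

model = TinyTimeMixerForPrediction.from_pretrained("ibm-granite/granite-timeseries-ttm-v1")
# The test flips this flag to force token injection; frequency-aware checkpoints
# would have it set at training time.
model.config.resolution_prefix_tuning = True

tsp = TimeSeriesPreprocessor(
    timestamp_column=timestamp_column,
    id_columns=[],
    target_columns=target_columns,
    context_length=512,
    prediction_length=96,
    freq="1h",
    scaling=True,
)
tsp.train(train_data)  # fit the scalers on the training split

pipe = TimeSeriesForecastingPipeline(
    model=model,
    timestamp_column=timestamp_column,
    id_columns=[],
    target_columns=target_columns,
    freq="1h",
    feature_extractor=tsp,
    explode_forecasts=False,
    inverse_scale_outputs=True,
)
forecasts = pipe(test_data)  # assumed call pattern; the tests treat the output as a DataFrame
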
14 changes: 14 additions & 0 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Expand Up @@ -143,6 +143,19 @@ def _sanitize_parameters(
context_length = kwargs.get("context_length", self.model.config.context_length)
prediction_length = kwargs.get("prediction_length", self.model.config.prediction_length)

# get freq from kwargs or the preprocessor
freq = kwargs.get("freq", None)
if self.feature_extractor and not freq:
freq = self.feature_extractor.freq

        # check if we need to use the frequency token, get the token if needed
use_frequency_token = getattr(self.model.config, "resolution_prefix_tuning", False)

if use_frequency_token and self.feature_extractor:
frequency_token = self.feature_extractor.get_frequency_token(freq)
else:
frequency_token = None

# autopopulate from feature extractor
if self.feature_extractor:
for p in [
@@ -161,6 +174,7 @@
preprocess_kwargs = {
"prediction_length": prediction_length,
"context_length": context_length,
"frequency_token": frequency_token,
}
postprocess_kwargs = {
"prediction_length": prediction_length,
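
Read together with the new test, the effect of this change can be summarized in a small standalone sketch; `model`, `feature_extractor`, `kwargs`, `prediction_length`, and `context_length` are assumed to be in scope exactly as they are inside `_sanitize_parameters`, and the example values in the comments follow the test's assertions.

# Resolve the frequency: an explicit `freq` kwarg wins, otherwise fall back to the
# preprocessor's own setting (e.g. "1h" for the ETTh2 fixture above).
freq = kwargs.get("freq", None)
if feature_extractor and not freq:
    freq = feature_extractor.freq

# Only models trained with a resolution prefix expect a frequency token; for all
# other models the token stays None and preprocessing is unchanged.
if getattr(model.config, "resolution_prefix_tuning", False) and feature_extractor:
    # With freq="1h" the test expects this to equal DEFAULT_FREQUENCY_MAPPING["h"].
    frequency_token = feature_extractor.get_frequency_token(freq)
else:
    frequency_token = None

# The token travels to the preprocess step alongside the existing lengths.
preprocess_kwargs = {
    "prediction_length": prediction_length,
    "context_length": context_length,
    "frequency_token": frequency_token,
}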
