diff --git a/tests/toolkit/test_time_series_forecasting_pipeline.py b/tests/toolkit/test_time_series_forecasting_pipeline.py index 503e5c46..baa862a0 100644 --- a/tests/toolkit/test_time_series_forecasting_pipeline.py +++ b/tests/toolkit/test_time_series_forecasting_pipeline.py @@ -4,23 +4,88 @@ """Tests the time series preprocessor and functions""" import pandas as pd +import pytest from transformers import PatchTSTForPrediction +from tsfm_public import TinyTimeMixerForPrediction from tsfm_public.toolkit.time_series_forecasting_pipeline import ( TimeSeriesForecastingPipeline, ) -from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor +from tsfm_public.toolkit.time_series_preprocessor import DEFAULT_FREQUENCY_MAPPING, TimeSeriesPreprocessor from tsfm_public.toolkit.util import select_by_index -def test_forecasting_pipeline_forecasts(): +@pytest.fixture(scope="module") +def patchtst_model(): + model_path = "ibm-granite/granite-timeseries-patchtst" + model = PatchTSTForPrediction.from_pretrained(model_path) + + return model + + +@pytest.fixture(scope="module") +def ttm_model(): + model_path = "ibm-granite/granite-timeseries-ttm-v1" + model = TinyTimeMixerForPrediction.from_pretrained(model_path) + + return model + + +@pytest.fixture(scope="module") +def etth_data(): timestamp_column = "date" id_columns = [] target_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"] prediction_length = 96 - model_path = "ibm/patchtst-etth1-forecasting" - model = PatchTSTForPrediction.from_pretrained(model_path) + dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv" + data = pd.read_csv( + dataset_path, + parse_dates=[timestamp_column], + ) + train_end_index = 12 * 30 * 24 + + context_length = 512 # model.config.context_length + + test_end_index = 12 * 30 * 24 + 8 * 30 * 24 + test_start_index = test_end_index - context_length - 4 + + data = pd.read_csv( + dataset_path, + parse_dates=[timestamp_column], + ) + + train_data = select_by_index( + data, + id_columns=id_columns, + start_index=0, + end_index=train_end_index, + ) + test_data = select_by_index( + data, + id_columns=id_columns, + start_index=test_start_index, + end_index=test_end_index, + ) + + params = { + "timestamp_column": timestamp_column, + "id_columns": id_columns, + "target_columns": target_columns, + "prediction_length": prediction_length, + "context_length": context_length, + } + + return train_data, test_data, params + + +def test_forecasting_pipeline_forecasts(patchtst_model): + timestamp_column = "date" + id_columns = [] + target_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"] + prediction_length = 96 + + model = patchtst_model context_length = model.config.context_length forecast_pipeline = TimeSeriesForecastingPipeline( @@ -111,14 +176,13 @@ def test_forecasting_pipeline_forecasts(): assert forecasts.shape == (10, 2 * len(target_columns) + 1) -def test_forecasting_pipeline_forecasts_with_preprocessor(): +def test_forecasting_pipeline_forecasts_with_preprocessor(patchtst_model): timestamp_column = "date" id_columns = [] target_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"] prediction_length = 96 - model_path = "ibm/patchtst-etth1-forecasting" - model = PatchTSTForPrediction.from_pretrained(model_path) + model = patchtst_model context_length = model.config.context_length dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv" @@ -181,3 +245,53 @@ def test_forecasting_pipeline_forecasts_with_preprocessor(): # if we have inverse scaled mean should be larger assert forecasts["HUFL_prediction"].mean().mean() > 10 + + +def test_frequency_token(ttm_model, etth_data): + model = ttm_model + train_data, test_data, params = etth_data + + timestamp_column = params["timestamp_column"] + id_columns = params["id_columns"] + target_columns = params["target_columns"] + prediction_length = params["prediction_length"] + context_length = params["context_length"] + + tsp = TimeSeriesPreprocessor( + timestamp_column=timestamp_column, + id_columns=id_columns, + target_columns=target_columns, + context_length=context_length, + prediction_length=prediction_length, + freq="1h", + scaling=True, + ) + + tsp.train(train_data) + + assert model.config.resolution_prefix_tuning is False + + forecast_pipeline = TimeSeriesForecastingPipeline( + model=model, + timestamp_column=timestamp_column, + id_columns=id_columns, + target_columns=target_columns, + freq="1h", + feature_extractor=tsp, + explode_forecasts=False, + inverse_scale_outputs=True, + ) + assert forecast_pipeline._preprocess_params["frequency_token"] is None + + model.config.resolution_prefix_tuning = True + forecast_pipeline = TimeSeriesForecastingPipeline( + model=model, + timestamp_column=timestamp_column, + id_columns=id_columns, + target_columns=target_columns, + freq="1h", + feature_extractor=tsp, + explode_forecasts=False, + inverse_scale_outputs=True, + ) + assert forecast_pipeline._preprocess_params["frequency_token"] == DEFAULT_FREQUENCY_MAPPING["h"] diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index ef53a67e..94e436c9 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -143,6 +143,19 @@ def _sanitize_parameters( context_length = kwargs.get("context_length", self.model.config.context_length) prediction_length = kwargs.get("prediction_length", self.model.config.prediction_length) + # get freq from kwargs or the preprocessor + freq = kwargs.get("freq", None) + if self.feature_extractor and not freq: + freq = self.feature_extractor.freq + + # check if we need to use the frequency token, get token in needed + use_frequency_token = getattr(self.model.config, "resolution_prefix_tuning", False) + + if use_frequency_token and self.feature_extractor: + frequency_token = self.feature_extractor.get_frequency_token(freq) + else: + frequency_token = None + # autopopulate from feature extractor if self.feature_extractor: for p in [ @@ -161,6 +174,7 @@ def _sanitize_parameters( preprocess_kwargs = { "prediction_length": prediction_length, "context_length": context_length, + "frequency_token": frequency_token, } postprocess_kwargs = { "prediction_length": prediction_length,