From 96e4cea140be95912e2382da5a1a8ed1226bca54 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 21 Aug 2024 19:44:44 -0400 Subject: [PATCH 1/3] move defaults to init, per HF pipeline docs --- .../time_series_forecasting_pipeline.py | 95 +++++++++++-------- 1 file changed, 55 insertions(+), 40 deletions(-) diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index 94e436c9..7dfa26f7 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -123,6 +123,44 @@ def __init__( kwargs["explode_forecasts"] = explode_forecasts kwargs["inverse_scale_outputs"] = inverse_scale_outputs kwargs["add_known_ground_truth"] = add_known_ground_truth + + # autopopulate from feature extractor and model + if "feature_extractor" in kwargs: + for p in [ + "id_columns", + "timestamp_column", + "target_columns", + "observable_columns", + "control_columns", + "conditional_columns", + "static_categorical_columns", + "freq", + ]: + if p not in kwargs: + kwargs[p] = getattr(kwargs["feature_extractor"], p) + + # get freq from kwargs or the preprocessor + if "freq" not in kwargs: + kwargs["freq"] = kwargs["feature_extractor"].freq + + model = kwargs.get("model", None) + if not model: + raise ValueError("A model must be supplied") + + if "context_length" not in kwargs: + kwargs["context_length"] = model.config.context_length + + if "prediction_length" not in kwargs: + kwargs["prediction_length"] = model.config.prediction_length + + # check if we need to use the frequency token, get token if needed + use_frequency_token = getattr(model.config, "resolution_prefix_tuning", False) + + if use_frequency_token and "feature_extractor" in kwargs: + kwargs["frequency_token"] = kwargs["feature_extractor"].get_frequency_token(kwargs["freq"]) + else: + kwargs["frequency_token"] = None + super().__init__(*args, 
**kwargs) if self.framework == "tf": @@ -140,48 +178,13 @@ def _sanitize_parameters( For expected parameters see the call method below. """ - context_length = kwargs.get("context_length", self.model.config.context_length) - prediction_length = kwargs.get("prediction_length", self.model.config.prediction_length) - - # get freq from kwargs or the preprocessor - freq = kwargs.get("freq", None) - if self.feature_extractor and not freq: - freq = self.feature_extractor.freq - - # check if we need to use the frequency token, get token in needed - use_frequency_token = getattr(self.model.config, "resolution_prefix_tuning", False) - - if use_frequency_token and self.feature_extractor: - frequency_token = self.feature_extractor.get_frequency_token(freq) - else: - frequency_token = None - - # autopopulate from feature extractor - if self.feature_extractor: - for p in [ - "id_columns", - "timestamp_column", - "target_columns", - "observable_columns", - "control_columns", - "conditional_columns", - "static_categorical_columns", - "freq", - ]: - if p not in kwargs: - kwargs[p] = getattr(self.feature_extractor, p) - - preprocess_kwargs = { - "prediction_length": prediction_length, - "context_length": context_length, - "frequency_token": frequency_token, - } - postprocess_kwargs = { - "prediction_length": prediction_length, - "context_length": context_length, - } + preprocess_kwargs = {} + postprocess_kwargs = {} preprocess_params = [ + "prediction_length", + "context_length", + "frequency_token", "id_columns", "timestamp_column", "target_columns", @@ -192,6 +195,8 @@ def _sanitize_parameters( "future_time_series", ] postprocess_params = [ + "prediction_length", + "context_length", "id_columns", "timestamp_column", "target_columns", @@ -356,6 +361,16 @@ def preprocess(self, time_series, **kwargs) -> Dict[str, Union[GenericTensor, Li if c not in time_series.columns: raise ValueError(f"Future time series input contains an unknown column {c}.") + if id_columns: + id_count = 
time_series[id_columns].drop_duplicates().shape[0] + else: + id_count = 1 + + if future_time_series.shape[0] != prediction_length * id_count: + raise ValueError( + f"If provided, `future_time_series` data should cover the prediction length for each of the time series in the test dataset. Received data of length {future_time_series.shape[0]} but expected {prediction_length * id_count}" + ) + time_series = pd.concat((time_series, future_time_series), axis=0) else: # no additional exogenous data provided, extend with empty periods From d99b140cce5df525905612a2442d5ead702ec35a Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 21 Aug 2024 20:14:53 -0400 Subject: [PATCH 2/3] add tests for default handling --- .../test_time_series_forecasting_pipeline.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/toolkit/test_time_series_forecasting_pipeline.py b/tests/toolkit/test_time_series_forecasting_pipeline.py index c788bcd6..d77f781c 100644 --- a/tests/toolkit/test_time_series_forecasting_pipeline.py +++ b/tests/toolkit/test_time_series_forecasting_pipeline.py @@ -5,7 +5,7 @@ import pandas as pd import pytest -from transformers import PatchTSTForPrediction +from transformers import PatchTSTConfig, PatchTSTForPrediction from tsfm_public import TinyTimeMixerConfig, TinyTimeMixerForPrediction from tsfm_public.toolkit.time_series_forecasting_pipeline import ( @@ -17,7 +17,7 @@ @pytest.fixture(scope="module") def patchtst_model(): - model_path = "ibm-granite/granite-timeseries-patchtst" + model_path = "ibm/test-patchtst" model = PatchTSTForPrediction.from_pretrained(model_path) return model @@ -81,6 +81,20 @@ def etth_data(): return train_data, test_data, params +def test_forecasting_pipeline_defaults(): + model = PatchTSTForPrediction(PatchTSTConfig(prediction_length=3, context_length=33)) + + tspipe = TimeSeriesForecastingPipeline(model=model) + + assert 
tspipe._preprocess_params["prediction_length"] == 3 + assert tspipe._preprocess_params["context_length"] == 33 + + tspipe = TimeSeriesForecastingPipeline(model=model, prediction_length=6, context_length=66) + + assert tspipe._preprocess_params["prediction_length"] == 6 + assert tspipe._preprocess_params["context_length"] == 66 + + def test_forecasting_pipeline_forecasts(patchtst_model): timestamp_column = "date" id_columns = [] From 591681c428ec7629efbae6c75db8bccddeb40985 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 21 Aug 2024 20:15:05 -0400 Subject: [PATCH 3/3] improve message --- tsfm_public/toolkit/time_series_forecasting_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index 7dfa26f7..f51dbe2a 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -145,7 +145,7 @@ def __init__( model = kwargs.get("model", None) if not model: - raise ValueError("A model must be supplied") + raise ValueError("A model must be supplied during instantiation of a TimeSeriesForecastingPipeline") if "context_length" not in kwargs: kwargs["context_length"] = model.config.context_length