Skip to content

Commit

Permalink
Merge pull request #114 from ibm-granite/pipeline_updates
Browse files Browse the repository at this point in the history
Update forecasting pipeline
  • Loading branch information
wgifford authored Aug 22, 2024
2 parents 2021c10 + 591681c commit 6cb3f1a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 42 deletions.
18 changes: 16 additions & 2 deletions tests/toolkit/test_time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pandas as pd
import pytest
from transformers import PatchTSTForPrediction
from transformers import PatchTSTConfig, PatchTSTForPrediction

from tsfm_public import TinyTimeMixerConfig, TinyTimeMixerForPrediction
from tsfm_public.toolkit.time_series_forecasting_pipeline import (
Expand All @@ -17,7 +17,7 @@

@pytest.fixture(scope="module")
def patchtst_model():
    """Load a pretrained PatchTST prediction model once per test module.

    Returns:
        PatchTSTForPrediction: model loaded from the lightweight test
        checkpoint (kept small so the test suite stays fast).
    """
    # Removed the dead first assignment ("ibm-granite/granite-timeseries-patchtst")
    # that was immediately overwritten; only the test checkpoint is used.
    model_path = "ibm/test-patchtst"
    model = PatchTSTForPrediction.from_pretrained(model_path)

    return model
Expand Down Expand Up @@ -81,6 +81,20 @@ def etth_data():
return train_data, test_data, params


def test_forecasting_pipeline_defaults():
    """Pipeline length parameters default from the model config and are
    overridden by explicit keyword arguments."""
    config = PatchTSTConfig(prediction_length=3, context_length=33)
    model = PatchTSTForPrediction(config)

    # Without explicit arguments, the lengths come from the model's config.
    pipeline = TimeSeriesForecastingPipeline(model=model)
    assert pipeline._preprocess_params["prediction_length"] == 3
    assert pipeline._preprocess_params["context_length"] == 33

    # Explicit keyword arguments take precedence over the model config.
    pipeline = TimeSeriesForecastingPipeline(model=model, prediction_length=6, context_length=66)
    assert pipeline._preprocess_params["prediction_length"] == 6
    assert pipeline._preprocess_params["context_length"] == 66


def test_forecasting_pipeline_forecasts(patchtst_model):
timestamp_column = "date"
id_columns = []
Expand Down
95 changes: 55 additions & 40 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,44 @@ def __init__(
kwargs["explode_forecasts"] = explode_forecasts
kwargs["inverse_scale_outputs"] = inverse_scale_outputs
kwargs["add_known_ground_truth"] = add_known_ground_truth

# autopopulate from feature extractor and model
if "feature_extractor" in kwargs:
for p in [
"id_columns",
"timestamp_column",
"target_columns",
"observable_columns",
"control_columns",
"conditional_columns",
"static_categorical_columns",
"freq",
]:
if p not in kwargs:
kwargs[p] = getattr(kwargs["feature_extractor"], p)

# get freq from kwargs or the preprocessor
if "freq" not in kwargs:
kwargs["freq"] = kwargs["feature_extractor"].freq

model = kwargs.get("model", None)
if not model:
raise ValueError("A model must be supplied during instantiation of a TimeSeriesForecastingPipeline")

if "context_length" not in kwargs:
kwargs["context_length"] = model.config.context_length

if "prediction_length" not in kwargs:
kwargs["prediction_length"] = model.config.prediction_length

# check if we need to use the frequency token, get token if needed
use_frequency_token = getattr(model.config, "resolution_prefix_tuning", False)

if use_frequency_token and "feature_extractor" in kwargs:
kwargs["frequency_token"] = kwargs["feature_extractor"].get_frequency_token(kwargs["freq"])
else:
kwargs["frequency_token"] = None

super().__init__(*args, **kwargs)

if self.framework == "tf":
Expand All @@ -140,48 +178,13 @@ def _sanitize_parameters(
For expected parameters see the call method below.
"""

context_length = kwargs.get("context_length", self.model.config.context_length)
prediction_length = kwargs.get("prediction_length", self.model.config.prediction_length)

# get freq from kwargs or the preprocessor
freq = kwargs.get("freq", None)
if self.feature_extractor and not freq:
freq = self.feature_extractor.freq

# check if we need to use the frequency token, get token if needed
use_frequency_token = getattr(self.model.config, "resolution_prefix_tuning", False)

if use_frequency_token and self.feature_extractor:
frequency_token = self.feature_extractor.get_frequency_token(freq)
else:
frequency_token = None

# autopopulate from feature extractor
if self.feature_extractor:
for p in [
"id_columns",
"timestamp_column",
"target_columns",
"observable_columns",
"control_columns",
"conditional_columns",
"static_categorical_columns",
"freq",
]:
if p not in kwargs:
kwargs[p] = getattr(self.feature_extractor, p)

preprocess_kwargs = {
"prediction_length": prediction_length,
"context_length": context_length,
"frequency_token": frequency_token,
}
postprocess_kwargs = {
"prediction_length": prediction_length,
"context_length": context_length,
}
preprocess_kwargs = {}
postprocess_kwargs = {}

preprocess_params = [
"prediction_length",
"context_length",
"frequency_token",
"id_columns",
"timestamp_column",
"target_columns",
Expand All @@ -192,6 +195,8 @@ def _sanitize_parameters(
"future_time_series",
]
postprocess_params = [
"prediction_length",
"context_length",
"id_columns",
"timestamp_column",
"target_columns",
Expand Down Expand Up @@ -356,6 +361,16 @@ def preprocess(self, time_series, **kwargs) -> Dict[str, Union[GenericTensor, Li
if c not in time_series.columns:
raise ValueError(f"Future time series input contains an unknown column {c}.")

if id_columns:
id_count = time_series[id_columns].unique().shape[0]
else:
id_count = 1

if future_time_series.shape[0] != prediction_length * id_count:
raise ValueError(
f"If provided, `future_time_series` data should cover the prediction length for each of the time series in the test dataset. Received data of length {future_time_series.shape[0]} but expected {prediction_length * id_count}"
)

time_series = pd.concat((time_series, future_time_series), axis=0)
else:
# no additional exogenous data provided, extend with empty periods
Expand Down

0 comments on commit 6cb3f1a

Please sign in to comment.