diff --git a/tests/toolkit/test_time_series_forecasting_pipeline.py b/tests/toolkit/test_time_series_forecasting_pipeline.py index 87e8753c..defdd123 100644 --- a/tests/toolkit/test_time_series_forecasting_pipeline.py +++ b/tests/toolkit/test_time_series_forecasting_pipeline.py @@ -81,6 +81,34 @@ def test_forecasting_pipeline_forecasts(): forecasts_exploded = forecast_pipeline(test_data) assert forecasts_exploded.shape == (prediction_length, len(target_columns) + 1) + forecast_pipeline = TimeSeriesForecastingPipeline( + model=model, + timestamp_column=timestamp_column, + id_columns=id_columns, + target_columns=target_columns, + freq="1h", + batch_size=10, + ) + + dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv" + test_end_index = 12 * 30 * 24 + 8 * 30 * 24 + test_start_index = test_end_index - context_length - 9 + + data = pd.read_csv( + dataset_path, + parse_dates=[timestamp_column], + ) + + test_data = select_by_index( + data, + id_columns=id_columns, + start_index=test_start_index, + end_index=test_end_index, + ) + forecasts = forecast_pipeline(test_data) + assert forecast_pipeline._batch_size == 10 + assert forecasts.shape == (10, 2 * len(target_columns) + 1) + def test_forecasting_pipeline_forecasts_with_preprocessor(): timestamp_column = "date" @@ -92,30 +120,13 @@ def test_forecasting_pipeline_forecasts_with_preprocessor(): model = PatchTSTForPrediction.from_pretrained(model_path) context_length = model.config.context_length - tsp = TimeSeriesPreprocessor( - timestamp_column=timestamp_column, - id_columns=id_columns, - target_columns=target_columns, - context_length=context_length, - prediction_length=prediction_length, - freq="1h", - ) - - forecast_pipeline = TimeSeriesForecastingPipeline( - model=model, - timestamp_column=timestamp_column, - id_columns=id_columns, - target_columns=target_columns, - freq="1h", - feature_extractor=tsp, - explode_forecasts=False, - ) - dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv" data = pd.read_csv( dataset_path, parse_dates=[timestamp_column], ) + train_end_index = 12 * 30 * 24 + test_end_index = 12 * 30 * 24 + 8 * 30 * 24 test_start_index = test_end_index - context_length - 4 @@ -124,6 +135,12 @@ def test_forecasting_pipeline_forecasts_with_preprocessor(): parse_dates=[timestamp_column], ) + train_data = select_by_index( + data, + id_columns=id_columns, + start_index=0, + end_index=train_end_index, + ) test_data = select_by_index( data, id_columns=id_columns, @@ -131,11 +148,35 @@ def test_forecasting_pipeline_forecasts_with_preprocessor(): end_index=test_end_index, ) - forecasts = forecast_pipeline(test_data) + tsp = TimeSeriesPreprocessor( + timestamp_column=timestamp_column, + id_columns=id_columns, + target_columns=target_columns, + context_length=context_length, + prediction_length=prediction_length, + freq="1h", + scaling=True, + ) + + tsp.train(train_data) + + forecast_pipeline = TimeSeriesForecastingPipeline( + model=model, + timestamp_column=timestamp_column, + id_columns=id_columns, + target_columns=target_columns, + freq="1h", + feature_extractor=tsp, + explode_forecasts=False, + inverse_scale_outputs=True, + ) + + forecasts = forecast_pipeline(tsp.preprocess(test_data)) assert forecasts.shape == ( test_end_index - test_start_index - context_length + 1, 2 * len(target_columns) + 1, ) - # to do: add check on the scaling + # if we have inverse scaled mean should be larger + assert forecasts["HUFL_prediction"].mean().mean() > 10 diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index 15f55db6..35bb2ec9 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -7,11 +7,14 @@ import pandas as pd import torch +from torch.utils.data import DataLoader +from transformers.data.data_collator import default_data_collator from transformers.pipelines.base import ( GenericTensor, Pipeline, build_pipeline_init_args, ) +from transformers.trainer_utils import RemoveColumnsCollator from transformers.utils import add_end_docstrings, logging from .dataset import ForecastDFDataset @@ -31,10 +34,75 @@ logger = logging.get_logger(__name__) +class TimeSeriesPipeline(Pipeline): + def run_single(self, inputs, preprocess_params, forward_params, postprocess_params): + """Replaces base `run_single` method which does batching during inference. This is needed to support + large inference requests. + + Args: + inputs (_type_): _description_ + preprocess_params (_type_): _description_ + forward_params (_type_): _description_ + postprocess_params (_type_): _description_ + + Returns: + _type_: _description_ + """ + # our preprocess returns a dataset + dataset = self.preprocess(inputs, **preprocess_params) + + batch_size = forward_params["batch_size"] + signature = inspect.signature(self.model.forward) + signature_columns = list(signature.parameters.keys()) + + # if len(dataset) < batch_size: + # build a dataloader + # collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, feature_extractor) + + remove_columns_collator = RemoveColumnsCollator( + data_collator=default_data_collator, + signature_columns=signature_columns, + logger=None, + description=None, + model_name=self.model.__class__.__name__, + ) + dataloader = DataLoader( + dataset, num_workers=1, batch_size=batch_size, collate_fn=remove_columns_collator, shuffle=False + ) + + # iterate over dataloader + it = iter(dataloader) + accumulator = [] + model_output_key = None + while (batch := next(it, None)) is not None: + item = self.forward(batch, **forward_params) + if not model_output_key: + model_output_key = "prediction_outputs" if "prediction_outputs" in item.keys() else "prediction_logits" + accumulator.append(item[model_output_key]) + + # collect all ouputs needed for post processing + first = dataset[0] + model_outputs = {} + for k, v in first.items(): + if isinstance(v, torch.Tensor): + model_outputs[k] = torch.stack(tuple(r[k] for r in dataset)) + else: + model_outputs[k] = [r[k] for r in dataset] + + # without shuffling in the dataloader above, we assume that order is preserved + # otherwise we need to incorporate sequence id somewhere and do a proper join + model_outputs["prediction_outputs"] = torch.cat(accumulator, axis=0) + + # call postprocess + outputs = self.postprocess(model_outputs, **postprocess_params) + + return outputs + + @add_end_docstrings( build_pipeline_init_args(has_tokenizer=False, has_feature_extractor=True, has_image_processor=False) ) -class TimeSeriesForecastingPipeline(Pipeline): +class TimeSeriesForecastingPipeline(TimeSeriesPipeline): """Hugging Face Pipeline for Time Series Forecasting feature_extractor (TimeSeriesPreprocessor): A time series preprpocessor object that specifies how the time @@ -112,6 +180,16 @@ def _sanitize_parameters( if c in kwargs: postprocess_kwargs[c] = kwargs[c] + # same logic as HF Pipeline + batch_size = kwargs.get("batch_size", self._batch_size) + if batch_size is None: + if self._batch_size is None: + batch_size = 1 + else: + batch_size = self._batch_size + + forward_kwargs = {"batch_size": batch_size} + # if "id_columns" in kwargs: # preprocess_kwargs["id_columns"] = kwargs["id_columns"] # postprocess_kwargs["id_columns"] = kwargs["id_columns"] @@ -128,7 +206,7 @@ def _sanitize_parameters( # preprocess_kwargs["output_columns"] = kwargs["input_columns"] # postprocess_kwargs["output_columns"] = kwargs["input_columns"] - return preprocess_kwargs, {}, postprocess_kwargs + return preprocess_kwargs, forward_kwargs, postprocess_kwargs def __call__( self, @@ -248,17 +326,18 @@ def preprocess(self, time_series, **kwargs) -> Dict[str, Union[GenericTensor, Li **kwargs, ) - # stack all the outputs - # torch tensors are stacked, but other values are passed through as a list - first = dataset[0] - full_output = {} - for k, v in first.items(): - if isinstance(v, torch.Tensor): - full_output[k] = torch.stack(tuple(r[k] for r in dataset)) - else: - full_output[k] = [r[k] for r in dataset] + # # stack all the outputs + # # torch tensors are stacked, but other values are passed through as a list + # first = dataset[0] + # full_output = {} + # for k, v in first.items(): + # if isinstance(v, torch.Tensor): + # full_output[k] = torch.stack(tuple(r[k] for r in dataset)) + # else: + # full_output[k] = [r[k] for r in dataset] - return full_output + # return full_output + return dataset def _forward(self, model_inputs, **kwargs): """Forward step @@ -279,20 +358,22 @@ def _forward(self, model_inputs, **kwargs): # "freq_token", # } # todo: this should not be hardcoded - signature = inspect.signature(self.model.forward) - model_input_keys = list(signature.parameters.keys()) + # signature = inspect.signature(self.model.forward) + # model_input_keys = list(signature.parameters.keys()) + + # model_inputs_only = {} + # for k in model_input_keys: + # if k in model_inputs: + # model_inputs_only[k] = model_inputs[k] - model_inputs_only = {} - for k in model_input_keys: - if k in model_inputs: - model_inputs_only[k] = model_inputs[k] + # model_outputs = self.model(**model_inputs_only) - model_outputs = self.model(**model_inputs_only) + # # copy the other inputs + # copy_inputs = True + # for k in [akey for akey in model_inputs.keys() if (akey not in model_input_keys) or copy_inputs]: + # model_outputs[k] = model_inputs[k] - # copy the other inputs - copy_inputs = True - for k in [akey for akey in model_inputs.keys() if (akey not in model_input_keys) or copy_inputs]: - model_outputs[k] = model_inputs[k] + model_outputs = self.model(**model_inputs) return model_outputs @@ -307,7 +388,7 @@ def postprocess(self, input, **kwargs): """ out = {} - model_output_key = "prediction_outputs" if "prediction_outputs" in input.keys() else "prediction_logits" + model_output_key = "prediction_outputs" # if "prediction_outputs" in input.keys() else "prediction_logits" # name the predictions of target columns # outputs should only have size equal to target columns diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index 4de30c24..43a4745f 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -670,6 +670,7 @@ def get_datasets( split_config: Dict[str, Union[List[Union[int, float]], float]], fewshot_fraction: Optional[float] = None, fewshot_location: str = FractionLocation.LAST.value, + return_dataframe: bool = False, ) -> Tuple[Any]: """Creates the preprocessed pytorch datasets needed for training and evaluation using the HuggingFace trainer @@ -697,6 +698,8 @@ def get_datasets( fewshot_location (str): Determines where the fewshot data is chosen. Valid options are "first" and "last" as described in the enum FewshotLocation. Default is to choose the fewshot data at the end of the training dataset (i.e., "last"). + return_dataframe: Instead for returning a pytorch dataset, return tuples of pandas dataframes, after any + preprocessing. Returns: Tuple of pytorch datasets, including: train, validation, test. @@ -752,16 +755,23 @@ def get_datasets( params["prediction_length"] = self.prediction_length # get torch datasets - test_dataset = ForecastDFDataset( - self.preprocess(test_data), - **params, - ) - train_dataset = ForecastDFDataset(self.preprocess(train_data), **params) - valid_dataset = ForecastDFDataset( - self.preprocess(valid_data), - **params, - ) - return train_dataset, valid_dataset, test_dataset + train_valid_test = [train_data, valid_data, test_data] + + if return_dataframe: + return tuple(train_valid_test) + + return tuple([ForecastDFDataset(self.preprocess(d), **params) for d in train_valid_test]) + + # test_dataset = ForecastDFDataset( + # self.preprocess(test_data), + # **params, + # ) + # train_dataset = ForecastDFDataset(self.preprocess(train_data), **params) + # valid_dataset = ForecastDFDataset( + # self.preprocess(valid_data), + # **params, + # ) + # return train_dataset, valid_dataset, test_dataset def create_timestamps(