From e29b2573297b91056bf102c56e4448e646513bf6 Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Thu, 25 Jul 2024 15:51:58 -0400 Subject: [PATCH] option to add ground truth data when available --- .../toolkit/time_series_forecasting_pipeline.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index 9226eaa2..39e3e0e5 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -114,11 +114,13 @@ def __init__( freq: Optional[str] = None, explode_forecasts: bool = False, inverse_scale_outputs: bool = True, + add_known_ground_truth: bool = True, **kwargs, ): kwargs["freq"] = freq kwargs["explode_forecasts"] = explode_forecasts kwargs["inverse_scale_outputs"] = inverse_scale_outputs + kwargs["add_known_ground_truth"] = add_known_ground_truth super().__init__(*args, **kwargs) if self.framework == "tf": @@ -168,6 +170,7 @@ def _sanitize_parameters( "freq", "explode_forecasts", "inverse_scale_outputs", + "add_known_ground_truth", ] for c in preprocess_params: @@ -367,14 +370,16 @@ def postprocess(self, input, **kwargs): # name the predictions of target columns # outputs should only have size equal to target columns + prediction_columns = [] for i, c in enumerate(kwargs["target_columns"]): - prediction_columns.append(f"{c}_prediction") + prediction_columns.append(f"{c}_prediction" if kwargs["add_known_ground_truth"] else c) out[prediction_columns[-1]] = input[model_output_key][:, :, i].numpy().tolist() # provide the ground truth values for the targets # when future is unknown, we will have augmented the provided dataframe with NaN values to cover the future - for i, c in enumerate(kwargs["target_columns"]): - out[c] = input["future_values"][:, :, i].numpy().tolist() + if kwargs["add_known_ground_truth"]: + for i, c in enumerate(kwargs["target_columns"]): + out[c] = input["future_values"][:, :, i].numpy().tolist() if "timestamp_column" in kwargs: out[kwargs["timestamp_column"]] = input["timestamp"]