Skip to content

Commit

Permalink
option to add ground truth data when available
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Jul 25, 2024
1 parent 52214ca commit e29b257
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,13 @@ def __init__(
freq: Optional[str] = None,
explode_forecasts: bool = False,
inverse_scale_outputs: bool = True,
add_known_ground_truth: bool = True,
**kwargs,
):
kwargs["freq"] = freq
kwargs["explode_forecasts"] = explode_forecasts
kwargs["inverse_scale_outputs"] = inverse_scale_outputs
kwargs["add_known_ground_truth"] = add_known_ground_truth
super().__init__(*args, **kwargs)

if self.framework == "tf":
Expand Down Expand Up @@ -168,6 +170,7 @@ def _sanitize_parameters(
"freq",
"explode_forecasts",
"inverse_scale_outputs",
"add_known_ground_truth",
]

for c in preprocess_params:
Expand Down Expand Up @@ -367,14 +370,16 @@ def postprocess(self, input, **kwargs):

# name the predictions of target columns
# outputs should only have size equal to target columns

prediction_columns = []
for i, c in enumerate(kwargs["target_columns"]):
prediction_columns.append(f"{c}_prediction")
prediction_columns.append(f"{c}_prediction" if kwargs["add_known_ground_truth"] else c)
out[prediction_columns[-1]] = input[model_output_key][:, :, i].numpy().tolist()
# provide the ground truth values for the targets
# when future is unknown, we will have augmented the provided dataframe with NaN values to cover the future
for i, c in enumerate(kwargs["target_columns"]):
out[c] = input["future_values"][:, :, i].numpy().tolist()
if kwargs["add_known_ground_truth"]:
for i, c in enumerate(kwargs["target_columns"]):
out[c] = input["future_values"][:, :, i].numpy().tolist()

if "timestamp_column" in kwargs:
out[kwargs["timestamp_column"]] = input["timestamp"]
Expand Down

0 comments on commit e29b257

Please sign in to comment.