better forecast output, support freq specifier
Signed-off-by: Wesley M. Gifford <[email protected]>
wgifford committed Mar 8, 2024
1 parent 050ea67 commit a85be02
Showing 2 changed files with 168 additions and 80 deletions.
131 changes: 53 additions & 78 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
@@ -2,7 +2,8 @@
#
"""Hugging Face Pipeline for Time Series Tasks"""

from typing import Any, Dict, List, Union
import inspect
from typing import Any, Dict, List, Optional, Union

import pandas as pd
import torch
@@ -14,6 +15,7 @@
from transformers.utils import add_end_docstrings, logging

from .dataset import ForecastDFDataset
from .time_series_preprocessor import create_timestamps, extend_time_series


# Eventually we should support all time series models
@@ -37,12 +39,22 @@
class TimeSeriesForecastingPipeline(Pipeline):
"""Hugging Face Pipeline for Time Series Forecasting"""

def __init__(self, *args, **kwargs):
# has_feature_extractor means we can pass feature_extractor=TimeSeriesPreprocessor

def __init__(
self,
*args,
single_forecast: bool = True,
freq: Optional[Union[Any]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)

if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")

self.single_forecast = single_forecast
self.freq = freq
# self.check_model_type(MODEL_FOR_TIME_SERIES_FORECASTING_MAPPING)

def _sanitize_parameters(self, **kwargs):
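The two new constructor arguments control the output shape and future timestamp generation. A minimal usage sketch, assuming a pretrained Hugging Face time series model in `model`; the column names and frequency below are illustrative and do not appear in this commit:

pipeline = TimeSeriesForecastingPipeline(
    model=model,  # assumed: a pretrained time series forecasting model
    timestamp_column="date",  # illustrative column names
    id_columns=["series_id"],
    target_columns=["value"],
    freq="1h",  # new: explicit frequency for generating future timestamps
    single_forecast=True,  # new: explode predictions to one row per forecast step
)
forecast = pipeline(history_df)  # history_df: a pandas DataFrame of past observations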
@@ -64,13 +76,6 @@ def _sanitize_parameters(self, **kwargs):
"prediction_length": prediction_length,
"context_length": context_length,
}
# id_columns: List[str] = [],
# timestamp_column: Optional[str] = None,
# target_columns: List[str] = [],
# observable_columns: List[str] = [],
# control_columns: List[str] = [],
# conditional_columns: List[str] = [],
# static_categorical_columns: List[str] = [],

preprocess_params = [
"id_columns",
@@ -196,6 +201,7 @@ def preprocess(
parse_dates=[timestamp_column],
)
elif isinstance(future_time_series, pd.DataFrame):
# do we need to check the timestamp column?
pass
else:
raise ValueError(
@@ -206,14 +212,13 @@
for c in future_time_series.columns:
if c not in time_series.columns:
raise ValueError(
f"Future time series input contains an unknown column {c}"
f"Future time series input contains an unknown column {c}."
)

time_series = pd.concat((time_series, future_time_series), axis=0)
else:
# not additional exogenous data provided, augment with empty periods

time_series = augment_time_series(
# no additional exogenous data provided, extend with empty periods
time_series = extend_time_series(
time_series=time_series,
timestamp_column=timestamp_column,
grouping_columns=id_columns,
@@ -251,11 +256,15 @@ def _forward(self, model_inputs, **kwargs):

# Eventually we should use inspection somehow
# inspect.signature(model_forward).parameters.keys()
model_input_keys = {
"past_values",
"static_categorical_values",
"freq_token",
} # todo: this should not be hardcoded
# model_input_keys = {
# "past_values",
# "static_categorical_values",
# "freq_token",
# } # todo: this should not be hardcoded

signature = inspect.signature(self.model.forward)
model_input_keys = list(signature.parameters.keys())

model_inputs_only = {}
for k in model_input_keys:
if k in model_inputs:
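Deriving the input keys from the model's forward signature replaces the hardcoded set above. A standalone sketch of the mechanism with a toy forward function (not taken from the commit):

import inspect

def forward(past_values, static_categorical_values=None, freq_token=None):
    return past_values

signature = inspect.signature(forward)
model_input_keys = list(signature.parameters.keys())
# ['past_values', 'static_categorical_values', 'freq_token']

batch = {"past_values": [1.0, 2.0], "id": ["A"], "timestamp": ["2024-03-08"]}
# keep only the entries the model actually accepts
model_inputs_only = {k: batch[k] for k in model_input_keys if k in batch}
# {'past_values': [1.0, 2.0]}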
@@ -289,8 +298,12 @@ def postprocess(self, input, **kwargs):

# name the predictions of target columns
# outputs should only have size equal to target columns
prediction_columns = []
for i, c in enumerate(kwargs["target_columns"]):
out[f"{c}_prediction"] = input[model_output_key][:, :, i].numpy().tolist()
prediction_columns.append(f"{c}_prediction")
out[prediction_columns[-1]] = (
input[model_output_key][:, :, i].numpy().tolist()
)
# provide the ground truth values for the targets
# when future is unknown, we will have augmented the provided dataframe with NaN values to cover the future
for i, c in enumerate(kwargs["target_columns"]):
@@ -302,6 +315,27 @@
out[c] = [elem[i] for elem in input["id"]]
out = pd.DataFrame(out)

if self.single_forecast:
# we made only one forecast per time series, explode results
# explode == expand the lists in the dataframe
out_explode = []
for _, row in out.iterrows():
l = len(row[prediction_columns[0]])
tmp = {}
if "timestamp_column" in kwargs:
tmp[kwargs["timestamp_column"]] = create_timestamps(
row[kwargs["timestamp_column"]], freq=self.freq, periods=l
) # expand timestamps
if "id_columns" in kwargs:
for c in kwargs["id_columns"]:
tmp[c] = row[c]
for p in prediction_columns:
tmp[p] = row[p]

out_explode.append(pd.DataFrame(tmp))

out = pd.concat(out_explode)

# reorder columns
cols = out.columns.to_list()
cols_ordered = []
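With single_forecast set, each raw output row holds one list per prediction column; the loop above expands those lists into one row per forecast step, regenerating timestamps with create_timestamps. A toy before/after illustration (values and frequency are hypothetical):

import pandas as pd

# raw postprocess output: one row per series, list-valued predictions
out = pd.DataFrame(
    {
        "timestamp": [pd.Timestamp("2024-03-08 00:00")],
        "series_id": ["A"],
        "value_prediction": [[1.1, 1.2, 1.3]],
    }
)

rows = []
for _, row in out.iterrows():
    n = len(row["value_prediction"])
    rows.append(
        pd.DataFrame(
            {
                # new timestamps start one step after the anchor, as create_timestamps does
                "timestamp": pd.date_range(row["timestamp"], freq="h", periods=n + 1)[1:],
                "series_id": row["series_id"],
                "value_prediction": row["value_prediction"],
            }
        )
    )
exploded = pd.concat(rows)
# three rows at 01:00, 02:00, 03:00 with predictions 1.1, 1.2, 1.3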
@@ -313,62 +347,3 @@

out = out[cols_ordered]
return out


def augment_time_series(
time_series: pd.DataFrame,
# last_known_timestamp,
timestamp_column: str,
grouping_columns: List[str],
periods: int = 1,
# delta: datetime.timedelta = datetime.timedelta(days=1),
):
"""Augments the provided time series with empty data for the number of periods specified. For each time series, based
on groups defined by grouping columns, adds empty records following the last timestamp. The empty records contain
only timestamps and grouping indicators, remaining fields will be null.
Args:
time_series (pd.DataFrame): _description_
start_timestamp (_type_): _description_
column_name (str): _description_
grouping_columns (List[str]): _description_
periods (int, optional): _description_. Defaults to 1.
delta (datetime.timedelta, optional): _description_. Defaults to datetime.timedelta(days=1).
"""

def augment_one_series(group: Union[pd.Series, pd.DataFrame]):

last_timestamp = group[timestamp_column].iloc[-1]
delta = group[timestamp_column].iloc[-1] - group[timestamp_column].iloc[-2]

new_data = pd.DataFrame(
{
timestamp_column: pd.date_range(
last_timestamp + delta,
freq=delta,
periods=periods,
)
}
)

# for c in grouping_columns:
# new_data[c] = group[c].iloc[0]

df = pd.concat(
(group, new_data),
axis=0,
)
return df.reset_index(drop=True)

if grouping_columns == []:
new_time_series = augment_one_series(time_series)
else:
new_time_series = time_series.groupby(grouping_columns).apply(
augment_one_series, include_groups=False
)
idx_names = list(new_time_series.index.names)
idx_names[-1] = "__delete"
new_time_series = new_time_series.reset_index(names=idx_names)
new_time_series.drop(columns=["__delete"], inplace=True)

return new_time_series
117 changes: 115 additions & 2 deletions tsfm_public/toolkit/time_series_preprocessor.py
@@ -2,8 +2,10 @@
#
"""Preprocessor for time series data preparation"""

import datetime
import enum
import json
from datetime import timedelta
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from warnings import warn

@@ -104,11 +106,12 @@ def __init__(
context_length: int = 64,
prediction_length: Optional[int] = None,
scaling: bool = False,
scale_outputs: bool = False,
# scale_outputs: bool = False,
scaler_type: ScalerType = ScalerType.STANDARD.value,
encode_categorical: bool = True,
time_series_task: str = TimeSeriesTask.FORECASTING.value,
frequency_mapping: Dict[str, int] = DEFAULT_FREQUENCY_MAPPING,
freq: Optional[Union[int, float, timedelta, pd.Timedelta, str]] = None,
**kwargs,
):
# note base class __init__ methods sets all arguments as attributes
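The new freq argument lets a caller pin the sampling frequency up front instead of relying on later estimation. A construction sketch, assuming the preprocessor's usual timestamp_column and id_columns arguments (the column names here are illustrative):

tsp = TimeSeriesPreprocessor(
    timestamp_column="date",  # illustrative
    id_columns=["series_id"],  # illustrative
    context_length=64,
    prediction_length=24,
    freq="1h",  # new: explicit frequency; leave as None to estimate from data
)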
@@ -131,14 +134,15 @@
self.scaling = scaling
self.encode_categorical = encode_categorical
self.time_series_task = time_series_task
self.scale_outputs = scale_outputs
# self.scale_outputs = scale_outputs
self.scaler_type = scaler_type

# we maintain two scalers per time series to facilitate inverse scaling of the targets
self.scaler_dict = dict()
self.target_scaler_dict = dict()
self.categorical_encoder = None
self.frequency_mapping = frequency_mapping
self.freq = freq

kwargs["processor_class"] = self.__class__.__name__

@@ -468,6 +472,24 @@ def _check_dataset(self, dataset: Union[Dataset, pd.DataFrame]):
if dataset is None or len(dataset) == 0:
raise ValueError("Input dataset must not be null or zero length.")

def _estimate_frequency(self, df: pd.DataFrame):
if self.timestamp_column:
if self.id_columns:
# to do: be more efficient
grps = df.groupby(self.id_columns)
_, df_subset = list(grps)[0]
else:
df_subset = df

# to do: make more robust
self.freq = (
df_subset[self.timestamp_column].iloc[-1]
- df_subset[self.timestamp_column].iloc[-2]
)
else:
# no timestamp, assume sequential count?
self.freq = 1

def train(
self,
dataset: Union[Dataset, pd.DataFrame],
@@ -487,6 +509,9 @@ def train(

df = self._standardize_dataframe(dataset)

if self.freq is None:
self._estimate_frequency(df)

if self.scaling:
self._train_scaler(df)
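With freq left unset, train() now fills it in from the data. A sketch on a minimal frame, assuming the default scaling and encoding settings pass through untouched for a frame like this:

import pandas as pd

df = pd.DataFrame(
    {
        "date": pd.date_range("2024-03-01", periods=8, freq="h"),
        "value": range(8),
    }
)
tsp = TimeSeriesPreprocessor(timestamp_column="date", context_length=4, prediction_length=2)
tsp.train(df)
# tsp.freq is now Timedelta('0 days 01:00:00'), estimated by _estimate_frequency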

@@ -590,3 +615,91 @@ def scale_func(grp, id_columns):
df[cols_to_encode] = self.categorical_encoder.transform(df[cols_to_encode])

return df


def create_timestamps(
last_timestamp: Union[datetime.datetime, pd.Timestamp],
freq: Optional[Union[int, float, datetime.timedelta, pd.Timedelta, str]] = None,
time_sequence: Optional[
Union[List[int], List[float], List[datetime.datetime], List[pd.Timestamp]]
] = None,
periods: int = 1,
):
"""Simple utility to create a list of timestamps based on start, delta and number of periods"""

if freq is None and time_sequence is None:
raise ValueError(
"Neither `freq` nor `time_sequence` provided, cannot determine frequency."
)

if freq is None:
# to do: make more robust
freq = time_sequence[-1] - time_sequence[-2]

# more complex logic is required to support all edge cases
if isinstance(freq, (pd.Timedelta, datetime.timedelta, str)):
return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
else:
# numerical timestamp column
return [last_timestamp + i * freq for i in range(1, periods + 1)]
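The helper covers both datetime-like and purely numeric time axes. A few sketch calls with illustrative values:

import pandas as pd

create_timestamps(pd.Timestamp("2024-03-08"), freq="D", periods=2)
# [Timestamp('2024-03-09 00:00:00'), Timestamp('2024-03-10 00:00:00')]

create_timestamps(10.0, freq=0.5, periods=2)
# [10.5, 11.0]  (a numeric axis falls through to plain arithmetic)

create_timestamps(
    pd.Timestamp("2024-03-08"),
    time_sequence=[pd.Timestamp("2024-03-06"), pd.Timestamp("2024-03-07"), pd.Timestamp("2024-03-08")],
    periods=1,
)
# [Timestamp('2024-03-09 00:00:00')]  (freq inferred from the last two entries)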


def extend_time_series(
time_series: pd.DataFrame,
# last_known_timestamp,
timestamp_column: str,
grouping_columns: List[str],
freq: Optional[Union[int, float, datetime.timedelta, pd.Timedelta]] = None,
periods: int = 1,
# delta: datetime.timedelta = datetime.timedelta(days=1),
):
"""Extends the provided time series with empty data for the number of periods specified. For each time series, based
on groups defined by grouping columns, adds emptry records following the last timestamp. The empty records contain
only timestamps and grouping indicators, remaining fields will be null.
Args:
time_series (pd.DataFrame): _description_
start_timestamp (_type_): _description_
column_name (str): _description_
grouping_columns (List[str]): _description_
periods (int, optional): _description_. Defaults to 1.
delta (datetime.timedelta, optional): _description_. Defaults to datetime.timedelta(days=1).
"""

def augment_one_series(group: Union[pd.Series, pd.DataFrame]):

last_timestamp = group[timestamp_column].iloc[-1]

new_data = pd.DataFrame(
{
timestamp_column: create_timestamps(
last_timestamp,
freq=freq,
time_sequence=group[timestamp_column].values,
periods=periods,
)
}
)

df = pd.concat(
(group, new_data),
axis=0,
)
return df.reset_index(drop=True)

if grouping_columns == []:
new_time_series = augment_one_series(time_series)
else:
new_time_series = time_series.groupby(grouping_columns).apply(
augment_one_series, include_groups=False
)
idx_names = list(new_time_series.index.names)
idx_names[-1] = "__delete"
new_time_series = new_time_series.reset_index(names=idx_names)
new_time_series.drop(columns=["__delete"], inplace=True)

return new_time_series
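A usage sketch on a toy frame (data illustrative): the appended rows carry only the timestamp and grouping values, leaving the target null for the pipeline to forecast.

import pandas as pd

df = pd.DataFrame(
    {
        "date": pd.date_range("2024-01-01", periods=3, freq="D"),
        "series_id": ["A", "A", "A"],
        "value": [1.0, 2.0, 3.0],
    }
)
extended = extend_time_series(
    df, timestamp_column="date", grouping_columns=["series_id"], periods=2
)
# appends rows for 2024-01-04 and 2024-01-05 with value == NaN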
