Skip to content

Commit

Permalink
Merge pull request #39 from IBM/issue_38
Browse files Browse the repository at this point in the history
Issue 38
  • Loading branch information
wgifford authored Apr 24, 2024
2 parents 064cbb8 + 6017681 commit 92d3fa2
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 23 deletions.
17 changes: 8 additions & 9 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
}
},
"isort.args": [
"--profile",
"black"
],
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"ruff.format.args": [
"--config=./pyproject.toml"
],
"ruff.lint.args": [
"--config=./pyproject.toml"
]
}
17 changes: 15 additions & 2 deletions tests/toolkit/test_time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,18 @@ def test_create_timestamps():
2,
[103.5, 107.0],
),
(
pd.Timestamp(2021, 12, 31),
"QE",
None,
4,
[
pd.Timestamp(2022, 3, 31),
pd.Timestamp(2022, 6, 30),
pd.Timestamp(2022, 9, 30),
pd.Timestamp(2022, 12, 31),
],
),
]

for start, freq, sequence, periods, expected in test_cases:
Expand All @@ -220,8 +232,9 @@ def test_create_timestamps():
assert ts == expected

# test based on provided sequence
ts = create_timestamps(start, time_sequence=sequence, periods=periods)
assert ts == expected
if sequence is not None:
ts = create_timestamps(start, time_sequence=sequence, periods=periods)
assert ts == expected

# it is an error to provide neither freq or sequence
with pytest.raises(ValueError):
Expand Down
2 changes: 1 addition & 1 deletion tsfm_public/toolkit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def get_group_data(
):
return cls(
data_df=group,
group_id=group_id,
group_id=group_id if isinstance(group_id, tuple) else (group_id,),
id_columns=id_columns,
timestamp_column=timestamp_column,
context_length=context_length,
Expand Down
4 changes: 4 additions & 0 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ def postprocess(self, input, **kwargs):
"""Postprocess step
Takes the dictionary of outputs from the previous step and converts to a more user
readable pandas format.
If the explode forecasts option is True, then individual forecasts are expanded as multiple
rows in the dataframe. This should only be used when producing a single forecast (i.e., unexploded
result is one row per ID).
"""
out = {}

Expand Down
54 changes: 43 additions & 11 deletions tsfm_public/toolkit/time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,29 @@ def _standardize_dataframe(

return df

def _clean_up_dataframe(self, df: pd.DataFrame) -> None:
    """Removes, in place, columns added during internal processing of the provided dataframe.

    Currently, the following checks are done:
    - Remove INTERNAL_ID_COLUMN if present, but only when the preprocessor has no
      user-specified ``id_columns`` (when ``id_columns`` is set, that column is
      presumably meaningful to the caller and is left untouched -- confirm against
      _standardize_dataframe, which adds the internal column).

    Args:
        df (pd.DataFrame): Input pandas dataframe; modified in place via
            ``DataFrame.drop(..., inplace=True)``.

    Returns:
        None: The dataframe is mutated in place; no value is returned (the signature
        is ``-> None`` -- an earlier docstring incorrectly claimed a DataFrame return).
    """

    if not self.id_columns:
        if INTERNAL_ID_COLUMN in df.columns:
            df.drop(columns=INTERNAL_ID_COLUMN, inplace=True)

def _get_groups(
self,
dataset: pd.DataFrame,
) -> Generator[Tuple[Any, pd.DataFrame], None, None]:
"""Get groups of the time series dataset (multi-time series) based on the ID columns.
"""Get groups of the time series dataset (multi-time series) based on the ID columns for scaling.
Note that this is used for scaling purposes only.
Args:
dataset (pd.DataFrame): Input dataset
Expand Down Expand Up @@ -472,7 +490,7 @@ def _check_dataset(self, dataset: Union[Dataset, pd.DataFrame]):

def _set_targets(self, dataset: pd.DataFrame) -> None:
if self.target_columns == []:
skip_columns = copy.copy(self.id_columns)
skip_columns = copy.copy(self.id_columns) + [INTERNAL_ID_COLUMN]
if self.timestamp_column:
skip_columns.append(self.timestamp_column)

Expand Down Expand Up @@ -531,6 +549,7 @@ def train(
if self.encode_categorical:
self._train_categorical_encoder(df)

self._clean_up_dataframe(df)
return self

def inverse_scale_targets(
Expand Down Expand Up @@ -581,10 +600,12 @@ def inverse_scale_func(grp, id_columns):
else:
id_columns = INTERNAL_ID_COLUMN

return df.groupby(id_columns, group_keys=False).apply(
df_inv = df.groupby(id_columns, group_keys=False).apply(
inverse_scale_func,
id_columns=id_columns,
)
self._clean_up_dataframe(df_inv)
return df_inv

def preprocess(
self,
Expand Down Expand Up @@ -640,6 +661,7 @@ def scale_func(grp, id_columns):
raise RuntimeError("Attempt to encode categorical columns, but the encoder has not been trained yet.")
df[cols_to_encode] = self.categorical_encoder.transform(df[cols_to_encode])

self._clean_up_dataframe(df)
return df

def get_datasets(
Expand Down Expand Up @@ -759,14 +781,24 @@ def create_timestamps(

# more complex logic is required to support all edge cases
if isinstance(freq, (pd.Timedelta, datetime.timedelta, str)):
if isinstance(freq, str):
freq = pd._libs.tslibs.timedeltas.Timedelta(freq)

return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
try:
# try date range directly
return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
except ValueError as e:
# if it fails, we can try to compute a timedelta from the provided string
if isinstance(freq, str):
freq = pd._libs.tslibs.timedeltas.Timedelta(freq)
return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
else:
raise e
else:
# numerical timestamp column
return [last_timestamp + i * freq for i in range(1, periods + 1)]
Expand Down

0 comments on commit 92d3fa2

Please sign in to comment.