Skip to content

Commit

Permalink
Merge pull request #124 from ibm-granite/minor_fixes
Browse files Browse the repository at this point in the history
Minor fixes
  • Loading branch information
wgifford authored Aug 29, 2024
2 parents 6f55407 + 4a7439f commit ff96377
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
29 changes: 28 additions & 1 deletion tests/toolkit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Tests basic dataset functions"""

from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -317,3 +317,30 @@ def my_collate(batch):
b = next(iter(dl))

assert len(b["target_values"].shape) == 1


def test_datetime_handling(ts_data):
df = ts_data.copy()
df["time_date_utc_offset"] = pd.date_range(start="2022-10-01 10:00:00 +01:00", periods=10, freq="h")

ds = ForecastDFDataset(
df,
timestamp_column="time_date_utc_offset",
id_columns=["id"],
target_columns=["val"],
context_length=3,
prediction_length=2,
)

assert ds[0]["timestamp"].tz == timezone(timedelta(hours=1))

ds = ForecastDFDataset(
df,
timestamp_column="time_date",
id_columns=["id"],
target_columns=["val"],
context_length=3,
prediction_length=2,
)

assert ds[0]["timestamp"].tz is None
3 changes: 1 addition & 2 deletions tsfm_public/toolkit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ def __init__(
data_df = self.pad_zero(data_df)

if timestamp_column in list(data_df.columns):
self.timestamps = data_df[timestamp_column].values

self.timestamps = data_df[timestamp_column].to_list() # .values coerces timestamps
# get the input data
if len(x_cols) > 0:
self.X = data_df[x_cols]
Expand Down
3 changes: 3 additions & 0 deletions tsfm_public/toolkit/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def plot_predictions(
):
random_indices = indices

if random_indices is not None:
num_plots = len(random_indices)

# possible operations:
if test_df is not None and predictions_df is not None:
# 1) test_df and predictions plus column information is provided
Expand Down

0 comments on commit ff96377

Please sign in to comment.