diff --git a/tests/toolkit/test_dataset.py b/tests/toolkit/test_dataset.py index 0961a1cd..dc5376f8 100644 --- a/tests/toolkit/test_dataset.py +++ b/tests/toolkit/test_dataset.py @@ -3,7 +3,7 @@ """Tests basic dataset functions""" -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import numpy as np import pandas as pd @@ -317,3 +317,30 @@ def my_collate(batch): b = next(iter(dl)) assert len(b["target_values"].shape) == 1 + + +def test_datetime_handling(ts_data): + df = ts_data.copy() + df["time_date_utc_offset"] = pd.date_range(start="2022-10-01 10:00:00 +01:00", periods=10, freq="h") + + ds = ForecastDFDataset( + df, + timestamp_column="time_date_utc_offset", + id_columns=["id"], + target_columns=["val"], + context_length=3, + prediction_length=2, + ) + + assert ds[0]["timestamp"].tz == timezone(timedelta(hours=1)) + + ds = ForecastDFDataset( + df, + timestamp_column="time_date", + id_columns=["id"], + target_columns=["val"], + context_length=3, + prediction_length=2, + ) + + assert ds[0]["timestamp"].tz is None