From ca16f993f66e0377bc2cf6d8d64fcb4bb4146233 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Tue, 13 Aug 2024 11:28:46 +0000 Subject: [PATCH] slight tidy --- .../select/find_contiguous_time_periods.py | 29 +++++++++++++++---- ocf_data_sampler/select/select_time_slice.py | 2 -- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/ocf_data_sampler/select/find_contiguous_time_periods.py b/ocf_data_sampler/select/find_contiguous_time_periods.py index 31f4a07..9013513 100644 --- a/ocf_data_sampler/select/find_contiguous_time_periods.py +++ b/ocf_data_sampler/select/find_contiguous_time_periods.py @@ -66,7 +66,7 @@ def trim_contiguous_time_periods( history_duration: pd.Timedelta, forecast_duration: pd.Timedelta, ) -> pd.DataFrame: - """Trim the contiguous time periods in-place to allow for history and forecast durations. + """Trim the contiguous time periods to allow for history and forecast durations. Args: contiguous_time_periods: DataFrame where each row represents a single time period. The @@ -78,9 +78,14 @@ def trim_contiguous_time_periods( Returns: The contiguous_time_periods DataFrame with the `start_dt` and `end_dt` columns updated. """ + contiguous_time_periods = contiguous_time_periods.copy() + contiguous_time_periods["start_dt"] += history_duration contiguous_time_periods["end_dt"] -= forecast_duration - assert (contiguous_time_periods["start_dt"] < contiguous_time_periods["end_dt"]).all() + + valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"] + contiguous_time_periods = contiguous_time_periods.loc[valid_mask] + return contiguous_time_periods @@ -91,7 +96,19 @@ def find_contiguous_t0_periods( forecast_duration: pd.Timedelta, sample_period_duration: pd.Timedelta, ) -> pd.DataFrame: - + """Return a pd.DataFrame where each row records the boundary of a contiguous time period. + + Args: + datetimes: pd.DatetimeIndex. Must be sorted. + history_duration: Length of the historical slice used for each sample + forecast_duration: Length of the forecast slice used for each sample + sample_period_duration: The sample frequency of the timeseries + + + Returns: + pd.DataFrame where each row represents a single time period. The pd.DataFrame + has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime'). + """ total_duration = history_duration + forecast_duration contiguous_time_periods = find_contiguous_time_periods( @@ -106,6 +123,8 @@ def find_contiguous_t0_periods( forecast_duration=forecast_duration, ) + assert len(contiguous_t0_periods) > 0 + return contiguous_t0_periods @@ -192,13 +211,13 @@ def find_contiguous_t0_periods_nwp( # considering dropout - then the contiguous period breaks, and new starts with considering # dropout and history duration if end_this_period < dt_init + max_dropout: - contiguous_periods += [[start_this_period, end_this_period]] + contiguous_periods.append([start_this_period, end_this_period]) # And start a new period start_this_period = dt_init + hist_drop_buffer end_this_period = dt_init + max_staleness - contiguous_periods += [[start_this_period, end_this_period]] + contiguous_periods.append([start_this_period, end_this_period]) return pd.DataFrame(contiguous_periods, columns=["start_dt", "end_dt"]) diff --git a/ocf_data_sampler/select/select_time_slice.py b/ocf_data_sampler/select/select_time_slice.py index 4b548e0..9fa2641 100644 --- a/ocf_data_sampler/select/select_time_slice.py +++ b/ocf_data_sampler/select/select_time_slice.py @@ -103,8 +103,6 @@ def select_time_slice_nwp( ds[channel_dim_name].values, accum_channels ) - # TODO: Once all times have been converted to pd.Timestamp, remove this - t0 = pd.Timestamp(t0) start_dt = (t0 - history_duration).ceil(sample_period_duration) end_dt = (t0 + forecast_duration).ceil(sample_period_duration)