Skip to content

Commit

Permalink
slight tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Aug 13, 2024
1 parent 7d46f3f commit ca16f99
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
29 changes: 24 additions & 5 deletions ocf_data_sampler/select/find_contiguous_time_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def trim_contiguous_time_periods(
history_duration: pd.Timedelta,
forecast_duration: pd.Timedelta,
) -> pd.DataFrame:
"""Trim the contiguous time periods in-place to allow for history and forecast durations.
"""Trim the contiguous time periods to allow for history and forecast durations.
Args:
contiguous_time_periods: DataFrame where each row represents a single time period. The
Expand All @@ -78,9 +78,14 @@ def trim_contiguous_time_periods(
Returns:
The contiguous_time_periods DataFrame with the `start_dt` and `end_dt` columns updated.
"""
contiguous_time_periods = contiguous_time_periods.copy()

contiguous_time_periods["start_dt"] += history_duration
contiguous_time_periods["end_dt"] -= forecast_duration
assert (contiguous_time_periods["start_dt"] < contiguous_time_periods["end_dt"]).all()

valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
contiguous_time_periods = contiguous_time_periods.loc[valid_mask]

return contiguous_time_periods


Expand All @@ -91,7 +96,19 @@ def find_contiguous_t0_periods(
forecast_duration: pd.Timedelta,
sample_period_duration: pd.Timedelta,
) -> pd.DataFrame:

"""Return a pd.DataFrame where each row records the boundary of a contiguous time period.
Args:
datetimes: pd.DatetimeIndex. Must be sorted.
history_duration: Length of the historical slice used for each sample
forecast_duration: Length of the forecast slice used for each sample
sample_period_duration: The sample frequency of the timeseries
Returns:
pd.DataFrame where each row represents a single time period. The pd.DataFrame
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
"""
total_duration = history_duration + forecast_duration

contiguous_time_periods = find_contiguous_time_periods(
Expand All @@ -106,6 +123,8 @@ def find_contiguous_t0_periods(
forecast_duration=forecast_duration,
)

assert len(contiguous_t0_periods) > 0

return contiguous_t0_periods


Expand Down Expand Up @@ -192,13 +211,13 @@ def find_contiguous_t0_periods_nwp(
# considering dropout - then the contiguous period breaks, and new starts with considering
# dropout and history duration
if end_this_period < dt_init + max_dropout:
contiguous_periods += [[start_this_period, end_this_period]]
contiguous_periods.append([start_this_period, end_this_period])

# And start a new period
start_this_period = dt_init + hist_drop_buffer
end_this_period = dt_init + max_staleness

contiguous_periods += [[start_this_period, end_this_period]]
contiguous_periods.append([start_this_period, end_this_period])

return pd.DataFrame(contiguous_periods, columns=["start_dt", "end_dt"])

Expand Down
2 changes: 0 additions & 2 deletions ocf_data_sampler/select/select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ def select_time_slice_nwp(
ds[channel_dim_name].values, accum_channels
)

# TODO: Once all times have been converted to pd.Timestamp, remove this
t0 = pd.Timestamp(t0)
start_dt = (t0 - history_duration).ceil(sample_period_duration)
end_dt = (t0 + forecast_duration).ceil(sample_period_duration)

Expand Down

0 comments on commit ca16f99

Please sign in to comment.