
Commit

Merge pull request #199 from openclimatefix/fill_select
Fill select
dfulu authored May 16, 2023
2 parents 39ea02b + a2bf1ce commit 6edcfce
Showing 4 changed files with 84 additions and 11 deletions.
28 changes: 22 additions & 6 deletions ocf_datapipes/select/select_time_slice.py
@@ -25,6 +25,7 @@ def __init__(
forecast_duration: Optional[timedelta] = None,
interval_start: Optional[timedelta] = None,
interval_end: Optional[timedelta] = None,
fill_selection: Optional[bool] = False,
):
"""
Selects time slice. Either `history_duration` and `forecast_duration` or `interval_start` and
@@ -38,9 +39,14 @@
forecast_duration (optional): Forecast time used
interval_start (optional): timedelta with respect to t0 where the open interval begins
interval_end (optional): timedelta with respect to t0 where the open interval ends
fill_selection (optional): If True, and the data yielded from `source_datapipe` does not
extend over the entire requested time period, the missing timestamps are filled with NaN
values in the returned xarray object. Otherwise the default xarray slicing behaviour is
used.
"""
self.source_datapipe = source_datapipe
self.t0_datapipe = t0_datapipe
self.fill_selection = fill_selection

used_duration = history_duration is not None and forecast_duration is not None
used_intervals = interval_start is not None and interval_end is not None
@@ -55,6 +61,18 @@ def __init__(

self.sample_period_duration = sample_period_duration

def _sel_fillnan(self, xr_data, start_dt, end_dt):
requested_times = pd.date_range(
start_dt,
end_dt,
freq=self.sample_period_duration,
)
# Missing time indexes are returned with all NaN values
return xr_data.reindex(time_utc=requested_times)

def _sel_default(self, xr_data, start_dt, end_dt):
return xr_data.sel(time_utc=slice(start_dt, end_dt))

def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
xr_data = next(iter(self.source_datapipe))

@@ -66,9 +84,7 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
start_dt = start_dt.ceil(self.sample_period_duration)
end_dt = end_dt.ceil(self.sample_period_duration)

yield xr_data.sel(
time_utc=slice(
start_dt,
end_dt,
)
)
if self.fill_selection:
yield self._sel_fillnan(xr_data, start_dt, end_dt)
else:
yield self._sel_default(xr_data, start_dt, end_dt)
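
For context, a minimal standalone sketch of the xarray behaviour the new `fill_selection` path relies on: `reindex` returns every requested timestamp and pads missing ones with NaN, whereas the default `sel(time_utc=slice(...))` simply stops at the last available value. This sketch is not part of the commit; the toy data and variable names are illustrative assumptions.

import numpy as np
import pandas as pd
import xarray as xr

# Four 5-minutely timestamps of toy data
times = pd.date_range("2023-05-16 00:00", periods=4, freq="5min")
da = xr.DataArray(np.arange(4.0), coords={"time_utc": times}, dims="time_utc")

# Request a window whose end lies 5 minutes beyond the available data
start_dt = times[2]
end_dt = times[2] + pd.Timedelta(minutes=10)

# Default behaviour (_sel_default): the slice stops at the last available timestamp
default_sel = da.sel(time_utc=slice(start_dt, end_dt))  # 2 values

# fill_selection behaviour (_sel_fillnan): every requested timestamp is present,
# with NaN wherever the source data has no value
requested_times = pd.date_range(start_dt, end_dt, freq="5min")
filled_sel = da.reindex(time_utc=requested_times)  # 3 values, last is NaN

assert len(default_sel) == 2 and len(filled_sel) == 3
assert np.isnan(filled_sel.values[-1])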
26 changes: 21 additions & 5 deletions ocf_datapipes/training/pvnet.py
@@ -260,7 +260,10 @@ def minutes(num_mins: int):


def slice_datapipes_by_time(
datapipes_dict: Dict, t0_datapipe: IterDataPipe, configuration: Configuration
datapipes_dict: Dict,
t0_datapipe: IterDataPipe,
configuration: Configuration,
production: bool = False,
) -> None:
"""
Modifies a dictionary of datapipes in-place to yield samples for given times t0.
@@ -289,6 +292,9 @@ def slice_datapipes_by_time(
datapipes_dict: Dictionary of used datapipes and t0 ones
t0_datapipe: Datapipe which yields t0 times for sample
configuration: Configuration object.
production: Whether constructing the pipeline for production inference. No dropout is used
if True.
"""

conf_in = configuration.input_data
@@ -301,7 +307,7 @@
# In samples where dropout is applied, the first non-nan value could be 20 - 45 mins before
# time t0.
dropout_timedeltas=[minutes(m) for m in range(-45, -15, 5)],
dropout_frac=0.5,
dropout_frac=0 if production else 0.5,
)

# Satellite data never more recent than t0-15mins
@@ -315,7 +321,7 @@
forecast_duration=minutes(conf_in.nwp.forecast_minutes),
# The NWP forecast will always be at least 90 minutes stale
dropout_timedeltas=[minutes(-90)],
dropout_frac=1.0,
dropout_frac=0 if production else 1.0,
)

if "sat" in datapipes_dict:
@@ -325,6 +331,7 @@
sample_period_duration=minutes(5),
interval_start=minutes(-conf_in.satellite.history_minutes),
interval_end=sat_delay,
fill_selection=production,
)

# Generate randomly sampled dropout times
@@ -357,6 +364,7 @@ def slice_datapipes_by_time(
sample_period_duration=minutes(5),
interval_start=minutes(-conf_in.hrvsatellite.history_minutes),
interval_end=sat_delay,
fill_selection=production,
)

# Apply the dropout
@@ -372,13 +380,15 @@ def slice_datapipes_by_time(
sample_period_duration=minutes(5),
interval_start=minutes(5),
interval_end=minutes(conf_in.pv.forecast_minutes),
fill_selection=production,
)

datapipes_dict["pv"] = datapipes_dict["pv"].select_time_slice(
t0_datapipe=get_t0_datapipe("pv"),
sample_period_duration=minutes(5),
interval_start=minutes(-conf_in.pv.history_minutes),
interval_end=minutes(0),
fill_selection=production,
)

if "gsp" in datapipes_dict:
@@ -389,20 +399,22 @@ def slice_datapipes_by_time(
sample_period_duration=minutes(30),
interval_start=minutes(30),
interval_end=minutes(conf_in.gsp.forecast_minutes),
fill_selection=production,
)

datapipes_dict["gsp"] = datapipes_dict["gsp"].select_time_slice(
t0_datapipe=get_t0_datapipe(None),
sample_period_duration=minutes(30),
interval_start=-minutes(conf_in.gsp.history_minutes),
interval_end=minutes(0),
fill_selection=production,
)

# Dropout on the GSP, but not the future GSP
gsp_dropout_time_datapipe = get_t0_datapipe("gsp").select_dropout_time(
# GSP data for time t0 may be missing. Only have value for t0-30mins
dropout_timedeltas=[minutes(-30)],
dropout_frac=0.1,
dropout_frac=0 if production else 0.1,
)

datapipes_dict["gsp"] = datapipes_dict["gsp"].apply_dropout_time(
@@ -420,6 +432,7 @@ def construct_sliced_data_pipeline(
t0_datapipe: IterDataPipe,
block_sat: bool = False,
block_nwp: bool = False,
production: bool = False,
) -> IterDataPipe:
"""Constructs data pipeline for the input data config file.
@@ -431,6 +444,7 @@
t0_datapipe: Datapipe yielding times.
block_sat: Whether to load zeroes for satellite data.
block_nwp: Whether to load zeroes for NWP data.
production: Whether constructing the pipeline for production inference.
"""

datapipes_dict = _get_datapipes_dict(
@@ -439,14 +453,16 @@
block_nwp,
)

assert not (production and (block_sat or block_nwp))

configuration = datapipes_dict.pop("config")

# Unpack for convenience
conf_sat = configuration.input_data.satellite
conf_nwp = configuration.input_data.nwp

# Slice all of the datasets by time - this is an in-place operation
slice_datapipes_by_time(datapipes_dict, t0_datapipe, configuration)
slice_datapipes_by_time(datapipes_dict, t0_datapipe, configuration, production)

# Spatially slice, normalize, and convert data to numpy arrays
numpy_modalities = []
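
The production flag follows one pattern throughout `slice_datapipes_by_time` above: dropout fractions are forced to zero and time-slice selections are NaN-filled rather than truncated. A small illustrative sketch of that pattern (the helper below is not in the codebase; names are assumptions):

def production_overrides(training_dropout_frac: float, production: bool) -> dict:
    # In production, dropout is disabled and SelectTimeSlice is asked to
    # NaN-fill short selections instead of silently returning fewer timestamps.
    return {
        "dropout_frac": 0.0 if production else training_dropout_frac,
        "fill_selection": production,
    }

# Training keeps the configured dropout; production disables it and fills gaps.
assert production_overrides(0.5, production=False)["dropout_frac"] == 0.5
assert production_overrides(0.5, production=True) == {"dropout_frac": 0.0, "fill_selection": True}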
29 changes: 29 additions & 0 deletions tests/select/test_select_time_slice.py
@@ -1,5 +1,6 @@
from datetime import timedelta
import pandas as pd
import numpy as np
from torchdata.datapipes.iter import IterableWrapper
from ocf_datapipes.select import SelectTimeSlice

@@ -39,3 +40,31 @@ def test_select_time_slice_sat(sat_datapipe):
for sat_sample, t0 in zip(sat_samples, t0_values):
assert len(sat_sample.time_utc) == 3
assert sat_sample.time_utc[1] == t0

# Check with out of bounds selection
t_last = pd.to_datetime(data.time_utc.values[-1])
t0_values = [
t_last - timedelta(minutes=5),
t_last,
t_last + timedelta(minutes=5),
t_last + timedelta(minutes=10),
]
t0_datapipe = IterableWrapper(t0_values)

dp = SelectTimeSlice(
sat_datapipe,
t0_datapipe,
sample_period_duration=timedelta(minutes=5),
interval_start=timedelta(minutes=-5),
interval_end=timedelta(minutes=5),
fill_selection=True,
)

sat_samples = list(dp)

for i, (sat_sample, t0) in enumerate(zip(sat_samples, t0_values)):
assert len(sat_sample.time_utc) == 3
assert sat_sample.time_utc[1] == t0
# The i latest requested timestamps fall beyond the available data, so exactly i steps are NaN
sat_sel = sat_sample.isel(x_geostationary=0, y_geostationary=0, channel=0)
assert np.isnan(sat_sel.values).sum() == i
12 changes: 12 additions & 0 deletions tests/training/test_pvnet.py
@@ -37,6 +37,18 @@ def test_construct_sliced_data_pipeline(configuration_filename):

batch = next(iter(dp))

# Chosen to lie beyond end of test data
t0_pipe = IterableWrapper([datetime(2020, 4, 2, 0, 30)])

dp = construct_sliced_data_pipeline(
configuration_filename,
location_pipe=loc_pipe,
t0_datapipe=t0_pipe,
production=True,
)

batch = next(iter(dp))


def test_pvnet_datapipe(configuration_filename):
start_time = datetime(1900, 1, 1)
