
Commit

Merge commit 'b15a6562d16c90777c6486fbf59fd6c1a70708e1' into development
peterdudfield committed Aug 13, 2024
2 parents 3ed14c0 + b15a656 commit 0b6f06e
Showing 3 changed files with 35 additions and 27 deletions.
31 changes: 11 additions & 20 deletions ocf_data_sampler/datasets/pvnet.py
@@ -67,43 +67,37 @@ def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataAr

     in_config = config.input_data
 
-    # Check which modalities to use
-    # TODO: Clean these up
-    use_nwp = (
-        (in_config.nwp is not None)
-        and len(in_config.nwp) != 0
-        and all(v.nwp_zarr_path != "" for _, v in in_config.nwp.items())
-    )
-    use_sat = is_config_and_path_valid(True, in_config.satellite, "satellite_zarr_path")
-
-    datasets = {}
+    datasets_dict = {}
 
     # We always assume GSP will be included
-    datasets["gsp"] = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path)
+    da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path)
+
+    # Remove national GSP
+    datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
 
     # Load NWP data if in config
-    if use_nwp:
+    if in_config.nwp:
 
-        datasets["nwp"] = {}
+        datasets_dict["nwp"] = {}
         for nwp_source, nwp_config in in_config.nwp.items():
 
             da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
 
             da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
 
-            datasets["nwp"][nwp_source] = da_nwp
+            datasets_dict["nwp"][nwp_source] = da_nwp
 
     # Load satellite data if in config
-    if use_sat:
+    if in_config.satellite:
         sat_config = config.input_data.satellite
 
         da_sat = open_sat_data(sat_config.satellite_zarr_path)
 
         da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
 
-        datasets["sat"] = da_sat
+        datasets_dict["sat"] = da_sat
 
-    return datasets
+    return datasets_dict



@@ -445,9 +439,6 @@ def __init__(
         config = load_yaml_configuration(config_filename)
 
         datasets_dict = get_dataset_dict(config)
-
-        # Remove national GSP ID
-        datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=slice(1, None))
 
         # Get t0 times where all input data is available
        valid_t0_times = find_valid_t0_times(datasets_dict, config)
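
This pair of hunks moves the national-GSP filtering into get_dataset_dict itself: the GSP DataArray is opened, the national total (gsp_id=0) is dropped with .sel(gsp_id=slice(1, None)), and only the individual GSPs are stored, so the dataset __init__ no longer has to do it. A minimal, self-contained sketch of that selection on toy data (not the real GSP zarr):

import numpy as np
import xarray as xr

# Toy stand-in for the GSP DataArray returned by open_gsp(): gsp_id=0 is the
# national total, gsp_id>=1 are individual GSPs.
da_gsp = xr.DataArray(
    np.random.rand(4, 3),
    dims=("time_utc", "gsp_id"),
    coords={"time_utc": np.arange(4), "gsp_id": [0, 1, 2]},
)

# The same selection as in get_dataset_dict above: drop the national total.
da_gsp_no_national = da_gsp.sel(gsp_id=slice(1, None))
print(da_gsp_no_national.gsp_id.values)  # [1 2]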
29 changes: 24 additions & 5 deletions ocf_data_sampler/select/find_contiguous_time_periods.py
@@ -66,7 +66,7 @@ def trim_contiguous_time_periods(
     history_duration: pd.Timedelta,
     forecast_duration: pd.Timedelta,
 ) -> pd.DataFrame:
-    """Trim the contiguous time periods in-place to allow for history and forecast durations.
+    """Trim the contiguous time periods to allow for history and forecast durations.
 
     Args:
         contiguous_time_periods: DataFrame where each row represents a single time period. The
@@ -78,9 +78,14 @@
     Returns:
         The contiguous_time_periods DataFrame with the `start_dt` and `end_dt` columns updated.
     """
+    contiguous_time_periods = contiguous_time_periods.copy()
 
     contiguous_time_periods["start_dt"] += history_duration
     contiguous_time_periods["end_dt"] -= forecast_duration
-    assert (contiguous_time_periods["start_dt"] < contiguous_time_periods["end_dt"]).all()
+
+    valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
+    contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
 
     return contiguous_time_periods
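
With this change, trim_contiguous_time_periods works on a copy and drops any period that becomes empty after trimming, instead of asserting. A short sketch of the new behaviour, assuming ocf_data_sampler at this commit is importable (keyword names follow the signature shown above):

import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import (
    trim_contiguous_time_periods,
)

# Two toy periods: after trimming, the second is shorter than
# history + forecast, so it is dropped rather than tripping the old assert.
periods = pd.DataFrame(
    {
        "start_dt": pd.to_datetime(["2024-01-01 00:00", "2024-01-02 00:00"]),
        "end_dt": pd.to_datetime(["2024-01-01 12:00", "2024-01-02 01:00"]),
    }
)

trimmed = trim_contiguous_time_periods(
    periods,
    history_duration=pd.Timedelta("2h"),
    forecast_duration=pd.Timedelta("2h"),
)

print(len(trimmed))                 # 1 - only the first period survives
print(periods["start_dt"].iloc[0])  # 2024-01-01 00:00 - input not modified in place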


@@ -91,7 +96,19 @@ def find_contiguous_t0_periods(
     forecast_duration: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
 ) -> pd.DataFrame:
-
+    """Return a pd.DataFrame where each row records the boundary of a contiguous time period.
+
+    Args:
+        datetimes: pd.DatetimeIndex. Must be sorted.
+        history_duration: Length of the historical slice used for each sample
+        forecast_duration: Length of the forecast slice used for each sample
+        sample_period_duration: The sample frequency of the timeseries
+
+    Returns:
+        pd.DataFrame where each row represents a single time period. The pd.DataFrame
+        has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
+    """
     total_duration = history_duration + forecast_duration
 
     contiguous_time_periods = find_contiguous_time_periods(
@@ -106,6 +123,8 @@
         forecast_duration=forecast_duration,
     )
 
+    assert len(contiguous_t0_periods) > 0
+
     return contiguous_t0_periods
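
The new docstring spells out the contract of find_contiguous_t0_periods. A usage sketch under the same assumptions (package importable; the toy index and durations are purely illustrative):

import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import (
    find_contiguous_t0_periods,
)

# A 30-minute timeseries with a gap, giving two contiguous blocks of data.
datetimes = pd.date_range("2024-01-01 00:00", "2024-01-01 06:00", freq="30min").append(
    pd.date_range("2024-01-01 12:00", "2024-01-01 18:00", freq="30min")
)

t0_periods = find_contiguous_t0_periods(
    datetimes,
    history_duration=pd.Timedelta("1h"),
    forecast_duration=pd.Timedelta("2h"),
    sample_period_duration=pd.Timedelta("30min"),
)

# One row per contiguous block, trimmed so every t0 inside it has a full hour
# of history and two hours of forecast available.
print(t0_periods[["start_dt", "end_dt"]])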


@@ -192,13 +211,13 @@ def find_contiguous_t0_periods_nwp(
         # considering dropout - then the contiguous period breaks, and new starts with considering
         # dropout and history duration
         if end_this_period < dt_init + max_dropout:
-            contiguous_periods += [[start_this_period, end_this_period]]
+            contiguous_periods.append([start_this_period, end_this_period])
 
             # And start a new period
             start_this_period = dt_init + hist_drop_buffer
         end_this_period = dt_init + max_staleness
 
-    contiguous_periods += [[start_this_period, end_this_period]]
+    contiguous_periods.append([start_this_period, end_this_period])
 
     return pd.DataFrame(contiguous_periods, columns=["start_dt", "end_dt"])

2 changes: 0 additions & 2 deletions ocf_data_sampler/select/select_time_slice.py
@@ -103,8 +103,6 @@ def select_time_slice_nwp(
         ds[channel_dim_name].values, accum_channels
     )
 
-    # TODO: Once all times have been converted to pd.Timestamp, remove this
-    t0 = pd.Timestamp(t0)
     start_dt = (t0 - history_duration).ceil(sample_period_duration)
     end_dt = (t0 + forecast_duration).ceil(sample_period_duration)
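
With the defensive pd.Timestamp(t0) cast removed, callers of select_time_slice_nwp are expected to pass t0 as a pd.Timestamp already. The window arithmetic from the context lines above, shown standalone with illustrative values:

import pandas as pd

# t0 must now already be a pd.Timestamp; the cast above has been removed.
t0 = pd.Timestamp("2024-01-01 10:17")
history_duration = pd.Timedelta("1h")
forecast_duration = pd.Timedelta("2h")
sample_period_duration = pd.Timedelta("30min")

# Same window arithmetic as the surrounding context lines.
start_dt = (t0 - history_duration).ceil(sample_period_duration)
end_dt = (t0 + forecast_duration).ceil(sample_period_duration)

print(start_dt)  # 2024-01-01 09:30:00
print(end_dt)    # 2024-01-01 12:30:00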

