
Commit

minor tweaks
dfulu committed Aug 2, 2024
1 parent 626ebbb · commit e681508
Showing 2 changed files with 61 additions and 47 deletions.
107 changes: 61 additions & 46 deletions ocf_dataset_alpha/datasets/pvnet.py
@@ -35,8 +35,7 @@
from ocf_datapipes.config.load import load_yaml_configuration

from ocf_datapipes.utils.location import Location

from ocf_datapipes.batch import BatchKey
from ocf_datapipes.batch import BatchKey, NumpyBatch

from ocf_datapipes.utils.consts import (
NWP_MEANS,
@@ -53,25 +52,23 @@
xr.set_options(keep_attrs=True)


xarray_object = xr.DataArray | xr.Dataset

# TODO: This is messy. Remove the need for this in NumpyBatch conversion
def add_t0_idx_and_sample_period_duration(ds: xarray_object, source_config):
def add_t0_idx_and_sample_period_duration(da: xr.DataArray, source_config) -> xr.DataArray:
"""Add attributes to the xarray dataset needed for numpy batch
Args:
ds: Xarray object
da: xarray DataArray
source_config: Input data source config
"""

ds.attrs["t0_idx"] = int(
da.attrs["t0_idx"] = int(
source_config.history_minutes / source_config.time_resolution_minutes
)
ds.attrs["sample_period_duration"] = minutes(source_config.time_resolution_minutes)
return ds
da.attrs["sample_period_duration"] = minutes(source_config.time_resolution_minutes)
return da
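
For illustration, a minimal sketch of how this helper behaves, assuming a source configured with 60 minutes of history at 30-minute resolution (the SimpleNamespace below is only a stand-in for the real source config object):

from types import SimpleNamespace
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical source config exposing the two fields the helper reads
source_config = SimpleNamespace(history_minutes=60, time_resolution_minutes=30)

da = xr.DataArray(
    np.random.rand(5),
    dims=("time_utc",),
    coords={"time_utc": pd.date_range("2023-06-01 00:00", periods=5, freq="30min")},
)

da = add_t0_idx_and_sample_period_duration(da, source_config)
print(da.attrs["t0_idx"])                  # 2 -> t0 is the third step along time_utc
print(da.attrs["sample_period_duration"])  # 0:30:00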


def minutes(m: float):
def minutes(m: float) -> timedelta:
"""Timedelta minutes
Args:
@@ -80,14 +77,16 @@ def minutes(m: float):
return timedelta(minutes=m)


def get_dataset_dict(config: Configuration):
def get_dataset_dict(config: Configuration) -> dict[str, xr.DataArray | dict[str, xr.DataArray]]:
"""Construct dictionary of all of the input data sources
Args:
config: Configuration file
"""
# Check which modalities to use
conf_in = config.input_data

# TODO: Clean these up
use_nwp = (

(conf_in.nwp is not None)
@@ -103,36 +102,36 @@ def get_dataset_dict(config: Configuration):
# We always assume GSP will be included
gsp_config = config.input_data.gsps

ds_gsp = open_gsp(zarr_path=gsp_config.gsp_zarr_path)
ds_gsp = add_t0_idx_and_sample_period_duration(ds_gsp, gsp_config)
da_gsp = open_gsp(zarr_path=gsp_config.gsp_zarr_path)
da_gsp = add_t0_idx_and_sample_period_duration(da_gsp, gsp_config)

datasets["gsp"] = ds_gsp
datasets["gsp"] = da_gsp

# Load NWP data if in config
if use_nwp:

datasets["nwp"] = {}
for nwp_source, nwp_config in conf_in.nwp.items():

ds_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)

ds_nwp = ds_nwp.sel(channel=list(nwp_config.nwp_channels))
da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))

ds_nwp = add_t0_idx_and_sample_period_duration(ds_nwp, nwp_config)
da_nwp = add_t0_idx_and_sample_period_duration(da_nwp, nwp_config)

datasets["nwp"][nwp_source] = ds_nwp
datasets["nwp"][nwp_source] = da_nwp

# Load satellite data if in config
if use_sat:
sat_config = config.input_data.satellite

ds_sat = open_sat_data(sat_config.satellite_zarr_path)
da_sat = open_sat_data(sat_config.satellite_zarr_path)

ds_sat.sel(channel=list(sat_config.satellite_channels))
da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))

ds_sat = add_t0_idx_and_sample_period_duration(ds_sat, sat_config)
da_sat = add_t0_idx_and_sample_period_duration(da_sat, sat_config)

datasets["sat"] = ds_sat
datasets["sat"] = da_sat

return datasets
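# For illustration, the dictionary returned above has roughly this shape
# (the "ukv" key is only an example provider name; "nwp" and "sat" appear
# only when they are configured):
#     {"gsp": da_gsp, "nwp": {"ukv": da_nwp, ...}, "sat": da_sat}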

@@ -157,7 +156,7 @@ def find_valid_t0_times(
if key == "nwp":
for nwp_key, nwp_conf in config.input_data.nwp.items():

ds = datasets_dict["nwp"][nwp_key]
da = datasets_dict["nwp"][nwp_key]

if nwp_conf.dropout_timedeltas_minutes is None:
max_dropout = minutes(0)
@@ -178,7 +177,7 @@ def find_valid_t0_times(

# This is the max staleness we can use considering the max step of the input data
max_possible_staleness = (
pd.Timedelta(ds["step"].max().item())
pd.Timedelta(da["step"].max().item())
- minutes(nwp_conf.forecast_minutes)
- end_buffer
)
@@ -192,7 +191,7 @@ def find_valid_t0_times(
max_staleness = max_staleness

time_periods = find_contiguous_t0_periods_nwp(
datetimes=pd.DatetimeIndex(ds["init_time_utc"]),
datetimes=pd.DatetimeIndex(da["init_time_utc"]),
history_duration=minutes(nwp_conf.history_minutes),
max_staleness=max_staleness,
max_dropout=max_dropout,
@@ -299,14 +298,14 @@ def slice_datasets_by_time(

sliced_datasets_dict["nwp"] = {}

for nwp_key, ds_nwp in datasets_dict["nwp"].items():
for nwp_key, da_nwp in datasets_dict["nwp"].items():

dropout_timedeltas = minutes_list_to_timedeltas(
conf_in.nwp[nwp_key].dropout_timedeltas_minutes
)

sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
ds_nwp,
da_nwp,
t0,
sample_period_duration=minutes(conf_in.nwp[nwp_key].time_resolution_minutes),
history_duration=minutes(conf_in.nwp[nwp_key].history_minutes),
@@ -377,7 +376,7 @@ def slice_datasets_by_time(
return sliced_datasets_dict


def merge_dicts(list_of_dicts: list[dict]):
def merge_dicts(list_of_dicts: list[dict]) -> dict:
"""Merge a list of dictionaries into a single dictionary"""
# TODO: This doesn't account for duplicate keys
all_d = {}
@@ -387,7 +386,7 @@ def merge_dicts(list_of_dicts: list[dict]):



def process_and_combine_datasets(dataset_dict: dict, config: Configuration):
def process_and_combine_datasets(dataset_dict: dict, config: Configuration) -> NumpyBatch:
"""Normalize and convert data to numpy arrays"""

numpy_modalities = []
@@ -397,27 +396,27 @@ def process_and_combine_datasets(dataset_dict: dict, config: Configuration):
conf_nwp = config.input_data.nwp
nwp_numpy_modalities = dict()

for nwp_key, ds_nwp in dataset_dict["nwp"].items():
for nwp_key, da_nwp in dataset_dict["nwp"].items():
# Normalise
provider = conf_nwp[nwp_key].nwp_provider
ds_nwp = (ds_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
# Convert to NumpyBatch
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(ds_nwp)
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)

# Combine the NWPs into NumpyBatch
numpy_modalities.append({BatchKey.nwp: nwp_numpy_modalities})

if "sat" in dataset_dict:
# Normalise
ds_sat = (dataset_dict["sat"] - RSS_MEAN) / RSS_STD
da_sat = (dataset_dict["sat"] - RSS_MEAN) / RSS_STD
# Convert to NumpyBatch
numpy_modalities.append(convert_satellite_to_numpy_batch(ds_sat))
numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))


# GSP always assumed to be in data
ds_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
ds_gsp = normalize_gsp(ds_gsp)
numpy_modalities.append(convert_gsp_to_numpy_batch(ds_gsp))
da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
da_gsp = normalize_gsp(da_gsp)
numpy_modalities.append(convert_gsp_to_numpy_batch(da_gsp))

# Combine all the modalities
combined_sample = merge_dicts(numpy_modalities)
@@ -428,8 +427,8 @@ def process_and_combine_datasets(dataset_dict: dict, config: Configuration):
return combined_sample


def compute(d):
"""Eagerly load a nested dictionary of xarray data"""
def compute(xarray_dict: dict) -> dict:
"""Eagerly load a nested dictionary of xarray DataArrays"""
for k, v in xarray_dict.items():
if isinstance(v, dict):
xarray_dict[k] = compute(v)
@@ -438,16 +437,16 @@ def compute(d):
return d


def get_locations(gs_gsp: xr.Dataset) -> list[Location]:
def get_locations(gs_gsp: xr.DataArray) -> list[Location]:
"""Get list of locations of GSP"""
locations = []
for gsp_id in gs_gsp.gsp_id.values:
ds_ = gs_gsp.sel(gsp_id=gsp_id)
da_ = gs_gsp.sel(gsp_id=gsp_id)
locations.append(
Location(
coordinate_system = "osgb",
x=ds_.x_osgb.item(),
y=ds_.y_osgb.item(),
x=da_.x_osgb.item(),
y=da_.y_osgb.item(),
id=gsp_id,
)
)
@@ -483,18 +482,23 @@ def __init__(

# Construct list of locations to sample from
locations = get_locations(datasets_dict["gsp"])

# Construct a lookup from GSP ID to location - useful for constructing a sample by GSP ID
location_lookup = {loc.id: loc for loc in locations}

# Construct indices for sampling
t_index, loc_index = np.meshgrid(
np.arange(len(valid_t0_times)),
np.arange(len(locations)),
)

# Make array of all possible (t0, location) coordinates. Each row is a single coordinate
index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
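# For example, with 3 valid t0 times and 2 locations this gives
# index_pairs = [[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1]],
# i.e. every (t0_index, location_index) pair appears exactly once.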

# Assign coords and indices to self
self.valid_t0_times = valid_t0_times
self.locations = locations
self.location_lookup = location_lookup
self.index_pairs = index_pairs

# Assign config and input data to self
@@ -505,8 +509,8 @@ def __init__(
def __len__(self):
return len(self.index_pairs)

# TODO: Would this be better if we could pass in GSP ID int instead of Location object?
def get_sample(self, t0: pd.Timestamp, location: Location):

def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpyBatch:
"""Generate the PVNet sample for given coordinates
Args:
@@ -530,7 +534,18 @@ def __getitem__(self, idx):
t0 = self.valid_t0_times[t_index]

# Generate the sample
return self.get_sample(t0, location)
return self._get_sample(t0, location)


def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch:
"""Generate the PVNet sample for given coordinates
Args:
t0: init-time for sample
gsp_id: GSP ID
"""
location = self.location_lookup[gsp_id]
return self._get_sample(t0, location)


if __name__=="__main__":
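As a rough usage sketch of the sampling interface above (the dataset class name and constructor arguments are not visible in this diff, so `dataset` stands in for an already-constructed instance of the class defined in pvnet.py, and the timestamp and GSP ID are hypothetical):

import pandas as pd

# Index-based access: one sample per (t0, location) pair
sample = dataset[0]

# The new public get_sample() takes a GSP ID directly and resolves the
# Location through the location_lookup built in __init__
t0 = pd.Timestamp("2023-06-01 12:00")
sample = dataset.get_sample(t0, gsp_id=34)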
1 change: 0 additions & 1 deletion ocf_dataset_alpha/numpy_batch/add_sun_position.py
@@ -1,7 +1,6 @@

import pvlib
import numpy as np
import warnings
from ocf_datapipes.batch import BatchKey
from ocf_datapipes.utils.consts import (
AZIMUTH_MEAN,
