Commit

linting and docs
dfulu committed Aug 1, 2024
1 parent 281bc7e commit 64d8ae0
Showing 1 changed file with 97 additions and 52 deletions.
149 changes: 97 additions & 52 deletions ocf_dataset_alpha/datasets/pvnet.py
@@ -1,3 +1,5 @@
"""Torch dataset for PVNet"""

from datetime import datetime, timedelta

import numpy as np
@@ -51,7 +53,16 @@
xr.set_options(keep_attrs=True)


def add_t0_idx_and_sample_period_duration(ds, source_config):
xarray_object = xr.DataArray | xr.Dataset

# TODO: This is messy. Remove the need for this in NumpyBatch conversion
def add_t0_idx_and_sample_period_duration(ds: xarray_object, source_config):
"""Add attributes to the xarray dataset needed for numpy batch
Args:
ds: Xarray object
source_config: Input data source config
"""

ds.attrs["t0_idx"] = int(
source_config.history_minutes / source_config.time_resolution_minutes
@@ -60,12 +71,21 @@ def add_t0_idx_and_sample_period_duration(ds, source_config):
return ds


def minutes(m):
def minutes(m: float):
"""Timedelta minutes
Args:
m: minutes
"""
return timedelta(minutes=m)


def get_dataset_dict(config: Configuration):
"""Construct dictionary of all of the input data sources
Args:
config: Configuration file
"""
# Check which modalities to use
conf_in = config.input_data
use_nwp = (
@@ -80,44 +100,36 @@ def get_dataset_dict(config: Configuration):

datasets = {}

# Load GSP national data
if use_gsp:
gsp_config = config.input_data.gsp
# We always assume GSP will be included
gsp_config = config.input_data.gsps

ds_gsp = open_gsp(
zarr_path=gsp_config.gsp_zarr_path
)
ds_gsp = open_gsp(zarr_path=gsp_config.gsp_zarr_path)
ds_gsp = add_t0_idx_and_sample_period_duration(ds_gsp, gsp_config)

# These attrs are still currently needed but will be removed in the future
ds_gsp = add_t0_idx_and_sample_period_duration(ds_gsp, gsp_config)

datasets["gsp"] = ds_gsp
datasets["gsp"] = ds_gsp

# Load NWP data
# Load NWP data if in config
if use_nwp:

datasets["nwp"] = {}
for nwp_source, nwp_config in conf_in.nwp.items():
ds_nwp = open_nwp(
nwp_config.nwp_zarr_path,
provider=nwp_config.nwp_provider,
)


ds_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)

ds_nwp = ds_nwp.sel(channel=list(nwp_config.nwp_channels))

# These attrs are still currently needed but will be removed in the future

ds_nwp = add_t0_idx_and_sample_period_duration(ds_nwp, nwp_config)

datasets["nwp"][nwp_source] = ds_nwp

# Load satellite data if in config
if use_sat:
sat_config = config.input_data.satellite

ds_sat = open_sat_data(sat_config.satellite_zarr_path)

ds_sat = ds_sat.sel(channel=list(sat_config.satellite_channels))

# These attrs are still currently needed but will be removed in the future
ds_sat = add_t0_idx_and_sample_period_duration(ds_sat, sat_config)

datasets["sat"] = ds_sat
@@ -130,9 +142,16 @@ def find_valid_t0_times(
datasets_dict: dict,
config: Configuration,
):
"""Find the t0 times where all of the requested input data is available
Args:
datasets_dict: A dictionary of input datasets
config: Configuration file
"""

contiguous_time_periods = [] # Used to store contiguous time periods from each data source

# TODO: Is this cleaner as series of `if key in datasets_dict` statements rather than loop?
for key in datasets_dict.keys():

if key == "nwp":
@@ -207,7 +226,7 @@ def find_valid_t0_times(
else:
valid_time_periods = contiguous_time_periods[0]

# Select time periods and set length
# Fill out the contiguous time periods to get the t0 times
valid_t0_times = fill_time_periods(
valid_time_periods,
freq=minutes(config.input_data.gsp.time_resolution_minutes)
@@ -220,12 +239,12 @@ def slice_datasets_by_space(
datasets_dict: dict,
location: Location,
config: Configuration,
) -> None:
"""Modifies a dictionary of datapipes in-place to yield slices around a given location
) -> dict:
"""Slice a dictionaries of input data sources around a given location
Args:
datapipes_dict: Dictionary of used datapipes and t0 ones
location_pipe: Datapipe which yields location for sample
datasets_dict: Dictionary of the input data sources
location: The location to sample around
config: Configuration object.
"""

@@ -261,11 +280,17 @@ def slice_datasets_by_space(


def slice_datasets_by_time(
datasets_dict,
t0,
datasets_dict: dict,
t0: pd.Timestamp,
config: Configuration,
) -> None:
) -> dict:
"""Slice a dictionaries of input data sources around a given t0 time
Args:
datasets_dict: Dictionary of the input data sources
t0: The init-time
config: Configuration object.
"""
conf_in = config.input_data

sliced_datasets_dict = {}
@@ -274,7 +299,6 @@ def slice_datasets_by_time(

sliced_datasets_dict["nwp"] = {}

# NWP is nested in the dict
for nwp_key, ds_nwp in datasets_dict["nwp"].items():

dropout_timedeltas = minutes_list_to_timedeltas(
@@ -293,7 +317,7 @@ def slice_datasets_by_time(
)

if "sat" in datasets_dict:
# Take time slices of sat data

sliced_datasets_dict["sat"] = select_time_slice(
datasets_dict["sat"],
t0,
Expand All @@ -303,7 +327,7 @@ def slice_datasets_by_time(
max_steps_gap=2,
)

# Generate randomly sampled dropout times
# Randomly sample dropout
sat_dropout_time = draw_dropout_time(
t0,
dropout_timedeltas=minutes_list_to_timedeltas(
@@ -318,6 +342,7 @@ def slice_datasets_by_time(
sat_dropout_time,
)

# GSP always assumed to be included
sliced_datasets_dict["gsp_future"] = select_time_slice(
datasets_dict["gsp"],
t0,
@@ -352,37 +377,40 @@ def slice_datasets_by_time(
return sliced_datasets_dict


def merge_dicts(list_of_dicts):
def merge_dicts(list_of_dicts: list[dict]):
"""Merge a list of dictionaries into a single dictionary"""
# TODO: This doesn't account for duplicate keys
all_d = {}
for d in list_of_dicts:
all_d.update(d)
return all_d
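# Sketch only, not part of this commit: the TODO above notes that duplicate keys are
# silently overwritten by later dictionaries. A hypothetical stricter variant could
# surface key clashes instead of masking them:
def merge_dicts_strict(list_of_dicts: list[dict]) -> dict:
    """Merge a list of dictionaries, raising if any key appears more than once"""
    all_d: dict = {}
    for d in list_of_dicts:
        clashes = all_d.keys() & d.keys()
        if clashes:
            raise KeyError(f"Duplicate keys when merging: {clashes}")
        all_d.update(d)
    return all_d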



def process_and_combine_datasets(dataset_dict, config):
def process_and_combine_datasets(dataset_dict: dict, config: Configuration):
"""Normalize and convert data to numpy arrays"""

numpy_modalities = []

# Normalise the inputs and convert to numpy format
if "nwp" in dataset_dict:

conf_nwp = config.input_data.nwp

nwp_numpy_modalities = dict()

for nwp_key, ds_nwp in dataset_dict["nwp"].items():

# Normalise
provider = conf_nwp[nwp_key].nwp_provider
ds_nwp = (ds_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
# Convert to NumpyBatch
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(ds_nwp)

# Combine the NWPs into NumpyBatch
numpy_modalities.append({BatchKey.nwp: nwp_numpy_modalities})

if "sat" in dataset_dict:
# Normalise
ds_sat = (dataset_dict["sat"] - RSS_MEAN) / RSS_STD
# Convert to NumpyBatch
numpy_modalities.append(convert_satellite_to_numpy_batch(ds_sat))


@@ -393,12 +421,15 @@ def process_and_combine_datasets(dataset_dict, config):

# Combine all the modalities
combined_sample = merge_dicts(numpy_modalities)

# Add sun coords
combined_sample = add_sun_position(combined_sample, modality_name="gsp")

return combined_sample


def compute(d):
"""Eagerly load a nested dictionary of xarray data"""
for k, v in d.items():
if isinstance(v, dict):
d[k] = compute(v)
@@ -408,7 +439,7 @@ def compute(d):


def get_locations(gs_gsp: xr.Dataset) -> list[Location]:
"""Get list of locations of the GSPs"""
"""Get list of locations of GSP"""
locations = []
for gsp_id in gs_gsp.gsp_id.values:
ds_ = gs_gsp.sel(gsp_id=gsp_id)
@@ -426,10 +457,9 @@ def get_locations(gs_gsp: xr.Dataset) -> list[Location]:
class PVNetDataset(Dataset):
def __init__(
self,
config_filename,
start_time = None,
end_time=None,
preshuffle: bool = False
config_filename: str,
start_time: str | None = None,
end_time: str | None = None,
):
"""A torch Dataset for PVNet
@@ -439,7 +469,7 @@ def __init__(

datasets_dict = get_dataset_dict(config)

# Remove national GSP data
# Remove national GSP ID
datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=slice(1, None))

if (start_time is not None) or (end_time is not None):
@@ -461,9 +491,6 @@ def __init__(
)

index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T

if preshuffle:
index_pairs = np.random.permutation(index_pairs)

# Assign coords and indices to self
self.valid_t0_times = valid_t0_times
@@ -478,8 +505,14 @@ def __init__(
def __len__(self):
return len(self.index_pairs)


def get_sample(self, t0, location):
# TODO: Would this be better if we could pass in GSP ID int instead of Location object?
def get_sample(self, t0: pd.Timestamp, location: Location):
"""Generate the PVNet sample for given coordinates
Args:
t0: init-time for sample
location: location for sample
"""
sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
sample_dict = compute(sample_dict)
@@ -496,25 +529,37 @@ def __getitem__(self, idx):
location = self.locations[loc_index]
t0 = self.valid_t0_times[t_index]

# Generate the sample
return self.get_sample(t0, location)


if __name__=="__main__":

# ------------------ basic usage ---------------------

# TODO: remove this, it's messy, but useful until we have tests and docs
config_filename = (
"/home/jamesfulton/repos/PVNet/configs/datamodule/configuration/gcp_configuration.yaml"
)

# Create dataset object
dataset = PVNetDataset(config_filename)


print(len(dataset.valid_t0_times))
# Print number of samples
print(f"Found {len(dataset.valid_t0_times)} possible samples")

idx = 100
# Find the 0th sample coordinates
# TODO: Should we be able to use the dataset to map from index to t0, location coords more
# easily?
idx = 0
t_index, loc_index = dataset.index_pairs[idx]

location = dataset.locations[loc_index]
t0 = dataset.valid_t0_times[t_index]

# Print coords
print(t0)
print(location)

# Generate sample - no printing since it's BIG
sample = dataset[idx]
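# Usage sketch (not part of this commit): wrapping the dataset in a torch DataLoader.
# This assumes the NumpyBatch dict samples collate cleanly with torch's default
# collate_fn; in practice a custom collate function may be needed.
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(dataloader))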
