Skip to content

Commit

Permalink
Update what Site Dataset returns
Browse files Browse the repository at this point in the history
  • Loading branch information
Sukhil Patel authored and Sukhil Patel committed Nov 18, 2024
1 parent 1535415 commit 373fca7
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 27 deletions.
85 changes: 85 additions & 0 deletions ocf_data_sampler/torch_datasets/process_and_combine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import xarray as xr
from typing import Tuple

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
Expand Down Expand Up @@ -115,6 +116,37 @@ def process_and_combine_datasets(

return combined_sample

def process_and_combine_site_sample_dict(
dataset_dict: dict,
config: Configuration,
) -> xr.Dataset:
"""Normalize and combine data to xr Dataset"""

data_arrays = []

if "nwp" in dataset_dict:
for nwp_key, da_nwp in dataset_dict["nwp"].items():
# Standardise
provider = config.input_data.nwp[nwp_key].provider
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
data_arrays.append((f"nwp-{provider}", da_nwp))

if "sat" in dataset_dict:
# Satellite is already in the range [0-1] so no need to standardise
da_sat = dataset_dict["sat"]
data_arrays.append(("satellite", da_sat))

if "site" in dataset_dict:
# site_config = config.input_data.site
da_sites = dataset_dict["site"]
da_sites = da_sites / da_sites.capacity_kwp
data_arrays.append(("sites", da_sites))

combined_sample_dataset = merge_arrays(data_arrays)

# Fill any nan values
return combined_sample_dataset.fillna(0.0)


def merge_dicts(list_of_dicts: list[dict]) -> dict:
"""Merge a list of dictionaries into a single dictionary"""
Expand All @@ -124,6 +156,59 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
combined_dict.update(d)
return combined_dict

def merge_arrays(list_of_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
"""
Combine a list of DataArrays into a single Dataset with unique naming conventions.
Args:
list_of_arrays: List of tuples where each tuple contains:
- A string (key name).
- An xarray.DataArray.
Returns:
xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
"""
datasets = []

for key, data_array in list_of_arrays:
# Ensure all attributes are strings for consistency
data_array = data_array.assign_attrs(
{attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
)

# Convert DataArray to Dataset with the variable name as the key
dataset = data_array.to_dataset(name=key)

# Prepend key name to all dimension and coordinate names for uniqueness
dataset = dataset.rename(
{dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
)
dataset = dataset.rename(
{coord: f"{key}__{coord}" for coord in dataset.coords}
)

# Handle concatenation dimension if applicable
concat_dim = (
f"{key}__target_time_utc" if f"{key}__target_time_utc" in dataset.coords
else f"{key}__time_utc"
)

if f"{key}__init_time_utc" in dataset.coords:
init_coord = f"{key}__init_time_utc"
if dataset[init_coord].ndim == 0: # Check if scalar
expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})

datasets.append(dataset)

# Ensure all datasets are valid xarray.Dataset objects
for ds in datasets:
assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"

# Merge all prepared datasets
combined_dataset = xr.merge(datasets)

return combined_dataset

def fill_nans_in_arrays(batch: dict) -> dict:
"""Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
Expand Down
4 changes: 2 additions & 2 deletions ocf_data_sampler/torch_datasets/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
slice_datasets_by_time, slice_datasets_by_space
)
from ocf_data_sampler.utils import minutes
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_site_sample_dict, compute
from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

xr.set_options(keep_attrs=True)
Expand Down Expand Up @@ -154,7 +154,7 @@ def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
sample_dict = compute(sample_dict)

sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')
sample = process_and_combine_site_sample_dict(sample_dict, self.config)

return sample

Expand Down
2 changes: 1 addition & 1 deletion tests/select/test_select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
interval_start = pd.Timedelta(-6, "h")
interval_end = pd.Timedelta(3, "h")
freq = pd.Timedelta("1H")
freq = pd.Timedelta("1h")
dropout_timedelta = pd.Timedelta("-2h")

t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
Expand Down
51 changes: 27 additions & 24 deletions tests/torch_datasets/test_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.site import SiteBatchKey
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
from xarray import Dataset


@pytest.fixture()
Expand Down Expand Up @@ -34,30 +35,32 @@ def test_site(site_config_filename):
# Generate a sample
sample = dataset[0]

assert isinstance(sample, dict)

for key in [
NWPBatchKey.nwp,
SatelliteBatchKey.satellite_actual,
SiteBatchKey.generation,
SiteBatchKey.site_solar_azimuth,
SiteBatchKey.site_solar_elevation,
]:
assert key in sample

for nwp_source in ["ukv"]:
assert nwp_source in sample[NWPBatchKey.nwp]

# check the shape of the data is correct
# 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
# 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
# 3 hours of 30 minute data (inclusive)
assert sample[SiteBatchKey.generation].shape == (4,)
# Solar angles have same shape as GSP data
assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)
assert isinstance(sample, Dataset)

# TODO change this bit of the test to check for sensible things

# for key in [
# NWPBatchKey.nwp,
# SatelliteBatchKey.satellite_actual,
# SiteBatchKey.generation,
# SiteBatchKey.site_solar_azimuth,
# SiteBatchKey.site_solar_elevation,
# ]:
# assert key in sample

# for nwp_source in ["ukv"]:
# assert nwp_source in sample[NWPBatchKey.nwp]

# # check the shape of the data is correct
# # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
# assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
# # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
# assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
# # 3 hours of 30 minute data (inclusive)
# assert sample[SiteBatchKey.generation].shape == (4,)
# # Solar angles have same shape as GSP data
# assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
# assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)


def test_site_time_filter_start(site_config_filename):
Expand Down

0 comments on commit 373fca7

Please sign in to comment.