Site torch dataset update #82

Draft · wants to merge 3 commits into base: main
118 changes: 95 additions & 23 deletions ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import xarray as xr
from typing import Tuple

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
@@ -9,7 +10,6 @@
convert_satellite_to_numpy_batch,
convert_gsp_to_numpy_batch,
make_sun_position_numpy_batch,
convert_site_to_numpy_batch,
)
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
@@ -73,18 +73,6 @@ def process_and_combine_datasets(
}
)


if "site" in dataset_dict:
site_config = config.input_data.site
da_sites = dataset_dict["site"]
da_sites = da_sites / da_sites.capacity_kwp

numpy_modalities.append(
convert_site_to_numpy_batch(
da_sites, t0_idx=-site_config.interval_start_minutes / site_config.time_resolution_minutes
)
)

if target_key == 'gsp':
# Make sun coords NumpyBatch
datetimes = pd.date_range(
@@ -95,16 +83,6 @@

lon, lat = osgb_to_lon_lat(location.x, location.y)

elif target_key == 'site':
# Make sun coords NumpyBatch
datetimes = pd.date_range(
t0+minutes(site_config.interval_start_minutes),
t0+minutes(site_config.interval_end_minutes),
freq=minutes(site_config.time_resolution_minutes),
)

lon, lat = location.x, location.y

numpy_modalities.append(
make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
)
@@ -115,6 +93,47 @@

return combined_sample

def process_and_combine_site_sample_dict(
dataset_dict: dict,
config: Configuration,
) -> xr.Dataset:
"""
Normalize and combine data into a single xr Dataset

Args:
dataset_dict: dict containing sliced xr DataArrays
config: Configuration for the model

Returns:
xr.Dataset: A merged Dataset with nans filled in.

"""

data_arrays = []

if "nwp" in dataset_dict:
for nwp_key, da_nwp in dataset_dict["nwp"].items():
# Standardise
provider = config.input_data.nwp[nwp_key].provider
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
data_arrays.append((f"nwp-{provider}", da_nwp))

if "sat" in dataset_dict:
# Satellite is already in the range [0-1] so no need to standardise
da_sat = dataset_dict["sat"]
data_arrays.append(("satellite", da_sat))

if "site" in dataset_dict:
# site_config = config.input_data.site
da_sites = dataset_dict["site"]
da_sites = da_sites / da_sites.capacity_kwp
data_arrays.append(("sites", da_sites))

combined_sample_dataset = merge_arrays(data_arrays)

# Fill any nan values
return combined_sample_dataset.fillna(0.0)


def merge_dicts(list_of_dicts: list[dict]) -> dict:
"""Merge a list of dictionaries into a single dictionary"""
@@ -124,6 +143,59 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
combined_dict.update(d)
return combined_dict

def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
"""
Combine a list of DataArrays into a single Dataset with unique naming conventions.

Args:
        normalised_data_arrays: List of tuples where each tuple contains:
- A string (key name).
- An xarray.DataArray.

Returns:
xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
"""
datasets = []

for key, data_array in normalised_data_arrays:
# Ensure all attributes are strings for consistency
data_array = data_array.assign_attrs(
{attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
)

# Convert DataArray to Dataset with the variable name as the key
dataset = data_array.to_dataset(name=key)

# Prepend key name to all dimension and coordinate names for uniqueness
dataset = dataset.rename(
{dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
)
dataset = dataset.rename(
{coord: f"{key}__{coord}" for coord in dataset.coords}
)

# Handle concatenation dimension if applicable
concat_dim = (
f"{key}__target_time_utc" if f"{key}__target_time_utc" in dataset.coords
else f"{key}__time_utc"
)

if f"{key}__init_time_utc" in dataset.coords:
init_coord = f"{key}__init_time_utc"
if dataset[init_coord].ndim == 0: # Check if scalar
expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})

datasets.append(dataset)

# Ensure all datasets are valid xarray.Dataset objects
for ds in datasets:
assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"

# Merge all prepared datasets
combined_dataset = xr.merge(datasets)

return combined_dataset

def fill_nans_in_arrays(batch: dict) -> dict:
"""Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
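
For reference, a minimal sketch of what merge_arrays produces for a single input array. The toy DataArray below (its shape, coordinates, and values) is purely illustrative and is not data from this PR:

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.torch_datasets.process_and_combine import merge_arrays

# Toy satellite-like array: 3 timestamps, 2x2 pixels (placeholder values)
times = pd.date_range("2024-01-02 12:00", periods=3, freq="5min")
da_sat = xr.DataArray(
    np.random.rand(3, 2, 2),
    dims=["time_utc", "y_geostationary", "x_geostationary"],
    coords={"time_utc": times},
)

ds = merge_arrays([("satellite", da_sat)])

# The variable takes the key name, and every dimension/coordinate gains a
# "satellite__" prefix, matching the naming the updated test expects
print(list(ds.data_vars))  # ['satellite']
print(sorted(ds.dims))     # ['satellite__time_utc', 'satellite__x_geostationary', 'satellite__y_geostationary']
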
7 changes: 3 additions & 4 deletions ocf_data_sampler/torch_datasets/site.py
@@ -15,7 +15,7 @@
slice_datasets_by_time, slice_datasets_by_space
)
from ocf_data_sampler.utils import minutes
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_site_sample_dict
from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

xr.set_options(keep_attrs=True)
@@ -152,10 +152,9 @@ def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
"""
sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
sample_dict = compute(sample_dict)

sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')

sample = process_and_combine_site_sample_dict(sample_dict, self.config)
sample = sample.compute()
return sample

def get_location_from_site_id(self, site_id):
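
A rough sketch of the downstream effect of this change. The class name SitesDataset and its constructor signature are assumptions (neither appears in this hunk); the point is that each sample is now an eagerly computed xarray Dataset rather than a dict of numpy arrays:

import xarray as xr

from ocf_data_sampler.torch_datasets.site import SitesDataset  # class name assumed, not shown in this hunk

dataset = SitesDataset("path/to/site_config.yaml")  # constructor signature assumed

sample = dataset[0]  # _get_sample now returns sample.compute(), i.e. an in-memory Dataset
assert isinstance(sample, xr.Dataset)
print(list(sample.data_vars))  # e.g. ['nwp-ukv', 'satellite', 'sites']
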
2 changes: 1 addition & 1 deletion tests/select/test_select_time_slice.py
@@ -197,7 +197,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
interval_start = pd.Timedelta(-6, "h")
interval_end = pd.Timedelta(3, "h")
freq = pd.Timedelta("1H")
freq = pd.Timedelta("1h")
dropout_timedelta = pd.Timedelta("-2h")

t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
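
For context, this looks like a fix for pandas' deprecation of uppercase frequency aliases: newer pandas versions warn on pd.Timedelta("1H") and expect the lowercase "1h" spelling.
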
34 changes: 15 additions & 19 deletions tests/torch_datasets/test_site.py
@@ -6,6 +6,7 @@
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.site import SiteBatchKey
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
from xarray import Dataset


@pytest.fixture()
@@ -34,31 +35,26 @@ def test_site(site_config_filename):
# Generate a sample
sample = dataset[0]

assert isinstance(sample, dict)
assert isinstance(sample, Dataset)

for key in [
NWPBatchKey.nwp,
SatelliteBatchKey.satellite_actual,
SiteBatchKey.generation,
SiteBatchKey.site_solar_azimuth,
SiteBatchKey.site_solar_elevation,
]:
assert key in sample
# Expected dimensions and data variables
expected_dims = {'satellite__x_geostationary', 'sites__time_utc', 'nwp-ukv__target_time_utc',
'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb'}
expected_data_vars = {"nwp-ukv", "satellite", "sites"}

for nwp_source in ["ukv"]:
assert nwp_source in sample[NWPBatchKey.nwp]
# Check dimensions
assert set(sample.dims) == expected_dims, f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
# Check data variables
assert set(sample.data_vars) == expected_data_vars, f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"

# check the shape of the data is correct
# 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
assert sample["satellite"].values.shape == (7, 1, 2, 2)
# 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
# 3 hours of 30 minute data (inclusive)
assert sample[SiteBatchKey.generation].shape == (4,)
# Solar angles have same shape as GSP data
assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)

assert sample["nwp-ukv"].values.shape == (4, 1, 2, 2)
# 1.5 hours of 30 minute data (inclusive)
assert sample["sites"].values.shape == (4,)

def test_site_time_filter_start(site_config_filename):
