diff --git a/ocf_datapipes/training/metnet_gsp_national.py b/ocf_datapipes/training/metnet_gsp_national.py index 7e60ce6f5..df7e8dc58 100644 --- a/ocf_datapipes/training/metnet_gsp_national.py +++ b/ocf_datapipes/training/metnet_gsp_national.py @@ -62,6 +62,7 @@ def metnet_national_datapipe( use_gsp: bool = True, use_topo: bool = True, output_size: int = 256, + gsp_in_image: bool = False, start_time: datetime.datetime = datetime.datetime(2014, 1, 1), end_time: datetime.datetime = datetime.datetime(2023, 1, 1), ) -> IterDataPipe: @@ -82,6 +83,7 @@ def metnet_national_datapipe( start_time: Start time to select on end_time: End time to select from output_size: Size, in pixels, of the output image + gsp_in_image: Add GSP history as channels in MetNet image Returns: datapipe """ @@ -130,6 +132,16 @@ def metnet_national_datapipe( # Now combine in the MetNet format modalities = [] + if gsp_in_image and "hrv" in used_datapipes.keys(): + sat_hrv_datapipe, sat_gsp_datapipe = sat_hrv_datapipe.fork(2) + gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image( + image_datapipe=sat_gsp_datapipe + ) + elif gsp_in_image and "sat" in used_datapipes.keys(): + sat_datapipe, sat_gsp_datapipe = sat_datapipe.fork(2) + gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image( + image_datapipe=sat_gsp_datapipe + ) if "nwp" in used_datapipes.keys(): modalities.append(nwp_datapipe) if "hrv" in used_datapipes.keys(): @@ -138,6 +150,8 @@ def metnet_national_datapipe( modalities.append(sat_datapipe) if "topo" in used_datapipes.keys(): modalities.append(topo_datapipe) + if gsp_in_image: + modalities.append(gsp_history) gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2, buffer_size=5) @@ -154,8 +168,10 @@ def metnet_national_datapipe( output_height_pixels=output_size, add_sun_features=use_sun, ) - - gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) - gsp_history = gsp_history.map(_remove_nans) - gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True) - return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples + if not gsp_in_image: + gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) + gsp_history = gsp_history.map(_remove_nans) + gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True) + return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples + else: + return metnet_datapipe.zip(gsp_datapipe) diff --git a/ocf_datapipes/transform/xarray/__init__.py b/ocf_datapipes/transform/xarray/__init__.py index 49c6f80bb..463103091 100644 --- a/ocf_datapipes/transform/xarray/__init__.py +++ b/ocf_datapipes/transform/xarray/__init__.py @@ -24,6 +24,7 @@ from .get_contiguous_time_periods import ( GetContiguousT0TimePeriodsIterDataPipe as GetContiguousT0TimePeriods, ) +from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage from .metnet_preprocessor import PreProcessMetNetIterDataPipe as PreProcessMetNet from .normalize import NormalizeIterDataPipe as Normalize from .pv.create_pv_image import CreatePVImageIterDataPipe as CreatePVImage diff --git a/ocf_datapipes/transform/xarray/gsp/__init__.py b/ocf_datapipes/transform/xarray/gsp/__init__.py new file mode 100644 index 000000000..384d610ef --- /dev/null +++ b/ocf_datapipes/transform/xarray/gsp/__init__.py @@ -0,0 +1 @@ +"""GSP specific transforms""" diff --git a/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py b/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py new file mode 100644 index 000000000..295796c3c --- /dev/null +++ b/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py @@ -0,0 +1,90 @@ +"""Convert point PV sites to image output""" +import logging +from typing import Union + +import numpy as np +import xarray as xr +from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.iter import IterDataPipe + +from ocf_datapipes.utils import Zipper + +logger = logging.getLogger(__name__) + + +@functional_datapipe("create_gsp_image") +class CreateGSPImageIterDataPipe(IterDataPipe): + """Create GSP image from individual sites""" + + def __init__( + self, + source_datapipe: IterDataPipe, + image_datapipe: IterDataPipe, + normalize: bool = False, + image_dim: str = "geostationary", + always_return_first: bool = False, + seed=None, + ): + """ + Creates a 3D data cube of GSP output image x number of timesteps + + This is primarily for national GSP, so single GSP inputs are preferred + + Args: + source_datapipe: Source datapipe of PV data + image_datapipe: Datapipe emitting images to get the shape from, with coordinates + normalize: Whether to normalize based off the image max, or leave raw data + image_dim: Dimension name for the x and y dimensions + always_return_first: Always return the first image data cube, to save computation + Only use for if making the image at the beginning of the stack + seed: Random seed to use if using max_num_pv_systems + """ + self.source_datapipe = source_datapipe + self.image_datapipe = image_datapipe + self.normalize = normalize + self.x_dim = "x_" + image_dim + self.y_dim = "y_" + image_dim + self.rng = np.random.default_rng(seed=seed) + self.always_return_first = always_return_first + + def __iter__(self) -> xr.DataArray: + for gsp_systems_xr, image_xr in Zipper(self.source_datapipe, self.image_datapipe): + # Create empty image to use for the PV Systems, assumes image has x and y coordinates + pv_image = np.zeros( + ( + len(gsp_systems_xr["time_utc"]), + len(image_xr[self.y_dim]), + len(image_xr[self.x_dim]), + ), + dtype=np.float32, + ) + for i, gsp_system_id in enumerate(gsp_systems_xr["gsp_id"]): + gsp_system = gsp_systems_xr.sel(gsp_id=gsp_system_id) + for time_step in range(len(gsp_system.time_utc.values)): + # Now go by the timestep to create cube of GSP data + pv_image[time_step:, :] = gsp_system.isel(time_utc=time_step).values + + pv_image = np.nan_to_num(pv_image) + + # Should return Xarray as in Xarray transforms + # Same coordinates as the image xarray, so can take that + pv_image = _create_data_array_from_image(pv_image, gsp_systems_xr, image_xr) + yield pv_image + + +def _create_data_array_from_image( + pv_image: np.ndarray, + pv_systems_xr: Union[xr.Dataset, xr.DataArray], + image_xr: Union[xr.Dataset, xr.DataArray], +): + data_array = xr.DataArray( + data=pv_image, + coords=( + ("time_utc", pv_systems_xr.time_utc.values), + ("y_geostationary", image_xr.y_geostationary.values), + ("x_geostationary", image_xr.x_geostationary.values), + ), + name="gsp_image", + ).astype(np.float32) + data_array.attrs = image_xr.attrs + return data_array diff --git a/tests/training/test_metnet_gsp_national.py b/tests/training/test_metnet_gsp_national.py index aaf73eb01..3b0874f02 100644 --- a/tests/training/test_metnet_gsp_national.py +++ b/tests/training/test_metnet_gsp_national.py @@ -1,5 +1,6 @@ import os +import numpy as np import pytest from torchdata.dataloader2 import DataLoader2 @@ -16,3 +17,15 @@ def test_metnet_datapipe(): _ = batch if i + 1 % 50000 == 0: break + + +def test_metnet_gsp_image_datapipe(): + filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") + gsp_datapipe = metnet_national_datapipe( + filename, use_pv=False, gsp_in_image=True, output_size=128 + ) + dataloader = iter(gsp_datapipe) + batch = next(dataloader) + x, y = batch + assert np.isfinite(x).all() + assert np.isfinite(y).all() diff --git a/tests/transform/xarray/gsp/test_create_gsp_image.py b/tests/transform/xarray/gsp/test_create_gsp_image.py new file mode 100644 index 000000000..e90b668c8 --- /dev/null +++ b/tests/transform/xarray/gsp/test_create_gsp_image.py @@ -0,0 +1,12 @@ +import numpy as np + +from ocf_datapipes.select import DropGSP +from ocf_datapipes.transform.xarray import CreateGSPImage + + +def test_create_gsp_image(gsp_datapipe, sat_datapipe): + gsp_datapipe = DropGSP(gsp_datapipe, gsps_to_keep=[0]) + pv_image_datapipe = CreateGSPImage(gsp_datapipe, sat_datapipe) + data = next(iter(pv_image_datapipe)) + assert np.max(data) > 0 + assert np.min(data) >= 0