Skip to content

Commit

Permalink
Merge pull request #117 from openclimatefix/jacob/gsp_image
Browse files Browse the repository at this point in the history
Add GSP image creation
  • Loading branch information
jacobbieker authored Nov 28, 2022
2 parents dca9fa7 + 9367294 commit b84dd06
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 5 deletions.
26 changes: 21 additions & 5 deletions ocf_datapipes/training/metnet_gsp_national.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def metnet_national_datapipe(
use_gsp: bool = True,
use_topo: bool = True,
output_size: int = 256,
gsp_in_image: bool = False,
start_time: datetime.datetime = datetime.datetime(2014, 1, 1),
end_time: datetime.datetime = datetime.datetime(2023, 1, 1),
) -> IterDataPipe:
Expand All @@ -82,6 +83,7 @@ def metnet_national_datapipe(
start_time: Start time to select on
end_time: End time to select from
output_size: Size, in pixels, of the output image
gsp_in_image: Add GSP history as channels in MetNet image
Returns: datapipe
"""
Expand Down Expand Up @@ -130,6 +132,16 @@ def metnet_national_datapipe(

# Now combine in the MetNet format
modalities = []
if gsp_in_image and "hrv" in used_datapipes.keys():
sat_hrv_datapipe, sat_gsp_datapipe = sat_hrv_datapipe.fork(2)
gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image(
image_datapipe=sat_gsp_datapipe
)
elif gsp_in_image and "sat" in used_datapipes.keys():
sat_datapipe, sat_gsp_datapipe = sat_datapipe.fork(2)
gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image(
image_datapipe=sat_gsp_datapipe
)
if "nwp" in used_datapipes.keys():
modalities.append(nwp_datapipe)
if "hrv" in used_datapipes.keys():
Expand All @@ -138,6 +150,8 @@ def metnet_national_datapipe(
modalities.append(sat_datapipe)
if "topo" in used_datapipes.keys():
modalities.append(topo_datapipe)
if gsp_in_image:
modalities.append(gsp_history)

gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2, buffer_size=5)

Expand All @@ -154,8 +168,10 @@ def metnet_national_datapipe(
output_height_pixels=output_size,
add_sun_features=use_sun,
)

gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe)
gsp_history = gsp_history.map(_remove_nans)
gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True)
return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples
if not gsp_in_image:
gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe)
gsp_history = gsp_history.map(_remove_nans)
gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True)
return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples
else:
return metnet_datapipe.zip(gsp_datapipe)
1 change: 1 addition & 0 deletions ocf_datapipes/transform/xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .get_contiguous_time_periods import (
GetContiguousT0TimePeriodsIterDataPipe as GetContiguousT0TimePeriods,
)
from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage
from .metnet_preprocessor import PreProcessMetNetIterDataPipe as PreProcessMetNet
from .normalize import NormalizeIterDataPipe as Normalize
from .pv.create_pv_image import CreatePVImageIterDataPipe as CreatePVImage
Expand Down
1 change: 1 addition & 0 deletions ocf_datapipes/transform/xarray/gsp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""GSP specific transforms"""
90 changes: 90 additions & 0 deletions ocf_datapipes/transform/xarray/gsp/create_gsp_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Convert point PV sites to image output"""
import logging
from typing import Union

import numpy as np
import xarray as xr
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

from ocf_datapipes.utils import Zipper

logger = logging.getLogger(__name__)


@functional_datapipe("create_gsp_image")
class CreateGSPImageIterDataPipe(IterDataPipe):
"""Create GSP image from individual sites"""

def __init__(
self,
source_datapipe: IterDataPipe,
image_datapipe: IterDataPipe,
normalize: bool = False,
image_dim: str = "geostationary",
always_return_first: bool = False,
seed=None,
):
"""
Creates a 3D data cube of GSP output image x number of timesteps
This is primarily for national GSP, so single GSP inputs are preferred
Args:
source_datapipe: Source datapipe of PV data
image_datapipe: Datapipe emitting images to get the shape from, with coordinates
normalize: Whether to normalize based off the image max, or leave raw data
image_dim: Dimension name for the x and y dimensions
always_return_first: Always return the first image data cube, to save computation
Only use for if making the image at the beginning of the stack
seed: Random seed to use if using max_num_pv_systems
"""
self.source_datapipe = source_datapipe
self.image_datapipe = image_datapipe
self.normalize = normalize
self.x_dim = "x_" + image_dim
self.y_dim = "y_" + image_dim
self.rng = np.random.default_rng(seed=seed)
self.always_return_first = always_return_first

def __iter__(self) -> xr.DataArray:
for gsp_systems_xr, image_xr in Zipper(self.source_datapipe, self.image_datapipe):
# Create empty image to use for the PV Systems, assumes image has x and y coordinates
pv_image = np.zeros(
(
len(gsp_systems_xr["time_utc"]),
len(image_xr[self.y_dim]),
len(image_xr[self.x_dim]),
),
dtype=np.float32,
)
for i, gsp_system_id in enumerate(gsp_systems_xr["gsp_id"]):
gsp_system = gsp_systems_xr.sel(gsp_id=gsp_system_id)
for time_step in range(len(gsp_system.time_utc.values)):
# Now go by the timestep to create cube of GSP data
pv_image[time_step:, :] = gsp_system.isel(time_utc=time_step).values

pv_image = np.nan_to_num(pv_image)

# Should return Xarray as in Xarray transforms
# Same coordinates as the image xarray, so can take that
pv_image = _create_data_array_from_image(pv_image, gsp_systems_xr, image_xr)
yield pv_image


def _create_data_array_from_image(
pv_image: np.ndarray,
pv_systems_xr: Union[xr.Dataset, xr.DataArray],
image_xr: Union[xr.Dataset, xr.DataArray],
):
data_array = xr.DataArray(
data=pv_image,
coords=(
("time_utc", pv_systems_xr.time_utc.values),
("y_geostationary", image_xr.y_geostationary.values),
("x_geostationary", image_xr.x_geostationary.values),
),
name="gsp_image",
).astype(np.float32)
data_array.attrs = image_xr.attrs
return data_array
13 changes: 13 additions & 0 deletions tests/training/test_metnet_gsp_national.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

import numpy as np
import pytest
from torchdata.dataloader2 import DataLoader2

Expand All @@ -16,3 +17,15 @@ def test_metnet_datapipe():
_ = batch
if i + 1 % 50000 == 0:
break


def test_metnet_gsp_image_datapipe():
filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
gsp_datapipe = metnet_national_datapipe(
filename, use_pv=False, gsp_in_image=True, output_size=128
)
dataloader = iter(gsp_datapipe)
batch = next(dataloader)
x, y = batch
assert np.isfinite(x).all()
assert np.isfinite(y).all()
12 changes: 12 additions & 0 deletions tests/transform/xarray/gsp/test_create_gsp_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import numpy as np

from ocf_datapipes.select import DropGSP
from ocf_datapipes.transform.xarray import CreateGSPImage


def test_create_gsp_image(gsp_datapipe, sat_datapipe):
gsp_datapipe = DropGSP(gsp_datapipe, gsps_to_keep=[0])
pv_image_datapipe = CreateGSPImage(gsp_datapipe, sat_datapipe)
data = next(iter(pv_image_datapipe))
assert np.max(data) > 0
assert np.min(data) >= 0

0 comments on commit b84dd06

Please sign in to comment.