From 11e65b3b1d3dc7370549ee9eb4ad6803daf850ef Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 28 Nov 2022 16:11:26 +0000 Subject: [PATCH 01/10] Add GSP image creation A lot faster than the non-GSP one, and simpler for MetNet as it can be built into the inputs easier --- ocf_datapipes/__init__.py | 18 ++-- ocf_datapipes/training/metnet_gsp_national.py | 17 ++-- ocf_datapipes/transform/__init__.py | 4 +- ocf_datapipes/transform/xarray/__init__.py | 1 + .../transform/xarray/gsp/__init__.py | 0 .../transform/xarray/gsp/create_gsp_image.py | 90 +++++++++++++++++++ .../xarray/gsp/test_create_gsp_image.py | 10 +++ 7 files changed, 123 insertions(+), 17 deletions(-) create mode 100644 ocf_datapipes/transform/xarray/gsp/__init__.py create mode 100644 ocf_datapipes/transform/xarray/gsp/create_gsp_image.py create mode 100644 tests/transform/xarray/gsp/test_create_gsp_image.py diff --git a/ocf_datapipes/__init__.py b/ocf_datapipes/__init__.py index 1a6fc1661..5352f57b4 100644 --- a/ocf_datapipes/__init__.py +++ b/ocf_datapipes/__init__.py @@ -1,10 +1,10 @@ """Datapipes""" -import ocf_datapipes.batch -import ocf_datapipes.convert -import ocf_datapipes.experimental -import ocf_datapipes.fake -import ocf_datapipes.load -import ocf_datapipes.select -import ocf_datapipes.transform -import ocf_datapipes.utils -import ocf_datapipes.validation +from ocf_datapipes.batch import * +from ocf_datapipes.convert import * +from ocf_datapipes.experimental import * +from ocf_datapipes.fake import * +from ocf_datapipes.load import * +from ocf_datapipes.select import * +from ocf_datapipes.transform import * +from ocf_datapipes.utils import * +from ocf_datapipes.validation import * diff --git a/ocf_datapipes/training/metnet_gsp_national.py b/ocf_datapipes/training/metnet_gsp_national.py index 7e60ce6f5..2e878507e 100644 --- a/ocf_datapipes/training/metnet_gsp_national.py +++ b/ocf_datapipes/training/metnet_gsp_national.py @@ -62,6 +62,7 @@ def metnet_national_datapipe( use_gsp: bool = True, use_topo: bool = True, output_size: int = 256, + gsp_in_image: bool = False, start_time: datetime.datetime = datetime.datetime(2014, 1, 1), end_time: datetime.datetime = datetime.datetime(2023, 1, 1), ) -> IterDataPipe: @@ -82,6 +83,7 @@ def metnet_national_datapipe( start_time: Start time to select on end_time: End time to select from output_size: Size, in pixels, of the output image + gsp_in_image: Add GSP history as channels in MetNet image Returns: datapipe """ @@ -142,7 +144,8 @@ def metnet_national_datapipe( gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2, buffer_size=5) location_datapipe = LocationPicker(gsp_loc_datapipe) - + if gsp_in_image: + modalities.append(gsp_history.map(_remove_nans)) metnet_datapipe = PreProcessMetNet( modalities, location_datapipe=location_datapipe, @@ -154,8 +157,10 @@ def metnet_national_datapipe( output_height_pixels=output_size, add_sun_features=use_sun, ) - - gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) - gsp_history = gsp_history.map(_remove_nans) - gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True) - return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples + if not gsp_in_image: + gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) + gsp_history = gsp_history.map(_remove_nans) + gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True) + return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples + else: + metnet_datapipe.zip(gsp_datapipe) diff --git a/ocf_datapipes/transform/__init__.py b/ocf_datapipes/transform/__init__.py index 925d8c520..7919927a7 100644 --- a/ocf_datapipes/transform/__init__.py +++ b/ocf_datapipes/transform/__init__.py @@ -1,3 +1,3 @@ """Transforms for the data in both xarray and numpy formats""" -import ocf_datapipes.transform.numpy -import ocf_datapipes.transform.xarray +from ocf_datapipes.transform.numpy import * +from ocf_datapipes.transform.xarray import * diff --git a/ocf_datapipes/transform/xarray/__init__.py b/ocf_datapipes/transform/xarray/__init__.py index 49c6f80bb..0dc7a8aa7 100644 --- a/ocf_datapipes/transform/xarray/__init__.py +++ b/ocf_datapipes/transform/xarray/__init__.py @@ -29,3 +29,4 @@ from .pv.create_pv_image import CreatePVImageIterDataPipe as CreatePVImage from .remove_nans import RemoveNansIterDataPipe as RemoveNans from .reproject_topographic_data import ReprojectTopographyIterDataPipe as ReprojectTopography +from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage diff --git a/ocf_datapipes/transform/xarray/gsp/__init__.py b/ocf_datapipes/transform/xarray/gsp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py b/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py new file mode 100644 index 000000000..8cde733b7 --- /dev/null +++ b/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py @@ -0,0 +1,90 @@ +"""Convert point PV sites to image output""" +import logging +from typing import Union + +import numpy as np +import xarray as xr +from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.iter import IterDataPipe + +from ocf_datapipes.utils import Zipper + +logger = logging.getLogger(__name__) + + +@functional_datapipe("create_gsp_image") +class CreateGSPImageIterDataPipe(IterDataPipe): + """Create GSP image from individual sites""" + + def __init__( + self, + source_datapipe: IterDataPipe, + image_datapipe: IterDataPipe, + normalize: bool = False, + image_dim: str = "geostationary", + always_return_first: bool = False, + seed=None, + ): + """ + Creates a 3D data cube of GSP output image x number of timesteps + + This is primarily for national GSP, so single GSP inputs are preferred + + Args: + source_datapipe: Source datapipe of PV data + image_datapipe: Datapipe emitting images to get the shape from, with coordinates + normalize: Whether to normalize based off the image max, or leave raw data + image_dim: Dimension name for the x and y dimensions + always_return_first: Always return the first image data cube, to save computation + Only use for if making the image at the beginning of the stack + seed: Random seed to use if using max_num_pv_systems + """ + self.source_datapipe = source_datapipe + self.image_datapipe = image_datapipe + self.normalize = normalize + self.x_dim = "x_" + image_dim + self.y_dim = "y_" + image_dim + self.rng = np.random.default_rng(seed=seed) + self.always_return_first = always_return_first + + def __iter__(self) -> xr.DataArray: + for gsp_systems_xr, image_xr in Zipper(self.source_datapipe, self.image_datapipe): + # Create empty image to use for the PV Systems, assumes image has x and y coordinates + pv_image = np.zeros( + ( + len(gsp_systems_xr["time_utc"]), + len(image_xr[self.y_dim]), + len(image_xr[self.x_dim]), + ), + dtype=np.float32, + ) + for i, gsp_system_id in enumerate(gsp_systems_xr["gsp_id"]): + gsp_system = gsp_systems_xr.sel(gsp_id=gsp_system_id) + for time_step in range(len(gsp_system.time_utc.values)): + # Now go by the timestep to create cube of GSP data + pv_image[time_step :, :] = gsp_system.isel(time_utc=time_step).values + + pv_image = np.nan_to_num(pv_image) + + # Should return Xarray as in Xarray transforms + # Same coordinates as the image xarray, so can take that + pv_image = _create_data_array_from_image(pv_image, gsp_systems_xr, image_xr) + yield pv_image + + +def _create_data_array_from_image( + pv_image: np.ndarray, + pv_systems_xr: Union[xr.Dataset, xr.DataArray], + image_xr: Union[xr.Dataset, xr.DataArray], +): + data_array = xr.DataArray( + data=pv_image, + coords=( + ("time_utc", pv_systems_xr.time_utc.values), + ("y_geostationary", image_xr.y_geostationary.values), + ("x_geostationary", image_xr.x_geostationary.values), + ), + name="gsp_image", + ).astype(np.float32) + data_array.attrs = image_xr.attrs + return data_array diff --git a/tests/transform/xarray/gsp/test_create_gsp_image.py b/tests/transform/xarray/gsp/test_create_gsp_image.py new file mode 100644 index 000000000..44dc509e7 --- /dev/null +++ b/tests/transform/xarray/gsp/test_create_gsp_image.py @@ -0,0 +1,10 @@ +from ocf_datapipes import DropGSP, CreateGSPImage +import numpy as np + + +def test_create_gsp_image(gsp_datapipe, sat_datapipe): + gsp_datapipe = DropGSP(gsp_datapipe, gsps_to_keep=[0]) + pv_image_datapipe = CreateGSPImage(gsp_datapipe, sat_datapipe) + data = next(iter(pv_image_datapipe)) + assert np.max(data) > 0 + assert np.min(data) >= 0 From 64de56ccbaf07d5bd7cbb9aad0315c4d7718a998 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Nov 2022 16:12:42 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/transform/xarray/__init__.py | 2 +- ocf_datapipes/transform/xarray/gsp/create_gsp_image.py | 2 +- tests/transform/xarray/gsp/test_create_gsp_image.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ocf_datapipes/transform/xarray/__init__.py b/ocf_datapipes/transform/xarray/__init__.py index 0dc7a8aa7..463103091 100644 --- a/ocf_datapipes/transform/xarray/__init__.py +++ b/ocf_datapipes/transform/xarray/__init__.py @@ -24,9 +24,9 @@ from .get_contiguous_time_periods import ( GetContiguousT0TimePeriodsIterDataPipe as GetContiguousT0TimePeriods, ) +from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage from .metnet_preprocessor import PreProcessMetNetIterDataPipe as PreProcessMetNet from .normalize import NormalizeIterDataPipe as Normalize from .pv.create_pv_image import CreatePVImageIterDataPipe as CreatePVImage from .remove_nans import RemoveNansIterDataPipe as RemoveNans from .reproject_topographic_data import ReprojectTopographyIterDataPipe as ReprojectTopography -from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage diff --git a/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py b/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py index 8cde733b7..295796c3c 100644 --- a/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py +++ b/ocf_datapipes/transform/xarray/gsp/create_gsp_image.py @@ -62,7 +62,7 @@ def __iter__(self) -> xr.DataArray: gsp_system = gsp_systems_xr.sel(gsp_id=gsp_system_id) for time_step in range(len(gsp_system.time_utc.values)): # Now go by the timestep to create cube of GSP data - pv_image[time_step :, :] = gsp_system.isel(time_utc=time_step).values + pv_image[time_step:, :] = gsp_system.isel(time_utc=time_step).values pv_image = np.nan_to_num(pv_image) diff --git a/tests/transform/xarray/gsp/test_create_gsp_image.py b/tests/transform/xarray/gsp/test_create_gsp_image.py index 44dc509e7..d6ba15ea9 100644 --- a/tests/transform/xarray/gsp/test_create_gsp_image.py +++ b/tests/transform/xarray/gsp/test_create_gsp_image.py @@ -1,6 +1,7 @@ -from ocf_datapipes import DropGSP, CreateGSPImage import numpy as np +from ocf_datapipes import CreateGSPImage, DropGSP + def test_create_gsp_image(gsp_datapipe, sat_datapipe): gsp_datapipe = DropGSP(gsp_datapipe, gsps_to_keep=[0]) From afe7da3007e5535213ac5193528c085d409c1c54 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 28 Nov 2022 16:12:58 +0000 Subject: [PATCH 03/10] Linting --- ocf_datapipes/transform/xarray/gsp/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ocf_datapipes/transform/xarray/gsp/__init__.py b/ocf_datapipes/transform/xarray/gsp/__init__.py index e69de29bb..83a54091e 100644 --- a/ocf_datapipes/transform/xarray/gsp/__init__.py +++ b/ocf_datapipes/transform/xarray/gsp/__init__.py @@ -0,0 +1 @@ +"""GSP specific transforms""" \ No newline at end of file From 488c76a51f5702b75e201ed1a40a4c70e277eae3 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 28 Nov 2022 16:13:12 +0000 Subject: [PATCH 04/10] Linting --- ocf_datapipes/transform/xarray/gsp/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_datapipes/transform/xarray/gsp/__init__.py b/ocf_datapipes/transform/xarray/gsp/__init__.py index 83a54091e..384d610ef 100644 --- a/ocf_datapipes/transform/xarray/gsp/__init__.py +++ b/ocf_datapipes/transform/xarray/gsp/__init__.py @@ -1 +1 @@ -"""GSP specific transforms""" \ No newline at end of file +"""GSP specific transforms""" From 3547c822e2429c6b87f6e324bb5c59f8c073d087 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 28 Nov 2022 16:16:49 +0000 Subject: [PATCH 05/10] Change importing method --- ocf_datapipes/__init__.py | 18 +++++++++--------- ocf_datapipes/transform/__init__.py | 4 ++-- .../xarray/gsp/test_create_gsp_image.py | 3 ++- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/ocf_datapipes/__init__.py b/ocf_datapipes/__init__.py index 5352f57b4..1a6fc1661 100644 --- a/ocf_datapipes/__init__.py +++ b/ocf_datapipes/__init__.py @@ -1,10 +1,10 @@ """Datapipes""" -from ocf_datapipes.batch import * -from ocf_datapipes.convert import * -from ocf_datapipes.experimental import * -from ocf_datapipes.fake import * -from ocf_datapipes.load import * -from ocf_datapipes.select import * -from ocf_datapipes.transform import * -from ocf_datapipes.utils import * -from ocf_datapipes.validation import * +import ocf_datapipes.batch +import ocf_datapipes.convert +import ocf_datapipes.experimental +import ocf_datapipes.fake +import ocf_datapipes.load +import ocf_datapipes.select +import ocf_datapipes.transform +import ocf_datapipes.utils +import ocf_datapipes.validation diff --git a/ocf_datapipes/transform/__init__.py b/ocf_datapipes/transform/__init__.py index 7919927a7..925d8c520 100644 --- a/ocf_datapipes/transform/__init__.py +++ b/ocf_datapipes/transform/__init__.py @@ -1,3 +1,3 @@ """Transforms for the data in both xarray and numpy formats""" -from ocf_datapipes.transform.numpy import * -from ocf_datapipes.transform.xarray import * +import ocf_datapipes.transform.numpy +import ocf_datapipes.transform.xarray diff --git a/tests/transform/xarray/gsp/test_create_gsp_image.py b/tests/transform/xarray/gsp/test_create_gsp_image.py index d6ba15ea9..c4158a0ed 100644 --- a/tests/transform/xarray/gsp/test_create_gsp_image.py +++ b/tests/transform/xarray/gsp/test_create_gsp_image.py @@ -1,6 +1,7 @@ import numpy as np -from ocf_datapipes import CreateGSPImage, DropGSP +from ocf_datapipes.transform.xarray import CreateGSPImage +from ocf_datapipes.select import DropGSP def test_create_gsp_image(gsp_datapipe, sat_datapipe): From dbc3b6533fd4f91dfe2aa9ba0fd66f52b4ada2e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Nov 2022 16:17:07 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/transform/xarray/gsp/test_create_gsp_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transform/xarray/gsp/test_create_gsp_image.py b/tests/transform/xarray/gsp/test_create_gsp_image.py index c4158a0ed..e90b668c8 100644 --- a/tests/transform/xarray/gsp/test_create_gsp_image.py +++ b/tests/transform/xarray/gsp/test_create_gsp_image.py @@ -1,7 +1,7 @@ import numpy as np -from ocf_datapipes.transform.xarray import CreateGSPImage from ocf_datapipes.select import DropGSP +from ocf_datapipes.transform.xarray import CreateGSPImage def test_create_gsp_image(gsp_datapipe, sat_datapipe): From 1230339ccf0a5955be4566730e0bdffbb19dd777 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 28 Nov 2022 16:24:01 +0000 Subject: [PATCH 07/10] Add option for GSP national image in MetNet --- ocf_datapipes/training/metnet_gsp_national.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ocf_datapipes/training/metnet_gsp_national.py b/ocf_datapipes/training/metnet_gsp_national.py index 2e878507e..bcc198475 100644 --- a/ocf_datapipes/training/metnet_gsp_national.py +++ b/ocf_datapipes/training/metnet_gsp_national.py @@ -132,6 +132,12 @@ def metnet_national_datapipe( # Now combine in the MetNet format modalities = [] + if gsp_in_image and "hrv" in used_datapipes.keys(): + sat_hrv_datapipe, sat_gsp_datapipe = sat_hrv_datapipe.fork(2) + gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image(image_datapipe=sat_gsp_datapipe) + elif gsp_in_image and "sat" in used_datapipes.keys(): + sat_datapipe, sat_gsp_datapipe = sat_datapipe.fork(2) + gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image(image_datapipe=sat_gsp_datapipe) if "nwp" in used_datapipes.keys(): modalities.append(nwp_datapipe) if "hrv" in used_datapipes.keys(): @@ -140,12 +146,13 @@ def metnet_national_datapipe( modalities.append(sat_datapipe) if "topo" in used_datapipes.keys(): modalities.append(topo_datapipe) + if gsp_in_image: + modalities.append(gsp_history) gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2, buffer_size=5) location_datapipe = LocationPicker(gsp_loc_datapipe) - if gsp_in_image: - modalities.append(gsp_history.map(_remove_nans)) + metnet_datapipe = PreProcessMetNet( modalities, location_datapipe=location_datapipe, From 5b74827e615c75fd2028b9cf5234ed6a35b9d6d2 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 28 Nov 2022 16:25:02 +0000 Subject: [PATCH 08/10] Linting --- ocf_datapipes/training/metnet_gsp_national.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ocf_datapipes/training/metnet_gsp_national.py b/ocf_datapipes/training/metnet_gsp_national.py index bcc198475..26bc1902f 100644 --- a/ocf_datapipes/training/metnet_gsp_national.py +++ b/ocf_datapipes/training/metnet_gsp_national.py @@ -134,10 +134,12 @@ def metnet_national_datapipe( modalities = [] if gsp_in_image and "hrv" in used_datapipes.keys(): sat_hrv_datapipe, sat_gsp_datapipe = sat_hrv_datapipe.fork(2) - gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image(image_datapipe=sat_gsp_datapipe) + gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0])\ + .create_gsp_image(image_datapipe=sat_gsp_datapipe) elif gsp_in_image and "sat" in used_datapipes.keys(): sat_datapipe, sat_gsp_datapipe = sat_datapipe.fork(2) - gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image(image_datapipe=sat_gsp_datapipe) + gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0])\ + .create_gsp_image(image_datapipe=sat_gsp_datapipe) if "nwp" in used_datapipes.keys(): modalities.append(nwp_datapipe) if "hrv" in used_datapipes.keys(): From cb9ac4a64442077ecae90a0255981e4bebcd89cf Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 28 Nov 2022 16:31:48 +0000 Subject: [PATCH 09/10] Add test --- ocf_datapipes/training/metnet_gsp_national.py | 2 +- tests/training/test_metnet_gsp_national.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/ocf_datapipes/training/metnet_gsp_national.py b/ocf_datapipes/training/metnet_gsp_national.py index 26bc1902f..94063b46a 100644 --- a/ocf_datapipes/training/metnet_gsp_national.py +++ b/ocf_datapipes/training/metnet_gsp_national.py @@ -172,4 +172,4 @@ def metnet_national_datapipe( gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True) return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples else: - metnet_datapipe.zip(gsp_datapipe) + return metnet_datapipe.zip(gsp_datapipe) diff --git a/tests/training/test_metnet_gsp_national.py b/tests/training/test_metnet_gsp_national.py index aaf73eb01..149bfdd21 100644 --- a/tests/training/test_metnet_gsp_national.py +++ b/tests/training/test_metnet_gsp_national.py @@ -1,5 +1,6 @@ import os +import numpy as np import pytest from torchdata.dataloader2 import DataLoader2 @@ -16,3 +17,12 @@ def test_metnet_datapipe(): _ = batch if i + 1 % 50000 == 0: break + +def test_metnet_gsp_image_datapipe(): + filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") + gsp_datapipe = metnet_national_datapipe(filename, use_pv=False, gsp_in_image=True, output_size=128) + dataloader = iter(gsp_datapipe) + batch = next(dataloader) + x, y = batch + assert np.isfinite(x).all() + assert np.isfinite(y).all() From 93672945d13081e4f24f7bf02a7117897022cd87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Nov 2022 16:33:50 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/metnet_gsp_national.py | 10 ++++++---- tests/training/test_metnet_gsp_national.py | 5 ++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/ocf_datapipes/training/metnet_gsp_national.py b/ocf_datapipes/training/metnet_gsp_national.py index 94063b46a..df7e8dc58 100644 --- a/ocf_datapipes/training/metnet_gsp_national.py +++ b/ocf_datapipes/training/metnet_gsp_national.py @@ -134,12 +134,14 @@ def metnet_national_datapipe( modalities = [] if gsp_in_image and "hrv" in used_datapipes.keys(): sat_hrv_datapipe, sat_gsp_datapipe = sat_hrv_datapipe.fork(2) - gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0])\ - .create_gsp_image(image_datapipe=sat_gsp_datapipe) + gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image( + image_datapipe=sat_gsp_datapipe + ) elif gsp_in_image and "sat" in used_datapipes.keys(): sat_datapipe, sat_gsp_datapipe = sat_datapipe.fork(2) - gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0])\ - .create_gsp_image(image_datapipe=sat_gsp_datapipe) + gsp_history = gsp_history.drop_gsp(gsps_to_keep=[0]).create_gsp_image( + image_datapipe=sat_gsp_datapipe + ) if "nwp" in used_datapipes.keys(): modalities.append(nwp_datapipe) if "hrv" in used_datapipes.keys(): diff --git a/tests/training/test_metnet_gsp_national.py b/tests/training/test_metnet_gsp_national.py index 149bfdd21..3b0874f02 100644 --- a/tests/training/test_metnet_gsp_national.py +++ b/tests/training/test_metnet_gsp_national.py @@ -18,9 +18,12 @@ def test_metnet_datapipe(): if i + 1 % 50000 == 0: break + def test_metnet_gsp_image_datapipe(): filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") - gsp_datapipe = metnet_national_datapipe(filename, use_pv=False, gsp_in_image=True, output_size=128) + gsp_datapipe = metnet_national_datapipe( + filename, use_pv=False, gsp_in_image=True, output_size=128 + ) dataloader = iter(gsp_datapipe) batch = next(dataloader) x, y = batch