diff --git a/ocf_datapipes/training/metnet_gsp_national.py b/ocf_datapipes/training/metnet_gsp_national.py index 26bc1902f..94063b46a 100644 --- a/ocf_datapipes/training/metnet_gsp_national.py +++ b/ocf_datapipes/training/metnet_gsp_national.py @@ -172,4 +172,4 @@ def metnet_national_datapipe( gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True) return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples else: - metnet_datapipe.zip(gsp_datapipe) + return metnet_datapipe.zip(gsp_datapipe) diff --git a/tests/training/test_metnet_gsp_national.py b/tests/training/test_metnet_gsp_national.py index aaf73eb01..149bfdd21 100644 --- a/tests/training/test_metnet_gsp_national.py +++ b/tests/training/test_metnet_gsp_national.py @@ -1,5 +1,6 @@ import os +import numpy as np import pytest from torchdata.dataloader2 import DataLoader2 @@ -16,3 +17,12 @@ def test_metnet_datapipe(): _ = batch if i + 1 % 50000 == 0: break + +def test_metnet_gsp_image_datapipe(): + filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") + gsp_datapipe = metnet_national_datapipe(filename, use_pv=False, gsp_in_image=True, output_size=128) + dataloader = iter(gsp_datapipe) + batch = next(dataloader) + x, y = batch + assert np.isfinite(x).all() + assert np.isfinite(y).all()