diff --git a/tests/torch_datasets/test_pvnet_uk_regional.py b/tests/torch_datasets/test_pvnet_uk_regional.py index 6991ecc..ff5c2fe 100644 --- a/tests/torch_datasets/test_pvnet_uk_regional.py +++ b/tests/torch_datasets/test_pvnet_uk_regional.py @@ -35,7 +35,10 @@ def test_pvnet(pvnet_config_filename): assert isinstance(sample, dict) - for key in [BatchKey.nwp, BatchKey.satellite_actual, BatchKey.gsp]: + for key in [ + BatchKey.nwp, BatchKey.satellite_actual, BatchKey.gsp, + BatchKey.gsp_solar_azimuth, BatchKey.gsp_solar_elevation, + ]: assert key in sample for nwp_source in ["ukv"]: @@ -48,6 +51,9 @@ def test_pvnet(pvnet_config_filename): assert sample[BatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2) # 3 hours of 30 minute data (inclusive) assert sample[BatchKey.gsp].shape == (7,) + # Solar angles have same shape as GSP data + assert sample[BatchKey.gsp_solar_azimuth].shape == (7,) + assert sample[BatchKey.gsp_solar_elevation].shape == (7,)