Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Nov 28, 2022
1 parent 5b74827 commit cb9ac4a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ocf_datapipes/training/metnet_gsp_national.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,4 @@ def metnet_national_datapipe(
gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True)
return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples
else:
metnet_datapipe.zip(gsp_datapipe)
return metnet_datapipe.zip(gsp_datapipe)
10 changes: 10 additions & 0 deletions tests/training/test_metnet_gsp_national.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

import numpy as np
import pytest
from torchdata.dataloader2 import DataLoader2

Expand All @@ -16,3 +17,12 @@ def test_metnet_datapipe():
_ = batch
if i + 1 % 50000 == 0:
break

def test_metnet_gsp_image_datapipe():
filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")
gsp_datapipe = metnet_national_datapipe(filename, use_pv=False, gsp_in_image=True, output_size=128)
dataloader = iter(gsp_datapipe)
batch = next(dataloader)
x, y = batch
assert np.isfinite(x).all()
assert np.isfinite(y).all()

0 comments on commit cb9ac4a

Please sign in to comment.