Skip to content

Commit

Permalink
Fix multidim context merge bug (closes #93)
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Nov 10, 2023
1 parent 4382957 commit 2673492
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
2 changes: 1 addition & 1 deletion deepsensor/data/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,10 @@ def mask_nans_numpy(self):

def f(arr):
if isinstance(arr, deepsensor.backend.nps.Masked):
# Combine the existing nps.Masked mask with the NaN mask
nps_mask = arr.mask == 0
nan_mask = np.isnan(arr.y)
mask = np.logical_or(nps_mask, nan_mask)
mask = np.any(mask, axis=1, keepdims=True)
data = arr.y
data[nan_mask] = 0.0
arr = deepsensor.backend.nps.Masked(data, mask)
Expand Down
43 changes: 43 additions & 0 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,46 @@ def test_training(self):
# Check for NaNs in the loss
loss = np.mean(epoch_losses)
self.assertFalse(np.isnan(loss))

def test_training_multidim(self):
    """A basic test of the training loop with multidimensional context sets.

    Builds a two-variable dataset by duplicating the ``air`` variable,
    injects NaNs into the first context and target sets of every task,
    trains a ``ConvNP`` for a few epochs and checks that the mean loss
    is not NaN.
    """
    # Load raw data
    ds_raw = xr.tutorial.open_dataset("air_temperature")

    # Add a second data variable so the context/target sets are multidimensional
    ds_raw["air2"] = ds_raw["air"].copy()

    # Normalise data
    dp = DataProcessor(x1_name="lat", x2_name="lon")
    ds = dp(ds_raw)

    # Set up task loader
    tl = TaskLoader(context=ds, target=ds)

    # Set up model
    model = ConvNP(dp, tl)

    # Generate training tasks
    n_train_tasks = 10
    train_tasks = []
    for _ in range(n_train_tasks):
        # NOTE(review): dates are drawn from self.da (set up outside this
        # method), not from ds_raw - presumably the same dataset; confirm.
        date = np.random.choice(self.da.time.values)
        task = tl(date, 10, 10)
        # Put NaNs in the first column of the first context/target arrays so
        # that training has to cope with missing values.
        task["Y_c"][0][:, 0] = np.nan  # Add NaN to context
        task["Y_t"][0][:, 0] = np.nan  # Add NaN to target
        train_tasks.append(task)

    # Train
    trainer = Trainer(model, lr=5e-5)
    batch_size = 5
    n_epochs = 10
    epoch_losses = []
    for _ in tqdm(range(n_epochs)):
        batch_losses = trainer(train_tasks, batch_size=batch_size)
        epoch_losses.append(np.mean(batch_losses))

    # Check for NaNs in the loss
    loss = np.mean(epoch_losses)
    self.assertFalse(np.isnan(loss))

0 comments on commit 2673492

Please sign in to comment.