Skip to content

Commit

Permalink
tests passing again.
Browse files Browse the repository at this point in the history
  • Loading branch information
iancze committed Dec 28, 2023
1 parent 2177960 commit ba63e19
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/mpol/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def neg_log_likelihood_avg(
Returns
-------
:class:`torch.Tensor` of :class:`torch.double`
the :math:`\chi^2_\mathrm{r}`, summed over all dimensions of input array.
the average of the negative log likelihood, summed over all dimensions of
input array.
"""
N = len(torch.ravel(data_vis)) # number of complex visibilities
ll = log_likelihood(model_vis, data_vis, weight)
Expand Down Expand Up @@ -297,8 +298,7 @@ def reduced_chi_squared_gridded(
Returns
-------
:class:`torch.Tensor` of :class:`torch.double`
the normalized negative log likelihood likelihood loss, summed over all input
values
the :math:`\chi^2_\mathrm{r}` value summed over all input dimensions
"""
model_vis = griddedDataset(modelVisibilityCube)

Expand Down
4 changes: 2 additions & 2 deletions src/mpol/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import torch

from mpol.losses import TSV, TV_image, entropy, nll_gridded, sparsity
from mpol.losses import TSV, TV_image, entropy, reduced_chi_squared_gridded, sparsity
from mpol.plot import train_diagnostics_fig
from mpol.utils import torch2npy

Expand Down Expand Up @@ -205,7 +205,7 @@ def loss_eval(self, vis, dataset, sky_cube=None):
Value of loss function
"""
# negative log-likelihood loss function
loss = nll_gridded(vis, dataset)
loss = reduced_chi_squared_gridded(vis, dataset)

# regularizers
if sky_cube is not None:
Expand Down
18 changes: 9 additions & 9 deletions test/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import torch

from mpol import fourier, images, losses, utils
from mpol import fourier, images, losses


# create a fixture that returns nchan and an image
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_nll_hermitian_pairs(loose_visibilities, mock_visibility_data):
data = torch.tensor(data_re + 1.0j * data_im)
weight = torch.tensor(weight)

log_like = losses.nll(loose_visibilities, data, weight)
log_like = losses.reduced_chi_squared(loose_visibilities, data, weight)
print("loose nll", log_like)

# calculate it with Hermitian pairs
Expand All @@ -116,7 +116,7 @@ def test_nll_hermitian_pairs(loose_visibilities, mock_visibility_data):
data = torch.cat([data, torch.conj(data)], axis=1)
weight = torch.cat([weight, weight], axis=1)

log_like = losses.nll(loose_visibilities, data, weight)
log_like = losses.reduced_chi_squared(loose_visibilities, data, weight)
print("loose nll w/ Hermitian", log_like)


Expand All @@ -133,11 +133,11 @@ def test_nll_evaluation(
data = torch.tensor(data_re + 1.0j * data_im)
weight = torch.tensor(weight)

log_like = losses.nll(loose_visibilities, data, weight)
log_like = losses.reduced_chi_squared(loose_visibilities, data, weight)
print("loose nll", log_like)

# calculate the gridded log likelihood
log_like_gridded = losses.nll_gridded(gridded_visibilities, dataset)
log_like_gridded = losses.reduced_chi_squared_gridded(gridded_visibilities, dataset)
print("gridded nll", log_like_gridded)


Expand All @@ -156,7 +156,7 @@ def test_nll_1D_zero():
data_im = model_im
data_vis = torch.complex(data_re, data_im)

loss = losses.nll(model_vis, data_vis, weights)
loss = losses.reduced_chi_squared(model_vis, data_vis, weights)
assert loss.item() == 0.0


Expand All @@ -175,7 +175,7 @@ def test_nll_1D_random():
data_im = torch.randn_like(weights)
data_vis = torch.complex(data_re, data_im)

losses.nll(model_vis, data_vis, weights)
losses.reduced_chi_squared(model_vis, data_vis, weights)


def test_nll_2D_zero():
Expand All @@ -195,7 +195,7 @@ def test_nll_2D_zero():
data_im = model_im
data_vis = torch.complex(data_re, data_im)

loss = losses.nll(model_vis, data_vis, weights)
loss = losses.reduced_chi_squared(model_vis, data_vis, weights)
assert loss.item() == 0.0


Expand All @@ -215,7 +215,7 @@ def test_nll_2D_random():
data_im = torch.randn_like(weights)
data_vis = torch.complex(data_re, data_im)

losses.nll(model_vis, data_vis, weights)
losses.reduced_chi_squared(model_vis, data_vis, weights)


def test_entropy_raise_error_negative():
Expand Down
6 changes: 3 additions & 3 deletions test/train_test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_standalone_init_train(coords, dataset):
rml.zero_grad()

# calculate a loss
loss = losses.nll_gridded(vis, dataset)
loss = losses.reduced_chi_squared_gridded(vis, dataset)

# calculate gradients of parameters
loss.backward()
Expand All @@ -147,7 +147,7 @@ def test_standalone_train_loop(coords, dataset_cont, tmp_path):
vis = rml()

# calculate a loss
loss = losses.nll_gridded(vis, dataset_cont)
loss = losses.reduced_chi_squared_gridded(vis, dataset_cont)

# calculate gradients of parameters
loss.backward()
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_tensorboard(coords, dataset_cont):
vis = rml()

# calculate a loss
loss = losses.nll_gridded(vis, dataset_cont)
loss = losses.reduced_chi_squared_gridded(vis, dataset_cont)

writer.add_scalar("loss", loss.item(), i)

Expand Down

0 comments on commit ba63e19

Please sign in to comment.