Skip to content

Commit

Permalink
tests: new args for train diag fig
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffjennings committed Nov 29, 2023
1 parent 8505ac5 commit 2427ddb
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion test/train_test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mpol import losses, precomposed
from mpol.plot import train_diagnostics_fig
from mpol.training import TrainTest, train_to_dirty_image
from mpol.constants import *
from mpol.utils import torch2npy


def test_traintestclass_training(coords, imager, dataset, generic_parameters):
Expand Down Expand Up @@ -83,8 +83,15 @@ def test_traintestclass_train_diagnostics_fig(coords, imager, dataset, generic_p
trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
loss, loss_history = trainer.train(model, dataset)

learn_rates = np.repeat(learn_rate, len(loss_history))

old_mod_im = torch2npy(model.icube.sky_cube[0])

train_fig, train_axes = train_diagnostics_fig(model,
losses=loss_history,
learn_rates=learn_rates,
fluxes=np.zeros(len(loss_history)),
old_model_image=old_mod_im
)
train_fig.savefig(tmp_path / "train_diagnostics_fig.png", dpi=300)
plt.close("all")
Expand Down

0 comments on commit 2427ddb

Please sign in to comment.