From 2427ddbf585cc3618067782bf9ed973b7f113723 Mon Sep 17 00:00:00 2001 From: Jeff Jennings Date: Wed, 29 Nov 2023 02:25:38 -0500 Subject: [PATCH] tests: new args for train diag fig --- test/train_test_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/train_test_test.py b/test/train_test_test.py index fd457e4c..9826d9fc 100644 --- a/test/train_test_test.py +++ b/test/train_test_test.py @@ -7,7 +7,7 @@ from mpol import losses, precomposed from mpol.plot import train_diagnostics_fig from mpol.training import TrainTest, train_to_dirty_image -from mpol.constants import * +from mpol.utils import torch2npy def test_traintestclass_training(coords, imager, dataset, generic_parameters): @@ -83,8 +83,15 @@ def test_traintestclass_train_diagnostics_fig(coords, imager, dataset, generic_p trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars) loss, loss_history = trainer.train(model, dataset) + learn_rates = np.repeat(learn_rate, len(loss_history)) + + old_mod_im = torch2npy(model.icube.sky_cube[0]) + train_fig, train_axes = train_diagnostics_fig(model, losses=loss_history, + learn_rates=learn_rates, + fluxes=np.zeros(len(loss_history)), + old_model_image=old_mod_im ) train_fig.savefig(tmp_path / "train_diagnostics_fig.png", dpi=300) plt.close("all")