diff --git a/contrib/test-data/test_enzh_config_plain_expected.log b/contrib/test-data/test_enzh_config_plain_expected.log index ea33b7b..46b0cd5 100644 --- a/contrib/test-data/test_enzh_config_plain_expected.log +++ b/contrib/test-data/test_enzh_config_plain_expected.log @@ -9,4 +9,4 @@ [2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 7 [2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 8 [2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 9 -[2023-07-04 02:48:12] [Trainer] [INFO] waiting for trainer to exit. Press ctrl-c to be more aggressive +[2023-07-04 02:48:12] [Trainer] [INFO] waiting for trainer to exit diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 7c40ff7..503f7cd 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -9,7 +9,7 @@ from contextlib import closing from textwrap import dedent from io import StringIO -from itertools import chain +from itertools import chain, islice import yaml @@ -173,10 +173,11 @@ def test_resume(self): # Train on trainer1 with closing(Trainer(curriculum)) as trainer1: - batches = [batch for _, batch in zip(range(10), state_tracker.run(trainer1))] + batches = list(islice(state_tracker.run(trainer1), 10)) # Resume on trainer2 with closing(Trainer(curriculum)) as trainer2: + state_tracker.restore(trainer2) batches.extend(state_tracker.run(trainer2)) self.assertEqual(batches, batches_ref) diff --git a/tests/test_trainer_cli.py b/tests/test_trainer_cli.py index 7534621..79f013b 100644 --- a/tests/test_trainer_cli.py +++ b/tests/test_trainer_cli.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 import unittest -import yaml import sys from subprocess import Popen from pathlib import Path from tempfile import TemporaryDirectory, TemporaryFile + +import yaml + from opustrainer.trainer import parse_args @@ -60,7 +62,7 @@ def test_early_stopping(self): 'seed': 1111 } - with TemporaryDirectory() as tmp, TemporaryFile() as fout: + with TemporaryDirectory() as tmp, TemporaryFile() as fout, TemporaryFile() as ferr: with open(Path(tmp) / 'config.yml', 'w+t') as fcfg: yaml.safe_dump(config, fcfg) @@ -71,13 +73,15 @@ def test_early_stopping(self): '--no-shuffle', '--config', str(Path(tmp) / 'config.yml'), 'head', '-n', str(head_lines) - ], stdout=fout) + ], stdout=fout, stderr=ferr) - # Assert we exited neatly retval = child.wait(30) - self.assertEqual(retval, 0) + fout.seek(0) + ferr.seek(0) + + # Assert we exited neatly + self.assertEqual(retval, 0, msg=ferr.read().decode()) # Assert we got the number of lines we'd expect - fout.seek(0) line_count = sum(1 for _ in fout) self.assertEqual(line_count, len(config['stages']) * head_lines)