Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jelmervdl committed Dec 23, 2023
1 parent b1a7f04 commit f142670
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
2 changes: 1 addition & 1 deletion contrib/test-data/test_enzh_config_plain_expected.log
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 7
[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 8
[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 9
[2023-07-04 02:48:12] [Trainer] [INFO] waiting for trainer to exit. Press ctrl-c to be more aggressive
[2023-07-04 02:48:12] [Trainer] [INFO] waiting for trainer to exit
5 changes: 3 additions & 2 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextlib import closing
from textwrap import dedent
from io import StringIO
from itertools import chain
from itertools import chain, islice

import yaml

Expand Down Expand Up @@ -173,10 +173,11 @@ def test_resume(self):

# Train on trainer1
with closing(Trainer(curriculum)) as trainer1:
batches = [batch for _, batch in zip(range(10), state_tracker.run(trainer1))]
batches = list(islice(state_tracker.run(trainer1), 10))

# Resume on trainer2
with closing(Trainer(curriculum)) as trainer2:
state_tracker.restore(trainer2)
batches.extend(state_tracker.run(trainer2))

self.assertEqual(batches, batches_ref)
Expand Down
16 changes: 10 additions & 6 deletions tests/test_trainer_cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#!/usr/bin/env python3
import unittest
import yaml
import sys
from subprocess import Popen
from pathlib import Path
from tempfile import TemporaryDirectory, TemporaryFile

import yaml

from opustrainer.trainer import parse_args


Expand Down Expand Up @@ -60,7 +62,7 @@ def test_early_stopping(self):
'seed': 1111
}

with TemporaryDirectory() as tmp, TemporaryFile() as fout:
with TemporaryDirectory() as tmp, TemporaryFile() as fout, TemporaryFile() as ferr:
with open(Path(tmp) / 'config.yml', 'w+t') as fcfg:
yaml.safe_dump(config, fcfg)

Expand All @@ -71,13 +73,15 @@ def test_early_stopping(self):
'--no-shuffle',
'--config', str(Path(tmp) / 'config.yml'),
'head', '-n', str(head_lines)
], stdout=fout)
], stdout=fout, stderr=ferr)

# Assert we exited neatly
retval = child.wait(30)
self.assertEqual(retval, 0)
fout.seek(0)
ferr.seek(0)

# Assert we exited neatly
self.assertEqual(retval, 0, msg=ferr.read().decode())

# Assert we got the number of lines we'd expect
fout.seek(0)
line_count = sum(1 for _ in fout)
self.assertEqual(line_count, len(config['stages']) * head_lines)

0 comments on commit f142670

Please sign in to comment.