Skip to content

Commit

Permalink
Fix tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
johnml1135 committed Aug 16, 2024
1 parent 1893473 commit 58eb4da
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 20 deletions.
3 changes: 3 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"env": {
"PYTHONPATH": "${workspaceFolder}:${workspaceFolder}/tests"
},
"justMyCode": true
},
{
Expand Down
1 change: 1 addition & 0 deletions machine/jobs/thot/thot_word_alignment_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
class ThotWordAlignmentModelFactory(WordAlignmentModelFactory):

def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: ParallelTextCorpus) -> Trainer:
(self._model_dir / "tm").mkdir(parents=True, exist_ok=True)
direct_trainer = ThotWordAlignmentModelTrainer(
self._config.thot.word_alignment_model_type,
corpus.lowercase(),
Expand Down
4 changes: 2 additions & 2 deletions tests/translation/thot/test_thot_smt_model_trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
from tempfile import TemporaryDirectory

from machine.translation.thot import ThotSmtModel, ThotSmtModelTrainer, ThotSmtParameters, ThotWordAlignmentModelType
from translation.thot.thot_model_trainer_helper import get_emtpy_parallel_corpus, get_parallel_corpus

from .thot_model_trainer_helper import get_emtpy_parallel_corpus, get_parallel_corpus
from machine.translation.thot import ThotSmtModel, ThotSmtModelTrainer, ThotSmtParameters, ThotWordAlignmentModelType


def test_train_non_empty_corpus() -> None:
Expand Down
47 changes: 29 additions & 18 deletions tests/translation/thot/test_thot_word_alignment_model_trainer.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from pathlib import Path
from tempfile import TemporaryDirectory

from translation.thot.thot_model_trainer_helper import get_emtpy_parallel_corpus, get_parallel_corpus

from machine.corpora.parallel_text_corpus import ParallelTextCorpus
from machine.tokenization.whitespace_tokenizer import WhitespaceTokenizer
from machine.tokenization import StringTokenizer, WhitespaceTokenizer
from machine.translation.symmetrized_word_alignment_model_trainer import SymmetrizedWordAlignmentModelTrainer
from machine.translation.thot import ThotWordAlignmentModelTrainer
from machine.translation.thot.thot_symmetrized_word_alignment_model import ThotSymmetrizedWordAlignmentModel
from machine.translation.thot.thot_word_alignment_model_utils import create_thot_word_alignment_model
from machine.translation.word_alignment_matrix import WordAlignmentMatrix

from .thot_model_trainer_helper import get_emtpy_parallel_corpus, get_parallel_corpus


def train_model(
corpus: ParallelTextCorpus, direct_model_path: Path, inverse_model_path: Path, thot_word_alignment_model_type: str
corpus: ParallelTextCorpus,
direct_model_path: Path,
inverse_model_path: Path,
thot_word_alignment_model_type: str,
tokenizer: StringTokenizer,
):
tokenizer = WhitespaceTokenizer()
direct_trainer = ThotWordAlignmentModelTrainer(
thot_word_alignment_model_type,
corpus.lowercase(),
Expand All @@ -32,41 +35,49 @@ def train_model(
)

with SymmetrizedWordAlignmentModelTrainer(direct_trainer, inverse_trainer) as trainer:
trainer.train()
trainer.train(lambda status: print(f"{status.message}: {status.percent_completed:.2%}"))
trainer.save()


def test_train_non_empty_corpus() -> None:
thot_word_alignment_model_type = "hmm"
tokenizer = WhitespaceTokenizer()
corpus = get_parallel_corpus()

with TemporaryDirectory() as temp_dir:
corpus = get_parallel_corpus()
thot_word_alignment_model_type = "hmm"
tmp_path = Path(temp_dir)
(tmp_path / "tm").mkdir()
direct_model_path = tmp_path / "tm" / "src_trg_invswm"
inverse_model_path = tmp_path / "tm" / "src_trg_swm"
train_model(corpus, direct_model_path, inverse_model_path, thot_word_alignment_model_type)
train_model(corpus, direct_model_path, inverse_model_path, thot_word_alignment_model_type, tokenizer)
with ThotSymmetrizedWordAlignmentModel(
create_thot_word_alignment_model(thot_word_alignment_model_type, direct_model_path),
create_thot_word_alignment_model(thot_word_alignment_model_type, inverse_model_path),
) as model:
matrix = model.align("una habitación individual por semana", "a single room cost per week")
assert matrix == WordAlignmentMatrix.from_word_pairs(
6, 7, {(0, 0), (1, 1), (2, 2), (4, 3), (3, 4), (3, 5), (5, 6)}
matrix = model.align(
list(tokenizer.tokenize("una habitación individual por semana")),
list(tokenizer.tokenize("a single room cost per week")),
)
assert matrix == WordAlignmentMatrix.from_word_pairs(5, 6, {(0, 2), (1, 2), (2, 3), (2, 4), (2, 5)})


def test_train_empty_corpus() -> None:
thot_word_alignment_model_type = "hmm"
tokenizer = WhitespaceTokenizer()
corpus = get_emtpy_parallel_corpus()
with TemporaryDirectory() as temp_dir:
corpus = get_emtpy_parallel_corpus()
thot_word_alignment_model_type = "hmm"
tmp_path = Path(temp_dir)
direct_model_path = tmp_path / "tm" / "src_trg_invswm"
inverse_model_path = tmp_path / "tm" / "src_trg_swm"
train_model(corpus, direct_model_path, inverse_model_path, thot_word_alignment_model_type)
train_model(corpus, direct_model_path, inverse_model_path, thot_word_alignment_model_type, tokenizer)
with ThotSymmetrizedWordAlignmentModel(
create_thot_word_alignment_model(thot_word_alignment_model_type, direct_model_path),
create_thot_word_alignment_model(thot_word_alignment_model_type, inverse_model_path),
) as model:
matrix = model.align("una habitación individual por semana", "a single room cost per week")
assert matrix == WordAlignmentMatrix.from_word_pairs(
6, 7, {(0, 0), (1, 1), (2, 2), (4, 3), (3, 4), (3, 5), (5, 6)}
)
assert matrix == WordAlignmentMatrix.from_word_pairs(5, 6, {(0, 0)})


if __name__ == "__main__":
test_train_non_empty_corpus()
test_train_empty_corpus()

0 comments on commit 58eb4da

Please sign in to comment.