From 9ec77d3745823f9e05016700938e6b2ffbb770e0 Mon Sep 17 00:00:00 2001 From: evgeny pavlov <epavlov@mozilla.com> Date: Wed, 7 Feb 2024 16:24:22 -0800 Subject: [PATCH 1/2] Fix names of parameters for merge and noise --- src/opustrainer/modifiers/merge.py | 12 +++--- src/opustrainer/modifiers/noise.py | 6 +-- tests/test_config.yml | 61 ++++++++++++++++++++++++++++++ tests/test_trainer_cli.py | 15 ++++++++ 4 files changed, 85 insertions(+), 9 deletions(-) create mode 100644 tests/test_config.yml diff --git a/src/opustrainer/modifiers/merge.py b/src/opustrainer/modifiers/merge.py index eda5a55..535b09b 100644 --- a/src/opustrainer/modifiers/merge.py +++ b/src/opustrainer/modifiers/merge.py @@ -54,19 +54,19 @@ class MergeModifier(Modifier): max_lines: 4 ``` """ - min_lines_merge: int - max_lines_merge: int + min_lines: int + max_lines: int - def __init__(self, probability: float, min_lines_merge: int=2, max_lines_merge: int=4): + def __init__(self, probability: float, min_lines: int=2, max_lines: int=4): super().__init__(probability) - self.min_lines_merge = min_lines_merge - self.max_lines_merge = max_lines_merge + self.min_lines = min_lines + self.max_lines = max_lines def __call__(self, batch:List[str]) -> Iterable[str]: i = 0 while i < len(batch): if self.probability > random.random(): - merge_size = random.randint(self.min_lines_merge, self.max_lines_merge) + merge_size = random.randint(self.min_lines, self.max_lines) yield merge_sents(batch[i:i+merge_size]) i += merge_size else: diff --git a/src/opustrainer/modifiers/noise.py b/src/opustrainer/modifiers/noise.py index 9d0a287..fba6770 100644 --- a/src/opustrainer/modifiers/noise.py +++ b/src/opustrainer/modifiers/noise.py @@ -23,10 +23,10 @@ class NoiseModifier(Modifier): max_word_length: int max_words: int - def __init__(self, probability: float, min_word_legnth: int=2, - max_word_length: int=5, max_words: int=6): + def __init__(self, probability: float, min_word_length: int=2, + max_word_length: int=5, max_words: int=6): super().__init__(probability) - self.min_word_length = min_word_legnth + self.min_word_length = min_word_length self.max_word_length = max_word_length self.max_words = max_words diff --git a/tests/test_config.yml b/tests/test_config.yml new file mode 100644 index 0000000..380995c --- /dev/null +++ b/tests/test_config.yml @@ -0,0 +1,61 @@ + +datasets: + clean: test-data/clean + medium: test-data/medium + dirty: test-data/dirty + +stages: + - start + - mid + - end + +start: + - clean 0.8 + - medium 0.2 + - dirty 0 + - until clean 2 + +mid: + - clean 0.6 + - medium 0.3 + - dirty 0.1 + - until medium 1 + +end: + - clean 0.4 + - medium 0.3 + - dirty 0.3 + - until dirty 5 + +modifiers: + - UpperCase: 0.05 + - TitleCase: 0.05 + - Tags: 0.08 + custom_detok_src: null + custom_detok_trg: zh + template: "__source__ {src} __target__ {trg} __done__" + - Prefix: 0.05 + min_words: 2 + max_words: 5 + template: "__start__ {trg} __end__ " + - Merge: 0.01 + min_lines: 2 + max_lines: 4 + - Noise: 0.0005 + min_word_length: 2 # Minimum word length for each word in the noisy sentence + max_word_length: 5 # Maximum word length for each word in the noisy sentence + max_words: 6 # Maximum number of words in each noisy sentence + - Typos: 0.05 + char_swap: 0.1 # Swaps two random consecutive word characters in the string. + missing_char: 0.1 # Skips a random word character in the string. + extra_char: 0.1 # Adds an extra, keyboard-neighbor, letter next to a random word character. + nearby_char: 0.1 # Replaces a random word character with keyboard-neighbor letter. + similar_char: 0.1 # Replaces a random word character with another visually similar character. + skipped_space: 0.1 # Skips a random space from the string. + random_space: 0.1 # Adds a random space in the string. + repeated_char: 0.1 # Repeats a random word character. + unichar: 0.1 # Replaces a random consecutive repeated letter with a single letter. + + +seed: 1111 +trainer: /usr/bin/cat diff --git a/tests/test_trainer_cli.py b/tests/test_trainer_cli.py index ddd31a2..06c94f7 100644 --- a/tests/test_trainer_cli.py +++ b/tests/test_trainer_cli.py @@ -1,6 +1,11 @@ #!/usr/bin/env python3 +import os import unittest + +import yaml + from opustrainer.trainer import parse_args +from opustrainer.trainer import CurriculumLoader class TestArgumentParser(unittest.TestCase): @@ -28,3 +33,13 @@ def test_marian_log_args(self): 'trainer': ['marian', '--log', 'marian.log'] } self.assertEqual({**vars(parsed), **expected}, vars(parsed)) + + def test_config_loader(self): + config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_config.yml') + with open(config_path, 'r', encoding='utf-8') as fh: + config = yaml.safe_load(fh) + + curriculum = CurriculumLoader().load(config, basepath=os.path.dirname(config_path)) + + assert curriculum is not None + From a8e78f4038000abab4dff1c83d9eb56d4ca6f406 Mon Sep 17 00:00:00 2001 From: evgeny pavlov <epavlov@mozilla.com> Date: Wed, 7 Feb 2024 17:03:24 -0800 Subject: [PATCH 2/2] Move test config to contrib --- tests/test_config.yml => contrib/test_full_config.yml | 9 ++++----- tests/test_trainer_cli.py | 6 +++++- 2 files changed, 9 insertions(+), 6 deletions(-) rename tests/test_config.yml => contrib/test_full_config.yml (99%) diff --git a/tests/test_config.yml b/contrib/test_full_config.yml similarity index 99% rename from tests/test_config.yml rename to contrib/test_full_config.yml index 380995c..0f0582d 100644 --- a/tests/test_config.yml +++ b/contrib/test_full_config.yml @@ -30,10 +30,6 @@ end: modifiers: - UpperCase: 0.05 - TitleCase: 0.05 - - Tags: 0.08 - custom_detok_src: null - custom_detok_trg: zh - template: "__source__ {src} __target__ {trg} __done__" - Prefix: 0.05 min_words: 2 max_words: 5 @@ -55,7 +51,10 @@ modifiers: random_space: 0.1 # Adds a random space in the string. repeated_char: 0.1 # Repeats a random word character. unichar: 0.1 # Replaces a random consecutive repeated letter with a single letter. - + - Tags: 0.08 + custom_detok_src: null + custom_detok_trg: zh + template: "__source__ {src} __target__ {trg} __done__" seed: 1111 trainer: /usr/bin/cat diff --git a/tests/test_trainer_cli.py b/tests/test_trainer_cli.py index 06c94f7..790c7b5 100644 --- a/tests/test_trainer_cli.py +++ b/tests/test_trainer_cli.py @@ -35,7 +35,11 @@ def test_marian_log_args(self): self.assertEqual({**vars(parsed), **expected}, vars(parsed)) def test_config_loader(self): - config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_config.yml') + config_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + os.pardir, + 'contrib', + 'test_full_config.yml') with open(config_path, 'r', encoding='utf-8') as fh: config = yaml.safe_load(fh)