diff --git a/contrib/test_full_config.yml b/contrib/test_full_config.yml new file mode 100644 index 0000000..0f0582d --- /dev/null +++ b/contrib/test_full_config.yml @@ -0,0 +1,60 @@ + +datasets: + clean: test-data/clean + medium: test-data/medium + dirty: test-data/dirty + +stages: + - start + - mid + - end + +start: + - clean 0.8 + - medium 0.2 + - dirty 0 + - until clean 2 + +mid: + - clean 0.6 + - medium 0.3 + - dirty 0.1 + - until medium 1 + +end: + - clean 0.4 + - medium 0.3 + - dirty 0.3 + - until dirty 5 + +modifiers: + - UpperCase: 0.05 + - TitleCase: 0.05 + - Prefix: 0.05 + min_words: 2 + max_words: 5 + template: "__start__ {trg} __end__ " + - Merge: 0.01 + min_lines: 2 + max_lines: 4 + - Noise: 0.0005 + min_word_length: 2 # Minimum word length for each word in the noisy sentence + max_word_length: 5 # Maximum word length for each word in the noisy sentence + max_words: 6 # Maximum number of words in each noisy sentence + - Typos: 0.05 + char_swap: 0.1 # Swaps two random consecutive word characters in the string. + missing_char: 0.1 # Skips a random word character in the string. + extra_char: 0.1 # Adds an extra, keyboard-neighbor, letter next to a random word character. + nearby_char: 0.1 # Replaces a random word character with keyboard-neighbor letter. + similar_char: 0.1 # Replaces a random word character with another visually similar character. + skipped_space: 0.1 # Skips a random space from the string. + random_space: 0.1 # Adds a random space in the string. + repeated_char: 0.1 # Repeats a random word character. + unichar: 0.1 # Replaces a random consecutive repeated letter with a single letter. + - Tags: 0.08 + custom_detok_src: null + custom_detok_trg: zh + template: "__source__ {src} __target__ {trg} __done__" + +seed: 1111 +trainer: /usr/bin/cat diff --git a/src/opustrainer/modifiers/merge.py b/src/opustrainer/modifiers/merge.py index eda5a55..535b09b 100644 --- a/src/opustrainer/modifiers/merge.py +++ b/src/opustrainer/modifiers/merge.py @@ -54,19 +54,19 @@ class MergeModifier(Modifier): max_lines: 4 ``` """ - min_lines_merge: int - max_lines_merge: int + min_lines: int + max_lines: int - def __init__(self, probability: float, min_lines_merge: int=2, max_lines_merge: int=4): + def __init__(self, probability: float, min_lines: int=2, max_lines: int=4): super().__init__(probability) - self.min_lines_merge = min_lines_merge - self.max_lines_merge = max_lines_merge + self.min_lines = min_lines + self.max_lines = max_lines def __call__(self, batch:List[str]) -> Iterable[str]: i = 0 while i < len(batch): if self.probability > random.random(): - merge_size = random.randint(self.min_lines_merge, self.max_lines_merge) + merge_size = random.randint(self.min_lines, self.max_lines) yield merge_sents(batch[i:i+merge_size]) i += merge_size else: diff --git a/src/opustrainer/modifiers/noise.py b/src/opustrainer/modifiers/noise.py index 9d0a287..fba6770 100644 --- a/src/opustrainer/modifiers/noise.py +++ b/src/opustrainer/modifiers/noise.py @@ -23,10 +23,10 @@ class NoiseModifier(Modifier): max_word_length: int max_words: int - def __init__(self, probability: float, min_word_legnth: int=2, - max_word_length: int=5, max_words: int=6): + def __init__(self, probability: float, min_word_length: int=2, + max_word_length: int=5, max_words: int=6): super().__init__(probability) - self.min_word_length = min_word_legnth + self.min_word_length = min_word_length self.max_word_length = max_word_length self.max_words = max_words diff --git a/tests/test_trainer_cli.py b/tests/test_trainer_cli.py index ddd31a2..790c7b5 100644 --- a/tests/test_trainer_cli.py +++ b/tests/test_trainer_cli.py @@ -1,6 +1,11 @@ #!/usr/bin/env python3 +import os import unittest + +import yaml + from opustrainer.trainer import parse_args +from opustrainer.trainer import CurriculumLoader class TestArgumentParser(unittest.TestCase): @@ -28,3 +33,17 @@ def test_marian_log_args(self): 'trainer': ['marian', '--log', 'marian.log'] } self.assertEqual({**vars(parsed), **expected}, vars(parsed)) + + def test_config_loader(self): + config_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + os.pardir, + 'contrib', + 'test_full_config.yml') + with open(config_path, 'r', encoding='utf-8') as fh: + config = yaml.safe_load(fh) + + curriculum = CurriculumLoader().load(config, basepath=os.path.dirname(config_path)) + + assert curriculum is not None +