Skip to content

Commit

Permalink
Merge pull request #52 from mozilla/fix_config_parsing
Browse files Browse the repository at this point in the history
Fix names of parameters for merge and noise
  • Loading branch information
jelmervdl authored Feb 8, 2024
2 parents 4fceac9 + a8e78f4 commit c966d7b
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 9 deletions.
60 changes: 60 additions & 0 deletions contrib/test_full_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

datasets:
clean: test-data/clean
medium: test-data/medium
dirty: test-data/dirty

stages:
- start
- mid
- end

start:
- clean 0.8
- medium 0.2
- dirty 0
- until clean 2

mid:
- clean 0.6
- medium 0.3
- dirty 0.1
- until medium 1

end:
- clean 0.4
- medium 0.3
- dirty 0.3
- until dirty 5

modifiers:
- UpperCase: 0.05
- TitleCase: 0.05
- Prefix: 0.05
min_words: 2
max_words: 5
template: "__start__ {trg} __end__ "
- Merge: 0.01
min_lines: 2
max_lines: 4
- Noise: 0.0005
min_word_length: 2 # Minimum word length for each word in the noisy sentence
max_word_length: 5 # Maximum word length for each word in the noisy sentence
max_words: 6 # Maximum number of words in each noisy sentence
- Typos: 0.05
char_swap: 0.1 # Swaps two random consecutive word characters in the string.
missing_char: 0.1 # Skips a random word character in the string.
extra_char: 0.1 # Adds an extra, keyboard-neighbor, letter next to a random word character.
nearby_char: 0.1 # Replaces a random word character with keyboard-neighbor letter.
similar_char: 0.1 # Replaces a random word character with another visually similar character.
skipped_space: 0.1 # Skips a random space from the string.
random_space: 0.1 # Adds a random space in the string.
repeated_char: 0.1 # Repeats a random word character.
unichar: 0.1 # Replaces a random consecutive repeated letter with a single letter.
- Tags: 0.08
custom_detok_src: null
custom_detok_trg: zh
template: "__source__ {src} __target__ {trg} __done__"

seed: 1111
trainer: /usr/bin/cat
12 changes: 6 additions & 6 deletions src/opustrainer/modifiers/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,19 @@ class MergeModifier(Modifier):
max_lines: 4
```
"""
min_lines_merge: int
max_lines_merge: int
min_lines: int
max_lines: int

def __init__(self, probability: float, min_lines_merge: int=2, max_lines_merge: int=4):
def __init__(self, probability: float, min_lines: int=2, max_lines: int=4):
super().__init__(probability)
self.min_lines_merge = min_lines_merge
self.max_lines_merge = max_lines_merge
self.min_lines = min_lines
self.max_lines = max_lines

def __call__(self, batch:List[str]) -> Iterable[str]:
i = 0
while i < len(batch):
if self.probability > random.random():
merge_size = random.randint(self.min_lines_merge, self.max_lines_merge)
merge_size = random.randint(self.min_lines, self.max_lines)
yield merge_sents(batch[i:i+merge_size])
i += merge_size
else:
Expand Down
6 changes: 3 additions & 3 deletions src/opustrainer/modifiers/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ class NoiseModifier(Modifier):
max_word_length: int
max_words: int

def __init__(self, probability: float, min_word_legnth: int=2,
max_word_length: int=5, max_words: int=6):
def __init__(self, probability: float, min_word_length: int=2,
max_word_length: int=5, max_words: int=6):
super().__init__(probability)
self.min_word_length = min_word_legnth
self.min_word_length = min_word_length
self.max_word_length = max_word_length
self.max_words = max_words

Expand Down
19 changes: 19 additions & 0 deletions tests/test_trainer_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#!/usr/bin/env python3
import os
import unittest

import yaml

from opustrainer.trainer import parse_args
from opustrainer.trainer import CurriculumLoader


class TestArgumentParser(unittest.TestCase):
Expand Down Expand Up @@ -28,3 +33,17 @@ def test_marian_log_args(self):
'trainer': ['marian', '--log', 'marian.log']
}
self.assertEqual({**vars(parsed), **expected}, vars(parsed))

def test_config_loader(self):
    """Smoke-test that the full example config loads into a curriculum.

    Parses ``contrib/test_full_config.yml`` and hands it to
    ``CurriculumLoader.load``. The test passes when loading completes
    without raising and returns something other than ``None``.
    """
    # Locate the config relative to this test file so the test does not
    # depend on the current working directory.
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        os.pardir,
        'contrib',
        'test_full_config.yml')
    with open(config_path, 'r', encoding='utf-8') as fh:
        config = yaml.safe_load(fh)

    # NOTE(review): basepath is presumably what relative dataset paths in
    # the config are resolved against -- confirm in CurriculumLoader.load.
    curriculum = CurriculumLoader().load(
        config, basepath=os.path.dirname(config_path))

    # Use the TestCase assertion rather than a bare ``assert``: it is not
    # stripped under ``python -O`` and matches the other tests in this class.
    self.assertIsNotNone(curriculum)

0 comments on commit c966d7b

Please sign in to comment.