-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add merge and standalone noise modifiers
- Loading branch information
1 parent
452c4ea
commit 85ebe30
Showing
10 changed files
with
288 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# This file contains merge modifier and noise modifier | ||
import random | ||
from opustrainer.modifiers import Modifier | ||
|
||
import random | ||
from typing import List, Sequence, Union | ||
from opustrainer.modifiers import Modifier | ||
from opustrainer.alignments import format_alignments, parse_alignments, Pair | ||
|
||
def merge_sents(inputs: List[str]) -> str: | ||
"""Merges n sentences together, fixing up their alignments""" | ||
srcs: List[List[str]] = [x.split('\t')[0].split() for x in inputs] | ||
trgs: List[List[str]] = [x.split('\t')[1].split() for x in inputs] | ||
align_txt: Union[str, None] = None | ||
if len(inputs[0].split('\t')) > 2: | ||
aligns: List[List[Pair]] = [parse_alignments(x.split('\t')[2].strip()) for x in inputs] | ||
|
||
add_src = len(srcs[0]) | ||
add_trg = len(trgs[0]) | ||
for i in range(1, len(srcs)): | ||
for j in range(len(aligns[i])): | ||
aligns[i][j] = Pair(aligns[i][j][0] + add_src, aligns[i][j][1] + add_trg) | ||
add_src = add_src + len(srcs[i]) | ||
add_trg = add_trg + len(trgs[i]) | ||
|
||
align_txt = format_alignments([item for sublist in aligns for item in sublist]) | ||
|
||
srcs_txt: str = " ".join([x.split('\t')[0] for x in inputs]) | ||
trgs_txt: str = " ".join([x.split('\t')[1] for x in inputs]) | ||
|
||
if align_txt is not None: | ||
return srcs_txt + '\t' + trgs_txt + '\t' + align_txt | ||
else: | ||
return srcs_txt + '\t' + trgs_txt | ||
|
||
class MergeModifier(Modifier): | ||
"""Randomly merges up to n lines into one | ||
Usage: | ||
```yaml | ||
modifiers: | ||
- Merge: 0.01 | ||
min_lines: 2 | ||
max_lines: 4 | ||
``` | ||
""" | ||
min_lines_merge: int | ||
max_lines_merge: int | ||
def __init__(self, probability: float=0.0, min_lines_merge: int=2, max_lines_merge: int=4): | ||
super().__init__(probability) | ||
self.min_lines_merge = min_lines_merge | ||
self.max_lines_merge = max_lines_merge | ||
|
||
def __call__(self, batch:List[str]) -> Sequence[str]: | ||
newbatch: List[str] = [] | ||
# Identify merging candidates and their lengths | ||
prev_end = -1 | ||
for i in range(len(batch)): | ||
if i < prev_end: | ||
continue | ||
elif self.probability > random.random(): | ||
merge_end = i + random.randint(self.min_lines_merge, self.max_lines_merge) | ||
prev_end = merge_end | ||
merge_batch: str = merge_sents(batch[i:merge_end]) | ||
newbatch.append(merge_batch) | ||
else: | ||
newbatch.append(batch[i]) | ||
|
||
return newbatch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# This file contains merge modifier and noise modifier | ||
import random | ||
from opustrainer.modifiers import Modifier | ||
from opustrainer.modifiers.placeholders import get_random_unicode_words | ||
|
||
import random | ||
from typing import List, Sequence | ||
from opustrainer.modifiers import Modifier | ||
|
||
class NoiseModifier(Modifier): | ||
"""Adds noise during training. Nonsensitcal string on the source and on the target | ||
Usage: | ||
```yaml | ||
modifiers: | ||
- Noise: 0.01 | ||
min_word_length: 2 | ||
max_word_length: 5 | ||
max_words: 6 | ||
``` | ||
""" | ||
min_word_length: int | ||
max_word_length: int | ||
max_words: int | ||
|
||
def __init__(self, probability: float=0.0, min_word_legnth: int=2, | ||
max_word_length: int=5, max_words: int=6): | ||
super().__init__(probability) | ||
self.min_word_length = min_word_legnth | ||
self.max_word_length = max_word_length | ||
self.max_words = max_words | ||
|
||
def __call__(self, batch:List[str]) -> Sequence[str]: | ||
"""Generates a random noise line""" | ||
# The only problem is that we don't know if the dataset is supposed to have an alignment field | ||
# or not... A tradeoff is to look at the previous line and see if it has alignment info and then follow that | ||
# it's not ideal as we might hit a defective line, but oh well... | ||
ret_batch: List[str] = [] | ||
for line in batch: | ||
if self.probability > random.random(): | ||
newline: str = " ".join(get_random_unicode_words(self.min_word_length, self.max_word_length, self.max_words)) | ||
# Check if we have a 3rd field, which we assume is alignment | ||
if line.count('\t') == 2: | ||
# Generate alignments, just in case | ||
alignments: str = "" | ||
myrange = range(newline.count(' ') + 1) | ||
for j in myrange: | ||
alignments = alignments + str(j) + '-' + str(j) + " " | ||
alignments = alignments[:-1] # remove final space | ||
ret_batch.append(newline +'\t' + newline + '\t' + alignments) | ||
else: | ||
ret_batch.append(newline +'\t' + newline) | ||
ret_batch.append(line) | ||
return ret_batch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from doctest import Example | ||
import random | ||
import unittest | ||
|
||
from opustrainer.modifiers.merge import MergeModifier, merge_sents | ||
|
||
class TestMerge(unittest.TestCase): | ||
def setUp(self): | ||
random.seed(1) | ||
|
||
# Set up examples | ||
self.example = [ | ||
'429 运输 中队 ( 429 野牛) , 使用 CC - 177 429 Transport Squadron (429 Bison Squadron) - Flying the CC-177 0-0 1-1 2-2 3-3 4-3 5-4 5-5 7-5 8-5 9-6 8-7 9-8 10-9', | ||
"微生物 检验 与 食品 安全 控制 . Food Poisoning and Food Hygiene. 3-0 0-1 1-1 2-1 2-2 3-3 4-3 5-4 6-4" | ||
]*10 | ||
|
||
self.example_noalign = ["\t".join(a.split('\t')[:-1]) for a in self.example] | ||
|
||
# counts | ||
self.psn_cnt = " ".join(self.example).count('Poisoning') # 10 | ||
self.num_cnt = " ".join(self.example).count('429') # 40 because it appears once in src and trg | ||
|
||
def test_merge(self): | ||
merged = merge_sents(self.example_noalign[0:3]) | ||
expected = '429 运输 中队 ( 429 野牛) , 使用 CC - 177 微生物 检验 与 食品 安全 控制 . 429 运输 中队 ( 429 野牛) , 使用 CC - 177\t429 Transport Squadron (429 Bison Squadron) - Flying the CC-177 Food Poisoning and Food Hygiene. 429 Transport Squadron (429 Bison Squadron) - Flying the CC-177' | ||
self.assertEqual(merged, expected) | ||
|
||
# Expected based on counts | ||
lensrc = sum([len(a.split('\t')[0].split()) for a in self.example_noalign[0:3]]) | ||
lentrg = sum([len(a.split('\t')[1].split()) for a in self.example_noalign[0:3]]) | ||
|
||
lenmrgsrc = len(merged.split('\t')[0].split()) | ||
lenmrgtrg = len(merged.split('\t')[1].split()) | ||
self.assertEqual(lensrc, lenmrgsrc) | ||
self.assertEqual(lentrg, lenmrgtrg) | ||
|
||
def test_merge_align(self): | ||
merged = merge_sents(self.example[0:3]) | ||
expected = '429 运输 中队 ( 429 野牛) , 使用 CC - 177 微生物 检验 与 食品 安全 控制 . 429 运输 中队 ( 429 野牛) , 使用 CC - 177\t429 Transport Squadron (429 Bison Squadron) - Flying the CC-177 Food Poisoning and Food Hygiene. 429 Transport Squadron (429 Bison Squadron) - Flying the CC-177\t0-0 1-1 2-2 3-3 4-3 5-4 5-5 7-5 8-5 9-6 8-7 9-8 10-9 14-10 11-11 12-11 13-11 13-12 14-13 15-13 16-14 17-14 18-15 19-16 20-17 21-18 22-18 23-19 23-20 25-20 26-20 27-21 26-22 27-23 28-24' | ||
self.assertEqual(merged, expected) | ||
|
||
# Expected based on counts | ||
lensrc = sum([len(a.split('\t')[0].split()) for a in self.example[0:3]]) | ||
lentrg = sum([len(a.split('\t')[1].split()) for a in self.example[0:3]]) | ||
|
||
lenmrgsrc = len(merged.split('\t')[0].split()) | ||
lenmrgtrg = len(merged.split('\t')[1].split()) | ||
self.assertEqual(lensrc, lenmrgsrc) | ||
self.assertEqual(lentrg, lenmrgtrg) | ||
|
||
# Test alignment based on final letter | ||
len_srcalign_final = len(merged.split('\t')[0].split()) | ||
len_trgalign_final = len(merged.split('\t')[1].split()) | ||
self.assertEqual(len_srcalign_final, 29) | ||
self.assertEqual(len_trgalign_final, 25) | ||
|
||
def test_merge_full(self): | ||
merger = MergeModifier(0.8) | ||
merged = merger(self.example_noalign) | ||
|
||
psn_cnt = " ".join(merged).count('Poisoning') | ||
num_cnt = " ".join(merged).count('429') | ||
|
||
self.assertNotEqual(len(merged), len(self.example_noalign)) # Assert it being activated | ||
self.assertEqual(self.psn_cnt, psn_cnt) | ||
self.assertEqual(self.num_cnt, num_cnt) | ||
|
||
def test_merge_full_align(self): | ||
merger = MergeModifier(0.8) | ||
merged = merger(self.example) | ||
|
||
psn_cnt = " ".join(merged).count('Poisoning') | ||
num_cnt = " ".join(merged).count('429') | ||
|
||
self.assertNotEqual(len(merged), len(self.example)) # Assert it being activated | ||
self.assertEqual(self.psn_cnt, psn_cnt) | ||
self.assertEqual(self.num_cnt, num_cnt) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from doctest import Example | ||
import enum | ||
import random | ||
import unittest | ||
|
||
from opustrainer.modifiers.noise import NoiseModifier | ||
|
||
class TestMerge(unittest.TestCase): | ||
def setUp(self): | ||
random.seed(1) | ||
|
||
# Set up examples | ||
self.example = [ | ||
'429 运输 中队 ( 429 野牛) , 使用 CC - 177 429 Transport Squadron (429 Bison Squadron) - Flying the CC-177 0-0 1-1 2-2 3-3 4-3 5-4 5-5 7-5 8-5 9-6 8-7 9-8 10-9', | ||
"微生物 检验 与 食品 安全 控制 . Food Poisoning and Food Hygiene. 3-0 0-1 1-1 2-1 2-2 3-3 4-3 5-4 6-4" | ||
]*10 | ||
|
||
self.example_noalign = ["\t".join(a.split('\t')[:-1]) for a in self.example] | ||
|
||
# With 20% prob this is triggered 3 times we check one of the matches. We expect new length to be 23 | ||
self.num_nine_noise = "쑥맜\t쑥맜" | ||
self.num_nine_noise_align = "쑥맜\t쑥맜\t0-0" | ||
|
||
def test_noise(self): | ||
noiser = NoiseModifier(0.2) | ||
noised = noiser(self.example_noalign) | ||
self.assertEqual(noised[9], self.num_nine_noise) | ||
self.assertEqual(len(noised), 23) | ||
|
||
def test_noise_align(self): | ||
noiser = NoiseModifier(0.2) | ||
noised = noiser(self.example) | ||
self.assertEqual(noised[9], self.num_nine_noise_align) | ||
self.assertEqual(len(noised), 23) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters