diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index 9f9e8bb5..e8317846 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -37,7 +37,7 @@ class EasyDataAugmenter(Augmenter): https://arxiv.org/abs/1901.11196 """ - def __init__(self, pct_words_to_swap=0.1, transformations_per_example=4): + def __init__(self, pct_words_to_swap=0.1, transformations_per_example=4, **kwargs): assert 0.0 <= pct_words_to_swap <= 1.0, "pct_words_to_swap must be in [0., 1.]" assert ( transformations_per_example > 0 @@ -49,17 +49,22 @@ def __init__(self, pct_words_to_swap=0.1, transformations_per_example=4): self.synonym_replacement = WordNetAugmenter( pct_words_to_swap=pct_words_to_swap, transformations_per_example=n_aug_each, + **kwargs, ) self.random_deletion = DeletionAugmenter( pct_words_to_swap=pct_words_to_swap, transformations_per_example=n_aug_each, + **kwargs, ) self.random_swap = SwapAugmenter( pct_words_to_swap=pct_words_to_swap, transformations_per_example=n_aug_each, + **kwargs, ) self.random_insertion = SynonymInsertionAugmenter( - pct_words_to_swap=pct_words_to_swap, transformations_per_example=n_aug_each + pct_words_to_swap=pct_words_to_swap, + transformations_per_example=n_aug_each, + **kwargs, ) def augment(self, text):