From 035dfff71e4fd59806f854f652395ba3aa3ed7b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Sk=C3=B3rzewski?= Date: Fri, 15 Dec 2023 15:29:38 +0100 Subject: [PATCH 1/4] Implement back transcription augmentation method --- tests/test_augment_api.py | 10 ++ .../sentence_transformations/__init__.py | 1 + .../back_transcription.py | 95 +++++++++++++++++++ 3 files changed, 106 insertions(+) create mode 100644 textattack/transformations/sentence_transformations/back_transcription.py diff --git a/tests/test_augment_api.py b/tests/test_augment_api.py index 79e6fb46d..16c495f30 100644 --- a/tests/test_augment_api.py +++ b/tests/test_augment_api.py @@ -134,3 +134,13 @@ def test_back_translation(): augmented_text_list = augmenter.augment(s) augmented_s = "What the hell are you doing?" assert augmented_s in augmented_text_list + + +def test_back_transcription(): + from textattack.augmentation import Augmenter + from textattack.transformations.sentence_transformations import BackTranscription + + augmenter = Augmenter(transformation=BackTranscription()) + s = "What on earth are you doing?" + augmented_text_list = augmenter.augment(s) + assert augmented_text_list diff --git a/textattack/transformations/sentence_transformations/__init__.py b/textattack/transformations/sentence_transformations/__init__.py index a8a3928a2..24590edee 100644 --- a/textattack/transformations/sentence_transformations/__init__.py +++ b/textattack/transformations/sentence_transformations/__init__.py @@ -6,3 +6,4 @@ from .sentence_transformation import SentenceTransformation from .back_translation import BackTranslation +from .back_transcription import BackTranscription diff --git a/textattack/transformations/sentence_transformations/back_transcription.py b/textattack/transformations/sentence_transformations/back_transcription.py new file mode 100644 index 000000000..ae62930c9 --- /dev/null +++ b/textattack/transformations/sentence_transformations/back_transcription.py @@ -0,0 +1,95 @@ +""" +BackTranscription class +----------------------------------- + +""" + + +from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub +from fairseq.models.text_to_speech.hub_interface import TTSHubInterface +import librosa +from transformers import WhisperForConditionalGeneration, WhisperProcessor + +from textattack.shared import AttackedText + +from .sentence_transformation import SentenceTransformation + + +class BackTranscription(SentenceTransformation): + """A type of sentence level transformation that takes in a text input, converts it into + synthesized speech using ASR, and transcribes it back to text using TTS. + + tts_model: text-to-speech model from huggingface + asr_model: automatic speech recognition model from huggingface + + (!) Python libraries `fairseq`, `g2p_en` and `librosa` should be installed. + + Example:: + + >>> from textattack.transformations.sentence_transformations import BackTranscription + >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification + >>> from textattack.augmentation import Augmenter + + >>> transformation = BackTranscription() + >>> constraints = [RepeatModification(), StopwordModification()] + >>> augmenter = Augmenter(transformation = transformation, constraints = constraints) + >>> s = 'What on earth are you doing here.' + + >>> augmenter.augment(s) + """ + + def __init__( + self, + tts_model="facebook/fastspeech2-en-ljspeech", + asr_model="openai/whisper-base", + ): + # TTS model + self.tts_model_name = tts_model + models, cfg, self.tts_task = load_model_ensemble_and_task_from_hf_hub( + self.tts_model_name, + arg_overrides={"vocoder": "hifigan", "fp16": False}, + ) + self.tts_model = models[0] + TTSHubInterface.update_cfg_with_data_cfg(cfg, self.tts_task.data_cfg) + self.tts_generator = self.tts_task.build_generator(models, cfg) + + # ASR model + self.asr_model_name = asr_model + self.asr_sampling_rate = 16000 + self.asr_processor = WhisperProcessor.from_pretrained(self.asr_model_name) + self.asr_model = WhisperForConditionalGeneration.from_pretrained( + self.asr_model_name + ) + self.asr_model.config.forced_decoder_ids = None + + def back_transcribe(self, text): + # speech synthesis + sample = TTSHubInterface.get_model_input(self.tts_task, text) + wav, rate = TTSHubInterface.get_prediction( + self.tts_task, self.tts_model, self.tts_generator, sample + ) + + # speech recognition + resampled_wav = librosa.resample( + wav.numpy(), orig_sr=rate, target_sr=self.asr_sampling_rate + ) + input_features = self.asr_processor( + resampled_wav, sampling_rate=self.asr_sampling_rate, return_tensors="pt" + ).input_features + + predicted_ids = self.asr_model.generate(input_features) + + transcription = self.asr_processor.batch_decode( + predicted_ids, skip_special_tokens=True + ) + return transcription[0].strip() + + def _get_transformations(self, current_text, indices_to_modify): + transformed_texts = [] + current_text = current_text.text + + # do the back transcription + back_transcribed_text = self.back_transcribe([current_text]) + + transformed_texts.append(AttackedText(back_transcribed_text)) + return transformed_texts From 49d9c96c8c762a64bf227ed8f980ed195f0a05c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Sk=C3=B3rzewski?= Date: Mon, 18 Dec 2023 14:41:10 +0100 Subject: [PATCH 2/4] Make TextAttack not dependnet from BackTranscription's dependencies --- tests/test_augment_api.py | 14 ++++++++++---- .../sentence_transformations/back_transcription.py | 10 +++++++--- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/test_augment_api.py b/tests/test_augment_api.py index 16c495f30..b38725f2a 100644 --- a/tests/test_augment_api.py +++ b/tests/test_augment_api.py @@ -140,7 +140,13 @@ def test_back_transcription(): from textattack.augmentation import Augmenter from textattack.transformations.sentence_transformations import BackTranscription - augmenter = Augmenter(transformation=BackTranscription()) - s = "What on earth are you doing?" - augmented_text_list = augmenter.augment(s) - assert augmented_text_list + try: + augmenter = Augmenter(transformation=BackTranscription()) + except ModuleNotFoundError: + print( + "To use BackTranscription augmenter, install `fairseq`, `g2p_en` and `librosa` libraries" + ) + else: + s = "What on earth are you doing?" + augmented_text_list = augmenter.augment(s) + assert augmented_text_list diff --git a/textattack/transformations/sentence_transformations/back_transcription.py b/textattack/transformations/sentence_transformations/back_transcription.py index ae62930c9..8d23dc2ea 100644 --- a/textattack/transformations/sentence_transformations/back_transcription.py +++ b/textattack/transformations/sentence_transformations/back_transcription.py @@ -5,9 +5,6 @@ """ -from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub -from fairseq.models.text_to_speech.hub_interface import TTSHubInterface -import librosa from transformers import WhisperForConditionalGeneration, WhisperProcessor from textattack.shared import AttackedText @@ -44,6 +41,9 @@ def __init__( asr_model="openai/whisper-base", ): # TTS model + from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub + from fairseq.models.text_to_speech.hub_interface import TTSHubInterface + self.tts_model_name = tts_model models, cfg, self.tts_task = load_model_ensemble_and_task_from_hf_hub( self.tts_model_name, @@ -64,12 +64,16 @@ def __init__( def back_transcribe(self, text): # speech synthesis + from fairseq.models.text_to_speech.hub_interface import TTSHubInterface + sample = TTSHubInterface.get_model_input(self.tts_task, text) wav, rate = TTSHubInterface.get_prediction( self.tts_task, self.tts_model, self.tts_generator, sample ) # speech recognition + import librosa + resampled_wav = librosa.resample( wav.numpy(), orig_sr=rate, target_sr=self.asr_sampling_rate ) From 1d353e9e74c628d395412b6c2e2caf2abcd8a3af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Sk=C3=B3rzewski?= Date: Mon, 18 Dec 2023 15:01:05 +0100 Subject: [PATCH 3/4] Command line option for back transcription and documentation --- README.md | 65 +++++++++---------- docs/3recipes/augmenter_recipes_cmd.md | 12 ++-- textattack/augment_args.py | 1 + textattack/augmentation/recipes.py | 12 ++++ .../back_transcription.py | 21 ++++++ 5 files changed, 72 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index c05e20263..a235f27d4 100644 --- a/README.md +++ b/README.md @@ -3,12 +3,12 @@

Generating adversarial examples for NLP models

- [TextAttack Documentation on ReadTheDocs] + [TextAttack Documentation on ReadTheDocs]

AboutSetupUsage • - Design + Design

Github Runner Covergae Status @@ -19,7 +19,7 @@

TextAttack Demo GIF - + ## About TextAttack is a Python framework for adversarial attacks, data augmentation, and model training in NLP. @@ -52,8 +52,8 @@ pip install textattack Once TextAttack is installed, you can run it via command-line (`textattack ...`) or via python module (`python -m textattack ...`). -> **Tip**: TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models, -> dataset samples, and the configuration file `config.yaml`. To change the cache path, set the +> **Tip**: TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models, +> dataset samples, and the configuration file `config.yaml`. To change the cache path, set the > environment variable `TA_CACHE_DIR`. (for example: `TA_CACHE_DIR=/tmp/ textattack attack ...`). ## Usage @@ -62,16 +62,16 @@ or via python module (`python -m textattack ...`). TextAttack's main features can all be accessed via the `textattack` command. Two very common commands are `textattack attack `, and `textattack augment `. You can see more -information about all commands using +information about all commands using ```bash -textattack --help +textattack --help ``` or a specific command using, for example, ```bash textattack attack --help ``` -The [`examples/`](examples/) folder includes scripts showing common TextAttack usage for training models, running attacks, and augmenting a CSV file. +The [`examples/`](examples/) folder includes scripts showing common TextAttack usage for training models, running attacks, and augmenting a CSV file. The [documentation website](https://textattack.readthedocs.io/en/latest) contains walkthroughs explaining basic usage of TextAttack, including building a custom transformation and a custom constraint.. @@ -80,18 +80,18 @@ The [documentation website](https://textattack.readthedocs.io/en/latest) contain ### Running Attacks: `textattack attack --help` -The easiest way to try out an attack is via the command-line interface, `textattack attack`. +The easiest way to try out an attack is via the command-line interface, `textattack attack`. > **Tip:** If your machine has multiple GPUs, you can distribute the attack across them using the `--parallel` option. For some attacks, this can really help performance. (If you want to attack Keras models in parallel, please check out `examples/attack/attack_keras_parallel.py` instead) Here are some concrete examples: -*TextFooler on BERT trained on the MR sentiment classification dataset*: +*TextFooler on BERT trained on the MR sentiment classification dataset*: ```bash textattack attack --recipe textfooler --model bert-base-uncased-mr --num-examples 100 ``` -*DeepWordBug on DistilBERT trained on the Quora Question Pairs paraphrase identification dataset*: +*DeepWordBug on DistilBERT trained on the Quora Question Pairs paraphrase identification dataset*: ```bash textattack attack --model distilbert-base-uncased-cola --recipe deepwordbug --num-examples 100 ``` @@ -129,7 +129,7 @@ To run an attack recipe: `textattack attack --recipe [recipe_name]`
Attacks on classification tasks, like sentiment classification and entailment:
-a2t +a2t Untargeted {Classification, Entailment} Percentage of words perturbed, Word embedding distance, DistilBERT sentence encoding cosine similarity, part-of-speech consistency @@ -319,7 +319,8 @@ for data augmentation: - `eda` augments text with a combination of word insertions, substitutions and deletions. - `checklist` augments text by contraction/extension and by substituting names, locations, numbers. - `clare` augments text by replacing, inserting, and merging with a pre-trained masked language model. -- `back_trans` augments text by backtranslation approach. +- `back_trans` augments text by backtranslation approach. +- `back_transcription` augments text by back transcription approach. #### Augmentation Command-Line Interface @@ -339,7 +340,7 @@ For example, given the following as `examples.csv`: "it's a mystery how the movie could be released in this condition .", 0 ``` -The command +The command ```bash textattack augment --input-csv examples.csv --output-csv output.csv --input-column text --recipe embedding --pct-words-to-swap .1 --transformations-per-example 2 --exclude-original ``` @@ -412,7 +413,7 @@ textattack train --model-name-or-path bert-base-uncased --dataset glue^cola --pe ### To check datasets: `textattack peek-dataset` -To take a closer look at a dataset, use `textattack peek-dataset`. TextAttack will print some cursory statistics about the inputs and outputs from the dataset. For example, +To take a closer look at a dataset, use `textattack peek-dataset`. TextAttack will print some cursory statistics about the inputs and outputs from the dataset. For example, ```bash textattack peek-dataset --dataset-from-huggingface snli ``` @@ -427,7 +428,7 @@ There are lots of pieces in TextAttack, and it can be difficult to keep track of ## Design -### Models +### Models TextAttack is model-agnostic! You can use `TextAttack` to analyze any model that outputs IDs, tensors, or strings. To help users, TextAttack includes pre-trained models for different common NLP tasks. This makes it easier for users to get started with TextAttack. It also enables a more fair comparison of attacks from @@ -437,12 +438,12 @@ the literature. #### Built-in Models and Datasets -TextAttack also comes built-in with models and datasets. Our command-line interface will automatically match the correct -dataset to the correct model. We include 82 different (Oct 2020) pre-trained models for each of the nine [GLUE](https://gluebenchmark.com/) -tasks, as well as some common datasets for classification, translation, and summarization. +TextAttack also comes built-in with models and datasets. Our command-line interface will automatically match the correct +dataset to the correct model. We include 82 different (Oct 2020) pre-trained models for each of the nine [GLUE](https://gluebenchmark.com/) +tasks, as well as some common datasets for classification, translation, and summarization. A list of available pretrained models and their validation accuracies is available at -[textattack/models/README.md](textattack/models/README.md). You can also view a full list of provided models +[textattack/models/README.md](textattack/models/README.md). You can also view a full list of provided models & datasets via `textattack attack --help`. Here's an example of using one of the built-in models (the SST-2 dataset is automatically loaded): @@ -453,7 +454,7 @@ textattack attack --model roberta-base-sst2 --recipe textfooler --num-examples 1 #### HuggingFace support: `transformers` models and `datasets` datasets -We also provide built-in support for [`transformers` pretrained models](https://huggingface.co/models) +We also provide built-in support for [`transformers` pretrained models](https://huggingface.co/models) and datasets from the [`datasets` package](https://github.com/huggingface/datasets)! Here's an example of loading and attacking a pre-trained model and dataset: @@ -461,7 +462,7 @@ and attacking a pre-trained model and dataset: textattack attack --model-from-huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset-from-huggingface glue^sst2 --recipe deepwordbug --num-examples 10 ``` -You can explore other pre-trained models using the `--model-from-huggingface` argument, or other datasets by changing +You can explore other pre-trained models using the `--model-from-huggingface` argument, or other datasets by changing `--dataset-from-huggingface`. @@ -517,7 +518,7 @@ To allow for word replacement after a sequence has been tokenized, we include an which maintains both a list of tokens and the original text, with punctuation. We use this object in favor of a list of words or just raw text. -### Attacks and how to design a new attack +### Attacks and how to design a new attack We formulate an attack as consisting of four components: a **goal function** which determines if the attack has succeeded, **constraints** defining which perturbations are valid, a **transformation** that generates potential modifications given an input, and a **search method** which traverses through the search space of possible perturbations. The attack attempts to perturb an input text such that the model output fulfills the goal function (i.e., indicating whether the attack is successful) and the perturbation adheres to the set of constraints (e.g., grammar constraint, semantic similarity constraint). A search method is used to find a sequence of transformations that produce a successful adversarial example. @@ -549,11 +550,11 @@ A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a fi ## On Benchmarking Attacks -- See our analysis paper: Searching for a Search Method: Benchmarking Search Algorithms for Generating NLP Adversarial Examples at [EMNLP BlackBoxNLP](https://arxiv.org/abs/2009.06368). +- See our analysis paper: Searching for a Search Method: Benchmarking Search Algorithms for Generating NLP Adversarial Examples at [EMNLP BlackBoxNLP](https://arxiv.org/abs/2009.06368). -- As we emphasized in the above paper, we don't recommend to directly compare Attack Recipes out of the box. +- As we emphasized in the above paper, we don't recommend to directly compare Attack Recipes out of the box. -- This comment is due to that attack recipes in the recent literature used different ways or thresholds in setting up their constraints. Without the constraint space held constant, an increase in attack success rate could come from an improved search or transformation method or a less restrictive search space. +- This comment is due to that attack recipes in the recent literature used different ways or thresholds in setting up their constraints. Without the constraint space held constant, an increase in attack success rate could come from an improved search or transformation method or a less restrictive search space. - Our Github on benchmarking scripts and results: [TextAttack-Search-Benchmark Github](https://github.com/QData/TextAttack-Search-Benchmark) @@ -563,19 +564,19 @@ A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a fi - Our analysis Paper in [EMNLP Findings](https://arxiv.org/abs/2004.14174) - We analyze the generated adversarial examples of two state-of-the-art synonym substitution attacks. We find that their perturbations often do not preserve semantics, and 38% introduce grammatical errors. Human surveys reveal that to successfully preserve semantics, we need to significantly increase the minimum cosine similarities between the embeddings of swapped words and between the sentence encodings of original and perturbed sentences.With constraints adjusted to better preserve semantics and grammaticality, the attack success rate drops by over 70 percentage points. - Our Github on Reevaluation results: [Reevaluating-NLP-Adversarial-Examples Github](https://github.com/QData/Reevaluating-NLP-Adversarial-Examples) -- As we have emphasized in this analysis paper, we recommend researchers and users to be EXTREMELY mindful on the quality of generated adversarial examples in natural language -- We recommend the field to use human-evaluation derived thresholds for setting up constraints +- As we have emphasized in this analysis paper, we recommend researchers and users to be EXTREMELY mindful on the quality of generated adversarial examples in natural language +- We recommend the field to use human-evaluation derived thresholds for setting up constraints ## Multi-lingual Support -- see example code: [https://github.com/QData/TextAttack/blob/master/examples/attack/attack_camembert.py](https://github.com/QData/TextAttack/blob/master/examples/attack/attack_camembert.py) for using our framework to attack French-BERT. +- see example code: [https://github.com/QData/TextAttack/blob/master/examples/attack/attack_camembert.py](https://github.com/QData/TextAttack/blob/master/examples/attack/attack_camembert.py) for using our framework to attack French-BERT. -- see tutorial notebook: [https://textattack.readthedocs.io/en/latest/2notebook/Example_4_CamemBERT.html](https://textattack.readthedocs.io/en/latest/2notebook/Example_4_CamemBERT.html) for using our framework to attack French-BERT. +- see tutorial notebook: [https://textattack.readthedocs.io/en/latest/2notebook/Example_4_CamemBERT.html](https://textattack.readthedocs.io/en/latest/2notebook/Example_4_CamemBERT.html) for using our framework to attack French-BERT. -- See [README_ZH.md](https://github.com/QData/TextAttack/blob/master/README_ZH.md) for our README in Chinese +- See [README_ZH.md](https://github.com/QData/TextAttack/blob/master/README_ZH.md) for our README in Chinese @@ -598,5 +599,3 @@ If you use TextAttack for your research, please cite [TextAttack: A Framework fo year={2020} } ``` - - diff --git a/docs/3recipes/augmenter_recipes_cmd.md b/docs/3recipes/augmenter_recipes_cmd.md index bde5b3116..e44006cd1 100644 --- a/docs/3recipes/augmenter_recipes_cmd.md +++ b/docs/3recipes/augmenter_recipes_cmd.md @@ -1,8 +1,8 @@ -# Augmenter Recipes CommandLine Use +# Augmenter Recipes CommandLine Use -Transformations and constraints can be used for simple NLP data augmentations. +Transformations and constraints can be used for simple NLP data augmentations. -The [`examples/`](https://github.com/QData/TextAttack/tree/master/examples) folder includes scripts showing common TextAttack usage for training models, running attacks, and augmenting a CSV file. +The [`examples/`](https://github.com/QData/TextAttack/tree/master/examples) folder includes scripts showing common TextAttack usage for training models, running attacks, and augmenting a CSV file. The [documentation website](https://textattack.readthedocs.io/en/latest) contains walkthroughs explaining basic usage of TextAttack, including building a custom transformation and a custom constraint.. @@ -18,11 +18,12 @@ for data augmentation: - `eda` augments text with a combination of word insertions, substitutions and deletions. - `checklist` augments text by contraction/extension and by substituting names, locations, numbers. - `clare` augments text by replacing, inserting, and merging with a pre-trained masked language model. -- `back_trans` augments text by backtranslation method. +- `back_trans` augments text by backtranslation method. +- `back_transcription` augments text by back transcription approach. ### Augmentation Command-Line Interface -The easiest way to use our data augmentation tools is with `textattack augment `. +The easiest way to use our data augmentation tools is with `textattack augment `. `textattack augment` takes an input CSV file, the "text" column to augment, along with the number of words to change per augmentation @@ -65,4 +66,3 @@ it's a enigma how the filmmaking wo be publicized in this condition .,0 ``` The 'embedding' augmentation recipe uses counterfitted embedding nearest-neighbors to augment data. - diff --git a/textattack/augment_args.py b/textattack/augment_args.py index 666ed2e3c..c7ce2e78c 100644 --- a/textattack/augment_args.py +++ b/textattack/augment_args.py @@ -14,6 +14,7 @@ "checklist": "textattack.augmentation.CheckListAugmenter", "clare": "textattack.augmentation.CLAREAugmenter", "back_trans": "textattack.augmentation.BackTranslationAugmenter", + "back_transcription": "textattack.augmentation.BackTranscriptionAugmenter", } diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index fe647d9d9..4c0f87f74 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -263,3 +263,15 @@ def __init__(self, **kwargs): transformation = BackTranslation(chained_back_translation=5) super().__init__(transformation, **kwargs) + + +class BackTranscriptionAugmenter(Augmenter): + """Sentence level augmentation that uses back transcription (TTS+ASR).""" + + def __init__(self, **kwargs): + from textattack.transformations.sentence_transformations import ( + BackTranscription, + ) + + transformation = BackTranscription() + super().__init__(transformation, **kwargs) diff --git a/textattack/transformations/sentence_transformations/back_transcription.py b/textattack/transformations/sentence_transformations/back_transcription.py index 8d23dc2ea..ec83c4db4 100644 --- a/textattack/transformations/sentence_transformations/back_transcription.py +++ b/textattack/transformations/sentence_transformations/back_transcription.py @@ -33,6 +33,27 @@ class BackTranscription(SentenceTransformation): >>> s = 'What on earth are you doing here.' >>> augmenter.augment(s) + + You can find more about the back transcription method in the following paper: + + @inproceedings{kubis-etal-2023-back, + title = "Back Transcription as a Method for Evaluating Robustness of Natural Language Understanding Models to Speech Recognition Errors", + author = "Kubis, Marek and + Sk{\\'o}rzewski, Pawe{\\l} and + Sowa{\\'n}nski, Marcin and + Zietkiewicz, Tomasz", + editor = "Bouamor, Houda and + Pino, Juan and + Bali, Kalika", + booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing", + month = dec, + year = "2023", + address = "Singapore", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2023.emnlp-main.724", + doi = "10.18653/v1/2023.emnlp-main.724", + pages = "11824--11835", + } """ def __init__( From a9ad2d570ebda458a8a8cb0725b79c81f7e38592 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Sk=C3=B3rzewski?= Date: Mon, 18 Dec 2023 15:17:11 +0100 Subject: [PATCH 4/4] Add BackTranscription transformation documentation by modifying file --- .../textattack.transformations.sentence_transformations.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/apidoc/textattack.transformations.sentence_transformations.rst b/docs/apidoc/textattack.transformations.sentence_transformations.rst index 54f534ec9..1c56ed068 100644 --- a/docs/apidoc/textattack.transformations.sentence_transformations.rst +++ b/docs/apidoc/textattack.transformations.sentence_transformations.rst @@ -15,6 +15,12 @@ textattack.transformations.sentence\_transformations package :show-inheritance: +.. automodule:: textattack.transformations.sentence_transformations.back_transcription + :members: + :undoc-members: + :show-inheritance: + + .. automodule:: textattack.transformations.sentence_transformations.sentence_transformation :members: :undoc-members: