Transformer models have achieved promising results on natural language processing (NLP) tasks including extractive question answering (QA). Common Transformer encoders used in NLP tasks process the hidden states of all input tokens in the context paragraph throughout all layers. However, different from other tasks such as sequence classification, answering the raised question does not necessarily require all the tokens in the context paragraph. Following this motivation, we propose Block-Skim, which learns to skim unnecessary context in higher hidden layers to improve and accelerate Transformer performance. The key idea of Block-Skim is to identify the context that must be further processed and the context that can be safely discarded early on during inference. Critically, we find that such information can be sufficiently derived from the self-attention weights inside the Transformer model. We further prune the hidden states corresponding to the unnecessary positions early, in lower layers, achieving significant inference-time speedup. To our surprise, we observe that models pruned in this way outperform their full-size counterparts. Block-Skim improves QA models' accuracy on different datasets and achieves a 3× speedup on the BERT-base model.
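For readers browsing the repository before the paper, here is a minimal sketch of the mechanism described above: the context is split into fixed-size blocks, a small classifier scores each block from the self-attention weights, and the hidden states of rejected blocks are pruned before higher layers. All names here (`BlockSkimSketch`, `skim_predictor`, `block_size`) are illustrative assumptions, not this repository's exact code.

```python
import torch
import torch.nn as nn


class BlockSkimSketch(nn.Module):
    """Minimal sketch of the Block-Skim idea (illustrative, not the
    exact implementation in this repository)."""

    def __init__(self, num_heads: int, block_size: int = 32):
        super().__init__()
        self.block_size = block_size
        # Small classifier mapping per-block attention statistics to a
        # keep/skim decision; the paper derives this signal from the
        # self-attention weights.
        self.skim_predictor = nn.Linear(num_heads, 2)

    def forward(self, hidden_states, attention_probs):
        # hidden_states:   (batch, seq_len, hidden)
        # attention_probs: (batch, heads, seq_len, seq_len)
        # Assumes seq_len is a multiple of block_size.
        bsz, heads, seq_len, _ = attention_probs.shape
        num_blocks = seq_len // self.block_size

        # Attention mass received by each token, then averaged per block.
        attn_to_token = attention_probs.sum(dim=2)  # (batch, heads, seq_len)
        per_block = attn_to_token.view(bsz, heads, num_blocks, self.block_size)
        per_block = per_block.mean(dim=-1).transpose(1, 2)  # (batch, blocks, heads)

        keep_logits = self.skim_predictor(per_block)   # (batch, blocks, 2)
        keep_mask = keep_logits.argmax(dim=-1).bool()  # True = keep this block

        # At inference, hidden states of skimmed blocks are dropped so that
        # higher layers only process the kept positions.
        token_mask = keep_mask.repeat_interleave(self.block_size, dim=1)
        kept_hidden = [h[m] for h, m in zip(hidden_states, token_mask)]
        return keep_logits, kept_hidden
```

During training, `keep_logits` would be supervised by whether a block contains part of the answer span; at inference, only kept blocks flow to the upper layers.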
We will make the code base publicly available and release the checkpoints of Block-Skim models upon acceptance.
- Install Anaconda.
- Install dependencies with Anaconda:

  ```bash
  conda create --name blockskim --file requirements.txt
  ```

- Activate the installed Anaconda virtual environment:

  ```bash
  conda activate blockskim
  ```
- SQuAD: Download the SQuAD-1.1 datasets from the following links.
- MRQA datasets: We use the SearchQA, NewsQA, NaturalQuestions, and TriviaQA datasets in MRQA format. Download them from the official MRQA repository (a sketch of reading this format follows the list).
- HotpotQA: We use the HotpotQA dataset from the `datasets` library and parse it with an in-house preprocessing script to include supporting facts in SQuAD format (a conversion sketch follows the list):

  ```bash
  python src/utils/process_hotpotqa.py
  ```
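For reference, MRQA-format files are gzipped JSON Lines: a header line followed by one example per line. The sketch below shows how we understand that layout; the field names follow the MRQA shared-task format and this is not code from this repository.

```python
import gzip
import json


def read_mrqa(path):
    """Yield QA examples from an MRQA-format .jsonl.gz file (sketch)."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        header = json.loads(next(f))  # first line describes the dataset
        for line in f:
            example = json.loads(line)
            for qa in example["qas"]:
                yield {
                    "qid": qa["qid"],
                    "question": qa["question"],
                    "context": example["context"],
                    # detected answers carry the text and character spans
                    "answers": [a["text"] for a in qa["detected_answers"]],
                }
```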
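The conversion performed by src/utils/process_hotpotqa.py is, in spirit, something like the sketch below: keep the supporting-fact sentences as the SQuAD-style context and locate the answer span in it. This is our reading of the two formats, not the in-house script itself.

```python
def hotpot_to_squad_paragraph(example):
    """Convert one HotpotQA example to a SQuAD-style paragraph (sketch)."""
    # Supporting facts are (title, sentence_index) pairs.
    supporting = {(title, idx) for title, idx in example["supporting_facts"]}

    # Keep supporting-fact sentences, concatenated into a single context.
    kept = []
    for title, sentences in example["context"]:
        kept.extend(s for i, s in enumerate(sentences) if (title, i) in supporting)
    context = " ".join(kept)

    answer = example["answer"]
    start = context.find(answer)  # -1 for yes/no answers or missing spans
    return {
        "context": context,
        "qas": [{
            "id": example["_id"],
            "question": example["question"],
            "answers": [{"text": answer, "answer_start": start}],
        }],
    }
```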
For Block-Skim training, just run

```bash
bash scripts/finetune_squad.sh
```

This bash script runs the training loop in src/run_squad.py. The other training settings, e.g. distillation or pruning, are implemented in separate training-loop files under src.
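Conceptually, training optimizes the skim classifier jointly with the usual span-extraction objective, with blocks that overlap the gold answer labeled "keep". Below is a hedged sketch of such a joint loss; the weight `alpha` and the exact reduction are assumptions, not the repository's exact hyperparameters.

```python
import torch.nn.functional as F


def joint_blockskim_loss(start_logits, end_logits, start_positions,
                         end_positions, skim_logits, block_labels, alpha=1.0):
    """Span-extraction loss plus per-block keep/skim classification (sketch).

    skim_logits:  (batch, num_blocks, 2)
    block_labels: (batch, num_blocks), 1 if the block overlaps the answer
    """
    qa_loss = (F.cross_entropy(start_logits, start_positions)
               + F.cross_entropy(end_logits, end_positions)) / 2
    skim_loss = F.cross_entropy(skim_logits.flatten(0, 1),
                                block_labels.flatten())
    return qa_loss + alpha * skim_loss
```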
For Block-Skim inference, just run

```bash
bash scripts/eval_squad.sh
```