
SJTU-ReArch-Group/BlockSkim_AAAI22

 
 


Block-Skim: Efficient Question Answering for Transformer

Abstract

Transformer models have achieved promising results on natural language processing (NLP) tasks including extractive question answering (QA). Common Transformer encoders used in NLP tasks process the hidden states of all input tokens in the context paragraph throughout all layers. However, different from other tasks such as sequence classification, answering the raised question does not necessarily need all the tokens in the context paragraph. Following this motivation, we propose Block-Skim, which learns to skim unnecessary context in higher hidden layers to improve and accelerate the Transformer performance. The key idea of Block-Skim is to identify the context that must be further processed and those that could be safely discarded early on during inference. Critically, we find that such information could be sufficiently derived from the self-attention weights inside the Transformer model. We further prune the hidden states corresponding to the unnecessary positions early in lower layers, achieving significant inference-time speedup. To our surprise, we observe that models pruned in this way outperform their full-size counterparts. Block-Skim improves QA models' accuracy on different datasets and achieves 3× speedup on BERT-base model.
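The skimming idea can be illustrated with a minimal, dependency-free sketch. It is not the repository's implementation: Block-Skim learns its skim decisions, whereas this stand-in simply scores each fixed-size context block by the average attention it receives from all queries and drops blocks below a threshold. The names `block_scores` and `skim`, and the `block_size` and `threshold` parameters, are hypothetical.

```python
from typing import List, Sequence, Tuple


def block_scores(attention: Sequence[Sequence[float]], block_size: int) -> List[float]:
    """Average attention mass that each block of key positions receives.

    attention[i][j] is the weight query i places on key j (each row sums
    to 1), e.g. one head of one layer or an average over heads.
    """
    num_queries = len(attention)
    seq_len = len(attention[0])
    scores = []
    for start in range(0, seq_len, block_size):
        end = min(start + block_size, seq_len)
        # Total attention all queries place on keys in [start, end).
        mass = sum(sum(row[start:end]) for row in attention)
        scores.append(mass / num_queries)
    return scores


def skim(
    hidden: Sequence[Sequence[float]],
    attention: Sequence[Sequence[float]],
    block_size: int,
    threshold: float,
) -> Tuple[List[Sequence[float]], List[int]]:
    """Keep only the hidden states of blocks scoring at least `threshold`.

    Returns the kept hidden states and their original positions, so later
    layers process a shorter sequence. A real model would also exempt the
    question and special tokens from skimming.
    """
    kept_positions: List[int] = []
    for b, score in enumerate(block_scores(attention, block_size)):
        if score >= threshold:
            start = b * block_size
            end = min(start + block_size, len(hidden))
            kept_positions.extend(range(start, end))
    return [hidden[p] for p in kept_positions], kept_positions
```

On a toy six-token attention map where both queries concentrate on the first two keys, `skim(hidden, attn, block_size=2, threshold=0.2)` keeps only positions 0 and 1, so subsequent layers would run on a third of the sequence.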

Overview

Overall scheme of Block-Skim

3 times speedup with no accuracy loss on SQuAD

Compatible acceleration with model compression methods

SQuAD-1.1 and HotpotQA
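Why pruning in lower layers matters for speed can be seen with a back-of-the-envelope cost model. This sketch assumes each layer's cost is proportional to the number of tokens it processes (it ignores the quadratic attention term and the cost of the skim decision itself), and the keep ratios below are made-up illustrative values, not measurements from Block-Skim.

```python
from typing import Sequence


def estimated_speedup(keep_ratios: Sequence[float]) -> float:
    """Rough speedup of skimming under a linear per-token cost model.

    keep_ratios[l] is the fraction of tokens still processed at layer l
    (1.0 means nothing has been skimmed yet).
    """
    if not keep_ratios:
        raise ValueError("need at least one layer")
    full_cost = len(keep_ratios)   # baseline: every layer sees every token
    skim_cost = sum(keep_ratios)   # skimmed: each layer sees only kept tokens
    return full_cost / skim_cost


# Hypothetical 12-layer schedule: skim after layer 2, then again after layer 6.
ratios = [1.0] * 2 + [0.5] * 4 + [0.2] * 6
print(f"{estimated_speedup(ratios):.2f}x")
```

Under this model, dropping the same tokens one layer earlier lowers `sum(keep_ratios)` and therefore raises the estimate, which is why discarding context early pays off.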

How to use

We will make the code base publicly available and release the Block-Skim model checkpoints once the paper is accepted.

Requirements

  1. Install Anaconda.

  2. Install dependencies with Anaconda.

conda create --name blockskim --file requirements.txt

  3. Activate the installed Anaconda virtual environment.
conda activate blockskim
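Before running the scripts, it can help to confirm that the right environment is active. A small sketch, relying on the fact that `conda activate` exports the environment name in the `CONDA_DEFAULT_ENV` variable; the helper names `active_conda_env` and `require_env` are hypothetical.

```python
import os
from typing import Mapping, Optional


def active_conda_env(environ: Optional[Mapping[str, str]] = None) -> Optional[str]:
    """Name of the active conda environment, or None if none is active."""
    env = os.environ if environ is None else environ
    return env.get("CONDA_DEFAULT_ENV")


def require_env(expected: str = "blockskim",
                environ: Optional[Mapping[str, str]] = None) -> None:
    """Raise RuntimeError unless `expected` is the active conda environment."""
    name = active_conda_env(environ)
    if name != expected:
        raise RuntimeError(
            f"expected conda env {expected!r}, found {name!r}; "
            f"run `conda activate {expected}` first"
        )
```

Calling `require_env()` at the top of a custom script makes it fail fast if the environment was not activated.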

Datasets

  1. SQuAD

Download the SQuAD-1.1 dataset from the following links.
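SQuAD-1.1 ships as JSON nested by article, paragraph and question, with each answer given as its text plus a character offset (`answer_start`) into the paragraph. A minimal sketch that flattens it into one record per question; the function names are hypothetical and the repository's own loader may differ.

```python
import json
from typing import Dict, Iterator, List


def iter_squad_examples(squad: Dict) -> Iterator[Dict]:
    """Yield one flat record per question from a SQuAD-1.1 style dict."""
    for article in squad["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                yield {
                    "id": qa["id"],
                    "question": qa["question"],
                    "context": context,
                    # Each answer: {"text": str, "answer_start": char offset}.
                    "answers": qa["answers"],
                }


def load_squad(path: str) -> List[Dict]:
    """Read a SQuAD-1.1 JSON file such as dev-v1.1.json."""
    with open(path, encoding="utf-8") as f:
        return list(iter_squad_examples(json.load(f)))
```

Because `answer_start` indexes characters, `context[start:start + len(text)]` recovers the answer string.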

  2. MRQA datasets

We use the SearchQA, NewsQA, NaturalQuestions and TriviaQA datasets in MRQA format. Download the MRQA datasets from their official repository.
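The MRQA files are gzipped JSON Lines: a header line with dataset metadata, then one line per context with its questions, where detected answers carry character spans. A sketch that converts them to SQuAD-style records with `answer_start` offsets. It assumes the field names of the MRQA release (`context`, `qas`, `qid`, `question`, `detected_answers`, `char_spans`) and that `char_spans` end offsets are inclusive; the function names are hypothetical.

```python
import gzip
import json
from typing import Dict, Iterable, Iterator


def mrqa_to_squad_records(lines: Iterable[str]) -> Iterator[Dict]:
    """Convert MRQA JSON lines into flat SQuAD-style question records."""
    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)
        if "header" in record:
            continue  # first line holds dataset metadata only
        context = record["context"]
        for qa in record["qas"]:
            answers = []
            for ans in qa["detected_answers"]:
                for start, end in ans["char_spans"]:
                    # Assumed inclusive end offset, hence end + 1.
                    answers.append({"text": context[start:end + 1],
                                    "answer_start": start})
            yield {"id": qa["qid"], "question": qa["question"],
                   "context": context, "answers": answers}


def read_mrqa(path: str) -> Iterator[Dict]:
    """Stream records from a .jsonl.gz MRQA file."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        yield from mrqa_to_squad_records(f)
```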

  3. HotpotQA

We use the HotpotQA dataset from datasets and parse it with an in-house preprocessing script that converts it to SQuAD format while keeping the supporting facts.

python src/utils/process_hotpotqa.py
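The conversion can be sketched as follows: concatenate the paragraphs into one SQuAD-style context, locate the answer, and keep each supporting sentence as a character span. This is not `src/utils/process_hotpotqa.py`. It assumes the row layout of the Hugging Face `hotpot_qa` dataset (`context` with `title`/`sentences`, `supporting_facts` with `title`/`sent_id`), and the `supporting_spans` output field is a hypothetical name.

```python
from typing import Dict, List


def hotpot_to_squad(example: Dict) -> Dict:
    """Flatten one HotpotQA example into a SQuAD-style record.

    Supporting sentences are kept as [start, end) character spans in a
    hypothetical `supporting_spans` field.
    """
    facts = example["supporting_facts"]
    support = set(zip(facts["title"], facts["sent_id"]))

    context = ""
    spans: List[Dict[str, int]] = []
    paragraphs = zip(example["context"]["title"], example["context"]["sentences"])
    for p, (title, sentences) in enumerate(paragraphs):
        if p > 0:
            context += " "  # separate paragraphs with a single space
        for sent_id, sentence in enumerate(sentences):
            if (title, sent_id) in support:
                spans.append({"start": len(context),
                              "end": len(context) + len(sentence)})
            context += sentence

    # yes/no answers are not spans of the context, so they get none;
    # otherwise take the first occurrence (a simplification).
    answer = example["answer"]
    start = -1 if answer in ("yes", "no") else context.find(answer)
    answers = [] if start < 0 else [{"text": answer, "answer_start": start}]
    return {
        "id": example["id"],
        "question": example["question"],
        "context": context,
        "answers": answers,
        "supporting_spans": spans,
    }
```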

Training

For Block-Skim training, just run

bash scripts/finetune_squad.sh

This bash script runs the training loop in src/run_squad.py. Other training settings, e.g. distillation or pruning, are implemented in separate training loop files under src.
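Training the skim decisions needs per-block supervision. One plausible scheme, assumed here for illustration rather than taken from `src/run_squad.py`: a block is labeled "keep" if it overlaps the answer span (or, for HotpotQA, a supporting fact), and the predicted keep probabilities are trained with binary cross-entropy. The function names and the inclusive `[start, end]` token-span convention are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def block_labels(num_tokens: int, block_size: int,
                 keep_spans: Sequence[Tuple[int, int]]) -> List[int]:
    """1 for every block overlapping an inclusive [start, end] token span, else 0."""
    num_blocks = math.ceil(num_tokens / block_size)
    labels = [0] * num_blocks
    for start, end in keep_spans:
        last = min(end // block_size, num_blocks - 1)  # clamp to the last block
        for b in range(start // block_size, last + 1):
            labels[b] = 1
    return labels


def skim_loss(keep_probs: Sequence[float], labels: Sequence[int],
              eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between predicted keep probabilities and labels."""
    total = 0.0
    for p, y in zip(keep_probs, labels):
        p = min(max(p, eps), 1 - eps)  # avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(labels)
```

The overall objective would then combine this with the QA loss, for example `qa_loss + alpha * skim_loss(...)`, where `alpha` is a hypothetical weighting hyperparameter.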

Evaluation

For Block-Skim inference, just run

bash scripts/eval_squad.sh
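SQuAD-style QA is usually scored with exact match (EM) and token-level F1. Assuming the evaluation reports these standard metrics, here is a compact sketch that follows the answer normalization of the official SQuAD v1.1 evaluation script (lowercase, strip punctuation and the articles a/an/the, collapse whitespace); it is not the repository's own evaluation code.

```python
import re
import string
from collections import Counter
from typing import List

PUNCT = set(string.punctuation)


def normalize_answer(s: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(prediction: str, ground_truths: List[str]) -> float:
    """1.0 if the prediction matches any reference after normalization."""
    pred = normalize_answer(prediction)
    return float(any(pred == normalize_answer(gt) for gt in ground_truths))


def _f1(prediction: str, ground_truth: str) -> float:
    pred = normalize_answer(prediction).split()
    gold = normalize_answer(ground_truth).split()
    common = Counter(pred) & Counter(gold)  # multiset token overlap
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred)
    recall = num_same / len(gold)
    return 2 * precision * recall / (precision + recall)


def f1(prediction: str, ground_truths: List[str]) -> float:
    """Best F1 over all reference answers."""
    return max(_f1(prediction, gt) for gt in ground_truths)
```

Dataset-level scores are the averages of these per-question values, usually reported as percentages.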
