add a new flax example for Bert model inference #34794

Open
wants to merge 8 commits into main
26 changes: 25 additions & 1 deletion examples/flax/language-modeling/README.md
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
-->

# Language model training examples
# Language model training and inference examples

The following example showcases how to train a language model from scratch
using the JAX/Flax backend.
@@ -542,3 +542,27 @@ python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \
--report_to="tensorboard" \
--save_strategy="no"
```

## Language model inference with bfloat16

The following example demonstrates performing inference with a language model using the JAX/Flax backend.

The example script `run_bert_flax.py` loads `bert-base-uncased` into a `FlaxBertModel`, feeds it randomly generated input tokens, and jits the forward pass with JAX.
By default, inference runs in float32 precision. To enable bfloat16, add the flag shown in the command below.

```bash
python3 run_bert_flax.py --precision bfloat16
```

> NOTE: For JAX v0.4.33 or later, you will need to set the environment variable below as a temporary
> workaround to use the bfloat16 data type. This restriction is expected to be removed in a future release.

```bash
export XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
```
bfloat16 gives better performance on GPUs as well as on Intel CPUs (Sapphire Rapids or later) with Advanced Matrix Extensions (Intel AMX).
By loading `FlaxBertModel` with `dtype=jax.numpy.bfloat16`, you get the performance benefits of the underlying hardware:
```python
import jax
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=jax.numpy.bfloat16)
```
On an AWS c7i.4xlarge instance (Intel Sapphire Rapids), switching from float32 to bfloat16 can speed up inference by more than 2x.
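To measure the gain on your own hardware, run the script in both precisions and compare the reported throughput:

```bash
# baseline run in float32 (the default)
python3 run_bert_flax.py --precision float32
# bfloat16 run (set XLA_FLAGS as described above for JAX v0.4.33 or later)
python3 run_bert_flax.py --precision bfloat16
```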
56 changes: 56 additions & 0 deletions examples/flax/language-modeling/run_bert_flax.py
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
import time
from argparse import ArgumentParser

import jax
import numpy as np

from transformers import BertConfig, FlaxBertModel


parser = ArgumentParser()
parser.add_argument("--precision", type=str, choices=["float32", "bfloat16"], default="float32")
args = parser.parse_args()

# Map the requested precision to the corresponding JAX dtype
dtype = jax.numpy.float32
if args.precision == "bfloat16":
    dtype = jax.numpy.bfloat16

# bert-base-uncased vocabulary size, benchmark batch size, and sequence length
VOCAB_SIZE = 30522
BS = 32
SEQ_LEN = 128


# Build a batch of random token ids with all-ones token type ids and attention mask
def get_input_data(batch_size=1, seq_length=384):
    shape = (batch_size, seq_length)
    input_ids = np.random.randint(1, VOCAB_SIZE, size=shape).astype(np.int32)
    token_type_ids = np.ones(shape).astype(np.int32)
    attention_mask = np.ones(shape).astype(np.int32)
    return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}


inputs = get_input_data(BS, SEQ_LEN)
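# Load the pretrained config (overriding the activation function) and the Flax model in the requested precision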
config = BertConfig.from_pretrained("bert-base-uncased", hidden_act="gelu_new")
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=dtype)


# JIT-compile the forward pass; the first call triggers XLA compilation,
# subsequent calls reuse the compiled executable
@jax.jit
def func():
    outputs = model(**inputs)
    return outputs


(nwarmup, nbenchmark) = (5, 100)

# warmup: exclude one-time JIT compilation from the timed runs
for _ in range(nwarmup):
    func()

# benchmark: block on the outputs so JAX's asynchronous dispatch does not skew the timing
start = time.time()
for _ in range(nbenchmark):
    jax.block_until_ready(func())
end = time.time()
print(f"Total time: {(end - start):.3f} s")
print(f"Throughput: {((nbenchmark * BS)/(end-start)):.3f} examples/sec")