Skip to content

Commit

Permalink
Update examples/flax/language-modeling/README.md
Browse files Browse the repository at this point in the history
Co-authored-by: Steven Liu <[email protected]>
  • Loading branch information
louie-tsai and stevhliu committed Dec 4, 2024
1 parent b95143f commit d7fc2d1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion examples/flax/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -565,4 +565,4 @@ By changing the dtype for `FlaxBertModel `to `jax.numpy.bfloat16`, you get the p
import jax
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=jax.numpy.bfloat16)
```
On a AWS c7i.4xlarge with Intel Sapphire Rapids, we get > 2X speedup by changing precision from float32 to bfloat16.
Switching from float32 to bfloat16 can increase the speed of an AWS c7i.4xlarge with Intel Sapphire Rapids by more than 2x.

0 comments on commit d7fc2d1

Please sign in to comment.